1 #pragma once 2 3 #include <c10/util/hash.h> 4 #include <c10/util/irange.h> 5 #include <torch/csrc/autograd/variable.h> 6 #include <torch/csrc/jit/python/pybind.h> 7 8 #include <ATen/ATen.h> 9 #include <functional> 10 #include <tuple> 11 #include <vector> 12 13 namespace torch::jit::python { 14 15 struct IODescriptor { 16 struct VariableMetadata { VariableMetadataIODescriptor::VariableMetadata17 VariableMetadata(const autograd::Variable& var) 18 : sizes(var.sizes().vec()), 19 type(var.scalar_type()), 20 device(var.device()), 21 requires_grad(var.requires_grad()) {} 22 23 bool operator==(const VariableMetadata& o) const { 24 return std::tie(device, requires_grad, type, sizes) == 25 std::tie(o.device, o.requires_grad, o.type, o.sizes); 26 } 27 hashIODescriptor::VariableMetadata28 static size_t hash(const VariableMetadata& m) { 29 return c10::get_hash(m.sizes, m.device, m.requires_grad, m.type); 30 } 31 32 std::vector<int64_t> sizes; 33 at::ScalarType type; 34 at::Device device; 35 bool requires_grad; 36 }; 37 38 bool operator==(const IODescriptor& o) const { 39 return std::tie(structure, metadata, grad_enabled) == 40 std::tie(o.structure, o.metadata, o.grad_enabled); 41 } 42 hashIODescriptor43 static size_t hash(const IODescriptor& o) { 44 return c10::get_hash(o.structure, o.metadata, o.grad_enabled); 45 } 46 extendIODescriptor47 void extend(const autograd::variable_list& list) { 48 metadata.reserve(metadata.size() + list.size()); 49 for (auto& var : list) 50 metadata.emplace_back(var); 51 } 52 53 // Description of argument structure. Variables are replaced with 54 // different characters, depending on their flags, beginnings and 55 // ends of tuples and lists are denoted by a pair of parenthesis 56 // of their corresponding kind. They should always be paired. 57 // Example desc: (vv[v(v)v]) 58 // NOTE: if extend() was ever called then metadata.size() can be 59 // different than the number of 'v's in structure. 60 std::string structure; 61 std::vector<std::string> strings; 62 std::vector<VariableMetadata> metadata; 63 bool grad_enabled = false; 64 }; 65 66 static inline std::ostream& operator<<( 67 std::ostream& out, 68 const IODescriptor::VariableMetadata& meta) { 69 at::Device meta_device = meta.device; 70 auto& t = at::getDeprecatedTypeProperties( 71 meta_device.is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type); 72 out << t << "(requires_grad=" << meta.requires_grad; 73 if (meta_device.is_cuda()) { 74 out << ", device=" << meta_device.index(); 75 } 76 out << ") {"; 77 for (const auto i : c10::irange(meta.sizes.size())) { 78 if (i > 0) 79 out << ", "; 80 out << meta.sizes[i]; 81 } 82 out << "}"; 83 return out; 84 } 85 86 static inline std::ostream& operator<<( 87 std::ostream& out, 88 const IODescriptor& desc) { 89 out << desc.structure << "\n"; 90 out << " with grad_enabled=" << desc.grad_enabled << "\n"; 91 for (const auto i : c10::irange(desc.metadata.size())) { 92 out << " with v" << i << " having type " << desc.metadata[i] << "\n"; 93 } 94 return out; 95 } 96 97 struct ParsedArgs { 98 // Flat vector of Variables found in arguments 99 autograd::variable_list vars; 100 // Metadata describing nesting of objects received from Python and 101 // metadata of vars and whether grad is enabled. 102 IODescriptor desc; 103 extendParsedArgs104 void extend(const autograd::variable_list& list) { 105 if (list.empty()) 106 return; 107 vars.reserve(vars.size() + list.size()); 108 for (auto& var : list) 109 vars.emplace_back(var); 110 desc.extend(list); 111 } 112 }; 113 114 ParsedArgs flatten(py::handle obj); 115 PyObject* unflatten( 116 at::ArrayRef<autograd::Variable> vars, 117 const IODescriptor& structure); 118 119 } // namespace torch::jit::python 120