#pragma once

#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/python/pybind.h>

#include <ATen/ATen.h>
#include <functional>
#include <tuple>
#include <vector>

namespace torch::jit::python {

struct IODescriptor {
  struct VariableMetadata {
    VariableMetadata(const autograd::Variable& var)
        : sizes(var.sizes().vec()),
          type(var.scalar_type()),
          device(var.device()),
          requires_grad(var.requires_grad()) {}

    bool operator==(const VariableMetadata& o) const {
      return std::tie(device, requires_grad, type, sizes) ==
          std::tie(o.device, o.requires_grad, o.type, o.sizes);
    }

    static size_t hash(const VariableMetadata& m) {
      return c10::get_hash(m.sizes, m.device, m.requires_grad, m.type);
    }

    std::vector<int64_t> sizes;
    at::ScalarType type;
    at::Device device;
    bool requires_grad;
  };

  bool operator==(const IODescriptor& o) const {
    return std::tie(structure, metadata, grad_enabled) ==
        std::tie(o.structure, o.metadata, o.grad_enabled);
  }

  static size_t hash(const IODescriptor& o) {
    return c10::get_hash(o.structure, o.metadata, o.grad_enabled);
  }

  void extend(const autograd::variable_list& list) {
    metadata.reserve(metadata.size() + list.size());
    for (auto& var : list)
      metadata.emplace_back(var);
  }

  // Description of the argument structure. Variables are replaced with
  // different characters depending on their flags; the beginnings and
  // ends of tuples and lists are denoted by a pair of parentheses of
  // the corresponding kind, and they should always be paired.
  // Example desc: (vv[v(v)v])
  // NOTE: if extend() was ever called, then metadata.size() can differ
  // from the number of 'v's in structure.
  std::string structure;
  std::vector<std::string> strings;
  std::vector<VariableMetadata> metadata;
  bool grad_enabled = false;
};
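
// Usage sketch (illustrative only, not part of this header): IODescriptor
// provides operator== and a static hash(), so callers can key a cache on a
// flattened input signature. "TraceCacheEntry" below is a hypothetical
// placeholder type, not something defined by PyTorch:
//
//   struct IODescriptorHash {
//     size_t operator()(const IODescriptor& d) const {
//       return IODescriptor::hash(d);
//     }
//   };
//   std::unordered_map<IODescriptor, TraceCacheEntry, IODescriptorHash> cache;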

static inline std::ostream& operator<<(
    std::ostream& out,
    const IODescriptor::VariableMetadata& meta) {
  at::Device meta_device = meta.device;
  auto& t = at::getDeprecatedTypeProperties(
      meta_device.is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type);
  out << t << "(requires_grad=" << meta.requires_grad;
  if (meta_device.is_cuda()) {
    out << ", device=" << meta_device.index();
  }
  out << ") {";
  for (const auto i : c10::irange(meta.sizes.size())) {
    if (i > 0)
      out << ", ";
    out << meta.sizes[i];
  }
  out << "}";
  return out;
}

static inline std::ostream& operator<<(
    std::ostream& out,
    const IODescriptor& desc) {
  out << desc.structure << "\n";
  out << "  with grad_enabled=" << desc.grad_enabled << "\n";
  for (const auto i : c10::irange(desc.metadata.size())) {
    out << "  with v" << i << " having type " << desc.metadata[i] << "\n";
  }
  return out;
}
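
// For illustration only (the exact type name comes from
// at::getDeprecatedTypeProperties): printing a descriptor for two CPU float
// tensors might look roughly like:
//
//   (vv)
//     with grad_enabled=0
//     with v0 having type CPUFloatType(requires_grad=0) {2, 3}
//     with v1 having type CPUFloatType(requires_grad=1) {4}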

struct ParsedArgs {
  // Flat vector of the Variables found in the arguments.
  autograd::variable_list vars;
  // Metadata describing the nesting of the objects received from Python,
  // the metadata of vars, and whether grad mode was enabled.
  IODescriptor desc;

  void extend(const autograd::variable_list& list) {
    if (list.empty())
      return;
    vars.reserve(vars.size() + list.size());
    for (auto& var : list)
      vars.emplace_back(var);
    desc.extend(list);
  }
};

// Flattens a Python object into a flat list of Variables plus an
// IODescriptor recording where each Variable sat in the original structure.
ParsedArgs flatten(py::handle obj);
// Rebuilds a Python object with the structure described by `structure`,
// filling the Variable slots from `vars` in order.
PyObject* unflatten(
    at::ArrayRef<autograd::Variable> vars,
    const IODescriptor& structure);
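//
// A minimal round-trip sketch (illustrative; `py_args` is a hypothetical
// py::handle and is not defined in this header):
//
//   ParsedArgs parsed = flatten(py_args);
//   // parsed.vars holds the flattened Variables, parsed.desc their layout.
//   PyObject* rebuilt = unflatten(parsed.vars, parsed.desc);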

} // namespace torch::jit::python