#pragma once

#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/ir/ir.h>
#include <ostream>
#include <vector>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif

namespace torch::jit {

// GraphExecutor creates specializations of Graphs for different
// dimensionalities and types of inputs.

struct ArgumentInfo {
  friend struct ArgumentSpec;
  using plain_data_type = uint64_t;

  bool defined() const {
    return defined_;
  }
  at::Device device() const {
    return at::Device(DeviceType(dev_type_), device_);
  }
  // XXX: It is guaranteed that this will return false when called on non-tensor
  // arguments
  bool requires_grad() const {
    return requires_grad_;
  }
  int dim() const {
    return dim_;
  }
  at::ScalarType type() const {
    return at::ScalarType(type_);
  }
  TypePtr toType() const {
    if (!defined())
      return TensorType::get();

    return TensorType::create(
        type(), device(), std::optional<size_t>(dim()), requires_grad());
  }
  operator TypePtr() const {
    return toType();
  }

 private:
  unsigned defined_ : 1;
  unsigned requires_grad_ : 1;
  unsigned : 5;
  unsigned dim_ : 8;
  unsigned device_ : 8;
  unsigned type_ : 8;
  unsigned dev_type_ : 16;
  unsigned : 16;
};

static_assert(
    std::is_standard_layout<ArgumentInfo>::value,
    "ArgumentInfo is expected to be a POD struct");
static_assert(
    sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
    "ArgumentInfo is expected to be a 64-bit struct");
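
// Note: the named bit-fields above, together with the anonymous padding
// fields, pack into a single 64-bit word (plain_data_type). That is what lets
// combineHash() below memcpy an ArgumentInfo into a uint64_t for hashing, and
// lets ArgumentSpec::operator== compare the records bytewise with memcmp.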

struct ArgumentSpec {
  ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs)
      : hash_code(c10::hash_combine(
            num_flat_tensor_inputs,
            num_flat_optional_inputs)) {
    tensor_args.reserve(num_flat_tensor_inputs);
    optional_presence.reserve(num_flat_optional_inputs);
  }

  void addOptional(const IValue& input) {
    bool is_present = !input.isNone();
    optional_presence.push_back(is_present);
    hash_code = c10::hash_combine(hash_code, is_present);
  }

  void addTensor(const IValue& input, bool with_grad) {
    AT_ASSERT(input.isTensor(), "Expected Tensor but found ", input.tagKind());
    tensor_args.emplace_back();
    auto& arg = tensor_args.back();
    // Initialize all fields to 0. This is convenient, because e.g.
    // requires_grad() can be checked even on undefined tensors, AND it makes
    // the padding bits all 0s.
    std::memset(&arg, 0, sizeof(ArgumentInfo));

    // [argspec refcounting] Reinterpret the IValue to avoid having to refcount
    // the Tensor; microbenchmarks
    // https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10
    // show overhead from the extra refcounting along this path.
    const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
    arg.defined_ = t->defined();
    if (arg.defined_) {
      arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
      arg.dim_ = t->dim();
      at::Device device = t->device();
      arg.dev_type_ =
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          static_cast<std::underlying_type<DeviceType>::type>(device.type());
      // NOLINTNEXTLINE(bugprone-signed-char-misuse)
      arg.device_ = device.index();
      arg.type_ = static_cast<unsigned>(t->scalar_type());
    }
    combineHash(arg);
  }

  void combineHash(const ArgumentInfo& arg) {
    ArgumentInfo::plain_data_type arg_data = 0;
    std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo));
    hash_code = c10::hash_combine(hash_code, arg_data);
  }

  // equality is fast: check ninputs, and then check the raw array data,
  // there are no size/stride indirections
  // hopefully std::vector<bool> has fast equality
  bool operator==(const ArgumentSpec& spec) const {
    if (optional_presence != spec.optional_presence) {
      return false;
    }
    if (tensor_args.size() != spec.tensor_args.size())
      return false;
    // NB: we need to break out early when there are no elements, because
    // passing a nullptr to memcmp is UB.
    if (tensor_args.empty())
      return true;
    return std::memcmp(
               tensor_args.data(),
               spec.tensor_args.data(),
               tensor_args.size() * sizeof(ArgumentInfo)) == 0;
  }
  bool operator!=(const ArgumentSpec& spec) const {
    return !(*this == spec);
  }
  size_t numTensors() const {
    return tensor_args.size();
  }
  const ArgumentInfo& tensorAt(size_t i) const {
    return tensor_args[i];
  }
  size_t numOptionals() const {
    return optional_presence.size();
  }
  bool isPresent(size_t i) const {
    return optional_presence[i];
  }
  size_t hashCode() const {
    return hash_code;
  }

 private:
  size_t hash_code; // precomputed on construction
  std::vector<ArgumentInfo> tensor_args;
  std::vector<bool> optional_presence;
};
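
// Illustrative sketch (not from this header; the traversal order is decided
// by ArgumentSpecCreator below): an ArgumentSpec is filled in by walking the
// flattened inputs and calling addTensor()/addOptional() in order, e.g.:
//
//   ArgumentSpec spec(/*num_flat_tensor_inputs=*/2,
//                     /*num_flat_optional_inputs=*/1);
//   spec.addTensor(stack[0], /*with_grad=*/true); // Tensor input
//   spec.addOptional(stack[1]);                   // Optional (non-tensor) input
//   spec.addTensor(stack[2], /*with_grad=*/true); // another Tensor input
//   size_t key = spec.hashCode();                 // incrementally built hash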

namespace {
static constexpr size_t ARG_SPEC_DEPTH_LIMIT = 128;
}

// ArgumentSpecCreator takes an initial graph and comes up with a set
// of simple instructions to compute the ArgumentSpec given a set of
// input tensors.
struct TORCH_API ArgumentSpecCreator {
  // Insts act on a stack of lists of input IValues.
  // At the beginning the stack contains a single list of the inputs to the
  // function; the ENTER_ insts descend into subobjects and push new lists
  // onto the stack.
  enum Inst : char {
    ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the
                 // list of its elements onto the stack as a new list
    ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class
    LEAVE, // pop the top-most list from the stack
    SKIP, // consume an element from the top-most list, and discard it
    SPECIALIZE_OPTIONAL_TENSOR, // consume an optional tensor from the top-most
                                // list, and add it to the ArgSpec key being
                                // created
    SPECIALIZE_TENSOR, // consume a tensor from the top-most
                       // list, and add it to the ArgSpec key being created
    SPECIALIZE_OPTIONAL,
    // consume a non-tensor optional from the top-most list,
    // and add it to the ArgSpec key being created
  };
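
  // Illustrative program (an assumption based on the instruction descriptions
  // above, not a dump of real output): for flattened inputs
  // (Tensor, Optional[Tensor], (Tensor, int)) the creator might emit
  //
  //   SPECIALIZE_TENSOR           // inputs[0]
  //   SPECIALIZE_OPTIONAL_TENSOR  // inputs[1]
  //   ENTER_TUPLE                 // descend into inputs[2]
  //   SPECIALIZE_TENSOR           //   tuple element 0
  //   SKIP                        //   tuple element 1 (int, not specialized)
  //   LEAVE
  //
  // create() executes such a program over the input Stack to build the
  // ArgumentSpec key.
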
  ArgumentSpecCreator(Graph& graph);
  ArgumentSpec create(bool with_grad, const Stack& stack) const;
  void specializeTypes(Graph& g, const ArgumentSpec& spec) const;
  void dump() const;
  using WrittenSlots = std::unordered_set<std::string>;

 private:
  void scan(
      const TypePtr& typ,
      size_t depth,
      const WrittenSlots& written_slots);
  size_t num_inputs_;
  size_t num_tensors_ = 0;
  size_t num_optionals_ = 0;
  std::vector<Inst> instructions_;
};

// CompleteArgumentSpec represents one particular specialization.
// It is designed so that it can be created, hashed, and compared quickly
// since it is used along the hot-path of the JIT to check if the code
// we have created is valid for the given inputs.

// CompleteArgumentInfoPOD is only used internally in CompleteArgumentSpec;
// API users should use ArgumentInfo.
struct CompleteArgumentInfoPOD {
  // total size is 64-bit
  unsigned is_tensor : 8; // all other fields are invalid if this is false
  unsigned type : 8; // scalar type
  unsigned defined : 1;
  unsigned requires_grad : 1;
  signed device : 14;
  unsigned dev_type : 16;
  unsigned
      total_dims : 16; // all TensorInfoPODs are in CompleteArgumentSpec's
                       // tensor_info() array. total_dims is the total number of
                       // dimensions seen so far in all previous members of
                       // tensor_info(), including this tensor. 2*total_dims
                       // becomes the offset into the sizes_strides list for the
                       // _next_ tensor in the tensor_info array; for tensor 0,
                       // the offset is always 0.
};

static_assert(
    sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t),
    "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work");
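
// Worked example (illustrative): for three tensor inputs with 2, 0, and 3
// dimensions, the PODs carry total_dims = 2, 2, 5. The sizes/strides of
// tensor i start at offset 2 * total_dims of the previous POD (0 for tensor
// 0), so the offsets are 0, 4, and 4, and tensor 2 occupies slots [4, 10) of
// the sizes_strides region (3 sizes followed by 3 strides).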

struct CompleteArgumentInfo;

struct CompleteArgumentSpec {
  CompleteArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs)
      : hash_code(0), ninputs(inputs.size()) {
    int32_t all_dims = 0;
    const auto num_inputs = inputs.size();
    for (const auto i : c10::irange(num_inputs)) {
      if (!inputs[i].isTensor())
        continue;
      auto& tensor = inputs[i].toTensor();
      all_dims += tensor.defined() ? tensor.ndimension() : 0;
    }
    // allocate enough room for all TensorPODs and dimensions
    data.resize(ninputs + all_dims * 2);

    // and reinterpret our data array as these structs
    auto* pods = reinterpret_cast<CompleteArgumentInfoPOD*>(data.data());
    int64_t* next_dim = sizes_strides();
    int32_t total_dims = 0;
    for (const auto i : c10::irange(num_inputs)) {
      auto& pod = pods[i];
      pod.is_tensor = static_cast<uint32_t>(inputs[i].isTensor());
      if (pod.is_tensor) {
        at::Tensor t = inputs[i].toTensor();
        pod.defined = t.defined();
        if (pod.defined) {
          pod.type = static_cast<int>(t.scalar_type());
          at::Device device = t.device();
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          pod.dev_type = static_cast<std::underlying_type<DeviceType>::type>(
              device.type());
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          pod.device = device.index();
          pod.requires_grad = with_grad && t.requires_grad();
          total_dims += t.ndimension();
          auto sizes = t.sizes();
          std::copy(sizes.begin(), sizes.end(), next_dim);
          next_dim += sizes.size();
          auto strides = t.strides();
          std::copy(strides.begin(), strides.end(), next_dim);
          next_dim += strides.size();
        }
      }
      // each POD has a running tally of all dimensions including its own
      TORCH_CHECK(
          total_dims < std::numeric_limits<uint16_t>::max(),
          "The number of dims cannot be packed into CompleteArgumentSpec:",
          total_dims);
      pod.total_dims = total_dims;
    }
    // we precompute the hash_code to minimize the time inside of hash
    // table operations where we may need to hold a compiler cache lock.
    hash_code = c10::hash_combine(0, ninputs);
    for (auto d : data) {
      hash_code = c10::hash_combine(hash_code, d);
    }
  }

  // equality is fast: check ninputs, and then check the raw array data,
  // there are no size/stride indirections
  bool operator==(const CompleteArgumentSpec& spec) const {
    return ninputs == spec.ninputs && data == spec.data;
  }
  bool operator!=(const CompleteArgumentSpec& spec) const {
    return !(*this == spec);
  }
  friend struct CompleteArgumentInfo;
  CompleteArgumentInfo at(size_t i) const;
  size_t size() const {
    return ninputs;
  }
  size_t hashCode() const {
    return hash_code;
  }

 private:
  ArrayRef<CompleteArgumentInfoPOD> tensor_info() const {
    return ArrayRef<CompleteArgumentInfoPOD>(
        reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs);
  }
  // the start of the sizes_strides information, which comes after the
  // CompleteArgumentInfoPOD list.
  const int64_t* sizes_strides() const {
    return data.data() + ninputs;
  }
  int64_t* sizes_strides() {
    return data.data() + ninputs;
  }
  size_t hash_code; // precomputed on construction
  size_t ninputs;
  // layout is ninputs of TensorPOD (each 64-bit) followed by their size and
  // stride info; e.g. for 3 tensors:
  // [t0POD][t1POD][t2POD]...
  // [t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides]
  std::vector<int64_t> data;
};
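
// Illustrative usage sketch (assumed; mirrors the ArgumentSpec example
// above): a CompleteArgumentSpec is built directly from the flattened input
// IValues and then used as a precomputed-hash cache key, e.g.:
//
//   Stack stack = ...; // flattened inputs
//   CompleteArgumentSpec spec(/*with_grad=*/true, last(stack, num_inputs));
//   size_t key = spec.hashCode();           // computed once, in the ctor
//   CompleteArgumentInfo arg0 = spec.at(0); // per-input view (sizes/strides)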

// public view of compressed CompleteArgumentInfo
struct CompleteArgumentInfo {
  CompleteArgumentInfo(const CompleteArgumentSpec& spec, const int i)
      : spec(spec), i(i) {}
  bool isTensor() const {
    return pod(i).is_tensor;
  }
  at::ScalarType type() const {
    return at::ScalarType(pod(i).type);
  }
  bool defined() const {
    return pod(i).defined;
  }
  bool requires_grad() const {
    return pod(i).requires_grad;
  }
  at::Device device() const {
    return at::Device(
        DeviceType(pod(i).dev_type),
        static_cast<c10::DeviceIndex>(pod(i).device));
  }
  int ndimension() const {
    // See [valid range], it is always valid to ask for offset for (i + 1)
    return (sizes_strides_offset(i + 1) - sizes_strides_offset(i)) / 2;
  }
  at::IntArrayRef sizes() const {
    return at::IntArrayRef(
        spec.sizes_strides() + sizes_strides_offset(i), ndimension());
  }
  at::IntArrayRef strides() const {
    int ndim = ndimension();
    return at::IntArrayRef(
        spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim);
  }
  operator TypePtr() const {
    if (!defined())
      return TensorType::get();
    return TensorType::create(
        type(),
        device(),
        c10::VaryingShape<int64_t>{sizes()},
        c10::VaryingShape<int64_t>{strides()},
        requires_grad());
  }

 private:
  // offset into the sizes_strides() array where the sizes start for tensor j
  // [valid range] valid range is [0, ninputs]
  // (i.e. you can ask for the offset at ninputs, which would be the offset of
  // the next tensor if it existed)
  int sizes_strides_offset(int j) const {
    if (j == 0)
      return 0;
    return 2 * pod(j - 1).total_dims;
  }
  const CompleteArgumentInfoPOD& pod(int j) const {
    return spec.tensor_info().at(j);
  }
  const CompleteArgumentSpec& spec;
  const int i;
};

inline std::ostream& operator<<(std::ostream& out, const ArgumentInfo& info) {
  if (!info.defined()) {
    return out << "<undefined>";
  }
  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
      << ", requires_grad=" << info.requires_grad() << ", dims=" << info.dim()
      << ")";
  return out;
}

inline std::ostream& operator<<(std::ostream& out, const ArgumentSpec& spec) {
  out << "{";
  for (const auto i : c10::irange(spec.numTensors())) {
    if (i > 0)
      out << ", ";
    out << spec.tensorAt(i);
  }
  out << "; ";
  for (const auto i : c10::irange(spec.numOptionals())) {
    if (i > 0)
      out << ", ";
    out << spec.isPresent(i);
  }
  out << "}";
  return out;
}

inline std::ostream& operator<<(
    std::ostream& out,
    const CompleteArgumentInfo& info) {
  if (!info.defined()) {
    return out << "<undefined>";
  }
  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
      << ", requires_grad=" << info.requires_grad()
      << ", sizes=" << info.sizes() << ", strides=" << info.strides() << ")";
  return out;
}

inline std::ostream& operator<<(
    std::ostream& out,
    const CompleteArgumentSpec& spec) {
  out << "{";
  for (const auto i : c10::irange(spec.size())) {
    if (i > 0)
      out << ", ";
    out << spec.at(i);
  }
  out << "}";
  return out;
}

inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const {
  return CompleteArgumentInfo(*this, i);
}

inline std::optional<int8_t> convertOptional(
    std::optional<c10::ScalarType> const& from) {
  return (from) ? std::optional<int8_t>(static_cast<int8_t>(*from))
                : std::optional<int8_t>{};
}

} // namespace torch::jit

namespace std {

template <typename T>
struct hash<c10::VaryingShape<T>> {
  size_t operator()(const c10::VaryingShape<T>& vs) const {
    return c10::get_hash(
        vs.size(),
        vs.size() ? vs.sizes().value() : std::vector<std::optional<T>>());
  }
};

template <>
struct hash<c10::TensorType> {
  size_t operator()(const c10::TensorType& ptt) const {
    return c10::get_hash<
        std::optional<int8_t>,
        c10::VaryingShape<int64_t>,
        c10::VaryingShape<int64_t>,
        std::optional<bool>>(
        torch::jit::convertOptional(ptt.scalarType()),
        ptt.sizes(),
        ptt.strides(),
        ptt.requiresGrad());
  }
};

template <>
struct hash<torch::jit::ArgumentSpec> {
  size_t operator()(const torch::jit::ArgumentSpec& spec) const {
    return spec.hashCode();
  }
};
template <>
struct hash<torch::jit::CompleteArgumentSpec> {
  size_t operator()(const torch::jit::CompleteArgumentSpec& spec) const {
    return spec.hashCode();
  }
};
} // namespace std
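
// Note (assumption about intended use, not part of this header): the
// std::hash specializations above let these specs act directly as keys of an
// unordered_map-based compilation cache, along the lines of:
//
//   // CompiledEntry is a placeholder for whatever the cache stores
//   std::unordered_map<torch::jit::ArgumentSpec, CompiledEntry> cache;
//   auto it = cache.find(spec); // uses hashCode() and ArgumentSpec::operator==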

C10_CLANG_DIAGNOSTIC_POP()