#pragma once

#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/ir/ir.h>
#include <ostream>
#include <vector>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif

namespace torch::jit {

// GraphExecutor creates specializations of Graphs for different
// dimensionalities and types of inputs.

struct ArgumentInfo {
  friend struct ArgumentSpec;
  using plain_data_type = uint64_t;

  bool defined() const {
    return defined_;
  }
  at::Device device() const {
    return at::Device(DeviceType(dev_type_), device_);
  }
  // XXX: It is guaranteed that this will return false when called on non-tensor
  // arguments
  bool requires_grad() const {
    return requires_grad_;
  }
  int dim() const {
    return dim_;
  }
  at::ScalarType type() const {
    return at::ScalarType(type_);
  }
  TypePtr toType() const {
    if (!defined())
      return TensorType::get();

    return TensorType::create(
        type(), device(), std::optional<size_t>(dim()), requires_grad());
  }
  operator TypePtr() const {
    return toType();
  }

 private:
  unsigned defined_ : 1;
  unsigned requires_grad_ : 1;
  unsigned : 5;
  unsigned dim_ : 8;
  unsigned device_ : 8;
  unsigned type_ : 8;
  unsigned dev_type_ : 16;
  unsigned : 16;
};

static_assert(
    std::is_standard_layout<ArgumentInfo>::value,
    "ArgumentInfo is to be a POD struct");
static_assert(
    sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type),
    "ArgumentInfo is expected to be a 64-bit struct");

struct ArgumentSpec {
  ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs)
      : hash_code(c10::hash_combine(
            num_flat_tensor_inputs,
            num_flat_optional_inputs)) {
    tensor_args.reserve(num_flat_tensor_inputs);
    optional_presence.reserve(num_flat_optional_inputs);
  }

  void addOptional(const IValue& input) {
    bool is_present = !input.isNone();
    optional_presence.push_back(is_present);
    hash_code = c10::hash_combine(hash_code, is_present);
  }

  void addTensor(const IValue& input, bool with_grad) {
    AT_ASSERT(input.isTensor(), "Expected Tensor but found ", input.tagKind());
    tensor_args.emplace_back();
    auto& arg = tensor_args.back();
    // Initialize all fields to 0. This is convenient because, e.g.,
    // requires_grad() can be checked even on undefined tensors, AND it makes
    // the padding bits all 0s.
    std::memset(&arg, 0, sizeof(ArgumentInfo));

    // [argspec refcounting] reinterpret the IValue to avoid having to refcount
    // the Tensor: microbenchmarks
    // (https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10)
    // show overhead from the extra refcounting along this path.
    const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
    arg.defined_ = t->defined();
    if (arg.defined_) {
      arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
      arg.dim_ = t->dim();
      at::Device device = t->device();
      arg.dev_type_ =
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          static_cast<std::underlying_type<DeviceType>::type>(device.type());
      // NOLINTNEXTLINE(bugprone-signed-char-misuse)
      arg.device_ = device.index();
      arg.type_ = static_cast<unsigned>(t->scalar_type());
    }
    combineHash(arg);
  }

  void combineHash(const ArgumentInfo& arg) {
    ArgumentInfo::plain_data_type arg_data = 0;
    std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo));
    hash_code = c10::hash_combine(hash_code, arg_data);
  }

  // equality is fast: check ninputs, and then check the raw array data,
  // there are no size/stride indirections
  // hopefully std::vector<bool> has fast equality
  bool operator==(const ArgumentSpec& spec) const {
    if (optional_presence != spec.optional_presence) {
      return false;
    }
    if (tensor_args.size() != spec.tensor_args.size())
      return false;
    // NB: we need to break out early when there are no elements, because
    // passing a nullptr to memcmp is UB.
    if (tensor_args.empty())
      return true;
    return std::memcmp(
               tensor_args.data(),
               spec.tensor_args.data(),
               tensor_args.size() * sizeof(ArgumentInfo)) == 0;
  }
  bool operator!=(const ArgumentSpec& spec) const {
    return !(*this == spec);
  }
  size_t numTensors() const {
    return tensor_args.size();
  }
  const ArgumentInfo& tensorAt(size_t i) const {
    return tensor_args[i];
  }
  size_t numOptionals() const {
    return optional_presence.size();
  }
  bool isPresent(size_t i) const {
    return optional_presence[i];
  }
  size_t hashCode() const {
    return hash_code;
  }

 private:
  size_t hash_code; // precomputed on construction
  std::vector<ArgumentInfo> tensor_args;
  std::vector<bool> optional_presence;
};
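
// Usage sketch (illustrative only; the tensors and counts below are made up).
// An ArgumentSpec is normally produced by ArgumentSpecCreator::create (see
// below), but the key steps look like this:
//
//   at::Tensor a = at::zeros({2, 3});
//   at::Tensor b = at::ones({4});
//   ArgumentSpec spec(/*num_flat_tensor_inputs=*/2,
//                     /*num_flat_optional_inputs=*/1);
//   spec.addTensor(a, /*with_grad=*/false);
//   spec.addTensor(b, /*with_grad=*/false);
//   spec.addOptional(IValue()); // a None optional argument
//
// Two specs compare equal only if every tensor has the same (defined, dtype,
// device, dim, requires_grad) tuple and the optional-presence bits match,
// which is what makes ArgumentSpec usable as a cache key
// (std::hash<ArgumentSpec> at the bottom of this file forwards to hashCode()).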

namespace {
static constexpr size_t ARG_SPEC_DEPTH_LIMIT = 128;
}

// ArgumentSpecCreator takes an initial graph and comes up with a set
// of simple instructions to compute the ArgumentSpec given a set of
// input tensors.
struct TORCH_API ArgumentSpecCreator {
  // instructions act on a stack of lists of input IValues:
  // at the beginning the stack contains a single list of the inputs to the
  // function; the ENTER_ instructions descend into sub-objects and push new
  // lists onto the stack
  enum Inst : char {
    ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the
                 // list of its elements onto the stack as a new list
    ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class
    LEAVE, // pop the top-most list from the stack
    SKIP, // consume an element from the top-most list, and discard it
    SPECIALIZE_OPTIONAL_TENSOR, // consume an optional tensor from the top-most
                                // list, and add it to the ArgSpec key being
                                // created
    SPECIALIZE_TENSOR, // consume a tensor from the top-most
                       // list, and add it to the ArgSpec key being created
    SPECIALIZE_OPTIONAL,
    // consume a non-tensor optional from the top-most list,
    // and add it to the ArgSpec key being created
  };
  ArgumentSpecCreator(Graph& graph);
  ArgumentSpec create(bool with_grad, const Stack& stack) const;
  void specializeTypes(Graph& g, const ArgumentSpec& spec) const;
  void dump() const;
  using WrittenSlots = std::unordered_set<std::string>;

 private:
  void scan(
      const TypePtr& typ,
      size_t depth,
      const WrittenSlots& written_slots);
  size_t num_inputs_;
  size_t num_tensors_ = 0;
  size_t num_optionals_ = 0;
  std::vector<Inst> instructions_;
};
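
// Illustrative sketch (the signature below is hypothetical, not taken from a
// real graph): for a graph whose inputs are
//
//   (Tensor, (Tensor, int), Tensor?)
//
// the creator could emit an instruction sequence along the lines of
//
//   SPECIALIZE_TENSOR          // the first Tensor
//   ENTER_TUPLE                //   descend into the tuple
//   SPECIALIZE_TENSOR          //   the Tensor inside the tuple
//   SKIP                       //   the int is not specialized on
//   LEAVE                      //   pop the tuple's element list
//   SPECIALIZE_OPTIONAL_TENSOR // the Tensor? argument
//
// create(with_grad, stack) then walks the inputs once, executing these
// instructions to fill in an ArgumentSpec, and specializeTypes applies the
// types recorded in a spec back onto a graph.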

// CompleteArgumentSpec represents one particular specialization.
// It is designed so that it can be created, hashed, and compared quickly
// since it is used along the hot-path of the JIT to check if the code
// we have created is valid for the given inputs.

// CompleteArgumentInfoPOD is only used internally in CompleteArgumentSpec;
// API users should use ArgumentInfo
struct CompleteArgumentInfoPOD {
  // total size is 64-bit
  unsigned is_tensor : 8; // all other fields are invalid if this is false
  unsigned type : 8; // scalar type
  unsigned defined : 1;
  unsigned requires_grad : 1;
  signed device : 14;
  unsigned dev_type : 16;
  unsigned
      total_dims : 16; // all TensorInfoPODs are in CompleteArgumentSpec's
                       // tensor_info() array. total_dims is the total number
                       // of dimensions seen so far in all previous members of
                       // tensor_info(), including this tensor. 2*total_dims
                       // becomes the offset into the sizes_strides list for
                       // the _next_ tensor in the tensor_info array; for
                       // tensor 0, the offset is always 0.
};

static_assert(
    sizeof(CompleteArgumentInfoPOD) == sizeof(int64_t),
    "CompleteArgumentInfoPOD must be 64-bit struct for CompleteArgumentSpec encoding to work");

struct CompleteArgumentInfo;

struct CompleteArgumentSpec {
  CompleteArgumentSpec(bool with_grad, at::ArrayRef<IValue> inputs)
      : hash_code(0), ninputs(inputs.size()) {
    int32_t all_dims = 0;
    const auto num_inputs = inputs.size();
    for (const auto i : c10::irange(num_inputs)) {
      if (!inputs[i].isTensor())
        continue;
      auto& tensor = inputs[i].toTensor();
      all_dims += tensor.defined() ? tensor.ndimension() : 0;
    }
    // allocate enough room for all TensorPODs and dimensions
    data.resize(ninputs + all_dims * 2);

    // and reinterpret our data array as these structs
    auto* pods = reinterpret_cast<CompleteArgumentInfoPOD*>(data.data());
    int64_t* next_dim = sizes_strides();
    int32_t total_dims = 0;
    for (const auto i : c10::irange(num_inputs)) {
      auto& pod = pods[i];
      pod.is_tensor = static_cast<uint32_t>(inputs[i].isTensor());
      if (pod.is_tensor) {
        at::Tensor t = inputs[i].toTensor();
        pod.defined = t.defined();
        if (pod.defined) {
          pod.type = static_cast<int>(t.scalar_type());
          at::Device device = t.device();
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          pod.dev_type = static_cast<std::underlying_type<DeviceType>::type>(
              device.type());
          // NOLINTNEXTLINE(bugprone-signed-char-misuse)
          pod.device = device.index();
          pod.requires_grad = with_grad && t.requires_grad();
          total_dims += t.ndimension();
          auto sizes = t.sizes();
          std::copy(sizes.begin(), sizes.end(), next_dim);
          next_dim += sizes.size();
          auto strides = t.strides();
          std::copy(strides.begin(), strides.end(), next_dim);
          next_dim += strides.size();
        }
      }
      // each POD has a running tally of all dimensions including its own
      TORCH_CHECK(
          total_dims < std::numeric_limits<uint16_t>::max(),
          "The number of dims cannot be packed into CompleteArgumentSpec:",
          total_dims);
      pod.total_dims = total_dims;
    }
    // we precompute the hash_code to minimize the time inside of hash
    // table operations where we may need to hold a compiler cache lock.
    hash_code = c10::hash_combine(0, ninputs);
    for (auto d : data) {
      hash_code = c10::hash_combine(hash_code, d);
    }
  }

  // equality is fast: check ninputs, and then check the raw array data,
  // there are no size/stride indirections
  bool operator==(const CompleteArgumentSpec& spec) const {
    return ninputs == spec.ninputs && data == spec.data;
  }
  bool operator!=(const CompleteArgumentSpec& spec) const {
    return !(*this == spec);
  }
  friend struct CompleteArgumentInfo;
  CompleteArgumentInfo at(size_t i) const;
  size_t size() const {
    return ninputs;
  }
  size_t hashCode() const {
    return hash_code;
  }

 private:
  ArrayRef<CompleteArgumentInfoPOD> tensor_info() const {
    return ArrayRef<CompleteArgumentInfoPOD>(
        reinterpret_cast<const CompleteArgumentInfoPOD*>(data.data()), ninputs);
  }
  // the start of the sizes_strides information, which comes after the
  // CompleteArgumentInfoPOD list.
  const int64_t* sizes_strides() const {
    return data.data() + ninputs;
  }
  int64_t* sizes_strides() {
    return data.data() + ninputs;
  }
  size_t hash_code; // precomputed on construction
  size_t ninputs;
  // layout is ninputs of TensorPOD (each 64-bit) followed by their size and
  // stride info; e.g., for 3 tensors:
  // [t0POD][t1POD][t2POD]...
  // [t0 sizes][t0 strides][t1 sizes][t1 strides][t2 sizes][t2 strides]
  std::vector<int64_t> data;
};
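
// Usage sketch (illustrative only; the inputs below are made up). A
// CompleteArgumentSpec captures concrete sizes and strides in addition to
// what ArgumentSpec records:
//
//   std::vector<IValue> inputs = {at::zeros({2, 3}), 4.0, at::ones({4})};
//   CompleteArgumentSpec spec(/*with_grad=*/false, inputs);
//   spec.size();            // 3: every input gets a POD slot
//   spec.at(0).sizes();     // {2, 3}
//   spec.at(0).strides();   // {3, 1} for a contiguous tensor
//   spec.at(1).isTensor();  // false: the double occupies a slot but carries
//                           // no further specialization info
//
// As with ArgumentSpec, std::hash<CompleteArgumentSpec> (bottom of this file)
// forwards to the precomputed hashCode(), so specs can key a cache of
// compiled code.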

// public view of compressed CompleteArgumentInfo
struct CompleteArgumentInfo {
  CompleteArgumentInfo(const CompleteArgumentSpec& spec, const int i)
      : spec(spec), i(i) {}
  bool isTensor() const {
    return pod(i).is_tensor;
  }
  at::ScalarType type() const {
    return at::ScalarType(pod(i).type);
  }
  bool defined() const {
    return pod(i).defined;
  }
  bool requires_grad() const {
    return pod(i).requires_grad;
  }
  at::Device device() const {
    return at::Device(
        DeviceType(pod(i).dev_type),
        static_cast<c10::DeviceIndex>(pod(i).device));
  }
  int ndimension() const {
    // See [valid range], it is always valid to ask for offset for (i + 1)
    return (sizes_strides_offset(i + 1) - sizes_strides_offset(i)) / 2;
  }
  at::IntArrayRef sizes() const {
    return at::IntArrayRef(
        spec.sizes_strides() + sizes_strides_offset(i), ndimension());
  }
  at::IntArrayRef strides() const {
    int ndim = ndimension();
    return at::IntArrayRef(
        spec.sizes_strides() + sizes_strides_offset(i) + ndim, ndim);
  }
  operator TypePtr() const {
    if (!defined())
      return TensorType::get();
    return TensorType::create(
        type(),
        device(),
        c10::VaryingShape<int64_t>{sizes()},
        c10::VaryingShape<int64_t>{strides()},
        requires_grad());
  }

 private:
  // offset into the sizes_strides() array where the sizes start for tensor j
  // [valid range] valid range is [0, ninputs]
  // (i.e. you can ask for the offset at ninputs, which would be the offset of
  // the next tensor if it existed)
  int sizes_strides_offset(int j) const {
    if (j == 0)
      return 0;
    return 2 * pod(j - 1).total_dims;
  }
  const CompleteArgumentInfoPOD& pod(int j) const {
    return spec.tensor_info().at(j);
  }
  const CompleteArgumentSpec& spec;
  const int i;
};

inline std::ostream& operator<<(std::ostream& out, const ArgumentInfo& info) {
  if (!info.defined()) {
    return out << "<undefined>";
  }
  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
      << ", requires_grad=" << info.requires_grad() << ", dims=" << info.dim()
      << ")";
  return out;
}

inline std::ostream& operator<<(std::ostream& out, const ArgumentSpec& spec) {
  out << "{";
  for (const auto i : c10::irange(spec.numTensors())) {
    if (i > 0)
      out << ", ";
    out << spec.tensorAt(i);
  }
  out << "; ";
  for (const auto i : c10::irange(spec.numOptionals())) {
    if (i > 0)
      out << ", ";
    out << spec.isPresent(i);
  }
  out << "}";
  return out;
}
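
// For example (format only; the values are illustrative), a spec holding one
// defined 2-D float CPU tensor and a single present optional prints roughly
// as:
//
//   {Tensor(device=cpu, type=Float, requires_grad=0, dims=2); 1}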

inline std::ostream& operator<<(
    std::ostream& out,
    const CompleteArgumentInfo& info) {
  if (!info.defined()) {
    return out << "<undefined>";
  }
  out << "Tensor(device=" << info.device() << ", type=" << toString(info.type())
      << ", requires_grad=" << info.requires_grad()
      << ", sizes=" << info.sizes() << ", strides=" << info.strides() << ")";
  return out;
}

inline std::ostream& operator<<(
    std::ostream& out,
    const CompleteArgumentSpec& spec) {
  out << "{";
  for (const auto i : c10::irange(spec.size())) {
    if (i > 0)
      out << ", ";
    out << spec.at(i);
  }
  out << "}";
  return out;
}

inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const {
  return CompleteArgumentInfo(*this, i);
}

inline std::optional<int8_t> convertOptional(
    std::optional<c10::ScalarType> const& from) {
  return (from) ? std::optional<int8_t>(static_cast<int8_t>(*from))
                : std::optional<int8_t>{};
}

} // namespace torch::jit

namespace std {

template <typename T>
struct hash<c10::VaryingShape<T>> {
  size_t operator()(const c10::VaryingShape<T>& vs) const {
    return c10::get_hash(
        vs.size(),
        vs.size() ? vs.sizes().value() : std::vector<std::optional<T>>());
  }
};

template <>
struct hash<c10::TensorType> {
  size_t operator()(const c10::TensorType& ptt) const {
    return c10::get_hash<
        std::optional<int8_t>,
        c10::VaryingShape<int64_t>,
        c10::VaryingShape<int64_t>,
        std::optional<bool>>(
        torch::jit::convertOptional(ptt.scalarType()),
        ptt.sizes(),
        ptt.strides(),
        ptt.requiresGrad());
  }
};

template <>
struct hash<torch::jit::ArgumentSpec> {
  size_t operator()(const torch::jit::ArgumentSpec& spec) const {
    return spec.hashCode();
  }
};
template <>
struct hash<torch::jit::CompleteArgumentSpec> {
  size_t operator()(const torch::jit::CompleteArgumentSpec& spec) const {
    return spec.hashCode();
  }
};
} // namespace std

C10_CLANG_DIAGNOSTIC_POP()