1 #include <ATen/PythonTorchFunctionTLS.h>
2 #include <c10/core/SafePyObject.h>
3 #include <c10/core/impl/PyInterpreter.h>
4 #define PY_SSIZE_T_CLEAN
5 #include <ATen/EmptyTensor.h>
6 #include <ATen/SparseCsrTensorUtils.h>
7 #include <c10/util/flat_hash_map.h>
8 #include <torch/csrc/autograd/grad_mode.h>
9 #include <torch/csrc/autograd/utils/wrap_outputs.h>
10 #include <torch/csrc/dynamo/guards.h>
11 #include <torch/csrc/inductor/inductor_ops.h>
12 #include <torch/csrc/utils/disable_torch_function.h>
13 #include <torch/csrc/utils/python_arg_parser.h>
14 #include <torch/csrc/utils/python_compat.h>
15 #include <torch/csrc/utils/python_numbers.h>
16 #include <torch/csrc/utils/python_symnode.h>
17 #include <torch/csrc/utils/pythoncapi_compat.h>
18 #include <torch/extension.h>
19
20 #ifdef USE_CUDA
21 #include <ATen/cuda/EmptyTensor.h>
22 #endif
23
24 #ifdef USE_XPU
25 #include <ATen/xpu/EmptyTensor.h>
26 #endif
27
28 #include <sstream>
29 #include <utility>
30
// For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
// underlying tuple and access the item. Before Python 3.12, the data
// structure is defined privately in tupleobject.c -
// https://github.com/python/cpython/blob/9afc6d102d16080535325f645849cd84eb04d57d/Objects/tupleobject.c#L1058-L1062
// To handle this, we manually copy the struct here and cast to the copy.
// From 3.12 onwards, the struct is available from an internal header file.
37 #if IS_PYTHON_3_12_PLUS
38
39 #define Py_BUILD_CORE
40 // Bring _PyTupleIterObject from the header file
41 #include <internal/pycore_tuple.h>
42 #undef Py_BUILD_CORE
43
44 #else
45
// Manually create _PyTupleIterObject struct.
// This branch is only compiled when IS_PYTHON_3_12_PLUS is false; on newer
// Pythons the definition comes from <internal/pycore_tuple.h> instead.
// NOTE(review): the field layout must stay in sync with the CPython
// definition it was copied from - confirm whenever a new pre-3.12 Python
// version is supported.
typedef struct {
  PyObject_HEAD Py_ssize_t it_index;
  PyTupleObject* it_seq; /* Set to NULL when iterator is exhausted */
} _PyTupleIterObject;
51
52 #endif // IS_PYTHON_3_12_PLUS
53
54 namespace torch::dynamo {
55
// Macro to skip addition of duplicate guards like EQUALS_MATCH.
//
// Expects a variable named `self` in the enclosing scope that provides
// is_leaf_guard_present(name) and insert_leaf_guard(name). When the guard is
// already present the macro executes a bare `return;`, so it can only be used
// inside functions returning void; otherwise it records `name` and execution
// continues with the statements following the macro.
#define SKIP_IF_GUARD_ALREADY_PRESENT(name) \
  if (self.is_leaf_guard_present(name)) {   \
    return;                                 \
  }                                         \
  self.insert_leaf_guard(name);
62
TensorCheck(const LocalState & state,PyTypeObject * pt,const at::Tensor & v,std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)63 TensorCheck::TensorCheck(
64 const LocalState& state,
65 PyTypeObject* pt,
66 const at::Tensor& v,
67 std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
68 std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
69 : pytype(pt),
70 dispatch_key_(state.apply(v.key_set()).raw_repr()),
71 dtype_(v.dtype().toScalarType()),
72 device_index_(v.device().index()),
73 requires_grad_(v.requires_grad()),
74 sizes_(std::move(dynamic_dims_sizes)),
75 strides_(std::move(dynamic_dims_strides)),
76 dim_(static_cast<int64_t>(sizes_.size())) {
77 // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
78 // we just treat this as optional?
79 }
80
// Builds a TensorCheck directly from already-extracted tensor metadata.
//
// `state` is applied to `dispatch_key_set` before the raw representation is
// stored, so later comparisons must apply their own LocalState the same way.
// `dynamic_dims_sizes` / `dynamic_dims_strides` hold the expected value per
// dimension; the tensor rank (`dim_`) is taken from the number of size
// entries.
TensorCheck::TensorCheck(
    const LocalState& state,
    PyTypeObject* pt,
    c10::DispatchKeySet dispatch_key_set,
    at::ScalarType dtype,
    at::DeviceIndex device_index,
    bool requires_grad,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
    : pytype(pt),
      dispatch_key_(state.apply(dispatch_key_set).raw_repr()),
      dtype_(dtype),
      device_index_(device_index),
      requires_grad_(requires_grad),
      sizes_(std::move(dynamic_dims_sizes)),
      strides_(std::move(dynamic_dims_strides)),
      dim_(static_cast<int64_t>(sizes_.size())) {}
98
99 // See note in guards.py [Note - On Export Tensor Guards]
100 // Logic parallel to here must be maintained in python
check(const LocalState & state,const at::Tensor & v)101 bool TensorCheck::check(const LocalState& state, const at::Tensor& v) {
102 // In terms of a sparse_csr tensor, it does not support strides informatio
103 c10::SymIntArrayRef sym_strides(std::vector<SymInt>(v.ndimension(), -1));
104 bool does_not_support_stride = v.layout() == c10::kSparseCsr ||
105 v.layout() == c10::kSparseCsc || v.layout() == c10::kSparseBsc ||
106 v.layout() == c10::kSparseBsr;
107 if (!does_not_support_stride) {
108 sym_strides = v.sym_strides();
109 }
110
111 return check(
112 state,
113 v.key_set(),
114 v.dtype().toScalarType(),
115 v.device(),
116 v.sym_sizes(),
117 sym_strides,
118 v.requires_grad());
119 }
120
check(const LocalState & state,const c10::DispatchKeySet & dispatch_key_set,const at::ScalarType & dtype,const c10::Device & device,const c10::SymIntArrayRef & sym_sizes,const c10::SymIntArrayRef & sym_strides,const bool & requires_grad)121 bool TensorCheck::check(
122 const LocalState& state,
123 const c10::DispatchKeySet& dispatch_key_set,
124 const at::ScalarType& dtype,
125 const c10::Device& device,
126 const c10::SymIntArrayRef& sym_sizes,
127 const c10::SymIntArrayRef& sym_strides,
128 const bool& requires_grad) {
129 if (dispatch_key_ != state.apply(dispatch_key_set).raw_repr() ||
130 dtype_ != dtype || device_index_ != device.index() ||
131 requires_grad_ != requires_grad) {
132 return false;
133 }
134
135 auto ndim = sym_sizes.size();
136 if (ndim != static_cast<size_t>(dim_)) {
137 return false;
138 }
139
140 const auto& sizes = sym_sizes;
141 const auto& strides = sym_strides;
142 for (auto i : c10::irange(ndim)) {
143 auto known_size = sizes_[i];
144 auto known_stride = strides_[i];
145 if (known_size.has_value()) {
146 if (known_size.value() != sizes[i]) {
147 return false;
148 }
149 }
150 if (known_stride.has_value()) {
151 if (known_stride.value() != strides[i]) {
152 return false;
153 }
154 }
155 }
156 return true;
157 }
158
check_verbose(const LocalState & state,const at::Tensor & v,const std::string & tensor_name)159 std::string TensorCheck::check_verbose(
160 const LocalState& state,
161 const at::Tensor& v,
162 const std::string& tensor_name) {
163 std::stringstream fail_reason;
164 fail_reason << "tensor '" << tensor_name << "' ";
165 if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
166 // return fmt::format("tensor dispatch key mismatch. expected {}, actual
167 // {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
168 fail_reason << "dispatch key set mismatch. expected "
169 << c10::DispatchKeySet(c10::DispatchKeySet::RAW, dispatch_key_)
170 << ", actual " << state.apply(v.key_set());
171 return fail_reason.str();
172 } else if (dtype_ != v.dtype().toScalarType()) {
173 // return fmt::format("tensor dtype mismatch. expected {}, actual {}",
174 // dtype_, v.dtype().toScalarType());
175 fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
176 << v.dtype().toScalarType();
177 return fail_reason.str();
178 } else if (device_index_ != v.device().index()) {
179 fail_reason << "Tensor device index mismatch. Expected device index to be "
180 << device_index_ << ", actual " << v.device().index();
181 return fail_reason.str();
182 } else if (requires_grad_ != v.requires_grad()) {
183 // return fmt::format("tensor requires_grad mismatch. expected {}",
184 // requires_grad_);
185 fail_reason << "requires_grad mismatch. expected requires_grad="
186 << requires_grad_;
187 return fail_reason.str();
188 }
189 auto ndim = v.ndimension();
190 if (ndim != dim_) {
191 // return fmt::format("tensor rank mismatch. expected {}, actual {}",
192 // sizes_.size(), ndim);
193 fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
194 << ndim;
195 return fail_reason.str();
196 }
197 const auto& sizes = v.sym_sizes();
198 for (auto i : c10::irange(ndim)) {
199 auto known_size = sizes_[i];
200 if (known_size.has_value() && (known_size.value() != sizes[i])) {
201 fail_reason << "size mismatch at index " << i << ". expected "
202 << known_size.value() << ", actual " << sizes[i];
203 return fail_reason.str();
204 }
205 }
206 const bool supports_stride =
207 !v.is_sparse() && !at::sparse_csr::is_sparse_compressed(v);
208 if (supports_stride) {
209 const auto& strides = v.sym_strides();
210 for (auto i : c10::irange(ndim)) {
211 auto known_stride = strides_[i];
212 if (known_stride.has_value() && known_stride.value() != strides[i]) {
213 fail_reason << "stride mismatch at index " << i << ". expected "
214 << known_stride.value() << ", actual " << strides[i];
215 return fail_reason.str();
216 }
217 }
218 }
219 return "";
220 }
221
222 namespace {
223
224 typedef std::vector<TensorCheck> ChecksList;
225
226 typedef struct {
227 PyObject_HEAD;
228 ChecksList* checks;
229 } TensorGuards;
230
TensorGuards_dealloc(TensorGuards * self)231 static void TensorGuards_dealloc(TensorGuards* self) {
232 if (self->checks != nullptr) {
233 delete self->checks;
234 self->checks = nullptr;
235 }
236 Py_TYPE(self)->tp_free((PyObject*)self);
237 }
238
TensorGuards_new(PyTypeObject * type,PyObject * args,PyObject * kwds)239 static PyObject* TensorGuards_new(
240 PyTypeObject* type,
241 PyObject* args,
242 PyObject* kwds) {
243 TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0);
244 if (self != nullptr) {
245 self->checks = new ChecksList();
246 }
247 return (PyObject*)self;
248 }
249
wrapIntegersInOptional(const c10::SymIntArrayRef & intArray)250 static std::vector<std::optional<c10::SymInt>> wrapIntegersInOptional(
251 const c10::SymIntArrayRef& intArray) {
252 std::vector<std::optional<c10::SymInt>> optVec(intArray.size());
253 std::transform(
254 intArray.begin(),
255 intArray.end(),
256 optVec.begin(),
257 [](const c10::SymInt& value) { return std::make_optional(value); });
258 return optVec;
259 }
260
pyListToVecOptInt(PyObject * pyList)261 static std::vector<std::optional<c10::SymInt>> pyListToVecOptInt(
262 PyObject* pyList) {
263 std::vector<std::optional<c10::SymInt>> vec;
264 Py_ssize_t size = PyList_Size(pyList);
265 for (Py_ssize_t i = 0; i < size; i++) {
266 PyObject* item = PyList_GetItem(pyList, i);
267 auto handle = py::handle(item);
268 if (item == Py_None) {
269 vec.emplace_back(std::nullopt);
270 } else if (torch::is_symint(handle)) {
271 vec.emplace_back(py::cast<c10::SymInt>(handle));
272 } else {
273 int64_t value = PyLong_AsLongLong(item);
274 if (value == -1 && PyErr_Occurred()) {
275 PyErr_SetString(
276 PyExc_TypeError,
277 "Size or stride list item is not a valid integer.");
278 TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
279 }
280 vec.emplace_back(c10::SymInt(value));
281 }
282 }
283 return vec;
284 }
285
get_dynamic_dims(PyObject * dynamic_dims_py)286 static std::vector<std::vector<std::optional<c10::SymInt>>> get_dynamic_dims(
287 PyObject* dynamic_dims_py) {
288 std::vector<std::vector<std::optional<c10::SymInt>>> per_tensor_dynamic_dims;
289 if (dynamic_dims_py != Py_None) {
290 Py_ssize_t size = PyList_Size(dynamic_dims_py);
291 for (Py_ssize_t i = 0; i < size; i++) {
292 PyObject* py_list = PyList_GetItem(dynamic_dims_py, i);
293 std::vector<std::optional<c10::SymInt>> vec = pyListToVecOptInt(py_list);
294 per_tensor_dynamic_dims.push_back(std::move(vec));
295 }
296 }
297 return per_tensor_dynamic_dims;
298 }
299
TensorGuards_init(TensorGuards * self,PyObject * args,PyObject * kwds)300 static int TensorGuards_init(
301 TensorGuards* self,
302 PyObject* args,
303 PyObject* kwds) {
304 if (!PyTuple_CheckExact(args)) {
305 PyErr_SetString(PyExc_TypeError, "expected tuple()");
306 return -1;
307 }
308 // Top level structure is List[List[Union[int, None]]]
309 PyObject* dynamic_dims_sizes_py =
310 PyDict_GetItemString(kwds, "dynamic_dims_sizes");
311 if (dynamic_dims_sizes_py == nullptr) {
312 PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
313 return -1;
314 }
315 PyObject* dynamic_dims_strides_py =
316 PyDict_GetItemString(kwds, "dynamic_dims_strides");
317 if (dynamic_dims_strides_py == nullptr) {
318 PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
319 return -1;
320 }
321
322 // dynamic_dims_strides/sizes_py is None when dynamic_shapes=False - this is
323 // an optimization to avoid invoking .size()/.stride() in python needlessly
324 std::vector<std::vector<std::optional<c10::SymInt>>>
325 per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
326 std::vector<std::vector<std::optional<c10::SymInt>>>
327 per_tensor_dynamic_dims_strides =
328 get_dynamic_dims(dynamic_dims_strides_py);
329
330 auto& checks = *self->checks;
331 auto len = PyTuple_GET_SIZE(args);
332 checks.reserve(len);
333 LocalState state;
334
335 for (auto i : c10::irange(len)) {
336 PyObject* item = PyTuple_GET_ITEM(args, i);
337 if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
338 PyErr_SetString(PyExc_TypeError, "expected Tensor()");
339 return -1;
340 }
341 auto tensor = THPVariable_Unpack(item);
342 std::vector<std::optional<c10::SymInt>> tensor_dims_size =
343 per_tensor_dynamic_dims_sizes.empty()
344 ? wrapIntegersInOptional(tensor.sym_sizes())
345 : per_tensor_dynamic_dims_sizes[i];
346 std::vector<std::optional<c10::SymInt>> tensor_dims_stride =
347 per_tensor_dynamic_dims_strides.empty()
348 ? wrapIntegersInOptional(tensor.sym_strides())
349 : per_tensor_dynamic_dims_strides[i];
350
351 checks.emplace_back(
352 state,
353 Py_TYPE(item),
354 std::move(tensor),
355 std::move(tensor_dims_size),
356 std::move(tensor_dims_stride));
357 }
358 return 0;
359 }
360
TensorGuards_check(TensorGuards * self,PyObject * args,PyObject * kwargs)361 PyObject* TensorGuards_check(
362 TensorGuards* self,
363 PyObject* args,
364 PyObject* kwargs) {
365 if (!PyTuple_CheckExact(args)) {
366 PyErr_SetString(PyExc_TypeError, "expected tuple()");
367 return nullptr;
368 }
369 auto& checks = *self->checks;
370 auto len = PyTuple_GET_SIZE(args);
371
372 // kwargs is just ignored here
373
374 if (static_cast<decltype(len)>(checks.size()) != len) {
375 PyErr_SetString(PyExc_TypeError, "wrong length");
376 return nullptr;
377 }
378
379 LocalState state;
380 // Note - all the tensors that make it to guards must be unique. Dynamo
381 // builder handles guarding for positive aliases (X is Y). However, we do not
382 // create guards for negative alias (X is not Y) as that is an N^2
383 // relationship. Instead, we rely on the uniqueness upstream to verify, at
384 // check_fn time (this function).
385 ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
386 for (auto i : c10::irange(len)) {
387 PyObject* item = PyTuple_GET_ITEM(args, i);
388
389 if (Py_TYPE(item) != checks[i].pytype) {
390 Py_RETURN_FALSE;
391 }
392 auto insertion = unique_tensors.insert({item, nullptr});
393 if (!insertion.second) {
394 // Violates uniqueness
395 Py_RETURN_FALSE;
396 }
397 if (!checks[i].check(state, THPVariable_Unpack(item))) {
398 Py_RETURN_FALSE;
399 }
400 }
401
402 Py_RETURN_TRUE;
403 }
404
TensorGuards_check_verbose(TensorGuards * self,PyObject * args,PyObject * kwargs)405 PyObject* TensorGuards_check_verbose(
406 TensorGuards* self,
407 PyObject* args,
408 PyObject* kwargs) {
409 if (!PyTuple_CheckExact(args)) {
410 PyErr_SetString(PyExc_TypeError, "expected tuple()");
411 return nullptr;
412 }
413 auto& checks = *self->checks;
414 auto len = PyTuple_GET_SIZE(args);
415
416 if (static_cast<decltype(len)>(checks.size()) != len) {
417 PyErr_SetString(PyExc_TypeError, "wrong length");
418 return nullptr;
419 }
420
421 PyObject* tensor_check_names_py =
422 PyDict_GetItemString(kwargs, "tensor_check_names");
423 if (tensor_check_names_py == nullptr) {
424 PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
425 return nullptr;
426 }
427
428 if (!PyList_Check(tensor_check_names_py)) {
429 PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
430 return nullptr;
431 }
432
433 auto names_size = PyList_Size(tensor_check_names_py);
434 if (names_size != static_cast<decltype(names_size)>(checks.size())) {
435 PyErr_SetString(
436 PyExc_TypeError,
437 "tensor_check_names should be the same size as # tensors");
438 return nullptr;
439 }
440
441 std::vector<std::string> tensor_check_names;
442 tensor_check_names.reserve(names_size);
443 for (auto i : c10::irange(names_size)) {
444 PyObject* value = PyList_GetItem(tensor_check_names_py, i);
445 if (!PyUnicode_Check(value)) {
446 PyErr_SetString(
447 PyExc_TypeError, "tensor_check_names must only contain strings");
448 return nullptr;
449 }
450 tensor_check_names.emplace_back(PyUnicode_AsUTF8(value));
451 }
452
453 LocalState state;
454 ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
455 for (auto i : c10::irange(len)) {
456 PyObject* item = PyTuple_GET_ITEM(args, i);
457 if (Py_TYPE(item) != checks[i].pytype) {
458 std::stringstream fail_reason;
459 PyObject* type_str = PyObject_Str(PyObject_Type(item));
460 fail_reason << "expected type of '" << tensor_check_names[i]
461 << "' to be a tensor type, ";
462 if (!type_str) {
463 fail_reason << "but found a different type";
464 } else {
465 fail_reason << "' but found " << PyUnicode_AsUTF8(type_str);
466 }
467 return Py_BuildValue("s", fail_reason.str().c_str());
468 }
469
470 auto insertion = unique_tensors.insert({item, nullptr});
471 if (!insertion.second) {
472 std::stringstream fail_reason;
473 fail_reason << "Duplicate tensor found where not expected! ";
474 fail_reason << tensor_check_names[i]
475 << "should not alias to anything, but is aliased";
476 return Py_BuildValue("s", fail_reason.str().c_str());
477 }
478 std::string fail_reason = checks[i].check_verbose(
479 state, THPVariable_Unpack(item), tensor_check_names[i]);
480 if (fail_reason.length() > 0) {
481 return Py_BuildValue("s", fail_reason.c_str());
482 }
483 }
484
485 Py_RETURN_TRUE;
486 }
487
// Method table for the Python TensorGuards type. Both entries are registered
// with METH_VARARGS | METH_KEYWORDS, matching the (self, args, kwargs)
// signature of the callbacks they point at.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef TensorGuards_methods[] = {
    {"check",
     (PyCFunction)(void*)TensorGuards_check,
     METH_VARARGS | METH_KEYWORDS,
     ""},
    {"check_verbose",
     (PyCFunction)(void*)TensorGuards_check_verbose,
     METH_VARARGS | METH_KEYWORDS,
     "verbose fail reasons for failed checks"},
    {nullptr} /* Sentinel */
};

// Only the object header is initialized here.
// NOTE(review): the remaining type slots are presumably filled in during
// module initialization elsewhere in this file - confirm.
static PyTypeObject TensorGuardsType = {PyVarObject_HEAD_INIT(nullptr, 0)};
502
503 // TODO (janimesh) - Remove the PyObject_HEAD part when C++ guard manager is
504 // merged.
505 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
506 struct GlobalStateGuard {
507 PyObject_HEAD;
508
inittorch::dynamo::__anon296b09360211::GlobalStateGuard509 inline void init() {
510 auto& ctx = at::globalContext();
511 _grad_mode = at::GradMode::is_enabled();
512 // The below two flags disambiguate
513 // if torch function disabled state is
514 // 1) enabled, 2) all disabled, 3) subclasses disabled
515 // we guard on the stack separately
516 _torch_function = torch::torch_function_enabled();
517 _torch_function_all_disabled = at::impl::torch_function_all_disabled();
518 _deterministic_algorithms = ctx.deterministicAlgorithms();
519 _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
520 _allow_tf32 = ctx.allowTF32CuBLAS();
521 _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
522 _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
523 _num_threads = at::get_num_threads();
524 _default_dtype = at::get_default_dtype();
525 }
526
checktorch::dynamo::__anon296b09360211::GlobalStateGuard527 inline bool check() const {
528 auto& ctx = at::globalContext();
529 return (_grad_mode == at::GradMode::is_enabled() &&
530 _torch_function == torch::torch_function_enabled() &&
531 _torch_function_all_disabled ==
532 at::impl::torch_function_all_disabled() &&
533 _deterministic_algorithms == ctx.deterministicAlgorithms() &&
534 _deterministic_algorithms_warn_only ==
535 ctx.deterministicAlgorithmsWarnOnly() &&
536 _allow_tf32 == ctx.allowTF32CuBLAS() &&
537 _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
538 _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
539 _num_threads == at::get_num_threads()) &&
540 _default_dtype == at::get_default_dtype();
541 }
542
reasontorch::dynamo::__anon296b09360211::GlobalStateGuard543 inline std::string reason() const {
544 std::ostringstream os;
545 auto& ctx = at::globalContext();
546 if (_grad_mode != at::GradMode::is_enabled())
547 os << "grad_mode ";
548 if (_torch_function != torch::torch_function_enabled())
549 os << "torch_function ";
550 if (_deterministic_algorithms != ctx.deterministicAlgorithms())
551 os << "deterministic_algorithms ";
552 if (_deterministic_algorithms_warn_only !=
553 ctx.deterministicAlgorithmsWarnOnly())
554 os << "deterministic_algorithms_warn_only ";
555 if (_allow_tf32 != ctx.allowTF32CuBLAS())
556 os << "allow_tf32 ";
557 if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
558 os << "allow_fp16_reduce ";
559 if (_allow_bf16_reduce != ctx.allowBF16ReductionCuBLAS())
560 os << "allow_bf16_reduce ";
561 if (_num_threads != at::get_num_threads())
562 os << "num_threads ";
563 if (_default_dtype != at::get_default_dtype())
564 os << "default_dtype ";
565 return os.str();
566 }
567
568 bool _grad_mode;
569 bool _torch_function;
570 bool _torch_function_all_disabled;
571 bool _deterministic_algorithms;
572 bool _deterministic_algorithms_warn_only;
573 bool _allow_tf32;
574 bool _allow_fp16_reduce;
575 bool _allow_bf16_reduce;
576 int _num_threads;
577 caffe2::TypeMeta _default_dtype;
578 // TODO(jansel): we should guard on more state as inductor starts using it
579 };
580
// tp_init-style entry point: snapshots the current global state into `self`.
// `args` and `kwargs` are not inspected. Always returns 0.
int GlobalStateGuard_init(
    GlobalStateGuard* self,
    PyObject* args,
    PyObject* kwargs) {
  self->init();
  return 0;
}
588
GlobalStateGuard_check(GlobalStateGuard * self,PyObject * args,PyObject * kwargs)589 PyObject* GlobalStateGuard_check(
590 GlobalStateGuard* self,
591 PyObject* args,
592 PyObject* kwargs) {
593 if (self->check()) {
594 Py_RETURN_TRUE;
595 } else {
596 Py_RETURN_FALSE;
597 }
598 }
599
GlobalStateGuard_reason(GlobalStateGuard * self,PyObject * args,PyObject * kwargs)600 PyObject* GlobalStateGuard_reason(
601 GlobalStateGuard* self,
602 PyObject* args,
603 PyObject* kwargs) {
604 return PyUnicode_FromString(self->reason().c_str());
605 }
606
// Method table for the Python GlobalStateGuard type.
// NOTE(review): the callbacks are declared with three parameters but are
// registered as METH_NOARGS; none of them reads `args` or `kwargs` in the
// visible code.
// NOLINTNEXTLINE(*array*)
static PyMethodDef GlobalStateGuard_methods[] = {
    {"check",
     (PyCFunction)(void*)GlobalStateGuard_check,
     METH_NOARGS,
     "Return true if global state was the same as at creation time"},
    {"reason",
     (PyCFunction)(void*)GlobalStateGuard_reason,
     METH_NOARGS,
     "Return string reason for guard check failing"},
    {nullptr}};
// Only the object header is initialized here.
// NOTE(review): remaining type slots are presumably assigned elsewhere in
// this file - confirm.
static PyTypeObject GlobalStateGuardType = {PyVarObject_HEAD_INIT(nullptr, 0)};
619
check_type_id(PyObject * dummy,PyObject * args)620 static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
621 // faster `lambda obj, expected: id(type(obj)) == expected`
622 PyObject* obj = nullptr;
623 unsigned long long expected = 0;
624 if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
625 return nullptr;
626 }
627 // NOLINTNEXTLINE(performance-no-int-to-ptr)
628 if (Py_TYPE(obj) == (void*)expected) {
629 Py_RETURN_TRUE;
630 } else {
631 Py_RETURN_FALSE;
632 }
633 }
634
check_obj_id(PyObject * dummy,PyObject * args)635 static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
636 // faster `lambda obj, expected: id(obj) == expected`
637 PyObject* obj = nullptr;
638 unsigned long long expected = 0;
639 if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
640 return nullptr;
641 }
642 // NOLINTNEXTLINE(performance-no-int-to-ptr)
643 if (obj == (void*)expected) {
644 Py_RETURN_TRUE;
645 } else {
646 Py_RETURN_FALSE;
647 }
648 }
649
650 #if IS_PYTHON_3_12_PLUS
651
652 static std::unordered_map<PyObject*, uint64_t> dict_version_map;
653 static int dict_version_watcher_id;
654 static uint64_t global_dict_version_id = 0;
dict_version_watch_callback(PyDict_WatchEvent event,PyObject * dict,PyObject * key,PyObject * new_value)655 static int dict_version_watch_callback(
656 PyDict_WatchEvent event,
657 PyObject* dict,
658 PyObject* key,
659 PyObject* new_value) noexcept {
660 if (event == PyDict_EVENT_DEALLOCATED) {
661 dict_version_map.erase(dict);
662 } else if (event != PyDict_EVENT_CLONED) {
663 dict_version_map[dict] = global_dict_version_id++;
664 }
665 return 0;
666 }
667
668 #endif
669
get_dict_version_unchecked(PyObject * dict)670 static uint64_t get_dict_version_unchecked(PyObject* dict) {
671 #if IS_PYTHON_3_12_PLUS
672
673 if (PyDict_Watch(dict_version_watcher_id, dict)) {
674 throw std::runtime_error("failed to add version watcher to dict!");
675 }
676 if (!dict_version_map.count(dict)) {
677 dict_version_map[dict] = global_dict_version_id++;
678 }
679 return dict_version_map[dict];
680
681 #else
682
683 return ((PyDictObject*)dict)->ma_version_tag;
684
685 #endif
686 }
687
dict_version(PyObject * dummy,PyObject * args)688 static PyObject* dict_version(PyObject* dummy, PyObject* args) {
689 // Retrieves the version of a dictionary.
690 PyObject* obj = nullptr;
691 if (!PyArg_ParseTuple(args, "O", &obj)) {
692 return nullptr;
693 }
694 if (!PyDict_Check(obj)) {
695 return nullptr;
696 }
697 return THPUtils_packUInt64(get_dict_version_unchecked(obj));
698 }
699
assert_size_stride(PyObject * dummy,PyObject * args)700 static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
701 /*
702 Assert that a given tensor has a given size/stride, but ignore strides
703 of size==1 dimensions. Implemented in C++ as this is on the hot path.
704 */
705 PyObject* item = nullptr;
706 PyObject* size = nullptr;
707 PyObject* stride = nullptr;
708 if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
709 return nullptr;
710 }
711 if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
712 PyErr_SetString(PyExc_TypeError, "expected Tensor()");
713 return nullptr;
714 }
715 if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
716 PyErr_SetString(PyExc_TypeError, "expected tuple()");
717 return nullptr;
718 }
719 at::Tensor tensor = THPVariable_Unpack(item);
720 int64_t ndim = tensor.ndimension();
721 if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
722 PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
723 return nullptr;
724 }
725 std::stringstream msg;
726 int num_errors = 0;
727 for (auto i : c10::irange(ndim)) {
728 int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
729 int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i));
730 int64_t actual_size = tensor.size(i);
731 int64_t actual_stride = tensor.stride(i);
732 if (want_size != actual_size ||
733 // ignore stride differences when size is 1
734 (want_stride != actual_stride && actual_size > 1)) {
735 if (num_errors > 0)
736 msg << "; ";
737 msg << "expected size " << actual_size << "==" << want_size << ", stride "
738 << actual_stride << "==" << want_stride << " at dim=" << i;
739 num_errors++;
740 }
741 }
742
743 if (num_errors) {
744 PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
745 return nullptr;
746 }
747
748 Py_RETURN_TRUE;
749 }
750
751 template <typename T>
unwrap_size_tuple(PyObject * obj,T & output)752 inline static void unwrap_size_tuple(PyObject* obj, T& output) {
753 TORCH_CHECK(PyTuple_CheckExact(obj));
754 size_t len = PyTuple_GET_SIZE(obj);
755 output.reserve(len);
756 for (size_t i = 0; i < len; ++i) {
757 auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(obj, i));
758 TORCH_CHECK(result >= 0);
759 output.emplace_back(result);
760 }
761 }
762
// Parses the `(sizes, strides, dtype)` argument tuple of the
// _empty_strided_* functions.
//
// `args` must be an exact tuple with three items: a tuple of sizes, a tuple
// of strides, and a torch dtype object. The parsed sizes and strides are
// appended to `sizes` / `strides`, and `dtype` receives the scalar type.
// Throws (via TORCH_CHECK) when any of these expectations is violated.
template <typename T>
inline static void _parse_empty_strided_args(
    PyObject* args,
    T& sizes,
    T& strides,
    at::ScalarType& dtype) {
  TORCH_CHECK(PyTuple_CheckExact(args));
  TORCH_CHECK(PyTuple_GET_SIZE(args) == 3);
  // note PyTuple_GET_ITEM returns a borrowed ref, so no need for refcounts
  unwrap_size_tuple(PyTuple_GET_ITEM(args, 0), sizes);
  unwrap_size_tuple(PyTuple_GET_ITEM(args, 1), strides);
  PyObject* py_dtype = PyTuple_GET_ITEM(args, 2);
  TORCH_CHECK(THPDtype_Check(py_dtype));
  dtype = reinterpret_cast<THPDtype*>(py_dtype)->scalar_type;
}
778
_empty_strided_device(PyObject * dummy,PyObject * args,c10::DeviceType device_type)779 inline static PyObject* _empty_strided_device(
780 PyObject* dummy,
781 PyObject* args,
782 c10::DeviceType device_type) {
783 HANDLE_TH_ERRORS;
784 at::SmallVector<int64_t, 8> sizes;
785 at::SmallVector<int64_t, 8> strides;
786 at::ScalarType dtype{at::ScalarType::Undefined};
787 _parse_empty_strided_args(args, sizes, strides, dtype);
788 if (device_type == c10::DeviceType::CPU) {
789 return THPVariable_Wrap(
790 at::detail::empty_strided_cpu(sizes, strides, dtype));
791 }
792 #ifdef USE_CUDA
793 else if (device_type == c10::DeviceType::CUDA) {
794 return THPVariable_Wrap(at::detail::empty_strided_cuda(
795 sizes, strides, dtype, c10::DeviceType::CUDA));
796 }
797 #endif
798 #ifdef USE_XPU
799 else if (device_type == c10::DeviceType::XPU) {
800 return THPVariable_Wrap(at::detail::empty_strided_xpu(
801 sizes, strides, dtype, c10::DeviceType::XPU));
802 }
803 #endif
804 else {
805 TORCH_CHECK(
806 false, "PyTorch compiled without support for the specified device.");
807 }
808
809 END_HANDLE_TH_ERRORS;
810 }
811
// Python function `_empty_strided_cpu(sizes, strides, dtype)`.
static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) {
  // at::empty_strided is surprisingly slow. This is a lower-overhead
  // version that saves ~2us on every allocation.
  return _empty_strided_device(dummy, args, c10::DeviceType::CPU);
}
817
// Python function `_empty_strided_cuda(sizes, strides, dtype)`.
static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) {
  // at::empty_strided is surprisingly slow. This is lower-overhead.
  return _empty_strided_device(dummy, args, c10::DeviceType::CUDA);
}
822
// Python function `_empty_strided_xpu(sizes, strides, dtype)`.
static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) {
  // at::empty_strided is surprisingly slow. This is lower-overhead.
  return _empty_strided_device(dummy, args, c10::DeviceType::XPU);
}
827
_reinterpret_tensor(PyObject * dummy,PyObject * args)828 static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) {
829 HANDLE_TH_ERRORS;
830 static PythonArgParser parser(
831 {"_reinterpret_tensor(Tensor base, IntArrayRef sizes, IntArrayRef strides, int64_t offset_increment=0)"},
832 /*traceable=*/true);
833
834 ParsedArgs<4> parsed_args;
835 auto r = parser.parse(args, /*kwargs=*/nullptr, parsed_args);
836
837 Tensor self = r.tensor(0);
838 auto sizes = r.intlist(1);
839 auto strides = r.intlist(2);
840 auto offset_increment = r.toInt64(3);
841
842 auto res = torch::inductor::_reinterpret_tensor(
843 self, sizes, strides, offset_increment);
844 return torch::autograd::utils::wrap(res);
845
846 END_HANDLE_TH_ERRORS;
847 }
848
// Module-level functions exported by this extension module. All of them take
// positional arguments only (METH_VARARGS) and carry no docstring.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef _methods[] = {
    {"check_type_id", check_type_id, METH_VARARGS, nullptr},
    {"check_obj_id", check_obj_id, METH_VARARGS, nullptr},
    {"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
    {"dict_version", dict_version, METH_VARARGS, nullptr},
    {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
    {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
    {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr},
    {"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
860
// Module definition for `torch._C._dynamo.guards`. The -1 size field means
// the module keeps no per-interpreter state block; the functions come from
// the `_methods` table.
static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.guards",
    "Module containing checks on tensors",
    -1,
    _methods};
867
get_exception_message()868 std::string get_exception_message() {
869 PyObject *ptype = nullptr, *pvalue = nullptr, *ptraceback = nullptr;
870 PyErr_Fetch(&ptype, &pvalue, &ptraceback);
871
872 PyObject* exc_message_pyobj = PyObject_Str(pvalue);
873 const char* exc_message = PyUnicode_AsUTF8(exc_message_pyobj);
874
875 Py_DECREF(exc_message_pyobj);
876 Py_XDECREF(ptype);
877 Py_XDECREF(pvalue);
878 Py_XDECREF(ptraceback);
879 return std::string(exc_message);
880 }
881
is_immutable_object(py::handle example_value)882 bool is_immutable_object(py::handle example_value) {
883 if (PyTuple_Check(example_value.ptr())) {
884 // Check that each element is immutable
885 for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
886 if (!is_immutable_object(
887 py::handle(PyTuple_GetItem(example_value.ptr(), i)))) {
888 return false;
889 }
890 }
891 return true;
892 }
893 return PyLong_Check(example_value.ptr()) ||
894 PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
895 PyUnicode_Check(example_value.ptr()) ||
896 THPVariable_Check(example_value.ptr());
897 }
898
is_parameter(py::handle tensor)899 bool is_parameter(py::handle tensor) {
900 py::object parameter = py::module::import("torch.nn").attr("Parameter");
901 return py::isinstance(tensor, parameter);
902 }
903
904 /**
905 * Stores relevant guard debug information, e.g., failure str for a LeafGuard
906 * failure. The data structure is also accessible in Python.
907 */
908
909 class GuardDebugInfo {
910 public:
GuardDebugInfo(bool result,py::list verbose_code_parts,int num_guards_executed)911 GuardDebugInfo(
912 bool result,
913 py::list verbose_code_parts,
914 int num_guards_executed)
915 : result(result),
916 verbose_code_parts(std::move(verbose_code_parts)),
917 num_guards_executed(num_guards_executed) {}
918
919 // This constructor is used when guard succeeds.
GuardDebugInfo(bool result,int num_guards_executed)920 GuardDebugInfo(bool result, int num_guards_executed)
921 : result(result), num_guards_executed(num_guards_executed) {}
922
GuardDebugInfo(bool result,const std::string & failed_reason,int num_guards_executed)923 GuardDebugInfo(
924 bool result,
925 const std::string& failed_reason,
926 int num_guards_executed)
927 : GuardDebugInfo(result, num_guards_executed) {
928 verbose_code_parts.append(failed_reason);
929 }
930
to_string()931 std::string to_string() {
932 std::stringstream ss;
933 ss << "GuardDebugInfo(\n"
934 << "result=" << result << ",\n"
935 << "verbose_code_parts=" << verbose_code_parts << ",\n"
936 << "num_guards_executed=" << num_guards_executed << ")\n";
937 return ss.str();
938 }
939
940 // Whether the guard passed or failed.
941 bool result;
942
943 // This is a list of verbose_code_parts for the failed guard. When there are
944 // more than one verbose_code_parts, then recompilation reasoning infra on the
945 // Python side can iterate over this list and eval each string to pinpoint the
946 // exact code part that failed.
947 py::list verbose_code_parts;
948
949 // Total number of executed guards so far. This is helpful in debugging if
950 // shuffling is working.
951 int num_guards_executed;
952 };
953
954 class GuardManager;
955 class RootGuardManager;
956 class DictGuardManager;
957
958 /**
959 * Base class for the leaf guard in the GuardManager hierarchy.
960 */
961 class LeafGuard {
962 public:
963 // Most guards do not need root guard manager.
LeafGuard(py::object verbose_code_parts)964 LeafGuard(py::object verbose_code_parts)
965 : _verbose_code_parts(std::move(verbose_code_parts)) {}
966
967 // Guards like TENSOR_MATCH require root_guard_manager to access local_state
968 // shared across all leaf guards.
LeafGuard(RootGuardManager * root_guard_manager,py::object verbose_code_parts)969 LeafGuard(RootGuardManager* root_guard_manager, py::object verbose_code_parts)
970 : _root_guard_manager(root_guard_manager),
971 _verbose_code_parts(std::move(verbose_code_parts)) {}
972
973 // check function could be called from python. This is useful for debugging
974 // purpose.
check(py::handle value)975 bool check(py::handle value) {
976 return check_nopybind(value.ptr());
977 }
978
check_verbose(py::handle value)979 GuardDebugInfo check_verbose(py::handle value) {
980 return check_verbose_nopybind(value.ptr());
981 }
982
check_verbose_nopybind(PyObject * value)983 virtual GuardDebugInfo check_verbose_nopybind(
984 PyObject* value) { // borrowed ref
985 bool result = check_nopybind(value);
986 if (!result) {
987 return GuardDebugInfo(result, _verbose_code_parts, 0);
988 }
989 return GuardDebugInfo(true, 0);
990 }
991
verbose_code_parts()992 py::list verbose_code_parts() {
993 return _verbose_code_parts;
994 }
995
996 // This is on the hot path and avoids any refcounting code from pybind. This
997 // is not exposed to Python and can only be called from C++.
998 virtual bool check_nopybind(PyObject* value) = 0;
999 virtual ~LeafGuard() = default;
1000
1001 protected:
1002 // RootGuardManager has state that is common across all guards like
1003 // LocalState.
1004 RootGuardManager* _root_guard_manager{nullptr};
1005
1006 private:
1007 // This is set while constructing the leaf guard. This is used for identifying
1008 // the cause of recompilation.
1009 py::list _verbose_code_parts;
1010 };
1011
1012 /**
1013 * Represents a leaf guard that accepts the python guard check function. We
1014 * would like to have most of the guards in C++ (to avoid a Python function
1015 * call). But, it will take some time to reach that goal. Also, there might be
1016 * cases where its too tedious to write an equivalent C++ guard.
1017 *
1018 * LAMBDA_GUARD allows us to gradually move to C++. We can start from all
1019 * guards of type PythonLambaGuard and incrementally move expensive guards to
1020 * C++.
1021 */
1022 class LAMBDA_GUARD : public LeafGuard {
1023 public:
LAMBDA_GUARD(py::object guard_check_fn,py::object verbose_code_parts)1024 LAMBDA_GUARD(py::object guard_check_fn, py::object verbose_code_parts)
1025 : LeafGuard(std::move(verbose_code_parts)) {
1026 if (py::isinstance<py::function>(guard_check_fn)) {
1027 _guard_check_fn = py::cast<py::function>(std::move(guard_check_fn));
1028 } else {
1029 throw py::type_error("LAMBDA_GUARD expects (callable, str)");
1030 }
1031 }
1032
1033 // Runs the lambda function with the current f_locals value.
check_nopybind(PyObject * value)1034 bool check_nopybind(PyObject* value) override { // borrowed ref
1035 PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
1036 if (x == nullptr) {
1037 // An exception is caught in the lambda function.
1038 PyErr_Clear();
1039 return false;
1040 }
1041 bool result = PyObject_IsTrue(x);
1042 Py_DECREF(x);
1043 return result;
1044 }
1045
check_verbose_nopybind(PyObject * value)1046 GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
1047 PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
1048 if (x == nullptr) {
1049 // An exception is caught in the lambda function.
1050 std::string exc_message = get_exception_message();
1051 PyErr_Clear();
1052 return GuardDebugInfo(false, exc_message, 0);
1053 }
1054 bool result = PyObject_IsTrue(x);
1055 Py_DECREF(x);
1056 if (result) {
1057 return GuardDebugInfo(true, 0);
1058 }
1059 return GuardDebugInfo(false, verbose_code_parts(), 0);
1060 }
1061
1062 private:
1063 // The user provided lambda function for check_fn.
1064 py::function _guard_check_fn;
1065 };
1066
1067 class TYPE_MATCH : public LeafGuard {
1068 public:
1069 // type_id = id(type(obj))
TYPE_MATCH(py::object type_id,py::object verbose_code_parts)1070 TYPE_MATCH(py::object type_id, py::object verbose_code_parts)
1071 : LeafGuard(std::move(verbose_code_parts)),
1072 _expected(py::cast<intptr_t>(std::move(type_id))) {}
1073
check_nopybind(PyObject * value)1074 bool check_nopybind(PyObject* value) override { // borrowed ref
1075 // NOLINTNEXTLINE(performance-no-int-to-ptr)
1076 return Py_TYPE(value) == (void*)_expected;
1077 }
1078
1079 private:
1080 // id of the type of the original object.
1081 intptr_t _expected;
1082 };
1083
1084 class ID_MATCH : public LeafGuard {
1085 public:
1086 // obj_id = id(obj)
ID_MATCH(py::object obj_id,py::object verbose_code_parts)1087 ID_MATCH(py::object obj_id, py::object verbose_code_parts)
1088 : LeafGuard(std::move(verbose_code_parts)),
1089 _expected(py::cast<intptr_t>(std::move(obj_id))) {}
1090
check_nopybind(PyObject * value)1091 bool check_nopybind(PyObject* value) override { // borrowed ref
1092 // NOLINTNEXTLINE(performance-no-int-to-ptr)
1093 return value == (void*)_expected;
1094 }
1095
1096 private:
1097 // id of the original object.
1098 intptr_t _expected;
1099 };
1100
1101 class EQUALS_MATCH : public LeafGuard {
1102 public:
EQUALS_MATCH(py::object value,py::object verbose_code_parts)1103 EQUALS_MATCH(py::object value, py::object verbose_code_parts)
1104 : LeafGuard(std::move(verbose_code_parts)),
1105 _value(value),
1106 _value_type(Py_TYPE(value.ptr())) {}
1107
check_nopybind(PyObject * value)1108 bool check_nopybind(PyObject* value) override { // borrowed ref
1109 // Fast path - pointer equality check. Pointer equality checks are ok
1110 // because objects guarded with EQUALS_MATCH are immutable.
1111 if (value != _value.ptr()) {
1112 // Check type
1113 if (Py_TYPE(value) != _value_type) {
1114 return false;
1115 }
1116 int result = PyObject_RichCompareBool(value, _value.ptr(), Py_EQ);
1117 // Check for exception
1118 if (result == -1) {
1119 PyErr_Clear();
1120 return false;
1121 }
1122 return result;
1123 }
1124 return true;
1125 }
1126
1127 private:
1128 // value to compare against. This is py::object so that we hold on to the
1129 // original value and prevent garbage collection. We run EQUALS_MATCH only on
1130 // selected objects which do not have high memory footprint, so holding on to
1131 // these objects is ok.
1132 py::object _value;
1133
1134 // Type of the value
1135 PyTypeObject* _value_type;
1136 };
1137
1138 class TUPLE_ITERATOR_LEN : public LeafGuard {
1139 public:
TUPLE_ITERATOR_LEN(py::object length,py::object type_id,py::object verbose_code_parts)1140 TUPLE_ITERATOR_LEN(
1141 py::object length,
1142 py::object type_id,
1143 py::object verbose_code_parts)
1144 : LeafGuard(std::move(verbose_code_parts)),
1145 _length(py::cast<Py_ssize_t>(std::move(length))),
1146 _type_id(py::cast<intptr_t>(std::move(type_id))) {}
1147
check_nopybind(PyObject * value)1148 bool check_nopybind(PyObject* value) override { // borrowed ref
1149 // Do a type match first.
1150 // NOLINTNEXTLINE(performance-no-int-to-ptr)
1151 if (Py_TYPE(value) != (void*)_type_id) {
1152 return false;
1153 }
1154 _PyTupleIterObject* it = (_PyTupleIterObject*)value;
1155 Py_ssize_t length = 0;
1156 if (it->it_seq)
1157 length = PyTuple_GET_SIZE(it->it_seq) - it->it_index;
1158 return length == _length;
1159 }
1160
1161 private:
1162 // Length of the guarded list
1163 Py_ssize_t _length;
1164 intptr_t _type_id;
1165 };
1166
1167 class LENGTH_CHECK : public LeafGuard {
1168 public:
LENGTH_CHECK(py::object value,py::object verbose_code_parts)1169 LENGTH_CHECK(py::object value, py::object verbose_code_parts)
1170 : LeafGuard(std::move(verbose_code_parts)),
1171 _length(py::cast<Py_ssize_t>(std::move(value))) {}
1172
check_nopybind(PyObject * value)1173 bool check_nopybind(PyObject* value) override { // borrowed ref
1174 // PySequence_Length returns -1 if the object is not a sequence. So, we
1175 // don't have to test for PySequence_Check.
1176 return PySequence_Length(value) == _length;
1177 }
1178
1179 private:
1180 // Length of the guarded list
1181 Py_ssize_t _length;
1182 };
1183
1184 class DICT_LENGTH : public LeafGuard {
1185 public:
DICT_LENGTH(py::object value,py::object verbose_code_parts)1186 DICT_LENGTH(py::object value, py::object verbose_code_parts)
1187 : LeafGuard(std::move(verbose_code_parts)),
1188 _length(py::cast<Py_ssize_t>(std::move(value))) {}
1189
check_nopybind(PyObject * value)1190 bool check_nopybind(PyObject* value) override { // borrowed ref
1191 return PyDict_Check(value) && PyDict_Size(value) == _length;
1192 }
1193
1194 private:
1195 // Length of the guarded dict
1196 Py_ssize_t _length;
1197 };
1198
1199 class NOT_NONE : public LeafGuard {
1200 public:
NOT_NONE(py::object verbose_code_parts)1201 NOT_NONE(py::object verbose_code_parts)
1202 : LeafGuard(std::move(verbose_code_parts)) {}
1203
check_nopybind(PyObject * value)1204 bool check_nopybind(PyObject* value) override { // borrowed ref
1205 return value != Py_None;
1206 }
1207 };
1208
1209 class DEFAULT_DEVICE : public LeafGuard {
1210 public:
DEFAULT_DEVICE(py::object verbose_code_parts)1211 DEFAULT_DEVICE(py::object verbose_code_parts)
1212 : LeafGuard(std::move(verbose_code_parts)) {
1213 py::handle device_module = py::module::import("torch.utils._device");
1214 // Save the dict using py::object
1215 _utils_device_dict = device_module.attr("__dict__");
1216 _device = _utils_device_dict["CURRENT_DEVICE"];
1217 }
1218
check_nopybind(PyObject * value)1219 bool check_nopybind(PyObject* value) override { // borrowed ref
1220 // Create a static interned string. Interned string is faster than creating
1221 // a new string every time. Even though its a new reference, we don't dec
1222 // ref it. Interned strings are used for things like variable names and are
1223 // leaked by design.
1224 static PyObject* current_device_str =
1225 PyUnicode_InternFromString("CURRENT_DEVICE");
1226 PyObject* device = PyDict_GetItem(
1227 _utils_device_dict.ptr(), current_device_str); // borrowed ref
1228 if (device != _device.ptr()) {
1229 int result = PyObject_RichCompareBool(device, _device.ptr(), Py_EQ);
1230 if (result == -1) {
1231 PyErr_Clear();
1232 return false;
1233 }
1234 return result;
1235 }
1236 return true;
1237 }
1238
1239 private:
1240 // Save the current device and the module dict during the guard construction.
1241 py::object _utils_device_dict;
1242 py::object _device;
1243 };
1244
1245 class GLOBAL_STATE : public LeafGuard {
1246 public:
GLOBAL_STATE(py::object verbose_code_parts)1247 GLOBAL_STATE(py::object verbose_code_parts)
1248 : LeafGuard(std::move(verbose_code_parts)) {
1249 _guard = std::make_unique<GlobalStateGuard>();
1250 _guard->init();
1251 }
1252
check_nopybind(PyObject * value)1253 bool check_nopybind(PyObject* value) override { // borrowed ref
1254 // Ignore value arg, this is just to satisfy the interface.
1255 return _guard->check();
1256 }
1257
check_verbose_nopybind(PyObject * value)1258 GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
1259 if (!_guard->check()) {
1260 return GuardDebugInfo(
1261 false, "GLOBAL_STATE changed: " + _guard->reason(), 0);
1262 }
1263 return GuardDebugInfo(true, 1);
1264 }
1265
1266 private:
1267 std::unique_ptr<GlobalStateGuard> _guard;
1268 };
1269
1270 class DATA_PTR_MATCH : public LeafGuard {
1271 public:
DATA_PTR_MATCH(py::object tensor,py::object verbose_code_parts)1272 DATA_PTR_MATCH(py::object tensor, py::object verbose_code_parts)
1273 : LeafGuard(std::move(verbose_code_parts)) {
1274 PyObject* value = tensor.ptr();
1275 if (!THPVariable_CheckExact(value) && !THPVariable_Check(value)) {
1276 throw std::runtime_error("DATA_PTR_MATCH guard requires a tensor");
1277 }
1278 _data_ptr = THPVariable_Unpack(value).data_ptr();
1279 }
1280
check_nopybind(PyObject * value)1281 bool check_nopybind(PyObject* value) override { // borrowed ref
1282 if (!THPVariable_CheckExact(value) && !THPVariable_Check(value)) {
1283 return false;
1284 }
1285 void* data_ptr = THPVariable_Unpack(value).data_ptr();
1286 return data_ptr == _data_ptr;
1287 }
1288
1289 private:
1290 // Original tensor data pointer.
1291 void* _data_ptr;
1292 };
1293
1294 // Checks that an attr is absent in the object. We don't need the opposite
1295 // HASATTR guard because we can just rely on GetAttrGuardAccessor to act as
1296 // HASATTR guard.
1297 class NO_HASATTR : public LeafGuard {
1298 public:
NO_HASATTR(py::object attr_name,py::object verbose_code_parts)1299 NO_HASATTR(py::object attr_name, py::object verbose_code_parts)
1300 : LeafGuard(std::move(verbose_code_parts)),
1301 _attr_name(std::move(attr_name)) {}
1302
check_nopybind(PyObject * value)1303 bool check_nopybind(PyObject* value) override { // borrowed ref
1304 return PyObject_HasAttr(value, _attr_name.ptr()) == 0;
1305 }
1306
1307 private:
1308 py::object _attr_name;
1309 };
1310
1311 // Checks that dict contains or does not contain a key. This happens for
1312 // PythonSysModulesVariable tracker.
1313 // TODO(janimesh) - Check if we can use DictGuardManager. The downside could be
1314 // large number of keys for sys module, so DICT_CONTAINS might still end up
1315 // being faster.
1316 class DICT_CONTAINS : public LeafGuard {
1317 public:
DICT_CONTAINS(bool contains,py::object key,py::object verbose_code_parts)1318 DICT_CONTAINS(bool contains, py::object key, py::object verbose_code_parts)
1319 : LeafGuard(std::move(verbose_code_parts)),
1320 _contains(contains ? 1 : 0),
1321 _key(std::move(key)) {}
1322
check_nopybind(PyObject * value)1323 bool check_nopybind(PyObject* value) override { // borrowed ref
1324 int result = PyDict_Contains(value, _key.ptr());
1325 if (result == -1) {
1326 PyErr_Clear();
1327 return false;
1328 }
1329 return result == _contains;
1330 }
1331
1332 private:
1333 int _contains;
1334 py::object _key;
1335 };
1336
1337 /**
1338 * Relational guards compare more than one value. We implement Relational
1339 * guards by capturing some state in the guard object. For example for tensor
1340 * aliasing guards - tensor X is not tensor Y - we construct one leaf guard
1341 * and and install it at as a leaf of two guard managers (one for X and
1342 * another for Y). Therefore, this guard is run twice. In the first
1343 * invocation, it saves the first value (state) and returns True. In the
1344 * second invocation, it compares the saved value with the new value and
1345 * returns True if they do not alias.
1346 *
1347 * We have to be careful about resetting in case the other guards fail and we
1348 * have some state in the relational guard. This is done by virtual method
1349 * reset_state(). This is called by the RootGuardManager before it exits.
1350 *
1351 */
1352 class RelationalGuard : public LeafGuard {
1353 public:
RelationalGuard(py::object verbose_code_parts)1354 RelationalGuard(py::object verbose_code_parts)
1355 : LeafGuard(std::move(verbose_code_parts)) {}
1356
1357 // reset the relational guard state on guard failure. This is called by the
1358 // guard manager.
1359 virtual void reset_state() = 0;
1360 };
1361
1362 /**
1363 * Checks that object x is object y.
1364 */
1365 class OBJECT_ALIASING : public RelationalGuard {
1366 public:
OBJECT_ALIASING(py::object verbose_code_parts)1367 OBJECT_ALIASING(py::object verbose_code_parts)
1368 : RelationalGuard(std::move(verbose_code_parts)) {}
1369
check_nopybind(PyObject * value)1370 bool check_nopybind(PyObject* value) override { // borrowed ref
1371 if (_is_first_call) {
1372 _first_tensor = value;
1373 _is_first_call = false;
1374 return true;
1375 }
1376 return _first_tensor == value;
1377 }
1378
reset_state()1379 void reset_state() final {
1380 _is_first_call = true;
1381 }
1382
1383 private:
1384 bool _is_first_call{true};
1385 PyObject* _first_tensor{nullptr};
1386 };
1387
1388 /**
1389 * Checks that none of the tensors alias.
1390 */
1391 class NO_TENSOR_ALIASING : public RelationalGuard {
1392 public:
NO_TENSOR_ALIASING(const py::list & tensor_names,py::object verbose_code_parts)1393 NO_TENSOR_ALIASING(
1394 const py::list& tensor_names,
1395 py::object verbose_code_parts)
1396 : RelationalGuard(std::move(verbose_code_parts)),
1397 _tensor_names(tensor_names) {
1398 _unique_tensors.reserve(tensor_names.size());
1399 }
1400
check_nopybind(PyObject * value)1401 bool check_nopybind(PyObject* value) override { // borrowed ref
1402 // Typically we don't have to increment the ref count here because the
1403 // tensors are held in f_locals. But there is a special case for
1404 // `from_numpy` source. `from_numpy` converts integers and such into tensors
1405 // and these tensors are ephemeral. If we don't incref, those tensors can be
1406 // garbage collected, and the next time from_numpy can reuse the memory
1407 // address. Therefore, we incref here. They are decref'd in reset_state.
1408 Py_INCREF(value);
1409 auto insertion = _unique_tensors.insert({value, nullptr});
1410 if (!insertion.second) {
1411 // No need to clear _unique_tensors, reset_state will do
1412 // it.
1413 return false;
1414 }
1415 return true;
1416 }
1417
check_verbose_nopybind(PyObject * value)1418 GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
1419 bool result = check_nopybind(value);
1420
1421 if (!result) {
1422 return GuardDebugInfo(
1423 false, "Duplicate tensor found where not expected!", 0);
1424 }
1425 return GuardDebugInfo(true, 1);
1426 }
1427
reset_state()1428 void reset_state() final {
1429 for (auto item : _unique_tensors) {
1430 Py_DECREF(item.first);
1431 }
1432 _unique_tensors.clear();
1433 }
1434
1435 private:
1436 py::list _tensor_names;
1437 ska::flat_hash_map<PyObject*, std::nullptr_t> _unique_tensors;
1438 };
1439
1440 class DYNAMIC_INDICES : public LeafGuard {
1441 // C++ equivalent of
1442 // code.append(
1443 // f"(({tensor_name}._dynamo_dynamic_indices.issubset({value._dynamo_dynamic_indices}))
1444 // if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" #
1445 // noqa: B950
1446 // )
1447 public:
DYNAMIC_INDICES(py::set dynamic_indices,py::object verbose_code_parts)1448 DYNAMIC_INDICES(py::set dynamic_indices, py::object verbose_code_parts)
1449 : LeafGuard(std::move(verbose_code_parts)),
1450 _dynamic_indices(std::move(dynamic_indices)) {}
1451
check_nopybind(PyObject * value)1452 bool check_nopybind(PyObject* value) override { // borrowed ref
1453 // Make an interned string
1454 static PyObject* dynamic_indices_str =
1455 PyUnicode_InternFromString("_dynamo_dynamic_indices");
1456 PyObject* indices = PyObject_GetAttr(value, dynamic_indices_str); // new ref
1457 if (indices == nullptr) {
1458 // Attr absent. Clear exception.
1459 PyErr_Clear();
1460 // This is true deliberately. If hasattr fails, we return true.
1461 return true;
1462 }
1463
1464 static PyObject* issubset_str = PyUnicode_InternFromString("issubset");
1465 PyObject* call_result = PyObject_CallMethodOneArg(
1466 indices, issubset_str, _dynamic_indices.ptr()); // new ref
1467 bool result = PyObject_IsTrue(call_result);
1468 Py_DECREF(call_result);
1469 Py_DECREF(indices);
1470 return result;
1471 }
1472
1473 private:
1474 py::set _dynamic_indices;
1475 };
1476
1477 class DICT_VERSION : public LeafGuard {
1478 public:
DICT_VERSION(py::object value,py::object verbose_code_parts)1479 DICT_VERSION(py::object value, py::object verbose_code_parts)
1480 : LeafGuard(std::move(verbose_code_parts)) {
1481 if (!PyDict_Check(value.ptr())) {
1482 throw py::type_error("DICT_VERSION expects a dict");
1483 }
1484 _tag = get_dict_version_unchecked(value.ptr());
1485 }
check_nopybind(PyObject * value)1486 bool check_nopybind(PyObject* value) override { // borrowed ref
1487 return PyDict_Check(value) && get_dict_version_unchecked(value) == _tag;
1488 }
1489
1490 // Saved dict version.
1491 uint64_t _tag;
1492 };
1493
1494 // GuardManager can be a pointer to DictGuardManager, but at this point the
1495 // compiler does not know that DictGuardManager is a derived class of
1496 // GuardManager (no way to define inheritance relationships in forward
1497 // declarations), so we forward declare a factory function and define it when
1498 // both DictGuardManager and GuardManager are fully defined.
1499 std::unique_ptr<GuardManager> make_guard_manager(
1500 RootGuardManager* root,
1501 std::string source,
1502 py::handle example_value,
1503 py::handle guard_manager_enum);
1504
1505 /**
1506 * Base class representing a pair of accessor and the associated guard
1507 * manager. The accessor defines how to access the child value from the
1508 * py::object given to the parent check function.
1509 *
1510 * GuardAccessors can be considered equivalent to name() method of Source
1511 * objects in guards.py. In python, name() method returns a str which we can
1512 * then eval in f_locals and f_globals to retrieve the actual py object.
1513 * GuardAccessor serves the same purpose. The minor difference is that
1514 * GuardManager is a tree structure, so a GuardAccessor just has to retrieve
1515 * the value in the next level in this tree and pass it to the child
1516 * GuardAccessor.
1517 *
1518 * GuardAccessor also owns the GuardManager associated with the retrieved
1519 * value from the GuardAccessor.
1520 */
1521 class GuardAccessor {
1522 public:
GuardAccessor(RootGuardManager * root,py::object accessor_key,std::string source,py::handle example_value,py::handle guard_manager_enum)1523 GuardAccessor(
1524 RootGuardManager* root,
1525 py::object accessor_key,
1526 std::string source,
1527 py::handle example_value,
1528 py::handle guard_manager_enum)
1529 : _guard_manager(make_guard_manager(
1530 root,
1531 source,
1532 example_value,
1533 guard_manager_enum)),
1534 _accessor_key(std::move(accessor_key)),
1535 _source(std::move(source)) {}
1536
1537 // Return by reference as GuardAccessor owns the GuardManager.
get_guard_manager()1538 std::unique_ptr<GuardManager>& get_guard_manager() {
1539 return _guard_manager;
1540 }
1541
matches_key(const py::handle & key) const1542 bool matches_key(const py::handle& key) const {
1543 return _accessor_key.equal(key);
1544 }
1545
get_source()1546 std::string get_source() {
1547 return _source;
1548 }
1549
1550 // matches_dict_tag is used by the DictGetItemGuardAccessor to skip the guard
1551 // subtree on immutable dict getitems.
1552 virtual bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) = 0;
1553 virtual GuardDebugInfo check_verbose_nopybind(PyObject* obj) = 0;
1554 virtual std::string repr() const = 0;
1555
1556 virtual ~GuardAccessor() = default;
1557
1558 protected:
1559 // Guard manager corresponding to the retrieved value from the
1560 // GuardAccessor.
1561 std::unique_ptr<GuardManager> _guard_manager;
1562 // accessor key could be py::str for getattr, getitem or py::function for
1563 // lambda accessor. It is a py::object because we need to keep these accessor
1564 // keys alive.
1565 py::object _accessor_key;
1566
1567 // A string that can be eval'd on f_locals or f_globals to access the variable
1568 // value. Only used for debugging.
1569 std::string _source;
1570 };
1571
1572 /**
1573 * GuardManager encapsulates all the guards related to a particular
1574 * py::object. It is a tree structure and consists of 1) Leaf guards - Guards
1575 * that are run on the user given object 2) Accessors - Guard accessors (like
1576 * getattr, getitem) to access the next value in the tree hierarchy. Accessor
1577 * object also holds the child GuardManager.
1578 *
 * Let's look at an example to understand how it works.
1580 * class Pair:
1581 * int x = 1;
1582 * int y = 2;
1583 *
 * At compile time
 * >> guard_manager = GuardManager()
 * >> guard_manager.x.add_lambda_guard(
 *    lambda x: isinstance(x, Pair),
 *    lambda x: f"expected Pair, found {type(x)}"
 * )
 * >> guard_manager.x.add_lambda_guard(lambda x: x == 1, lambda x: f"found
 * {x}, expected 1")
 * >> guard_manager.y.add_lambda_guard(lambda x: x == 2, lambda x: f"found
 * {x}, expected 2")
 *
 * At runtime
 * >> guard_manager.check(Pair())
1597 *
1598 * At compile time we build the tree structure. When we do `guard_manager.x`,
1599 * it creates an AttrGuardAccessorNode, initializes a child guard manager with
1600 * this accessor node, and adds it as a child. When we do
1601 * `guard_manager.x.add_lambda_guard`, we call add_lambda_guard on the newly
1602 * created guard manager and register a new leaf guard on it.
1603 *
1604 * At runtime, the accessor node has an important function of providing a way
1605 * to access the value for the child guard. In the above example,
1606 * guard_manager.x adds an AttrGuardAccessorNode with attr_name x. When check
1607 * function is called, parent GuardManager calls getattr(value, "x") on its
1608 * value passed to the check function to call the check function of the child
1609 * guard manager.
1610 *
 * Performance optimization for fail fast - An optimization for runtime here
 * is to sort the execution of child guards depending on the failure count.
 * This ensures that we run the guards that are statistically more prone to
 * fail first. This can improve the cache lookup time when we have multiple
 * cache entries.
1616 */
1617
1618 class GuardManager {
1619 public:
1620 GuardManager() = delete;
// Builds a manager that is not associated with an example value; it is never
// treated as a dict manager.
GuardManager(RootGuardManager* root, std::string source)
    : _root(root),
      _source(std::move(source)),
      _is_dict(false) {}
1623
// Builds a manager from an example value. When the example is a dict, the
// manager is flagged as a dict manager and the dict's current version tag is
// recorded.
GuardManager(
    RootGuardManager* root,
    std::string source,
    py::handle example_value)
    : _root(root),
      _source(std::move(source)),
      _is_dict(py::isinstance<py::dict>(example_value)) {
  if (!_is_dict) {
    return;
  }
  _dict_tag = get_dict_version_unchecked(example_value.ptr());
}
1635
1636 GuardManager(const GuardManager& m) = delete;
1637 GuardManager& operator=(const GuardManager&) = delete;
1638 virtual ~GuardManager() = default;
1639
get_root()1640 RootGuardManager* get_root() {
1641 return _root;
1642 }
1643
get_source()1644 std::string get_source() {
1645 return _source;
1646 }
1647
add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard)1648 virtual void add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) {
1649 _leaf_guards.emplace_back(std::move(leaf_guard));
1650 }
1651
1652 /**
1653 * Adds a new guard manager with appropriate Accessor. If the accessor is
1654 * already present, we just return the guard manager.
1655 */
1656 template <typename GuardAccessorT>
get_child_manager(py::object accessor_key,std::string source,py::handle example_value,py::handle guard_manager_enum)1657 GuardManager* get_child_manager(
1658 py::object accessor_key,
1659 std::string source,
1660 py::handle example_value,
1661 py::handle guard_manager_enum) {
1662 // accessor_key type depends on the GuardAccessorT
1663 // for example for GetAttrGuardAccessor - py::str name
1664
1665 // Return the manager if the guard accessor exists
1666 for (const auto& accessor : _accessors) {
1667 if (accessor->matches_key(accessor_key)) {
1668 return accessor->get_guard_manager().get();
1669 }
1670 }
1671
1672 // Construct a new guard accessor
1673 _accessors.emplace_back(std::make_unique<GuardAccessorT>(
1674 _root,
1675 std::move(accessor_key),
1676 source,
1677 example_value,
1678 guard_manager_enum));
1679 return _accessors.back()->get_guard_manager().get();
1680 }
1681
1682 // Runs the leaf guards check and then child managers check function.
1683 //
1684 // NB: There is some code DUPLICATION between this and check_verbose
1685 // function. This is intentional. check function is in the hot path and is
1686 // kept very simple. The purpose of check_verbose function is to get guard
1687 // failure reasoning to understand recompilations. check_verbose function
1688 // does not change the state of the guard, e.g., it does not shuffle the
1689 // guards and does not change the fail count. For simplicity, we duplicate
1690 // the code here.
check_nopybind(PyObject * value)1691 virtual bool check_nopybind(PyObject* value) { // borrowed ref
1692 // Iterate over leaf guards
1693 for (const auto& guard : _leaf_guards) {
1694 if (!guard->check_nopybind(value)) { // early exit
1695 _fail_count += 1;
1696 // no need of sorting, just return.
1697 return false;
1698 }
1699 }
1700
1701 bool matches_dict_tag = false;
1702 uint64_t new_tag = 0;
1703 if (_is_dict) {
1704 // Check if the dict tag matches. If it does, propagate to the child
1705 // accessors. This will pass to the child manager via
1706 // DictGetItemGuardManager.
1707 new_tag = get_dict_version_unchecked(value);
1708 matches_dict_tag = new_tag == _dict_tag;
1709 }
1710
1711 // Iterate over accessors.
1712 bool result = true;
1713 bool failed_on_first = true;
1714 for (const auto& accessor : _accessors) {
1715 if (!accessor->check_nopybind(value, matches_dict_tag)) { // early exit
1716 _fail_count += 1;
1717 result = false;
1718 // need to sort, so break the loop.
1719 break;
1720 }
1721 failed_on_first = false;
1722 }
1723
1724 // failed_on_first is just an optimization to avoid sorting if we are
1725 // failing on the first accessor itself. This is helpful when we have
1726 // already sorted the guards once, and dont need to sort again.
1727 if (!result && !failed_on_first) {
1728 // Inplace sort the child guards by fail count. This moves the guard
1729 // with higher fail count earlier in the queue, and enables fail fast
1730 // for the next check_verbose.
1731
1732 // An alternate implementation was to use priority queue directly on
1733 // _accessors, but it was rejected because of the complexity of
1734 // popping and creating a new pq on each run_guards. Moreover, this sort
1735 // is happening on the unhappy path when check_verbose guard
1736 // fails. So, its probably ok.
1737 std::sort(
1738 _accessors.begin(),
1739 _accessors.end(),
1740 [](const std::unique_ptr<GuardAccessor>& a,
1741 const std::unique_ptr<GuardAccessor>& b) {
1742 return a->get_guard_manager()->fail_count() >
1743 b->get_guard_manager()->fail_count();
1744 });
1745 }
1746
1747 if (_is_dict && result) {
1748 // If result is true, reset the _dict_tag. This is useful if there is a
1749 // mutation on the dict but it does not change the attr values (like
1750 // swapping).
1751 _dict_tag = new_tag;
1752 }
1753 return result;
1754 }
1755
1756 // This function has some code duplication with function check. This is
1757 // deliberate to keep check function simple and fast.
check_verbose_nopybind(PyObject * value)1758 virtual GuardDebugInfo check_verbose_nopybind(
1759 PyObject* value) { // borrowed ref
1760 int num_guards_executed = 0;
1761 // Iterate over leaf guards
1762 for (const auto& guard : _leaf_guards) {
1763 const GuardDebugInfo& debug_info = guard->check_verbose_nopybind(value);
1764 num_guards_executed++;
1765 if (!debug_info.result) {
1766 return GuardDebugInfo(
1767 false, debug_info.verbose_code_parts, num_guards_executed);
1768 }
1769 }
1770
1771 // Iterate over accessors
1772 for (const auto& accessor : _accessors) {
1773 const GuardDebugInfo& debug_info =
1774 accessor->check_verbose_nopybind(value);
1775 num_guards_executed += debug_info.num_guards_executed;
1776 if (!debug_info.result) {
1777 return GuardDebugInfo(
1778 false, debug_info.verbose_code_parts, num_guards_executed);
1779 }
1780 }
1781
1782 return GuardDebugInfo(true, num_guards_executed);
1783 }
1784
fail_count() const1785 int64_t fail_count() const {
1786 return _fail_count;
1787 }
1788
1789 // DEBUG function - Returning raw pointers because we can't return unique_ptr
1790 // and pybind does not accept a unique_ptr reference return type.
get_accessors() const1791 virtual std::vector<GuardAccessor*> get_accessors() const {
1792 std::vector<GuardAccessor*> ret;
1793 ret.reserve(_accessors.size());
1794 for (const auto& accessor : _accessors) {
1795 ret.emplace_back(accessor.get());
1796 }
1797 return ret;
1798 }
1799
1800 // DEBUG function - Returning raw pointers because we can't return unique_ptr
1801 // and pybind does not accept a unique_ptr reference return type.
get_child_managers()1802 virtual std::vector<GuardManager*> get_child_managers() {
1803 std::vector<GuardManager*> ret;
1804 ret.reserve(_accessors.size());
1805 for (const auto& accessor : _accessors) {
1806 ret.emplace_back(accessor->get_guard_manager().get());
1807 }
1808 return ret;
1809 }
1810
1811 // DEBUG function - Returning raw pointers because we can't return unique_ptr
1812 // and pybind does not accept a unique_ptr reference return type.
get_leaf_guards() const1813 std::vector<LeafGuard*> get_leaf_guards() const {
1814 std::vector<LeafGuard*> ret;
1815 ret.reserve(_leaf_guards.size());
1816 for (const auto& guard : _leaf_guards) {
1817 ret.push_back(guard.get());
1818 }
1819 return ret;
1820 }
1821
is_leaf_guard_present(const std::string & guard_name)1822 bool is_leaf_guard_present(const std::string& guard_name) {
1823 return _inserted_leaf_guards.find(guard_name) !=
1824 _inserted_leaf_guards.end();
1825 }
1826
insert_leaf_guard(const std::string & guard_name)1827 void insert_leaf_guard(const std::string& guard_name) {
1828 _inserted_leaf_guards.insert(guard_name);
1829 }
1830
add_permitted_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard)1831 void add_permitted_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) {
1832 // Selectively called for permitted guards. This is used by DictGuardManager
1833 // which overrides the add_leaf_guard manager to throw runtime error.
1834 GuardManager::add_leaf_guard(std::move(leaf_guard));
1835 }
1836
1837 protected:
1838 // Keeps a count of how many times this guard manager check function returns
1839 // False. This is used for sorting optimization.
1840 int64_t _fail_count{0};
1841
1842 private:
1843 // Root of the guard manager; this is used to install the relational
1844 // guard resetters.
1845 RootGuardManager* _root;
1846
1847 // A string that can be used to eval on f_locals or f_globals to get the
1848 // value. This is used only to pass on debugging information.
1849 std::string _source;
1850
1851 // A map of which leaf guards are inserted. This is to prevent duplicate
1852 // guards like TYPE_MATCH.
1853 std::unordered_set<std::string> _inserted_leaf_guards;
1854
1855 // Leaf guards are the terminal guards on this object, e.g, type check on a
1856 // list. These guards have to be run before any children are run.
1857 //
1858 // These leaf guards are not shufflable. In almost all cases, these guards
1859 // will have an order, e.g., type(x) is int guard and x == 5 guard. We also
1860 // expect very few leaf guards per GuardManager node.
1861 //
1862 // NB: Why are leaf guards shared ptr? This is primarily to enable relational
1863 // guards like `tensor X is not tensor Y`. These guards require multiple
1864 // values. We handle it by creating one guard object that holds state and this
1865 // guard is installed in many guard managers, hence a shared ptr.
1866 std::vector<std::shared_ptr<LeafGuard>> _leaf_guards;
1867
1868 // GuardAccessors nodes to access the child guards. These guards are
1869 // shufflable. On a guard failure, they are sorted based on their fail count
1870 // to enable fail fast for the next check.
1871 std::vector<std::unique_ptr<GuardAccessor>> _accessors;
1872
1873 bool _is_dict;
1874 uint64_t _dict_tag{0};
1875 };
1876
1877 /**
1878 * RootGuardManager is the root of the guard tree. This is primarily
1879 * constructed to hold the relational guard pointers so that we can reset the
1880 * state of those guards on guard failure. All the other important
1881 * implementation is in GuardManager class.
1882 */
1883
1884 class RootGuardManager : public GuardManager {
1885 public:
1886 // This is the root node, set its _root member to nullptr
RootGuardManager()1887 RootGuardManager() : GuardManager(this, "L") {}
1888
1889 // Adds the relational guard resetter
add_relational_guard_resetter(std::shared_ptr<RelationalGuard> relational_guard)1890 void add_relational_guard_resetter(
1891 std::shared_ptr<RelationalGuard> relational_guard) {
1892 _relational_guard_resetters.emplace_back(std::move(relational_guard));
1893 }
1894
1895 // Python visible API to check guard function.
check(py::handle value)1896 bool check(py::handle value) {
1897 return check_nopybind(value.ptr());
1898 }
1899
1900 // Python visible API to check_verbose guard function.
check_verbose(py::handle value)1901 GuardDebugInfo check_verbose(py::handle value) {
1902 return check_verbose_nopybind(value.ptr());
1903 }
1904
1905 // Fast check function.
check_nopybind(PyObject * value)1906 bool check_nopybind(PyObject* value) override { // borrowed ref
1907 // Check [Note on GIL interaction with mutex lock] for details on why we
1908 // need mutex and its interactions wth GIL.
1909 PyThreadState* _save = nullptr;
1910 Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting
1911 std::lock_guard<std::mutex> lock_guard(_lock);
1912 Py_BLOCK_THREADS; // ; is added to avoid clang-formatting
1913
1914 // Get the local state. This will be used for TENSOR_MATCH guards.
1915 if (_init_local_state) {
1916 LocalState state;
1917 _local_state = state;
1918 }
1919
1920 if (!GuardManager::check_nopybind(value)) {
1921 _reset_relational_guard_state();
1922 return false;
1923 }
1924
1925 // Iterate over epilogue leaf guards.
1926 for (const auto& guard : _epilogue_lambda_guards) {
1927 if (!guard->check_nopybind(value)) { // early exit
1928 _reset_relational_guard_state();
1929 return false;
1930 }
1931 }
1932 _reset_relational_guard_state();
1933 return true;
1934 }
1935
1936 // Fast check_verbose function.
check_verbose_nopybind(PyObject * value)1937 GuardDebugInfo check_verbose_nopybind(
1938 PyObject* value) override { // borrowed ref
1939 // Check [Note on GIL interaction with mutex lock] for details on why we
1940 // need mutex and its interactions wth GIL.
1941 PyThreadState* _save = nullptr;
1942 Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting
1943 std::lock_guard<std::mutex> lock_guard(_lock);
1944 Py_BLOCK_THREADS; // ; is added to avoid clang-formatting
1945
1946 // Get the local state. This will be used for TENSOR_MATCH guards.
1947 if (_init_local_state) {
1948 LocalState state;
1949 _local_state = state;
1950 }
1951
1952 GuardDebugInfo debug_info = GuardManager::check_verbose_nopybind(value);
1953 if (!debug_info.result) {
1954 _reset_relational_guard_state();
1955 return debug_info;
1956 }
1957
1958 int num_guards_executed = debug_info.num_guards_executed;
1959
1960 // Iterate over epilogue leaf guards
1961 for (const auto& guard : _epilogue_lambda_guards) {
1962 const GuardDebugInfo& tmp_debug_info =
1963 guard->check_verbose_nopybind(value);
1964 num_guards_executed++;
1965 if (!tmp_debug_info.result) {
1966 _reset_relational_guard_state();
1967 return GuardDebugInfo(
1968 false, tmp_debug_info.verbose_code_parts, num_guards_executed);
1969 }
1970 }
1971 _reset_relational_guard_state();
1972 return GuardDebugInfo(true, num_guards_executed);
1973 }
1974
add_epilogue_lambda_guard(std::unique_ptr<LeafGuard> leaf_guard)1975 void add_epilogue_lambda_guard(std::unique_ptr<LeafGuard> leaf_guard) {
1976 _epilogue_lambda_guards.emplace_back(std::move(leaf_guard));
1977 }
1978
set_init_local_state_flag()1979 void set_init_local_state_flag() {
1980 _init_local_state = true;
1981 }
1982
1983 // DEBUG function - Returning raw pointers because we can't return unique_ptr
1984 // and pybind does not accept a unique_ptr reference return type.
get_epilogue_lambda_guards() const1985 std::vector<LeafGuard*> get_epilogue_lambda_guards() const {
1986 std::vector<LeafGuard*> ret;
1987 ret.reserve(_epilogue_lambda_guards.size());
1988 for (const auto& guard : _epilogue_lambda_guards) {
1989 ret.push_back(guard.get());
1990 }
1991 return ret;
1992 }
1993
1994 private:
1995 // Reset the state of all the relational guards on failure.
_reset_relational_guard_state()1996 void _reset_relational_guard_state() {
1997 for (auto& guard : _relational_guard_resetters) {
1998 guard->reset_state();
1999 }
2000 }
2001
2002 public:
2003 // Local state for TENSOR_MATCH guards.
2004 LocalState _local_state;
2005
2006 private:
2007 // All the relational guards under this guard mananger. We only use these
2008 // when the guard evaluates to False. This ensures that guard state is reset
2009 // on guard failure so that next invocation is clean.
2010 std::vector<std::shared_ptr<RelationalGuard>> _relational_guard_resetters;
2011
2012 // These guards are lambda guards, i.e., the guards that lack C++
2013 // implementation. For simplicity, we add these guards at the root. They
2014 // MUST be run after all other guard managers have finished to ensure that
2015 // the epilogue guards do not step on some nonexistent getattr or getitem.
2016 std::vector<std::unique_ptr<LeafGuard>> _epilogue_lambda_guards;
2017
2018 // [Note on GIL interaction with mutex lock]
2019 // We use std::mutex to prevent multiple threads from running
2020 // check/check_verbose simultaneously. This is to prevent race condition due
2021 // to state changes in RelationalGuard.
2022 //
2023 // However, we also need to be careful about GIL interaction with mutex. There
2024 // is a chance of deadlock
2025 //
2026 // Thread 1: has GIL, waiting for lock
2027 // Thread 2: has lock, waiting for GIL
2028 //
2029 // This can happen when Thread 2 earlier acquired the mutex lock, starting
2030 // running the critical section of check function and then called some python
2031 // function (like LAMBDA_GUARD) and reached Cpython codebase that checks if it
2032 // should release the GIL (typically happens after every few bytecode
2033 // instructions). Thread 2 here can decide to release the GIL. Thread 1 can
2034 // acquire GIL and reach the mutex, where it will wait forever.
2035 //
2036 // To avoid this, each thread releases the GIL before acquiring the mutex and
2037 // then acquires the GIL again after acquiring the mutex lock by using
2038 // Py_BLOCK_THREADS and Py_UNBLOCK_THREADS. This avoids the deadlock.
2039 std::mutex _lock;
2040
2041 // We init LocalState only when this flag it set. This flag is set during
2042 // TENSOR_MATCH guard init.
2043 bool _init_local_state = false;
2044 };
2045
2046 /*
2047 * Dicts are common in python code. Therefore, we handle guards for dicts
2048 * differently and use PyDict_* APIs which are faster than PyObject_* APIs
2049 * because of no ref count increments/decrements.
2050 *
2051 * DictGuardManager relies on the order of dict.keys(). It keeps track of the
2052 * indices of dict.keys() to access the key, value pair.
2053 */
2054 typedef std::pair<std::unique_ptr<GuardManager>, std::unique_ptr<GuardManager>>
2055 KeyValueManager;
2056 class DictGuardManager : public GuardManager {
2057 public:
DictGuardManager(RootGuardManager * root,std::string source,py::handle example_value)2058 DictGuardManager(
2059 RootGuardManager* root,
2060 std::string source,
2061 py::handle example_value)
2062 : GuardManager(root, std::move(source)),
2063 _size(PyDict_Size(example_value.ptr())),
2064 _expected_type(Py_TYPE(example_value.ptr())),
2065 _is_exact_dict_type(PyDict_CheckExact(example_value.ptr())) {}
2066
get_key_manager(py::object key_index,std::string source,py::handle example_value,py::handle guard_manager_enum)2067 GuardManager* get_key_manager(
2068 py::object key_index,
2069 std::string source,
2070 py::handle example_value,
2071 py::handle guard_manager_enum) {
2072 KeyValueManager& key_value_manager =
2073 _get_index_manager(std::move(key_index));
2074 if (!key_value_manager.first) {
2075 key_value_manager.first = make_guard_manager(
2076 this->get_root(),
2077 std::move(source),
2078 example_value,
2079 guard_manager_enum);
2080 };
2081 return key_value_manager.first.get();
2082 }
2083
get_value_manager(py::object key_index,std::string source,py::handle example_value,py::handle guard_manager_enum)2084 GuardManager* get_value_manager(
2085 py::object key_index,
2086 std::string source,
2087 py::handle example_value,
2088 py::handle guard_manager_enum) {
2089 KeyValueManager& key_value_manager =
2090 _get_index_manager(std::move(key_index));
2091 if (!key_value_manager.second) {
2092 key_value_manager.second = make_guard_manager(
2093 this->get_root(),
2094 std::move(source),
2095 example_value,
2096 guard_manager_enum);
2097 };
2098 return key_value_manager.second.get();
2099 }
2100
check_nopybind(PyObject * obj)2101 bool check_nopybind(PyObject* obj) override { // borrowed ref
2102 // TODO(janimesh) - Implement a fast-path using dict versions.
2103
2104 if (Py_TYPE(obj) != _expected_type) {
2105 _fail_count += 1;
2106 return false;
2107 }
2108
2109 if (PyDict_Size(obj) != _size) {
2110 _fail_count += 1;
2111 return false;
2112 }
2113
2114 // Early return
2115 if (_size == 0) {
2116 return true;
2117 }
2118
2119 // Invokes the base class's check_nopybind method. We permit a limited set
2120 // of leaf guards and accessors within the DictGuardManager framework.
2121 // Integrating certain guards or accessors directly within the
2122 // DictGuardManager can be challenging. For instance, `type(dict_object)` as
2123 // an accessor is permissible, which otherwise would be hard to integrate
2124 // directly into DictGuardManager. Similarly, incorporating guards such as
2125 // DICT_CONTAINS and DICT_VERSION as leaf guards offers a simpler solution
2126 // than embedding these functionalities within the DictGuardManager itself.
2127 if (!GuardManager::check_nopybind(obj)) {
2128 _fail_count += 1;
2129 // No need to shuffle the child guards, just return.
2130 return false;
2131 }
2132
2133 PyObject *key = nullptr, *value = nullptr;
2134 Py_ssize_t pos = 0;
2135
2136 // Points to an element in the _indices vector.
2137 size_t index_pointer = 0;
2138 // Points to the key index in the dict
2139 Py_ssize_t dict_pointer = 0;
2140
2141 while (index_pointer < _indices.size() &&
2142 PyDict_Next(obj, &pos, &key, &value)) {
2143 // Skip if dict_pointer is not a saved index.
2144 if (dict_pointer == _indices[index_pointer]) {
2145 index_pointer += 1;
2146 KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2147 std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2148 if (key_manager && !key_manager->check_nopybind(key)) {
2149 return false;
2150 }
2151 std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2152 if (value_manager && !value_manager->check_nopybind(value)) {
2153 return false;
2154 }
2155 }
2156 dict_pointer += 1;
2157 }
2158 return true;
2159 }
2160
check_verbose_nopybind(PyObject * obj)2161 GuardDebugInfo check_verbose_nopybind(
2162 PyObject* obj) override { // borrowed ref
2163 if (Py_TYPE(obj) != _expected_type) {
2164 return GuardDebugInfo(false, "TYPE_MISMATCH(" + get_source() + ")", 0);
2165 }
2166
2167 if (PyDict_Size(obj) != _size) {
2168 return GuardDebugInfo(
2169 false, "len(" + get_source() + ") != " + std::to_string(_size), 0);
2170 }
2171
2172 // Early return
2173 if (_size == 0) {
2174 return GuardDebugInfo(true, 0);
2175 }
2176
2177 // Invokes the base class's check_nopybind method. We permit a limited set
2178 // of leaf guards and accessors within the DictGuardManager framework.
2179 // Integrating certain guards or accessors directly within the
2180 // DictGuardManager can be challenging. For instance, `type(dict_object)` as
2181 // an accessor is permissible, which otherwise would be hard to integrate
2182 // directly into DictGuardManager. Similarly, incorporating guards such as
2183 // DICT_CONTAINS and DICT_VERSION as leaf guards offers a simpler solution
2184 // than embedding these functionalities within the DictGuardManager itself.
2185 GuardDebugInfo debug_info = GuardManager::check_verbose_nopybind(obj);
2186 if (!debug_info.result) {
2187 return debug_info;
2188 }
2189
2190 PyObject *key = nullptr, *value = nullptr;
2191 Py_ssize_t pos = 0;
2192
2193 // Points to an element in the _indices vector.
2194 size_t index_pointer = 0;
2195 Py_ssize_t dict_pointer = 0;
2196
2197 int num_guards_executed = 0;
2198 while (index_pointer < _indices.size() &&
2199 PyDict_Next(obj, &pos, &key, &value)) {
2200 // Skip if pos is not a saved index.
2201 if (dict_pointer == _indices[index_pointer]) {
2202 index_pointer += 1;
2203 KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2204 std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2205 if (key_manager) {
2206 GuardDebugInfo debug_info = key_manager->check_verbose_nopybind(key);
2207 num_guards_executed += debug_info.num_guards_executed;
2208 if (!debug_info.result) {
2209 return GuardDebugInfo(
2210 false, debug_info.verbose_code_parts, num_guards_executed);
2211 }
2212 }
2213 std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2214 if (value_manager) {
2215 GuardDebugInfo debug_info =
2216 value_manager->check_verbose_nopybind(value);
2217 num_guards_executed += debug_info.num_guards_executed;
2218 if (!debug_info.result) {
2219 return GuardDebugInfo(
2220 false, debug_info.verbose_code_parts, num_guards_executed);
2221 }
2222 }
2223 }
2224 dict_pointer += 1;
2225 }
2226 return GuardDebugInfo(true, num_guards_executed);
2227 }
2228
skip_adding_guard(const py::object & a,const py::object & b)2229 void skip_adding_guard(const py::object& a, const py::object& b) {
2230 // The `add_leaf_guard` method in `DictGuardManager` is overridden to block
2231 // the addition of leaf guards. However, this is too strict. Python side of
2232 // guard management frequently adds TYPE_MATCH and DICT_LENGTH on
2233 // DictGuardManager. We could refactor Python side to never call these
2234 // guards on dict objects, but that results in messy code. Instead, we just
2235 // override these two guards to not go through add_leaf_guard code path and
2236 // skip adding guards. This makes the python side easy.
2237 }
2238
fail_on_get_child_manager(const py::object & a,const std::string & source,const py::object & b)2239 void fail_on_get_child_manager(
2240 const py::object& a,
2241 const std::string& source,
2242 const py::object& b) {
2243 throw std::runtime_error("Can not add an accessor to DictGuardManager");
2244 }
2245
add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard)2246 void add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) override {
2247 // If you are calling this, you probably want to go through a key, value
2248 // child manager and then add a leaf guard on them. DictGuardManager already
2249 // has TYPE_MATCH and LENGTH_CHECK built in.
2250 throw std::runtime_error("DictGuardManager does not support a leaf_guard");
2251 }
2252
2253 // Debug helper - Returning raw pointers because we can't return unique_ptr
2254 // and pybind does not accept a unique_ptr reference return type.
2255 std::unordered_map<Py_ssize_t, std::pair<GuardManager*, GuardManager*>>
get_key_value_managers()2256 get_key_value_managers() {
2257 std::unordered_map<Py_ssize_t, std::pair<GuardManager*, GuardManager*>> ret;
2258 for (auto index : _indices) {
2259 ret[index] = std::make_pair(
2260 _key_value_managers[index].first.get(),
2261 _key_value_managers[index].second.get());
2262 }
2263 return ret;
2264 }
2265
is_exact_dict_type()2266 bool is_exact_dict_type() {
2267 return _is_exact_dict_type;
2268 }
2269
2270 private:
2271 /**
2272 * Adds a new KeyDictGuardAccessor. If the accessor is already present, we
2273 * just return the guard manager.
2274 */
_get_index_manager(py::object key_index)2275 KeyValueManager& _get_index_manager(py::object key_index) {
2276 // Check if the accessor is already present.
2277 Py_ssize_t index = py::cast<Py_ssize_t>(std::move(key_index));
2278 auto it = _key_value_managers.find(index);
2279 if (it != _key_value_managers.end()) {
2280 return it->second;
2281 }
2282 _indices.push_back(index);
2283 // Always keep the _indices array sorted
2284 std::sort(_indices.begin(), _indices.end());
2285 _key_value_managers[index] = std::make_pair(nullptr, nullptr);
2286 return _key_value_managers[index];
2287 }
2288
2289 protected: // also used by DictSubclassGuardManager
2290 Py_ssize_t _size;
2291 // DictGuardManager supports both exact dict type and non-exact dict type.
2292 // Therefore, we have to compare the type to early exit.
2293 PyTypeObject* _expected_type;
2294 bool _is_exact_dict_type; // Useful to check getattr_manager validity.
2295 std::vector<Py_ssize_t> _indices;
2296 std::unordered_map<Py_ssize_t, KeyValueManager> _key_value_managers;
2297 };
2298
2299 /**
2300 * The DictSubclassGuardManager is designed to work with dict subclasses,
2301 * specifically focusing on OrderedDicts. Standard dictionaries leverage the
2302 * PyDict_Next function to iterate over keys, values, and items. OrderedDicts,
2303 * on the other hand, rely on an additional linked list structure to maintain
2304 * keys order. Although PyDict_Next and OrderedDict generally yield the same
2305 * order, discrepancies arise when using OrderedDict's move_to_end method (used
2306 * in Pytorch hooks). `move_to_end` method only updates the linked list, leaving
2307 * PyDict_Next unaffected. Therefore, to accurately capture key ordering in such
2308 * cases, DictSubclassGuardManager directly invoke the .keys() method.
2309 */
2310
2311 class DictSubclassGuardManager : public DictGuardManager {
2312 public:
DictSubclassGuardManager(RootGuardManager * root,std::string source,py::handle example_value)2313 DictSubclassGuardManager(
2314 RootGuardManager* root,
2315 std::string source,
2316 py::handle example_value)
2317 : DictGuardManager(root, std::move(source), example_value) {}
2318
2319 public:
check_nopybind(PyObject * obj)2320 bool check_nopybind(PyObject* obj) override { // borrowed ref
2321 // TODO(janimesh) - Implement a fast-path using dict versions.
2322
2323 if (Py_TYPE(obj) != _expected_type) {
2324 _fail_count += 1;
2325 return false;
2326 }
2327
2328 if (PyDict_Size(obj) != _size) {
2329 _fail_count += 1;
2330 return false;
2331 }
2332
2333 // Early return
2334 if (_size == 0) {
2335 return true;
2336 }
2337
2338 if (!GuardManager::check_nopybind(obj)) { // NOLINT
2339 _fail_count += 1;
2340 // No need to shuffle the child guards, just return.
2341 return false;
2342 }
2343
2344 // Points to an element in the _indices vector.
2345 size_t index_pointer = 0;
2346 // Points to the key index in the dict
2347 Py_ssize_t dict_pointer = 0;
2348
2349 // Use iter(dict.keys()) to iterate over the keys
2350 py::object keys =
2351 py::handle(obj).attr("keys")(); // py::object handles the references
2352 PyObject* iterator = PyObject_GetIter(keys.ptr()); // new reference
2353 PyObject* key = nullptr;
2354
2355 while (index_pointer < _indices.size() &&
2356 (key = PyIter_Next(iterator))) { // new reference
2357 if (dict_pointer == _indices[index_pointer]) {
2358 KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2359 std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2360 if (key_manager && !key_manager->check_nopybind(key)) {
2361 Py_DECREF(key);
2362 Py_DECREF(iterator);
2363 return false;
2364 }
2365
2366 PyObject* value = PyDict_GetItem(obj, key); // borrowed ref
2367 std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2368 if (value_manager && !value_manager->check_nopybind(value)) {
2369 Py_DECREF(key);
2370 Py_DECREF(iterator);
2371 return false;
2372 }
2373
2374 index_pointer++;
2375 }
2376 dict_pointer++;
2377 Py_DECREF(key);
2378 }
2379
2380 Py_DECREF(iterator);
2381 return true;
2382 }
2383
check_verbose_nopybind(PyObject * obj)2384 GuardDebugInfo check_verbose_nopybind(
2385 PyObject* obj) override { // borrowed ref
2386 if (Py_TYPE(obj) != _expected_type) {
2387 return GuardDebugInfo(false, "TYPE_MISMATCH(" + get_source() + ")", 0);
2388 }
2389
2390 if (PyDict_Size(obj) != _size) {
2391 return GuardDebugInfo(
2392 false, "len(" + get_source() + ") != " + std::to_string(_size), 0);
2393 }
2394
2395 // Early return
2396 if (_size == 0) {
2397 return GuardDebugInfo(true, 0);
2398 }
2399
2400 GuardDebugInfo debug_info =
2401 GuardManager::check_verbose_nopybind(obj); // NOLINT
2402 if (!debug_info.result) {
2403 return debug_info;
2404 }
2405
2406 // Points to an element in the _indices vector.
2407 size_t index_pointer = 0;
2408 // Points to the key index in the dict
2409 Py_ssize_t dict_pointer = 0;
2410
2411 int num_guards_executed = 0;
2412
2413 // Use iter(dict.keys()) to iterate over the keys
2414 py::object keys =
2415 py::handle(obj).attr("keys")(); // py::object handles the references
2416 PyObject* iterator = PyObject_GetIter(keys.ptr()); // new reference
2417 PyObject* key = nullptr;
2418
2419 while (index_pointer < _indices.size() &&
2420 (key = PyIter_Next(iterator))) { // new reference
2421 if (dict_pointer == _indices[index_pointer]) {
2422 KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2423 std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2424 if (key_manager) {
2425 GuardDebugInfo debug_info = key_manager->check_verbose_nopybind(key);
2426 num_guards_executed += debug_info.num_guards_executed;
2427 if (!debug_info.result) {
2428 Py_DECREF(key);
2429 Py_DECREF(iterator);
2430 return GuardDebugInfo(
2431 false, debug_info.verbose_code_parts, num_guards_executed);
2432 }
2433 }
2434
2435 PyObject* value = PyDict_GetItem(obj, key); // borrowed ref
2436 std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2437 if (value_manager) {
2438 GuardDebugInfo debug_info =
2439 value_manager->check_verbose_nopybind(value);
2440 num_guards_executed += debug_info.num_guards_executed;
2441 if (!debug_info.result) {
2442 Py_DECREF(key);
2443 Py_DECREF(iterator);
2444 return GuardDebugInfo(
2445 false, debug_info.verbose_code_parts, num_guards_executed);
2446 }
2447 }
2448 index_pointer++;
2449 }
2450 Py_DECREF(key);
2451 dict_pointer++;
2452 }
2453
2454 Py_DECREF(iterator);
2455 return GuardDebugInfo(true, num_guards_executed);
2456 }
2457 };
2458
make_guard_manager(RootGuardManager * root,std::string source,py::handle example_value,py::handle guard_manager_enum)2459 std::unique_ptr<GuardManager> make_guard_manager(
2460 RootGuardManager* root,
2461 std::string source,
2462 py::handle example_value,
2463 py::handle guard_manager_enum) {
2464 static py::object guard_manager_enum_class =
2465 py::module_::import("torch._dynamo.guards").attr("GuardManagerType");
2466 static py::object base_guard_manager_enum =
2467 guard_manager_enum_class.attr("GUARD_MANAGER");
2468 static py::object dict_guard_manager_enum =
2469 guard_manager_enum_class.attr("DICT_GUARD_MANAGER");
2470 static py::object dict_subclass_guard_manager_enum =
2471 guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER");
2472 if (py::isinstance<py::dict>(example_value)) {
2473 // The purpose of having both DictGuardManager and DictSubclassGuardManager
2474 // is to handle the variability in how dictionaries and their subclasses
2475 // manage key ordering.
2476
2477 // While inserting dictionary guards (check guards.py), we rely on the
2478 // list(d.keys()) ordering. Therefore, the cpp guard equivalent must have
2479 // the same keys ordering. For standard dictionaries, .keys() API internally
2480 // uses PyDict_Next. So, DictGuardManager directly uses PyDict_Next to
2481 // speedup the key fetches.
2482
2483 // But PyDict_Next might not give correct ordering for subclasses of dict.
2484 // For example, OrderedDict override the .keys() API without changing the
2485 // underlying datastructure. This leads to different keys ordering than the
2486 // one given by PyDict_Next. We use DictSubclassGuardManager to account for
2487 // this discrepancy. DictSubclassGuardManager directly calls the .keys() API
2488 // to accurately capture key ordering. This approach is less efficient than
2489 // using PyDict_Next (handled by DictGuardManager), but it ensures
2490 // correctness.
2491
2492 // Since regular dicts are more common than subclasses of dicts with
2493 // overridden keys method, we still optimize for the common case with
2494 // DictGuardManager by relying on PyDict_Next.
2495
2496 if (guard_manager_enum.is(base_guard_manager_enum)) {
2497 // For dicts that don't need to guard on keys, we can just rely on the
2498 // base GuardManager.
2499 return std::make_unique<GuardManager>(
2500 root, std::move(source), example_value);
2501 } else if (guard_manager_enum.is(dict_guard_manager_enum)) {
2502 return std::make_unique<DictGuardManager>(
2503 root, std::move(source), example_value);
2504 } else if (guard_manager_enum.is(dict_subclass_guard_manager_enum))
2505 return std::make_unique<DictSubclassGuardManager>(
2506 root, std::move(source), example_value);
2507 else {
2508 throw py::type_error("Invalid guard manager enum");
2509 }
2510 }
2511 return std::make_unique<GuardManager>(root, std::move(source));
2512 }
2513
2514 class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
2515 public:
TORCH_FUNCTION_MODE_STACK(const py::list & initial_stack,const py::list & ignored_types,py::object verbose_code_parts)2516 TORCH_FUNCTION_MODE_STACK(
2517 const py::list& initial_stack,
2518 const py::list& ignored_types,
2519 py::object verbose_code_parts)
2520 : LeafGuard(std::move(verbose_code_parts)),
2521 _ref_stack(),
2522 _ignored_types() {
2523 Py_ssize_t len = PyList_Size(initial_stack.ptr());
2524 for (Py_ssize_t idx = 0; idx < len; idx++) {
2525 PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
2526 this->_ref_stack.push_back(Py_TYPE(mode));
2527 }
2528
2529 len = PyList_Size(ignored_types.ptr());
2530 for (Py_ssize_t idx = 0; idx < len; idx++) {
2531 PyObject* type_obj =
2532 PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
2533 if (PyType_Check(type_obj) == 0) {
2534 PyErr_SetString(
2535 PyExc_TypeError, "ignored_types should contain a list of types");
2536 return;
2537 }
2538 PyTypeObject* type = (PyTypeObject*)type_obj;
2539 this->_ignored_types.insert(type);
2540 }
2541 }
2542
check_nopybind(PyObject * value)2543 bool check_nopybind(PyObject* value) override {
2544 // Ignore value arg, only used to satisfy the interface
2545 size_t ref_ind = 0;
2546 int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
2547 const size_t ref_stack_size = this->_ref_stack.size();
2548
2549 for (int64_t idx = 0; idx < len; idx++) {
2550 std::shared_ptr<c10::SafePyObject> mode =
2551 at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
2552
2553 PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
2554 // skip ignored types
2555 if (this->_ignored_types.count(mode_type) > 0) {
2556 continue;
2557 }
2558 // if we already have more non-ignored modes than the ref stack
2559 // or if the mode doesn't match at the current index, return false
2560 else if (
2561 (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) ||
2562 mode_type != _ref_stack[ref_ind]) {
2563 return false;
2564 }
2565 ref_ind++;
2566 }
2567
2568 return ref_ind == this->_ref_stack.size();
2569 }
2570
2571 private:
2572 std::vector<PyTypeObject*> _ref_stack;
2573 std::set<PyTypeObject*> _ignored_types;
2574 };
2575
2576 class TENSOR_MATCH : public LeafGuard {
2577 public:
TENSOR_MATCH(RootGuardManager * root_guard_manager,py::object value,py::object dynamic_dims_sizes_py,py::object dynamic_dims_strides_py,py::object tensor_name,py::object verbose_code_parts)2578 TENSOR_MATCH(
2579 RootGuardManager* root_guard_manager,
2580 py::object value,
2581 py::object dynamic_dims_sizes_py,
2582 py::object dynamic_dims_strides_py,
2583 py::object tensor_name,
2584 py::object verbose_code_parts)
2585 : LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
2586 _tensor_name(py::cast<py::str>(std::move(tensor_name))) {
2587 root_guard_manager->set_init_local_state_flag();
2588 PyObject* item = value.ptr();
2589 if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
2590 PyErr_SetString(PyExc_TypeError, "expected Tensor()");
2591 return;
2592 }
2593 auto tensor = THPVariable_Unpack(item);
2594
2595 std::vector<std::optional<c10::SymInt>> tensor_dims_size =
2596 pyListToVecOptInt(dynamic_dims_sizes_py.ptr());
2597 std::vector<std::optional<c10::SymInt>> tensor_dims_stride =
2598 pyListToVecOptInt(dynamic_dims_strides_py.ptr());
2599
2600 tensor_dims_size = tensor_dims_size.empty()
2601 ? wrapIntegersInOptional(tensor.sym_sizes())
2602 : tensor_dims_size;
2603 tensor_dims_stride = tensor_dims_stride.empty()
2604 ? wrapIntegersInOptional(tensor.sym_strides())
2605 : tensor_dims_stride;
2606 LocalState state;
2607 _tensor_check = std::make_unique<TensorCheck>(
2608 state,
2609 Py_TYPE(item),
2610 std::move(tensor),
2611 std::move(tensor_dims_size),
2612 std::move(tensor_dims_stride));
2613 }
2614
check_nopybind(PyObject * value)2615 bool check_nopybind(PyObject* value) override { // borrowed ref
2616 if (Py_TYPE(value) != _tensor_check->pytype) {
2617 return false;
2618 }
2619 return _tensor_check->check(
2620 _root_guard_manager->_local_state, THPVariable_Unpack(value));
2621 }
2622
check_verbose_nopybind(PyObject * value)2623 GuardDebugInfo check_verbose_nopybind(
2624 PyObject* value) override { // borrowed ref
2625
2626 if (Py_TYPE(value) != _tensor_check->pytype) {
2627 std::stringstream fail_reason;
2628 PyObject* type_str = PyObject_Str(PyObject_Type(value));
2629 fail_reason << "expected type of '" << _tensor_name
2630 << "' to be a tensor type, ";
2631 if (!type_str) {
2632 fail_reason << "but found a different type";
2633 } else {
2634 fail_reason << "' but found " << PyUnicode_AsUTF8(type_str);
2635 }
2636 return GuardDebugInfo(false, fail_reason.str(), 0);
2637 }
2638
2639 std::string fail_reason = _tensor_check->check_verbose(
2640 _root_guard_manager->_local_state,
2641 THPVariable_Unpack(value),
2642 _tensor_name);
2643
2644 if (!fail_reason.empty()) {
2645 if (is_parameter(py::handle(value))) {
2646 fail_reason += ". Guard failed on a parameter, consider using ";
2647 fail_reason +=
2648 "torch._dynamo.config.force_parameter_static_shapes = False ";
2649 fail_reason += "to allow dynamism on parameters.";
2650 }
2651 return GuardDebugInfo(false, fail_reason, 0);
2652 }
2653 return GuardDebugInfo(true, 1);
2654 }
2655
2656 private:
2657 std::string _tensor_name;
2658 std::unique_ptr<TensorCheck> _tensor_check;
2659 };
2660
2661 /**
 * Represents __getattr__ accessor.
2663 */
2664 class GetAttrGuardAccessor : public GuardAccessor {
2665 public:
GetAttrGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)2666 GetAttrGuardAccessor(
2667 RootGuardManager* root,
2668 py::str name,
2669 std::string source,
2670 py::handle example_value,
2671 py::handle guard_manager_enum)
2672 : GuardAccessor(
2673 root,
2674 name,
2675 std::move(source),
2676 example_value,
2677 guard_manager_enum),
2678 _attr_name(name.ptr()) {}
2679
2680 // NB: Intentional duplication between check_nopybind and
2681 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2682 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2683 override { // borrowed ref
2684 PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
2685 if (x == nullptr) {
2686 // Attribute absent, clear the exception and return false.
2687 PyErr_Clear();
2688 return false;
2689 }
2690 bool result = _guard_manager->check_nopybind(x);
2691 Py_DECREF(x);
2692 return result;
2693 }
2694
check_verbose_nopybind(PyObject * obj)2695 GuardDebugInfo check_verbose_nopybind(
2696 PyObject* obj) override { // borrowed ref
2697 PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
2698 if (x == nullptr) {
2699 // Attribute absent, clear the exception and return false.
2700 PyErr_Clear();
2701 return GuardDebugInfo(
2702 false, "getattr failed on source " + get_source(), 0);
2703 }
2704 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2705 Py_DECREF(x);
2706 return result;
2707 }
2708
repr() const2709 std::string repr() const override {
2710 // Helpful when priting GuardManager tree structure.
2711 return "GetAttrGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
2712 ")";
2713 }
2714
2715 private:
2716 // no need of py::object here because the attr_name is already passed on to
2717 // the base class as accessor_key which is a py::object.
2718 PyObject* _attr_name;
2719 };
2720
2721 /**
 * Represents x.__dict__ accessor.
2723 */
2724 class GetGenericDictGuardAccessor : public GuardAccessor {
2725 public:
GetGenericDictGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)2726 GetGenericDictGuardAccessor(
2727 RootGuardManager* root,
2728 py::str name,
2729 std::string source,
2730 py::handle example_value,
2731 py::handle guard_manager_enum)
2732 : GuardAccessor(
2733 root,
2734 std::move(name),
2735 std::move(source),
2736 example_value,
2737 guard_manager_enum) {}
2738
2739 // NB: Intentional duplication between check_nopybind and
2740 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2741 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2742 override { // borrowed ref
2743 PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref
2744 if (x == nullptr) {
2745 // Attribute absent, clear the exception and return false.
2746 PyErr_Clear();
2747 return false;
2748 }
2749 bool result = _guard_manager->check_nopybind(x);
2750 Py_DECREF(x);
2751 return result;
2752 }
2753
check_verbose_nopybind(PyObject * obj)2754 GuardDebugInfo check_verbose_nopybind(
2755 PyObject* obj) override { // borrowed ref
2756 PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref
2757 if (x == nullptr) {
2758 // Attribute absent, clear the exception and return false.
2759 PyErr_Clear();
2760 return GuardDebugInfo(
2761 false, "getattr failed on source " + get_source(), 0);
2762 }
2763 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2764 Py_DECREF(x);
2765 return result;
2766 }
2767
repr() const2768 std::string repr() const override {
2769 // Helpful when priting GuardManager tree structure.
2770 return "GetGenericDictGuardAccessor";
2771 }
2772 };
2773
2774 /**
 * Represents __getitem__ accessor.
2776 */
2777 class GetItemGuardAccessor : public GuardAccessor {
2778 public:
GetItemGuardAccessor(RootGuardManager * root,py::object name,std::string source,py::handle example_value,py::handle guard_manager_enum)2779 GetItemGuardAccessor(
2780 RootGuardManager* root,
2781 py::object name,
2782 std::string source,
2783 py::handle example_value,
2784 py::handle guard_manager_enum)
2785 : GuardAccessor(
2786 root,
2787 name,
2788 std::move(source),
2789 example_value,
2790 guard_manager_enum),
2791 _attr_name(name.ptr()) {}
2792
2793 // NB: Intentional duplication between check_nopybind and
2794 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2795 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2796 override { // borrowed ref
2797 PyObject* x = PyObject_GetItem(obj, _attr_name); // new ref
2798 if (x == nullptr) {
2799 PyErr_Clear();
2800 return false;
2801 }
2802 bool result = _guard_manager->check_nopybind(x);
2803 Py_DECREF(x);
2804 return result;
2805 }
2806
check_verbose_nopybind(PyObject * obj)2807 GuardDebugInfo check_verbose_nopybind(
2808 PyObject* obj) override { // borrowed ref
2809 PyObject* x = PyObject_GetItem(obj, _attr_name); // new ref
2810 if (x == nullptr) {
2811 PyErr_Clear();
2812 return GuardDebugInfo(
2813 false, std::string("KeyError on ") + get_source(), 0);
2814 }
2815 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2816 Py_DECREF(x);
2817 return result;
2818 }
2819
repr() const2820 std::string repr() const override {
2821 return "GetItemGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
2822 ")";
2823 }
2824
2825 private:
2826 // no need of py::object here because the attr_name is already passed on to
2827 // the base class as accessor_key which is a py::object.
2828 PyObject* _attr_name;
2829 };
2830
2831 /**
 * Represents dict[name] accessor. This is ONLY used for f_locals because it's a
 * dict, and DictGuardManager does not support sorting. We differentiate it from
 * GetItemGuardAccessor because PyDict_GetItem should be faster than
 * PyObject_GetItem.
2836 */
2837 class DictGetItemGuardAccessor : public GuardAccessor {
2838 public:
DictGetItemGuardAccessor(RootGuardManager * root,py::object key,std::string source,py::handle example_value,py::handle guard_manager_enum)2839 DictGetItemGuardAccessor(
2840 RootGuardManager* root,
2841 py::object key,
2842 std::string source,
2843 py::handle example_value,
2844 py::handle guard_manager_enum)
2845 : GuardAccessor(
2846 root,
2847 key,
2848 std::move(source),
2849 example_value,
2850 guard_manager_enum),
2851 _key(key.ptr()),
2852 _is_immutable_object(is_immutable_object(example_value)) {}
2853
2854 // NB: Intentional duplication between check_nopybind and
2855 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2856 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2857 override { // borrowed ref
2858 if (matches_dict_tag && _is_immutable_object) {
2859 // immutable object and dict tag matches, we can skip the guard subtree.
2860 return true;
2861 }
2862 PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref
2863 if (x == nullptr) {
2864 PyErr_Clear();
2865 return false;
2866 }
2867 bool result = _guard_manager->check_nopybind(x);
2868 return result;
2869 }
2870
check_verbose_nopybind(PyObject * obj)2871 GuardDebugInfo check_verbose_nopybind(
2872 PyObject* obj) override { // borrowed ref
2873 PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref
2874 if (x == nullptr) {
2875 PyErr_Clear();
2876 return GuardDebugInfo(
2877 false, std::string("KeyError on ") + get_source(), 0);
2878 }
2879 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2880 return result;
2881 }
2882
repr() const2883 std::string repr() const override {
2884 return "DictGetItemGuardAccessor(" + py::str(_key).cast<std::string>() +
2885 ")";
2886 }
2887
2888 private:
2889 PyObject* _key;
2890
2891 // If immutable object and dict tag matches, we can skip the guard subtree and
2892 // return true.
2893 bool _is_immutable_object;
2894 };
2895
2896 /**
2897 * Represents list[index] accessor. It is faster than generic
2898 * GetItemGuardAccessor.
2899 */
2900 class ListGetItemGuardAccessor : public GuardAccessor {
2901 public:
ListGetItemGuardAccessor(RootGuardManager * root,const py::object & index,std::string source,py::handle example_value,py::handle guard_manager_enum)2902 ListGetItemGuardAccessor(
2903 RootGuardManager* root,
2904 const py::object& index,
2905 std::string source,
2906 py::handle example_value,
2907 py::handle guard_manager_enum)
2908 : GuardAccessor(
2909 root,
2910 index,
2911 std::move(source),
2912 example_value,
2913 guard_manager_enum),
2914 _index(py::cast<Py_ssize_t>(index)) {}
2915
2916 // NB: Intentional duplication between check_nopybind and
2917 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2918 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2919 override { // borrowed ref
2920 PyObject* x = PyList_GetItem(obj, _index); // borrowed ref
2921 if (x == nullptr) {
2922 PyErr_Clear();
2923 return false;
2924 }
2925 bool result = _guard_manager->check_nopybind(x);
2926 return result;
2927 }
2928
check_verbose_nopybind(PyObject * obj)2929 GuardDebugInfo check_verbose_nopybind(
2930 PyObject* obj) override { // borrowed ref
2931 PyObject* x = PyList_GetItem(obj, _index); // borrowed ref
2932 if (x == nullptr) {
2933 PyErr_Clear();
2934 return GuardDebugInfo(
2935 false, std::string("IndexError on ") + get_source(), 0);
2936 }
2937 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2938 return result;
2939 }
2940
repr() const2941 std::string repr() const override {
2942 return "ListGetItemGuardAccessor(" + std::to_string(_index) + ")";
2943 }
2944
2945 private:
2946 Py_ssize_t _index;
2947 };
2948
2949 /**
2950 * Represents tuple[index] accessor. It is faster than generic
2951 * GetItemGuardAccessor.
2952 */
2953 class TupleGetItemGuardAccessor : public GuardAccessor {
2954 public:
TupleGetItemGuardAccessor(RootGuardManager * root,const py::object & index,std::string source,py::handle example_value,py::handle guard_manager_enum)2955 TupleGetItemGuardAccessor(
2956 RootGuardManager* root,
2957 const py::object& index,
2958 std::string source,
2959 py::handle example_value,
2960 py::handle guard_manager_enum)
2961 : GuardAccessor(
2962 root,
2963 index,
2964 std::move(source),
2965 example_value,
2966 guard_manager_enum),
2967 _index(py::cast<Py_ssize_t>(index)) {}
2968
2969 // NB: Intentional duplication between check_nopybind and
2970 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2971 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2972 override { // borrowed ref
2973 PyObject* x = PyTuple_GetItem(obj, _index); // borrowed ref
2974 if (x == nullptr) {
2975 PyErr_Clear();
2976 return false;
2977 }
2978 bool result = _guard_manager->check_nopybind(x);
2979 return result;
2980 }
2981
check_verbose_nopybind(PyObject * obj)2982 GuardDebugInfo check_verbose_nopybind(
2983 PyObject* obj) override { // borrowed ref
2984 PyObject* x = PyTuple_GetItem(obj, _index); // borrowed ref
2985 if (x == nullptr) {
2986 PyErr_Clear();
2987 return GuardDebugInfo(
2988 false, std::string("IndexError on ") + get_source(), 0);
2989 }
2990 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2991 return result;
2992 }
2993
repr() const2994 std::string repr() const override {
2995 return "TupleGetItemGuardAccessor(" + std::to_string(_index) + ")";
2996 }
2997
2998 private:
2999 Py_ssize_t _index;
3000 };
3001
3002 /**
 * Represents tensor.grad accessor.
3004 */
3005 class GradGuardAccessor : public GuardAccessor {
3006 public:
GradGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)3007 GradGuardAccessor(
3008 RootGuardManager* root,
3009 py::str name,
3010 std::string source,
3011 py::handle example_value,
3012 py::handle guard_manager_enum)
3013 : GuardAccessor(
3014 root,
3015 std::move(name),
3016 std::move(source),
3017 example_value,
3018 guard_manager_enum) {}
3019
3020 // NB: Intentional duplication between check_nopybind and
3021 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3022 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3023 override { // borrowed ref
3024 // check that its a tensor
3025 if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
3026 return false;
3027 }
3028 PyObject* grad =
3029 THPVariable_Wrap(THPVariable_Unpack(obj).grad()); // New reference
3030 bool result = _guard_manager->check_nopybind(grad);
3031 // For undefined tensor, THPVariable_Wrap returns Py_RETURN_NONE. So, no
3032 // need of Py_XDECREF.
3033 Py_DECREF(grad);
3034 return result;
3035 }
3036
check_verbose_nopybind(PyObject * obj)3037 GuardDebugInfo check_verbose_nopybind(
3038 PyObject* obj) override { // borrowed ref
3039 // check that its a tensor
3040 if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
3041 return GuardDebugInfo(
3042 false, "not a tensor - grad field is accessed " + get_source(), 0);
3043 }
3044 PyObject* grad =
3045 THPVariable_Wrap(THPVariable_Unpack(obj).grad()); // New reference
3046 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(grad);
3047 // For undefined tensor, THPVariable_Wrap returns Py_RETURN_NONE. So, no
3048 // need of Py_XDECREF.
3049 Py_DECREF(grad);
3050 return result;
3051 }
3052
repr() const3053 std::string repr() const override {
3054 // Helpful when priting GuardManager tree structure.
3055 return "GradGuardAccessor(grad)";
3056 }
3057 };
3058
3059 /**
3060 * Represents func.__defaults__ accessor.
3061 */
3062 class FuncDefaultsGuardAccessor : public GuardAccessor {
3063 public:
FuncDefaultsGuardAccessor(RootGuardManager * root,py::object name,std::string source,py::handle example_value,py::handle guard_manager_enum)3064 FuncDefaultsGuardAccessor(
3065 RootGuardManager* root,
3066 py::object name,
3067 std::string source,
3068 py::handle example_value,
3069 py::handle guard_manager_enum)
3070 : GuardAccessor(
3071 root,
3072 std::move(name),
3073 std::move(source),
3074 example_value,
3075 guard_manager_enum) {}
3076
3077 // NB: Intentional duplication between check_nopybind and
3078 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3079 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3080 override { // borrowed ref
3081 PyObject* func = obj;
3082 if (PyMethod_Check(obj)) {
3083 func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3084 } else if (PyInstanceMethod_Check(obj)) {
3085 func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3086 }
3087 PyObject* x = PyFunction_GetDefaults(func); // borrowed ref
3088 if (x == nullptr) {
3089 PyErr_Clear();
3090 return false;
3091 }
3092 return _guard_manager->check_nopybind(x);
3093 }
3094
check_verbose_nopybind(PyObject * obj)3095 GuardDebugInfo check_verbose_nopybind(
3096 PyObject* obj) override { // borrowed ref
3097 PyObject* func = obj;
3098 if (PyMethod_Check(obj)) {
3099 func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3100 } else if (PyInstanceMethod_Check(obj)) {
3101 func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3102 }
3103 PyObject* x = PyFunction_GetDefaults(func);
3104 if (x == nullptr) {
3105 PyErr_Clear();
3106 return GuardDebugInfo(
3107 false,
3108 std::string(repr() + ": Not a function on ") + get_source(),
3109 0);
3110 }
3111
3112 return _guard_manager->check_verbose_nopybind(x);
3113 }
3114
repr() const3115 std::string repr() const override {
3116 return "FuncDefaultsGuardAccessor";
3117 }
3118 };
3119
3120 /**
3121 * Represents func.__kwdefaults__ accessor.
3122 */
3123 class FuncKwDefaultsGuardAccessor : public GuardAccessor {
3124 public:
FuncKwDefaultsGuardAccessor(RootGuardManager * root,py::object name,std::string source,py::handle example_value,py::handle guard_manager_enum)3125 FuncKwDefaultsGuardAccessor(
3126 RootGuardManager* root,
3127 py::object name,
3128 std::string source,
3129 py::handle example_value,
3130 py::handle guard_manager_enum)
3131 : GuardAccessor(
3132 root,
3133 std::move(name),
3134 std::move(source),
3135 example_value,
3136 guard_manager_enum) {}
3137
3138 // NB: Intentional duplication between check_nopybind and
3139 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3140 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3141 override { // borrowed ref
3142 PyObject* func = obj;
3143 if (PyMethod_Check(obj)) {
3144 func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3145 } else if (PyInstanceMethod_Check(obj)) {
3146 func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3147 }
3148 PyObject* x = PyFunction_GetKwDefaults(func); // borrowed ref
3149 if (x == nullptr) {
3150 PyErr_Clear();
3151 return false;
3152 }
3153 return _guard_manager->check_nopybind(x);
3154 }
3155
check_verbose_nopybind(PyObject * obj)3156 GuardDebugInfo check_verbose_nopybind(
3157 PyObject* obj) override { // borrowed ref
3158 PyObject* func = obj;
3159 if (PyMethod_Check(obj)) {
3160 func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3161 } else if (PyInstanceMethod_Check(obj)) {
3162 func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3163 }
3164 PyObject* x = PyFunction_GetKwDefaults(func);
3165 if (x == nullptr) {
3166 PyErr_Clear();
3167 return GuardDebugInfo(
3168 false,
3169 std::string(repr() + ": Not a function on ") + get_source(),
3170 0);
3171 }
3172
3173 return _guard_manager->check_verbose_nopybind(x);
3174 }
3175
repr() const3176 std::string repr() const override {
3177 return "FuncKwDefaultsGuardAccessor";
3178 }
3179 };
3180
3181 /**
 * Represents f_globals accessor. This sits as a child accessor of the
3183 * RootGuardManager.
3184 */
3185 class GlobalsGuardAccessor : public GuardAccessor {
3186 public:
GlobalsGuardAccessor(RootGuardManager * root,py::dict globals_dict,std::string source,py::handle example_value,py::handle guard_manager_enum)3187 GlobalsGuardAccessor(
3188 RootGuardManager* root,
3189 py::dict globals_dict,
3190 std::string source,
3191 py::handle example_value,
3192 py::handle guard_manager_enum)
3193 : GuardAccessor(
3194 root,
3195 globals_dict,
3196 std::move(source),
3197 example_value,
3198 guard_manager_enum),
3199 _globals_dict(globals_dict.ptr()) {}
3200
3201 // NB: Intentional duplication between check_nopybind and
3202 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3203 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3204 override { // borrowed ref
3205 // Ignore the obj arg. This is required to satisfy the function signature.
3206 // Just pass on the globals dict to the child manager.
3207 return _guard_manager->check_nopybind(_globals_dict);
3208 }
3209
check_verbose_nopybind(PyObject * obj)3210 GuardDebugInfo check_verbose_nopybind(
3211 PyObject* obj) override { // borrowed ref
3212 // Ignore the obj arg. This is required to satisfy the function signature.
3213 // Just pass on the globals dict to the child manager.
3214 return _guard_manager->check_verbose_nopybind(_globals_dict);
3215 }
3216
repr() const3217 std::string repr() const override {
3218 return "GlobalsGuardAccessor";
3219 }
3220
3221 private:
3222 // no need of py::object here because the globals_dict is already passed on to
3223 // the base class as accessor_key which is a py::object.
3224 PyObject* _globals_dict;
3225 };
3226
3227 /**
 * Represents type(...) accessor.
3229 */
3230 class TypeGuardAccessor : public GuardAccessor {
3231 public:
3232 // name = __type_accessor__, a unique string used as attribute name.
TypeGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)3233 TypeGuardAccessor(
3234 RootGuardManager* root,
3235 py::str name,
3236 std::string source,
3237 py::handle example_value,
3238 py::handle guard_manager_enum)
3239 : GuardAccessor(
3240 root,
3241 std::move(name),
3242 std::move(source),
3243 example_value,
3244 guard_manager_enum) {}
3245
3246 // NB: Intentional duplication between check_nopybind and
3247 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3248 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3249 override { // borrowed ref
3250 PyObject* x = (PyObject*)Py_TYPE(obj); // borrowed ref
3251 return _guard_manager->check_nopybind(x);
3252 }
3253
check_verbose_nopybind(PyObject * obj)3254 GuardDebugInfo check_verbose_nopybind(
3255 PyObject* obj) override { // borrowed ref
3256 PyObject* x = (PyObject*)Py_TYPE(obj); // borrowed ref
3257 return _guard_manager->check_verbose_nopybind(x);
3258 }
3259
repr() const3260 std::string repr() const override {
3261 return "TypeGuardAccessor";
3262 }
3263 };
3264
3265 /**
3266 * Getitem tuple_iterator accessor.
3267 */
3268 class TupleIteratorGetItemAccessor : public GuardAccessor {
3269 public:
TupleIteratorGetItemAccessor(RootGuardManager * root,py::object index,std::string source,py::handle example_value,py::handle guard_manager_enum)3270 TupleIteratorGetItemAccessor(
3271 RootGuardManager* root,
3272 py::object index,
3273 std::string source,
3274 py::handle example_value,
3275 py::handle guard_manager_enum)
3276 : GuardAccessor(
3277 root,
3278 index,
3279 std::move(source),
3280 example_value,
3281 guard_manager_enum),
3282 _index(py::cast<Py_ssize_t>(std::move(index))) {}
3283
3284 // NB: Intentional duplication between check_nopybind and
3285 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3286 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3287 override { // borrowed ref
3288 _PyTupleIterObject* it = (_PyTupleIterObject*)obj;
3289 PyObject* x =
3290 PyTuple_GET_ITEM(it->it_seq, it->it_index + _index); // borrowed ref
3291 if (x == nullptr) {
3292 // Out of range.
3293 PyErr_Clear();
3294 return false;
3295 }
3296 bool result = _guard_manager->check_nopybind(x);
3297 return result;
3298 }
3299
check_verbose_nopybind(PyObject * obj)3300 GuardDebugInfo check_verbose_nopybind(
3301 PyObject* obj) override { // borrowed ref
3302 _PyTupleIterObject* it = (_PyTupleIterObject*)obj;
3303 PyObject* x =
3304 PyTuple_GET_ITEM(it->it_seq, it->it_index + _index); // borrowed ref
3305 if (x == nullptr) {
3306 // Out of range.
3307 PyErr_Clear();
3308 return GuardDebugInfo(false, std::string("IndexError ") + repr(), 0);
3309 }
3310 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
3311 return result;
3312 }
3313
repr() const3314 std::string repr() const override {
3315 return "TupleIteratorGetItemAccessor(" + std::to_string(_index) + ")";
3316 }
3317
3318 private:
3319 Py_ssize_t _index;
3320 };
3321
3322 /**
3323 * GlobalWeakRef accessor. Dynamo can insert a weakref object into the frame
3324 * globals. This accessor reads the globals and then calls the weakref object
3325 * to get the underlying object. This is a child of GlobalsGuardAccessor.
 * Therefore, we will get the globals dict while calling check_nopybind.
3327 */
3328 class GlobalWeakRefGuardAccessor : public GuardAccessor {
3329 public:
GlobalWeakRefGuardAccessor(RootGuardManager * root,py::object global_name,std::string source,py::handle example_value,py::handle guard_manager_enum)3330 GlobalWeakRefGuardAccessor(
3331 RootGuardManager* root,
3332 py::object global_name,
3333 std::string source,
3334 py::handle example_value,
3335 py::handle guard_manager_enum)
3336 : GuardAccessor(
3337 root,
3338 global_name,
3339 std::move(source),
3340 example_value,
3341 guard_manager_enum),
3342 _global_name(global_name.ptr()) {}
3343
3344 // NB: Intentional duplication between check_nopybind and
3345 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3346 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3347 override { // borrowed ref
3348 // obj is globals dict because GlobalWeakRefGuardAccessor has to be a
3349 // child of GlobalsGuardAccessor.
3350 PyObject* weakref = PyDict_GetItem(obj, _global_name); // borrowed ref
3351 if (weakref == nullptr) {
3352 // The weakref is not in the globals dict.
3353 PyErr_Clear();
3354 return false;
3355 }
3356
3357 if (!PyWeakref_Check(weakref)) {
3358 return false;
3359 }
3360
3361 PyObject* x = PyWeakref_GetObject(weakref); // borrowed ref
3362 return _guard_manager->check_nopybind(x);
3363 }
3364
check_verbose_nopybind(PyObject * obj)3365 GuardDebugInfo check_verbose_nopybind(
3366 PyObject* obj) override { // borrowed ref
3367 // obj is globals dict because GlobalWeakRefGuardAccessor has to be a
3368 // child of GlobalsGuardAccessor.
3369 PyObject* weakref = PyDict_GetItem(obj, _global_name); // borrowed ref
3370 if (weakref == nullptr) {
3371 // The weakref is not in the globals dict.
3372 PyErr_Clear();
3373 return GuardDebugInfo(
3374 false, std::string("KeyError on ") + get_source(), 0);
3375 }
3376
3377 if (!PyWeakref_Check(weakref)) {
3378 return GuardDebugInfo(
3379 false, std::string("Not a weakref ") + get_source(), 0);
3380 }
3381
3382 PyObject* x = PyWeakref_GetObject(weakref); // borrowed ref
3383 return _guard_manager->check_verbose_nopybind(x);
3384 }
3385
repr() const3386 std::string repr() const override {
3387 return "GlobalWeakRefGuardAccessor(" +
3388 py::str(_global_name).cast<std::string>() + ")";
3389 }
3390
3391 private:
3392 PyObject* _global_name;
3393 };
3394
3395 /**
3396 * Implements weakref call - x_weak()
3397 */
3398 class WeakRefCallGuardAccessor : public GuardAccessor {
3399 public:
WeakRefCallGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)3400 WeakRefCallGuardAccessor(
3401 RootGuardManager* root,
3402 py::str name,
3403 std::string source,
3404 py::handle example_value,
3405 py::handle guard_manager_enum)
3406 : GuardAccessor(
3407 root,
3408 std::move(name),
3409 std::move(source),
3410 example_value,
3411 guard_manager_enum) {}
3412
3413 // NB: Intentional duplication between check_nopybind and
3414 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3415 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3416 override { // borrowed ref
3417 if (!PyWeakref_Check(obj)) {
3418 return false;
3419 }
3420
3421 PyObject* x = PyWeakref_GetObject(obj); // borrowed ref
3422 return _guard_manager->check_nopybind(x);
3423 }
3424
check_verbose_nopybind(PyObject * obj)3425 GuardDebugInfo check_verbose_nopybind(
3426 PyObject* obj) override { // borrowed ref
3427 if (!PyWeakref_Check(obj)) {
3428 return GuardDebugInfo(
3429 false, std::string("Not a weakref obj ") + get_source(), 0);
3430 }
3431
3432 PyObject* x = PyWeakref_GetObject(obj); // borrowed ref
3433 return _guard_manager->check_verbose_nopybind(x);
3434 }
3435
repr() const3436 std::string repr() const override {
3437 return "WeakRefCallGuardAccessor()";
3438 }
3439 };
3440
3441 /**
3442 * Similar to PythonLambdaLeafGuard, this class is a way to allow developers to
3443 * supply accessor as a python function. This is useful for from_numpy source.
3444 */
3445 class PythonLambdaGuardAccessor : public GuardAccessor {
3446 public:
PythonLambdaGuardAccessor(RootGuardManager * root,py::function accessor_fn,std::string source,py::handle example_value,py::handle guard_manager_enum)3447 PythonLambdaGuardAccessor(
3448 RootGuardManager* root,
3449 py::function accessor_fn,
3450 std::string source,
3451 py::handle example_value,
3452 py::handle guard_manager_enum)
3453 : GuardAccessor(
3454 root,
3455 accessor_fn,
3456 std::move(source),
3457 example_value,
3458 guard_manager_enum),
3459 _accessor_fn(std::move(accessor_fn)) {}
3460
3461 // NB: Intentional duplication between check_nopybind and
3462 // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3463 bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3464 override { // borrowed ref
3465 PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
3466 if (x == nullptr) {
3467 // The accessor function failed.
3468 PyErr_Clear();
3469 return false;
3470 }
3471 bool result = _guard_manager->check_nopybind(x);
3472 Py_DECREF(x);
3473 return result;
3474 }
3475
check_verbose_nopybind(PyObject * obj)3476 GuardDebugInfo check_verbose_nopybind(
3477 PyObject* obj) override { // borrowed ref
3478 PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
3479 if (x == nullptr) {
3480 // The accessor function failed.
3481 std::string exc_message = get_exception_message();
3482 PyErr_Clear();
3483 return GuardDebugInfo(false, exc_message, 0);
3484 }
3485 GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
3486 Py_DECREF(x);
3487 return result;
3488 }
3489
repr() const3490 std::string repr() const override {
3491 return "PythonLambdaGuardAccessor";
3492 }
3493
3494 private:
3495 py::object _accessor_fn;
3496 };
3497
install_object_aliasing_guard(GuardManager * x,GuardManager * y,py::object verbose_code_parts)3498 void install_object_aliasing_guard(
3499 GuardManager* x,
3500 GuardManager* y,
3501 py::object verbose_code_parts) {
3502 // Adds tensor X is tensor Y guard. This is a an example of relational guard.
3503 // There is one guard object that is shared between two guard managers.
3504 std::shared_ptr<RelationalGuard> guard =
3505 std::make_shared<OBJECT_ALIASING>(std::move(verbose_code_parts));
3506
3507 // Register the resetter on the toor guard mananger, so that it can reset
3508 // the newly added relational guard when the guard eval fails.
3509 x->get_root()->add_relational_guard_resetter(guard);
3510
3511 // In case the guard is a DictGuardManager, OBJECT_ALIASING guard is a
3512 // permitted guard.
3513 x->add_permitted_leaf_guard(guard);
3514 y->add_permitted_leaf_guard(guard);
3515 }
3516
install_no_tensor_aliasing_guard(const py::list & guard_managers,const py::list & tensor_names,py::object verbose_code_parts)3517 void install_no_tensor_aliasing_guard(
3518 const py::list& guard_managers,
3519 const py::list& tensor_names,
3520 py::object verbose_code_parts) {
3521 // Adds a guard that checks none of tensors alias. This is a an example of
3522 // relational guard. There is one guard object that is shared between multiple
3523 // guard managers.
3524 std::shared_ptr<RelationalGuard> guard = std::make_shared<NO_TENSOR_ALIASING>(
3525 tensor_names, std::move(verbose_code_parts));
3526
3527 // Register the resetter on the toor guard mananger, so that it can reset
3528 // the newly added relational guard when the guard eval fails.
3529 py::cast<GuardManager*>(guard_managers[0])
3530 ->get_root()
3531 ->add_relational_guard_resetter(guard);
3532 for (const auto& guard_manager : guard_managers) {
3533 py::cast<GuardManager*>(guard_manager)->add_leaf_guard(guard);
3534 }
3535 }
3536
3537 } // namespace
3538
_torchinductor_pyobject_tensor_data_ptr(PyObject * obj)3539 static void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj) {
3540 if (C10_UNLIKELY(
3541 obj == nullptr ||
3542 (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)))) {
3543 throw std::runtime_error(
3544 "_torchinductor_pyobject_tensor_data_ptr: non-tensor input");
3545 }
3546 return THPVariable_Unpack(obj).data_ptr();
3547 }
3548
convert_to_root_guard_manager(py::object root)3549 void* convert_to_root_guard_manager(py::object root) {
3550 RootGuardManager* root_mgr = std::move(root).cast<RootGuardManager*>();
3551 return (void*)root_mgr;
3552 }
3553
run_root_guard_manager(void * root,PyObject * f_locals)3554 bool run_root_guard_manager(void* root, PyObject* f_locals) {
3555 return ((RootGuardManager*)root)->check_nopybind(f_locals);
3556 }
3557
torch_c_dynamo_guards_init()3558 PyObject* torch_c_dynamo_guards_init() {
3559 // initialize TensorGuardsType
3560 TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
3561 TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
3562 TensorGuardsType.tp_itemsize = 0;
3563 TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc;
3564 TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT;
3565 TensorGuardsType.tp_doc = "Check properties of a torch.Tensor";
3566 TensorGuardsType.tp_methods = TensorGuards_methods;
3567 TensorGuardsType.tp_init = (initproc)TensorGuards_init;
3568 TensorGuardsType.tp_new = TensorGuards_new;
3569
3570 if (PyType_Ready(&TensorGuardsType) < 0)
3571 return nullptr;
3572
3573 GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard";
3574 GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard);
3575 GlobalStateGuardType.tp_itemsize = 0;
3576 GlobalStateGuardType.tp_flags = Py_TPFLAGS_DEFAULT;
3577 GlobalStateGuardType.tp_doc = "Guard on PyTorch global flags such as no_grad";
3578 GlobalStateGuardType.tp_methods = GlobalStateGuard_methods;
3579 GlobalStateGuardType.tp_init = (initproc)GlobalStateGuard_init;
3580 GlobalStateGuardType.tp_new = PyType_GenericNew;
3581
3582 if (PyType_Ready(&GlobalStateGuardType) < 0)
3583 return nullptr;
3584
3585 auto m = PyModule_Create(&_module);
3586 if (m == nullptr)
3587 return nullptr;
3588
3589 Py_INCREF(&TensorGuardsType);
3590 if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
3591 Py_DECREF(&TensorGuardsType);
3592 Py_DECREF(m);
3593 return nullptr;
3594 }
3595
3596 Py_INCREF(&GlobalStateGuardType);
3597 if (PyModule_AddObject(
3598 m, "GlobalStateGuard", (PyObject*)&GlobalStateGuardType) < 0) {
3599 Py_DECREF(&GlobalStateGuardType);
3600 Py_DECREF(m);
3601 return nullptr;
3602 }
3603
3604 // We expose the address of _torchinductor_pyobject_tensor_data_ptr in order
3605 // to allow manual linking in our generated TorchInductor Python bindings.
3606 // While regular linking works in most cases, it does not work properly in
3607 // fbcode due to janky build setup there.
3608 if (PyModule_AddObject(
3609 m,
3610 "_torchinductor_pyobject_tensor_data_ptr",
3611 PyLong_FromVoidPtr(reinterpret_cast<void*>(
3612 &_torchinductor_pyobject_tensor_data_ptr))) < 0) {
3613 return nullptr;
3614 }
3615
3616 auto py_m = py::handle(m).cast<py::module>();
3617 py::class_<GuardDebugInfo, std::unique_ptr<GuardDebugInfo>>(
3618 py_m, "GuardDebugInfo")
3619 .def(py::init<bool, py::list, int>())
3620 .def("__str__", &GuardDebugInfo::to_string)
3621 .def_readonly("result", &GuardDebugInfo::result)
3622 .def_readonly("verbose_code_parts", &GuardDebugInfo::verbose_code_parts)
3623 .def_readonly(
3624 "num_guards_executed", &GuardDebugInfo::num_guards_executed);
3625
3626 // Leaf Guards
3627 py::class_<LeafGuard, std::shared_ptr<LeafGuard>>(py_m, "LeafGuard")
3628 .def("verbose_code_parts", &LeafGuard::verbose_code_parts);
3629 py::class_<LAMBDA_GUARD, LeafGuard, std::shared_ptr<LAMBDA_GUARD>>(
3630 py_m, "LAMBDA_GUARD")
3631 .def(py::init<py::function, py::list>())
3632 .def("__call__", &LAMBDA_GUARD::check);
3633 py::class_<TYPE_MATCH, LeafGuard, std::shared_ptr<TYPE_MATCH>>(
3634 py_m, "TYPE_MATCH")
3635 .def(py::init<py::object, py::list>())
3636 .def("__call__", &TYPE_MATCH::check);
3637 py::class_<ID_MATCH, LeafGuard, std::shared_ptr<ID_MATCH>>(py_m, "ID_MATCH")
3638 .def(py::init<py::object, py::list>())
3639 .def("__call__", &ID_MATCH::check);
3640 py::class_<EQUALS_MATCH, LeafGuard, std::shared_ptr<EQUALS_MATCH>>(
3641 py_m, "EQUALS_MATCH")
3642 .def(py::init<py::object, py::list>())
3643 .def("__call__", &EQUALS_MATCH::check);
3644 py::class_<LENGTH_CHECK, LeafGuard, std::shared_ptr<LENGTH_CHECK>>(
3645 py_m, "LENGTH_CHECK")
3646 .def(py::init<py::object, py::list>())
3647 .def("__call__", &LENGTH_CHECK::check);
3648 py::class_<DICT_LENGTH, LeafGuard, std::shared_ptr<DICT_LENGTH>>(
3649 py_m, "DICT_LENGTH")
3650 .def(py::init<py::object, py::list>())
3651 .def("__call__", &DICT_LENGTH::check);
3652 py::class_<DEFAULT_DEVICE, LeafGuard, std::shared_ptr<DEFAULT_DEVICE>>(
3653 py_m, "DEFAULT_DEVICE")
3654 .def(py::init<py::list>())
3655 .def("__call__", &DEFAULT_DEVICE::check);
3656 py::class_<NOT_NONE, LeafGuard, std::shared_ptr<NOT_NONE>>(py_m, "NOT_NONE")
3657 .def(py::init<py::list>())
3658 .def("__call__", &NOT_NONE::check);
3659 py::class_<
3660 TUPLE_ITERATOR_LEN,
3661 LeafGuard,
3662 std::shared_ptr<TUPLE_ITERATOR_LEN>>(py_m, "TUPLE_ITERATOR_LEN")
3663 .def(py::init<py::object, py::object, py::list>())
3664 .def("__call__", &TUPLE_ITERATOR_LEN::check);
3665 py::class_<GLOBAL_STATE, LeafGuard, std::shared_ptr<GLOBAL_STATE>>(
3666 py_m, "GLOBAL_STATE")
3667 .def(py::init<py::list>())
3668 .def("check_verbose", &GLOBAL_STATE::check_verbose)
3669 .def("__call__", &GLOBAL_STATE::check);
3670 py::class_<
3671 TORCH_FUNCTION_MODE_STACK,
3672 LeafGuard,
3673 std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
3674 py_m, "TORCH_FUNCTION_MODE_STACK")
3675 .def(py::init<py::list, py::list, py::list>())
3676 .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
3677 py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
3678 py_m, "DATA_PTR_MATCH")
3679 .def(py::init<py::object, py::list>())
3680 .def("__call__", &DATA_PTR_MATCH::check);
3681 py::class_<NO_HASATTR, LeafGuard, std::shared_ptr<NO_HASATTR>>(
3682 py_m, "NO_HASATTR")
3683 .def(py::init<py::object, py::list>())
3684 .def("__call__", &NO_HASATTR::check);
3685 py::class_<DICT_CONTAINS, LeafGuard, std::shared_ptr<DICT_CONTAINS>>(
3686 py_m, "DICT_CONTAINS")
3687 .def(py::init<bool, py::object, py::list>())
3688 .def("__call__", &DICT_CONTAINS::check);
3689 py::class_<DYNAMIC_INDICES, LeafGuard, std::shared_ptr<DYNAMIC_INDICES>>(
3690 py_m, "DYNAMIC_INDICES")
3691 .def(py::init<py::set, py::list>())
3692 .def("__call__", &DYNAMIC_INDICES::check);
3693 py::class_<DICT_VERSION, LeafGuard, std::shared_ptr<DICT_VERSION>>(
3694 py_m, "DICT_VERSION")
3695 .def(py::init<py::object, py::list>())
3696 .def("__call__", &DICT_VERSION::check);
3697 py::class_<TENSOR_MATCH, LeafGuard, std::shared_ptr<TENSOR_MATCH>>(
3698 py_m, "TENSOR_MATCH")
3699 .def(py::init<
3700 RootGuardManager*,
3701 py::object,
3702 py::object,
3703 py::object,
3704 py::str,
3705 py::list>())
3706 .def("__call__", &TENSOR_MATCH::check);
3707 // NOLINTNEXTLINE(bugprone-unused-raii)
3708 py::class_<OBJECT_ALIASING, LeafGuard, std::shared_ptr<OBJECT_ALIASING>>(
3709 py_m, "OBJECT_ALIASING");
3710 // NOLINTNEXTLINE(bugprone-unused-raii)
3711 py::class_<
3712 NO_TENSOR_ALIASING,
3713 LeafGuard,
3714 std::shared_ptr<NO_TENSOR_ALIASING>>(py_m, "NO_TENSOR_ALIASING");
3715
3716 // Guard Accessors - These are present so that we can iterate over the
3717 // GuardManager hierarchy. We intentionally do not provide even an init
3718 // function on these, because these should be constructed from within C++.
3719 py::class_<GuardAccessor, std::unique_ptr<GuardAccessor>>(
3720 py_m, "GuardAccessor")
3721 .def("repr", &GuardAccessor::repr);
3722 // NOLINTNEXTLINE(bugprone-unused-raii)
3723 py::class_<
3724 GetAttrGuardAccessor,
3725 GuardAccessor,
3726 std::unique_ptr<GetAttrGuardAccessor>>(py_m, "GetAttrGuardAccessor");
3727 // NOLINTNEXTLINE(bugprone-unused-raii)
3728 py::class_<
3729 GetGenericDictGuardAccessor,
3730 GuardAccessor,
3731 std::unique_ptr<GetGenericDictGuardAccessor>>(
3732 py_m, "GetGenericDictGuardAccessor");
3733 // NOLINTNEXTLINE(bugprone-unused-raii)
3734 py::class_<
3735 GetItemGuardAccessor,
3736 GuardAccessor,
3737 std::unique_ptr<GetItemGuardAccessor>>(py_m, "GetItemGuardAccessor");
3738 // NOLINTNEXTLINE(bugprone-unused-raii)
3739 py::class_<
3740 DictGetItemGuardAccessor,
3741 GuardAccessor,
3742 std::unique_ptr<DictGetItemGuardAccessor>>(
3743 py_m, "DictGetItemGuardAccessor");
3744 // NOLINTNEXTLINE(bugprone-unused-raii)
3745 py::class_<
3746 ListGetItemGuardAccessor,
3747 GuardAccessor,
3748 std::unique_ptr<ListGetItemGuardAccessor>>(
3749 py_m, "ListGetItemGuardAccessor");
3750 // NOLINTNEXTLINE(bugprone-unused-raii)
3751 py::class_<
3752 TupleGetItemGuardAccessor,
3753 GuardAccessor,
3754 std::unique_ptr<TupleGetItemGuardAccessor>>(
3755 py_m, "TupleGetItemGuardAccessor");
3756 // NOLINTNEXTLINE(bugprone-unused-raii)
3757 py::class_<
3758 FuncDefaultsGuardAccessor,
3759 GuardAccessor,
3760 std::unique_ptr<FuncDefaultsGuardAccessor>>(
3761 py_m, "FuncDefaultsGuardAccessor");
3762 // NOLINTNEXTLINE(bugprone-unused-raii)
3763 py::class_<
3764 FuncKwDefaultsGuardAccessor,
3765 GuardAccessor,
3766 std::unique_ptr<FuncKwDefaultsGuardAccessor>>(
3767 py_m, "FuncKwDefaultsGuardAccessor");
3768 // NOLINTNEXTLINE(bugprone-unused-raii)
3769 py::class_<
3770 GlobalsGuardAccessor,
3771 GuardAccessor,
3772 std::unique_ptr<GlobalsGuardAccessor>>(py_m, "GlobalsGuardAccessor");
3773 // NOLINTNEXTLINE(bugprone-unused-raii)
3774 py::class_<
3775 TypeGuardAccessor,
3776 GuardAccessor,
3777 std::unique_ptr<TypeGuardAccessor>>(py_m, "TypeGuardAccessor");
3778 // NOLINTNEXTLINE(bugprone-unused-raii)
3779 py::class_<
3780 WeakRefCallGuardAccessor,
3781 GuardAccessor,
3782 std::unique_ptr<WeakRefCallGuardAccessor>>(
3783 py_m, "WeakRefCallGuardAccessor");
3784 // NOLINTNEXTLINE(bugprone-unused-raii)
3785 py::class_<
3786 TupleIteratorGetItemAccessor,
3787 GuardAccessor,
3788 std::unique_ptr<TupleIteratorGetItemAccessor>>(
3789 py_m, "TupleIteratorGetItemAccessor");
3790 // NOLINTNEXTLINE(bugprone-unused-raii)
3791 py::class_<
3792 GlobalWeakRefGuardAccessor,
3793 GuardAccessor,
3794 std::unique_ptr<GlobalWeakRefGuardAccessor>>(
3795 py_m, "GlobalWeakRefGuardAccessor");
3796
3797 // Guard Manager - No constructor in python, python should use
3798 // RootGuardManager.
3799 py::class_<GuardManager, std::unique_ptr<GuardManager>>(py_m, "GuardManager")
3800 // return by reference because GuardManager has the ownership of accessors
3801 .def("get_source", &GuardManager::get_source)
3802 .def(
3803 "get_accessors",
3804 &GuardManager::get_accessors,
3805 py::return_value_policy::reference)
3806 // return by reference because GuardManager has the ownership of child
3807 // managers
3808 .def(
3809 "get_child_managers",
3810 &GuardManager::get_child_managers,
3811 py::return_value_policy::reference)
3812 // return by reference because GuardManager has the ownership of leaf
3813 // guards
3814 .def(
3815 "get_leaf_guards",
3816 &GuardManager::get_leaf_guards,
3817 py::return_value_policy::reference)
3818 .def(
3819 "add_lambda_guard",
3820 [](GuardManager& self,
3821 py::object lambda,
3822 py::object verbose_code_parts) -> void {
3823 self.add_leaf_guard(std::make_shared<LAMBDA_GUARD>(
3824 std::move(lambda), std::move(verbose_code_parts)));
3825 })
3826 .def(
3827 "add_type_match_guard",
3828 [](GuardManager& self,
3829 py::object value,
3830 py::object verbose_code_parts) -> void {
3831 SKIP_IF_GUARD_ALREADY_PRESENT("TYPE_MATCH");
3832 self.add_leaf_guard(std::make_shared<TYPE_MATCH>(
3833 std::move(value), std::move(verbose_code_parts)));
3834 })
3835 .def(
3836 "add_id_match_guard",
3837 [](GuardManager& self,
3838 py::object value,
3839 py::object verbose_code_parts) -> void {
3840 SKIP_IF_GUARD_ALREADY_PRESENT("ID_MATCH");
3841 self.add_leaf_guard(std::make_shared<ID_MATCH>(
3842 std::move(value), std::move(verbose_code_parts)));
3843 })
3844 .def(
3845 "add_equals_match_guard",
3846 [](GuardManager& self,
3847 py::object value,
3848 py::object verbose_code_parts) -> void {
3849 SKIP_IF_GUARD_ALREADY_PRESENT("EQUALS_MATCH");
3850 self.add_leaf_guard(std::make_shared<EQUALS_MATCH>(
3851 std::move(value), std::move(verbose_code_parts)));
3852 })
3853 .def(
3854 "add_length_check_guard",
3855 [](GuardManager& self,
3856 py::object value,
3857 py::object verbose_code_parts) -> void {
3858 SKIP_IF_GUARD_ALREADY_PRESENT("LENGTH_CHECK");
3859 self.add_leaf_guard(std::make_shared<LENGTH_CHECK>(
3860 std::move(value), std::move(verbose_code_parts)));
3861 })
3862 .def(
3863 "add_dict_length_check_guard",
3864 [](GuardManager& self,
3865 py::object value,
3866 py::object verbose_code_parts) -> void {
3867 SKIP_IF_GUARD_ALREADY_PRESENT("DICT_LENGTH");
3868 self.add_leaf_guard(std::make_shared<DICT_LENGTH>(
3869 std::move(value), std::move(verbose_code_parts)));
3870 })
3871 .def(
3872 "add_tuple_iterator_length_guard",
3873 [](GuardManager& self,
3874 py::object length,
3875 py::object type_id,
3876 py::object verbose_code_parts) -> void {
3877 SKIP_IF_GUARD_ALREADY_PRESENT("TUPLE_ITERATOR_LEN");
3878 self.add_leaf_guard(std::make_shared<TUPLE_ITERATOR_LEN>(
3879 std::move(length),
3880 std::move(type_id),
3881 std::move(verbose_code_parts)));
3882 })
3883 .def(
3884 "add_default_device_guard",
3885 [](GuardManager& self, py::object verbose_code_parts) -> void {
3886 self.add_leaf_guard(std::make_shared<DEFAULT_DEVICE>(
3887 std::move(verbose_code_parts)));
3888 })
3889 .def(
3890 "add_not_none_guard",
3891 [](GuardManager& self, py::object verbose_code_parts) -> void {
3892 SKIP_IF_GUARD_ALREADY_PRESENT("NOT_NONE");
3893 self.add_leaf_guard(
3894 std::make_shared<NOT_NONE>(std::move(verbose_code_parts)));
3895 })
3896 .def(
3897 "add_global_state_guard",
3898 [](GuardManager& self, py::object verbose_code_parts) -> void {
3899 self.add_leaf_guard(
3900 std::make_shared<GLOBAL_STATE>(std::move(verbose_code_parts)));
3901 })
3902 .def(
3903 "add_torch_function_mode_stack_guard",
3904 [](GuardManager& self,
3905 const py::list& initial_stack,
3906 const py::list& ignored_types,
3907 py::object verbose_code_parts) -> void {
3908 self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
3909 initial_stack, ignored_types, std::move(verbose_code_parts)));
3910 })
3911 .def(
3912 "add_data_ptr_guard",
3913 [](GuardManager& self,
3914 py::object data_ptr,
3915 py::object verbose_code_parts) -> void {
3916 SKIP_IF_GUARD_ALREADY_PRESENT("DATA_PTR_MATCH");
3917 self.add_leaf_guard(std::make_shared<DATA_PTR_MATCH>(
3918 std::move(data_ptr), std::move(verbose_code_parts)));
3919 })
3920 .def(
3921 "add_no_hasattr_guard",
3922 [](GuardManager& self,
3923 py::object attr_name,
3924 py::object verbose_code_parts) -> void {
3925 self.add_leaf_guard(std::make_shared<NO_HASATTR>(
3926 std::move(attr_name), std::move(verbose_code_parts)));
3927 })
3928 .def(
3929 "add_dict_contains_guard",
3930 [](GuardManager& self,
3931 bool contains,
3932 py::object key,
3933 py::object verbose_code_parts) -> void {
3934 self.add_leaf_guard(std::make_shared<DICT_CONTAINS>(
3935 contains, std::move(key), std::move(verbose_code_parts)));
3936 })
3937 .def(
3938 "add_dynamic_indices_guard",
3939 [](GuardManager& self,
3940 py::set value,
3941 py::object verbose_code_parts) -> void {
3942 self.add_leaf_guard(std::make_shared<DYNAMIC_INDICES>(
3943 std::move(value), std::move(verbose_code_parts)));
3944 })
3945 .def(
3946 "add_dict_version_guard",
3947 [](GuardManager& self,
3948 py::object value,
3949 py::object verbose_code_parts) -> void {
3950 self.add_leaf_guard(std::make_shared<DICT_VERSION>(
3951 std::move(value), std::move(verbose_code_parts)));
3952 })
3953 .def(
3954 "add_tensor_match_guard",
3955 [](GuardManager& self,
3956 py::object value,
3957 py::object sizes,
3958 py::object strides,
3959 py::object tensor_name,
3960 py::object verbose_code_parts) -> void {
3961 SKIP_IF_GUARD_ALREADY_PRESENT("TENSOR_MATCH");
3962 self.add_leaf_guard(std::make_shared<TENSOR_MATCH>(
3963 self.get_root(),
3964 std::move(value),
3965 std::move(sizes),
3966 std::move(strides),
3967 std::move(tensor_name),
3968 std::move(verbose_code_parts)));
3969 })
3970
3971 // return by reference because GuardManager has the ownership of accessors
3972 // and guard managers
3973 .def(
3974 "getitem_manager",
3975 &GuardManager::get_child_manager<GetItemGuardAccessor>,
3976 py::arg("key"),
3977 py::arg("source"),
3978 py::arg("example_value"),
3979 py::arg("guard_manager_enum"),
3980 py::return_value_policy::reference)
3981 // return by reference because GuardManager has the ownership of accessors
3982 // and guard managers
3983 .def(
3984 "dict_getitem_manager",
3985 &GuardManager::get_child_manager<DictGetItemGuardAccessor>,
3986 py::arg("key"),
3987 py::arg("source"),
3988 py::arg("example_value"),
3989 py::arg("guard_manager_enum"),
3990 py::return_value_policy::reference)
3991 // return by reference because GuardManager has the ownership of accessors
3992 // and guard managers
3993 .def(
3994 "list_getitem_manager",
3995 &GuardManager::get_child_manager<ListGetItemGuardAccessor>,
3996 py::arg("key"),
3997 py::arg("source"),
3998 py::arg("example_value"),
3999 py::arg("guard_manager_enum"),
4000 py::return_value_policy::reference)
4001 // return by reference because GuardManager has the ownership of accessors
4002 // and guard managers
4003 .def(
4004 "tuple_getitem_manager",
4005 &GuardManager::get_child_manager<TupleGetItemGuardAccessor>,
4006 py::arg("key"),
4007 py::arg("source"),
4008 py::arg("example_value"),
4009 py::arg("guard_manager_enum"),
4010 py::return_value_policy::reference)
4011 // return by reference because GuardManager has the ownership of accessors
4012 // and guard managers
4013 .def(
4014 "func_defaults_manager",
4015 [](GuardManager& self,
4016 std::string source,
4017 py::object example_value,
4018 py::handle guard_manager_enum) -> GuardManager* {
4019 // A unique key is used to save as the accessor key.
4020 py::str unique_key("__defaults_accessor__");
4021 return self.get_child_manager<FuncDefaultsGuardAccessor>(
4022 std::move(unique_key),
4023 std::move(source),
4024 std::move(example_value),
4025 guard_manager_enum);
4026 },
4027 py::arg("source"),
4028 py::arg("example_value"),
4029 py::arg("guard_manager_enum"),
4030 py::return_value_policy::reference)
4031
4032 // return by reference because GuardManager has the ownership of accessors
4033 // and guard managers
4034 .def(
4035 "func_kwdefaults_manager",
4036 [](GuardManager& self,
4037 std::string source,
4038 py::object example_value,
4039 py::handle guard_manager_enum) -> GuardManager* {
4040 // A unique key is used to save as the accessor key.
4041 py::str unique_key("__kwdefaults_accessor__");
4042 return self.get_child_manager<FuncKwDefaultsGuardAccessor>(
4043 std::move(unique_key),
4044 std::move(source),
4045 std::move(example_value),
4046 guard_manager_enum);
4047 },
4048 py::arg("source"),
4049 py::arg("example_value"),
4050 py::arg("guard_manager_enum"),
4051 py::return_value_policy::reference)
4052 // return by reference because GuardManager has the ownership of accessors
4053 // and guard managers
4054 .def(
4055 "globals_dict_manager",
4056 &GuardManager::get_child_manager<GlobalsGuardAccessor>,
4057 py::arg("f_globals"),
4058 py::arg("source"),
4059 py::arg("example_value"),
4060 py::arg("guard_manager_enum"),
4061 py::return_value_policy::reference)
4062 // return by reference because GuardManager has the ownership of accessors
4063 // and guard managers
4064 .def(
4065 "type_manager",
4066 [](GuardManager& self,
4067 std::string source,
4068 py::handle example_value,
4069 py::handle guard_manager_enum) -> GuardManager* {
4070 // A unique key is used to save as the accessor key.
4071 py::str unique_key("__type_accessor__");
4072 return self.get_child_manager<TypeGuardAccessor>(
4073 std::move(unique_key),
4074 std::move(source),
4075 example_value,
4076 guard_manager_enum);
4077 },
4078 py::arg("source"),
4079 py::arg("example_value"),
4080 py::arg("guard_manager_enum"),
4081 py::return_value_policy::reference)
4082 // return by reference because GuardManager has the ownership of accessors
4083 // and guard managers
4084 .def(
4085 "weakref_call_manager",
4086 [](GuardManager& self,
4087 std::string source,
4088 py::handle example_value,
4089 py::handle guard_manager_enum) -> GuardManager* {
4090 // A unique key is used to save as the accessor key.
4091 py::str unique_key("__weakref_call_accessor__");
4092 return self.get_child_manager<WeakRefCallGuardAccessor>(
4093 std::move(unique_key),
4094 std::move(source),
4095 example_value,
4096 guard_manager_enum);
4097 },
4098 py::arg("source"),
4099 py::arg("example_value"),
4100 py::arg("guard_manager_enum"),
4101 py::return_value_policy::reference)
4102 // return by reference because GuardManager has the ownership of accessors
4103 // and guard managers
4104 .def(
4105 "tuple_iterator_getitem_manager",
4106 &GuardManager::get_child_manager<TupleIteratorGetItemAccessor>,
4107 py::arg("index"),
4108 py::arg("source"),
4109 py::arg("example_value"),
4110 py::arg("guard_manager_enum"),
4111 py::return_value_policy::reference)
4112 // return by reference because GuardManager has the ownership of accessors
4113 // and guard managers
4114 .def(
4115 "global_weakref_manager",
4116 &GuardManager::get_child_manager<GlobalWeakRefGuardAccessor>,
4117 py::arg("global_name"),
4118 py::arg("source"),
4119 py::arg("example_value"),
4120 py::arg("guard_manager_enum"),
4121 py::return_value_policy::reference)
4122 // return by reference because GuardManager has the ownership of accessors
4123 // and guard managers
4124 .def(
4125 "lambda_manager",
4126 &GuardManager::get_child_manager<PythonLambdaGuardAccessor>,
4127 py::arg("python_lambda"),
4128 py::arg("source"),
4129 py::arg("example_value"),
4130 py::arg("guard_manager_enum"),
4131 py::return_value_policy::reference)
4132 // return by reference because GuardManager has the ownership of accessors
4133 // and guard managers
4134 .def(
4135 "grad_manager",
4136 [](GuardManager& self,
4137 std::string source,
4138 py::handle example_value,
4139 py::handle guard_manager_enum) -> GuardManager* {
4140 // A unique key is used to save as the accessor key.
4141 py::str unique_key("__grad_accessor__");
4142 return self.get_child_manager<GradGuardAccessor>(
4143 std::move(unique_key),
4144 std::move(source),
4145 example_value,
4146 guard_manager_enum);
4147 },
4148 py::arg("source"),
4149 py::arg("example_value"),
4150 py::arg("guard_manager_enum"),
4151 py::return_value_policy::reference)
4152 // return by reference because GuardManager has the ownership of accessors
4153 // and guard managers
4154 .def(
4155 "get_generic_dict_manager",
4156 [](GuardManager& self,
4157 std::string source,
4158 py::handle example_value,
4159 py::handle guard_manager_enum) -> GuardManager* {
4160 // A unique key is used to save as the accessor key.
4161 py::str unique_key("__generic_dict_accessor__");
4162 return self.get_child_manager<GetGenericDictGuardAccessor>(
4163 std::move(unique_key),
4164 std::move(source),
4165 example_value,
4166 guard_manager_enum);
4167 },
4168 py::arg("source"),
4169 py::arg("example_value"),
4170 py::arg("guard_manager_enum"),
4171 py::return_value_policy::reference)
4172 // return by reference because C++ GuardManager has the ownership of
4173 // accessors and guard managers
4174 .def(
4175 "getattr_manager",
4176 &GuardManager::get_child_manager<GetAttrGuardAccessor>,
4177 py::arg("attr"),
4178 py::arg("source"),
4179 py::arg("example_value"),
4180 py::arg("guard_manager_enum"),
4181 py::return_value_policy::reference);
4182
4183 // Root Guard Manager
4184 py::class_<RootGuardManager, GuardManager, std::unique_ptr<RootGuardManager>>(
4185 py_m, "RootGuardManager")
4186 .def(py::init<>())
4187 .def("check", &RootGuardManager::check)
4188 .def("check_verbose", &RootGuardManager::check_verbose)
4189 // return by reference because GuardManager has the ownership of leaf
4190 // guards
4191 .def(
4192 "get_epilogue_lambda_guards",
4193 &RootGuardManager::get_epilogue_lambda_guards,
4194 py::return_value_policy::reference)
4195 .def(
4196 "add_epilogue_lambda_guard",
4197 [](RootGuardManager& self,
4198 py::object lambda,
4199 py::object verbose_code_parts) -> void {
4200 self.add_epilogue_lambda_guard(std::make_unique<LAMBDA_GUARD>(
4201 std::move(lambda), std::move(verbose_code_parts)));
4202 });
4203
4204 // Dict Guard Manager
4205 py::class_<DictGuardManager, GuardManager, std::unique_ptr<DictGuardManager>>(
4206 py_m, "DictGuardManager")
4207 // return by reference because GuardManager has the ownership of accessors
4208 // and guard managers
4209 .def(
4210 "get_key_manager",
4211 [](DictGuardManager& self,
4212 py::object index,
4213 std::string source,
4214 py::handle example_value,
4215 py::handle guard_manager_enum) -> GuardManager* {
4216 return self.get_key_manager(
4217 std::move(index),
4218 std::move(source),
4219 example_value,
4220 guard_manager_enum);
4221 },
4222 py::arg("index"),
4223 py::arg("source"),
4224 py::arg("example_value"),
4225 py::arg("guard_manager_enum"),
4226 py::return_value_policy::reference)
4227 // return by reference because GuardManager has the ownership of accessors
4228 // and guard managers
4229 .def(
4230 "get_value_manager",
4231 [](DictGuardManager& self,
4232 py::object index,
4233 std::string source,
4234 py::handle example_value,
4235 py::handle guard_manager_enum) -> GuardManager* {
4236 return self.get_value_manager(
4237 std::move(index),
4238 std::move(source),
4239 example_value,
4240 guard_manager_enum);
4241 },
4242 py::arg("index"),
4243 py::arg("source"),
4244 py::arg("example_value"),
4245 py::arg("guard_manager_enum"),
4246 py::return_value_policy::reference)
4247 // return by reference because GuardManager has the ownership of leaf
4248 // guards
4249 .def(
4250 "get_key_value_managers",
4251 &DictGuardManager::get_key_value_managers,
4252 py::return_value_policy::reference)
4253 // Skipped leaf guards
4254 .def("add_type_match_guard", &DictGuardManager::skip_adding_guard)
4255 .def("add_dict_length_check_guard", &DictGuardManager::skip_adding_guard)
4256 // Permitted leaf guards
4257 .def(
4258 "add_dict_contains_guard",
4259 [](DictGuardManager& self,
4260 bool contains,
4261 py::object key,
4262 py::object verbose_code_parts) -> void {
4263 self.add_permitted_leaf_guard(std::make_shared<DICT_CONTAINS>(
4264 contains, std::move(key), std::move(verbose_code_parts)));
4265 })
4266 .def(
4267 "add_dict_version_guard",
4268 [](DictGuardManager& self,
4269 py::object value,
4270 py::object verbose_code_parts) -> void {
4271 // DICT_VERSION is used in a very narrow context today to guard on
4272 // pytree SUPPPORTED_NODES. We can remove this once we have tags in
4273 // DictGuardManager.
4274 self.add_permitted_leaf_guard(std::make_shared<DICT_VERSION>(
4275 std::move(value), std::move(verbose_code_parts)));
4276 })
4277 // Not permitted accesssors
4278 .def("lambda_manager", &DictGuardManager::fail_on_get_child_manager)
4279 .def("getitem_manager", &DictGuardManager::fail_on_get_child_manager)
4280 .def("dict_getitem_manager", &DictGuardManager::fail_on_get_child_manager)
4281 .def("globals_dict_manager", &DictGuardManager::fail_on_get_child_manager)
4282 .def(
4283 "tuple_iterator_getitem_manager",
4284 &DictGuardManager::fail_on_get_child_manager)
4285 .def(
4286 "global_weakref_manager",
4287 &DictGuardManager::fail_on_get_child_manager)
4288 .def("lambda_manager", &DictGuardManager::fail_on_get_child_manager)
4289 // Permitted accessors (and also type_manager)
4290 // return by reference because GuardManager has the ownership of accessors
4291 // and guard managers
4292 .def(
4293 "getattr_manager",
4294 [](DictGuardManager& self,
4295 py::object attr_name,
4296 std::string source,
4297 py::handle example_value,
4298 py::handle guard_manager_enum) -> GuardManager* {
4299 if (self.is_exact_dict_type()) {
4300 throw std::runtime_error(
4301 "getattr_manager on a DictGuardManager is supported only for dict subclasses");
4302 }
4303 return self.get_child_manager<GetAttrGuardAccessor>(
4304 std::move(attr_name),
4305 std::move(source),
4306 example_value,
4307 guard_manager_enum);
4308 },
4309 py::arg("attr"),
4310 py::arg("source"),
4311 py::arg("example_value"),
4312 py::arg("guard_manager_enum"),
4313 py::return_value_policy::reference);
4314
4315 // Dict Guard Manager
4316 py::class_< // NOLINT
4317 DictSubclassGuardManager,
4318 DictGuardManager,
4319 std::unique_ptr<DictSubclassGuardManager>>(
4320 py_m, "DictSubclassGuardManager") // NOLINT
4321 .def(
4322 "add_no_hasattr_guard",
4323 [](DictSubclassGuardManager& self,
4324 py::object attr_name,
4325 py::object verbose_code_parts) -> void {
4326 self.add_permitted_leaf_guard(std::make_shared<NO_HASATTR>(
4327 std::move(attr_name), std::move(verbose_code_parts)));
4328 });
4329
4330 py_m.def("install_object_aliasing_guard", install_object_aliasing_guard);
4331 py_m.def(
4332 "install_no_tensor_aliasing_guard", install_no_tensor_aliasing_guard);
4333
// Initialize the dict_version_map watcher (Python 3.12+ only)
4335 #if IS_PYTHON_3_12_PLUS
4336
4337 dict_version_watcher_id = PyDict_AddWatcher(dict_version_watch_callback);
4338 if (dict_version_watcher_id == -1) {
4339 throw std::runtime_error("Failed to install dict_version_watch_callback");
4340 }
4341
4342 #endif
4343
4344 return m;
4345 }
4346
4347 } // namespace torch::dynamo
4348