#include <ATen/NamedTensorUtils.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable_indexing.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>

#include <torch/csrc/utils/torch_dispatch_mode.h>

#include <ATen/ATen.h>

#include <c10/core/SymIntArrayRef.h>
#include <structmember.h>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

using namespace at;
using namespace torch;
using namespace torch::autograd;

std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments) {
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");
  const auto& schema = op.schema();
  py::dict kwargs;
  // About all the pointers:
  //
  // f(int x, int y = 0, *, int z = 0)
  //                                  ^- arguments.size()
  //                        ^- kwarg_only_start
  //          ^- positional_default_start
  //   ^- 0
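  //
  // Illustrative example (hypothetical call, not part of the original
  // comment): if f above is invoked as f(1, 0, z=2), arguments is [1, 0, 2];
  // kwarg_only_start lands at 2 (z is kwarg-only) and positional_default_start
  // at 1 (y equals its default), so we build args=(1,) and kwargs={"z": 2}.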

  // Find the split point between kwarg-only and regular.  Since most functions
  // don't have kwarg-only arguments, it is more efficient to scan from the
  // right (but ideally, this would just be precomputed in FunctionSchema
  // itself).  (NB: minus one in the loop is because we're testing if the
  // *next* argument is kwarg-only before we advance the starting index)
  int64_t kwarg_only_start = static_cast<int64_t>(arguments.size());
  for (; kwarg_only_start > 0; kwarg_only_start--) {
    const auto& arg = schema.arguments()[kwarg_only_start - 1];
    if (!arg.kwarg_only()) {
      break;
    }
  }

  // Find the first positional argument that isn't defaulted
  auto is_default = [&](size_t idx) -> bool {
    const auto& arg = schema.arguments()[idx];
    if (!arg.default_value().has_value()) {
      return false;
    }
    const auto& default_ivalue = *arg.default_value();
    const auto& ivalue = arguments[idx];
    if (default_ivalue != ivalue) {
      return false;
    }
    return true;
  };

  int64_t positional_default_start = kwarg_only_start;
  for (; positional_default_start > 0; positional_default_start--) {
    if (!is_default(positional_default_start - 1)) {
      break;
    }
  }

  auto args =
      py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));

  auto schemaAwareToPyObject = [&](size_t idx) -> py::object {
    const auto& arg = schema.arguments()[idx];
    auto match = [&](c10::TypeKind kind) {
      const auto& t = arg.real_type();
      if (t->kind() == kind)
        return true;
      if (auto opt_t = t->cast<c10::OptionalType>()) {
        if (opt_t->getElementType()->kind() == kind)
          return true;
      }
      return false;
    };
    if (arguments[idx].isNone()) {
      return py::none();
    } else if (match(c10::ScalarTypeType::Kind)) {
      auto* obj =
          getTHPDtype(static_cast<c10::ScalarType>(arguments[idx].toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::LayoutType::Kind)) {
      auto* obj =
          getTHPLayout(static_cast<c10::Layout>(arguments[idx].toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::MemoryFormatType::Kind)) {
      return py::cast(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
    } else {
      return torch::jit::toPyObject(arguments[idx]);
    }
  };

  // Populate positional arguments
  for (const auto idx : c10::irange(positional_default_start)) {
    PyTuple_SET_ITEM(
        args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr());
  }

  // Populate keyword arguments
  for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) {
    // But don't populate default keyword arguments
    if (is_default(idx))
      continue;
    const auto& arg = schema.arguments()[idx];
    kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx);
  }
  return std::make_pair(std::move(args), std::move(kwargs));
}

void pushPyOutToStack(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    py::object out,
    const char* msg) {
  TORCH_CHECK(
      PyGILState_Check(), "GIL must be held before you call pushPyOutToStack");
  auto schema_returns = op.schema().returns();
  const auto num_returns = schema_returns.size();
  if (num_returns == 0) {
    // Check that we got a None return from Python. Anything else is an error.
    TORCH_CHECK(
        out.is_none(),
        "Expected ",
        msg,
        " for ",
        op.operator_name(),
        " to return None but it returned something else instead.");
  } else if (num_returns == 1) {
    torch::jit::push(
        stack, torch::jit::toIValue(out.ptr(), schema_returns[0].real_type()));
  } else {
    auto outs = py::cast<py::sequence>(out);
    for (const auto idx : c10::irange(outs.size())) {
      torch::jit::push(
          stack,
          torch::jit::toIValue(
              outs[idx].ptr(), schema_returns[idx].real_type()));
    }
  }
}
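
// Illustrative behavior (summary of the function above, not new logic): for an
// op whose schema declares two returns, e.g. "op(Tensor t) -> (Tensor, Tensor)",
// a Python return value of (a, b) is converted element-by-element and pushed
// onto the JIT stack; a bare None is accepted only when the schema declares
// zero returns.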

namespace {

c10::TensorImpl::SizesStridesPolicy parseSizesStridesPolicyArgument(
    c10::string_view arg) {
  if (arg == "strides") {
    return c10::TensorImpl::SizesStridesPolicy::CustomStrides;
  }

  if (arg == "sizes") {
    return c10::TensorImpl::SizesStridesPolicy::CustomSizes;
  }

  TORCH_CHECK_VALUE(
      false,
      "Unknown sizes_strides_policy: ",
      arg,
      "; expected 'strides' or 'sizes'");
}
} // anonymous namespace

PyObject* THPVariableClass = nullptr;

PyObject* ParameterClass = nullptr;

static PyObject* THPVariable_NewWithVar(
    PyTypeObject* type,
    Variable _var,
    c10::impl::PyInterpreterStatus status,
    bool allow_preexisting_pyobj = false);

// clang-tidy gets confused by static const
static const char* VOLATILE_WARNING =
    "volatile was removed and now has no effect. Use "
    "`with torch.no_grad():` instead.";

static bool check_has_torch_dispatch(PyObject* obj) {
  PyTypeObject* tp = Py_TYPE(obj);
  if (THPVariable_CheckTypeExact(tp)) {
    return false;
  }
  py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
  return (
      attr.ptr() != nullptr &&
      attr.ptr() != torch::disabled_torch_dispatch_impl());
}

// NOLINTNEXTLINE(*-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyObject* device_to_py_class_[static_cast<size_t>(
    c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];

void registerPythonTensorClass(
    const std::string& device,
    PyObject* python_tensor_class) {
  c10::Device dev(device);

  TORCH_CHECK(
      dev.type() == kXLA, "Only the python class for XLA can be overridden");
  if (device_to_py_class_[static_cast<size_t>(dev.type())] != nullptr) {
    TORCH_WARN(
        "Overriding a previously registered python class for ", dev.str());
  }

  device_to_py_class_[static_cast<size_t>(dev.type())] = python_tensor_class;
}

static PyObject* getPythonTensorClass(c10::Device d) {
  return device_to_py_class_[static_cast<size_t>(d.type())];
}

void activateGPUTrace() {
  c10::impl::GPUTrace::set_trace(getPyInterpreter());
}

// TODO: Make this take Variable by const reference
PyObject* THPVariable_Wrap(at::TensorBase var) {
  if (!var.defined()) {
    Py_RETURN_NONE;
  }

  if (c10::impl::HermeticPyObjectTLS::get_state()) {
    return THPVariable_NewWithVar(
        (PyTypeObject*)THPVariableClass,
        std::move(var),
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  }

  std::optional<PyObject*> mb_obj =
      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  c10::impl::PyInterpreterStatus status{};
  if (mb_obj.has_value()) {
    auto obj = *mb_obj;
    if (obj) {
      if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
        // C++ owns the Python object; this implies there weren't any other
        // owning references to the Python object.  Since we're making the
        // object "live" again on Python side, let's flip back the ownership
        // (Python owns C++) as it would now be unsound to deallocate the C++
        // object if all C++ references go to zero
        var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
        reinterpret_cast<THPVariable*>(obj)->cdata =
            MaybeOwned<Variable>::owned(std::move(var));
        // NB: incref is not necessary, because we are "stealing" the previous
        // ownership from the Variable to return it here for the wrap
        return obj;
      }
      Py_INCREF(obj);
      return obj;
    }
    // TODO: a better invariant is that if we tagged, we MUST have a valid
    // PyObject.  That's PyObject preservation
    // (https://github.com/pytorch/pytorch/pull/56017).  Before that PR, the
    // PyObject field would get cleared when all references to the Python
    // object were removed.
    status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
  } else {
    // Assumption: if a Tensor has been shared across threads, this induces
    // a refcount bump.  Therefore, if the use count is 1, we are the sole
    // thread with access to this tensor and no race is possible.
    if (var.use_count() <= 1) {
      status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
    } else {
      status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
    }
  }

  if (C10_LIKELY(var.device().type() != c10::kXLA)) {
    return THPVariable_NewWithVar(
        (PyTypeObject*)THPVariableClass, std::move(var), status);
  }

  if (auto clazz = getPythonTensorClass(var.device())) {
    return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status);
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)THPVariableClass, std::move(var), status);
}
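
// Summary of the wrap logic above (descriptive note, not new behavior):
//  - hermetic mode: always build a fresh PyObject (DEFINITELY_UNINITIALIZED);
//  - slot holds a live PyObject: return it, flipping ownership back to Python
//    if C++ currently owns it;
//  - tagged by us but the PyObject was cleared: TAGGED_BY_US;
//  - untagged: DEFINITELY_UNINITIALIZED when use_count <= 1, otherwise
//    MAYBE_UNINITIALIZED.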

bool isResurrectable(THPVariable* self) {
  // We want to divide this check into 2 cases.

  // 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is
  // true). You might think that in this case, it is impossible for tp_clear to
  // be called: surely the C++ reference to the PyObject is keeping it live? And
  // you'd be right! In fact, when C++ owns the PyObject, we have an invariant
  // that the refcount on the PyObject should be precisely one (because if you
  // take out another reference to the PyObject, we're supposed to flip the
  // ownership pointer back). In reality, you can violate this invariant
  // temporarily with weak references, so we don't test for it in asserts.

  // 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is
  // false). In this case, tp_clear can get called if the PyObject is referenced
  // from a dead cycle, and nowhere else. But if resurrection did not occur,
  // then the reference to C++ from the PyObject must be the ONLY reference to
  // the C++ object.
  if (self->cdata.unsafeIsBorrowed()) {
    return false;
  }
  auto const& tensor = THPVariable_Unpack(self);
  if (!tensor.defined() || tensor.use_count() <= 1) {
    return false;
  }
  // Check if this is hermetic. If it is, no resurrection.
  if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false) !=
      std::make_optional((PyObject*)self)) {
    return false;
  }
  return true;
}

// returns true if successfully resurrected; if so, cancel the
// rest of deallocation
static bool THPVariable_tryResurrect(THPVariable* self) {
  const auto& tensor = THPVariable_Unpack(self);

  if (!isResurrectable(self)) {
    return false;
  }

  // At this point, we are definitely going to resurrect the tensor. So, the
  // tensor better be defined :)
  TORCH_INTERNAL_ASSERT(tensor.defined());

  // There are other C++ owners of the tensor.  Flip ownership
  // so that C++ owns this Python object, and cancel deallocation.
  TORCH_INTERNAL_ASSERT(
      !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());

  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
  auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
      getPyInterpreter(),
      /*ignore_hermetic_tls=*/false);

  TORCH_INTERNAL_ASSERT(
      maybe_pyobj.has_value(),
      "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");

  tensor_impl->pyobj_slot()->set_owns_pyobj(true);

// Resurrect the Python object.  This is something CPython does
// internally occasionally, see
// https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259
// so we just copy the pattern here.  Note that we don't have to worry
// about saving and restoring the refcount (as the quoted code does)
// because we actually DO need to reset the refcount to one here, we
// can't assume that some other code has taken care of it.
// NB: this will overreport _Py_RefTotal but based on inspection of object.c
// there is no way to avoid this
#ifdef Py_TRACE_REFS
  _Py_AddToAllObjects(reinterpret_cast<PyObject*>(self), 1);
#endif
  Py_INCREF(self);

  // Flip THPVariable to be non-owning
  // (near use-after-free miss here: fresh MaybeOwned is created breaking
  // reference on Tensor in struct BEFORE we overwrite the old one)
  TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
  self->cdata = MaybeOwned<Variable>::borrowed(tensor);

  // NB: At this point, tensor *could* be dead (e.g., some other C++ thread
  // decrefed it.)  At this point, it is probably waiting on the GIL to
  // deallocate the Python object and will kill self, BUT NOT YET.

  return true;
}
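
// Illustrative scenario (hypothetical): Python runs `del t` on the last Python
// reference while an autograd graph still holds the tensor in C++. Dealloc
// reaches tryResurrect above, ownership flips to "C++ owns PyObject", and a
// later THPVariable_Wrap hands the same PyObject back instead of building a
// fresh one.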

static int THPVariable_clear(THPVariable* self) {
  // Is it OK for an object to still be live after running
  // tp_clear? Yes. When Python is breaking reference cycles, it can't assume
  // that an object will dealloc after it's cleared.  The source code explicitly
  // handles this case:
  // https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025

  // Note that we don't need to actually resurrect here. There are 2 cases:
  // 1. The PyObject is not part of a reference cycle. In this case, we don't
  // need to do anything. The GC will move on to try and break the reference
  // cycle on another object, which will eventually trigger tp_dealloc (and thus
  // resurrection).

  // 2. The PyObject is part of a reference cycle. This case should not actually
  // be possible, due to the logic in our tp_traverse
  // (THPVariable_subclass_traverse).

  // In fact, resurrecting here breaks the invariant that "C++ owns Python only
  // when PyObject's refcount would otherwise be 0". Most immediately, as we're
  // merely breaking reference cycles here, there can be other references to the
  // PyObject. *However*, if other objects in the refcycle resurrect, then we
  // will be in a state where the PyObject has multiple Python references, yet
  // C++ owns the PyObject.

  // See https://github.com/pytorch/pytorch/pull/75933 for more discussion.
  if (isResurrectable((THPVariable*)self)) {
    return 0;
  }
  Py_CLEAR(self->backward_hooks);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.defined()) {
    // Two situations to consider:
    //    PyObject -owns-> Tensor
    //        unsafeIsBorrowed() is FALSE.  We're obligated to look through
    //        Tensor to break references.  Clearing cdata must induce the
    //        destruction of the C++ Tensor.  If there were other references
    //        to C++ tensor, the Python object would have been resurrected
    //        by flipping the ownership.
    //    Tensor -owns-> PyObject
    //        unsafeIsBorrowed() is TRUE.  We're deallocating the PyObject
    //        because Tensor asked us to (it's already destructing).

    if (!self->cdata.unsafeIsBorrowed() &&
        tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
            getPyInterpreter(), /*ignore_hermetic_tls=*/false) ==
            std::make_optional((PyObject*)self)) {
      // TODO: empirically, on OS X this assert appears to be untrue
      // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
      // distributed/rpc/test_process_group_agent.py
      //
      //  libc++abi.dylib: terminating with uncaught exception of type
      //  c10::Error:
      //  !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
      //  ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
      //  please report a bug to PyTorch. Exception raised from
      //  THPVariable_clear at
      //  ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
      //  first): frame #0: c10::Error::Error(c10::SourceLocation,
      //  std::__1::basic_string<char, std::__1::char_traits<char>,
      //  std::__1::allocator<char> >) + 98 (0x1158a0442 in libc10.dylib) frame
      //  #1: c10::detail::torchCheckFail(char const*, char const*, unsigned
      //  int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2:
      //  c10::detail::torchInternalAssertFail(char const*, char const*,
      //  unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
      //  (0x1141e3f89 in libtorch_python.dylib) frame #3:
      //  THPVariable_clear(THPVariable*) + 412 (0x1148a547c in
      //  libtorch_python.dylib) frame #4:
      //  THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
      //  libtorch_python.dylib) frame #5: (anonymous
      //  namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
      //  _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6:
      //  c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in
      //  libc10.dylib) frame #7:
      //  c10::MaybeOwned<at::Tensor>::operator=(c10::MaybeOwned<at::Tensor>&&)
      //  + 91 (0x11488c11b in libtorch_python.dylib) frame #8:
      //  THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in
      //  libtorch_python.dylib) <omitting python frames> frame #47: start + 1
      //  (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???)
      // TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
      if (auto grad_acc =
              torch::autograd::impl::try_get_grad_accumulator(tensor)) {
        grad_acc->pre_hooks().clear();
        grad_acc->tensor_pre_hooks().clear();
        grad_acc->retains_grad_hooks().clear();
      }
    }
  }
  TORCH_INTERNAL_ASSERT(!isResurrectable((THPVariable*)self));
  {
    // MapAllocator can take significant time to release large tensors;
    // release the GIL here to avoid impacting main thread perf.
    pybind11::gil_scoped_release no_gil;
    self->cdata = MaybeOwned<Variable>();
  }
  return 0;
}

int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
  TORCH_INTERNAL_ASSERT(
      false, "Tensor tp_traverse function was not overridden properly");
  return 0;
}

PyObject* THPVariable_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs);

static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
  const auto& var = THPVariable_Unpack(self);
  Py_DECREF(THPVariable_Wrap(var));
  Py_RETURN_NONE;
}

// Maps the given python callable over a vector of items, returning a vector
// of the same type of items.
template <typename T>
static std::vector<T> map_py_func(
    const py::function& func,
    const std::vector<T>& items) {
  std::vector<T> new_items;
  new_items.reserve(items.size());
  for (auto& item : items) {
    new_items.emplace_back(py::cast<T>(func(item)));
  }
  return new_items;
}

template <>
std::vector<at::Tensor> map_py_func(
    const py::function& func,
    const std::vector<at::Tensor>& items) {
  std::vector<at::Tensor> new_items;
  new_items.reserve(items.size());
  for (auto& item : items) {
    auto output = func(item);
    if (output.is(py::none())) {
      // treat None value as an undefined tensor
      new_items.emplace_back();
    } else {
      new_items.emplace_back(py::cast<at::Tensor>(output));
    }
  }
  return new_items;
}
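
// Illustrative usage (hypothetical): given a Python callable fn that doubles
// a SymInt, map_py_func(fn, view_func.get_symints()) rebuilds the SymInt list
// with fn applied to each element; in the at::Tensor specialization above, a
// Python None result maps to an undefined tensor instead of failing the cast.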

static PyObject* view_func_impl(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs,
    bool check_has_same_meta) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(_self);

  static PythonArgParser parser({
      "_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)",
  });
  ParsedArgs<3> parsed_args{};
  auto r = parser.parse(_self, args, kwargs, parsed_args);
  auto new_base = r.tensor(0);
  PyObject* symint_visitor_fn = r.pyobject(1);
  PyObject* tensor_visitor_fn = r.pyobject(2);

  // Ensure that self is indeed a backward differentiable view
  // If not, we return an undefined Tensor (None) and let the user handle it.
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  at::Tensor out;
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    const auto& view_info = diff_view_meta->get_backward_view();
    // Ensure that the newly provided base is similar to the original base
    if (!check_has_same_meta ||
        torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
      // Do the actual view replay
      if (view_info.has_view_fn()) {
        auto& view_func = view_info.view_fn();

        // Determine new SymInt / tensor state as needed.
        std::optional<std::vector<c10::SymInt>> new_symints = std::nullopt;
        if (symint_visitor_fn != Py_None) {
          new_symints = map_py_func(
              py::cast<py::function>(symint_visitor_fn),
              view_func.get_symints());
        }

        std::optional<std::vector<at::Tensor>> new_tensors = std::nullopt;
        if (tensor_visitor_fn != Py_None) {
          new_tensors = map_py_func(
              py::cast<py::function>(tensor_visitor_fn),
              view_func.get_tensors());
        }

        // call view func
        if (new_symints.has_value() || new_tensors.has_value()) {
          out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base);
        } else {
          out = view_func(new_base);
        }
      } else {
        out = new_base.as_strided(
            self.sizes(), self.strides(), self.storage_offset());
      }
    }
  }
  return THPVariable_Wrap(std::move(out));
  END_HANDLE_TH_ERRORS
}
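
// Example Python usage (illustrative, mirroring the parser schema above):
//
// ```python
// base = torch.randn(4, requires_grad=True)
// view = base[1:]                        # backward differentiable view
// new_base = torch.randn(4, requires_grad=True)
// replayed = view._view_func(new_base)   # replays the view on new_base
// ```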

static PyObject* THPVariable_view_func(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/true);
}

static PyObject* THPVariable_view_func_unsafe(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/false);
}

static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(self_);
  TORCH_CHECK(
      THPVariable_Check(arg),
      "_rev_view_func expects a single argument that is a Tensor");
  const auto& new_view = THPVariable_Unpack(arg);

  // Ensure that self is indeed a backward differentiable view
  // If not, we return an undefined Tensor (None) and let the user handle it.
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  at::Tensor out;
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    const auto& view_info = diff_view_meta->get_backward_view();
    // Do the actual view replay
    TORCH_CHECK(view_info.has_view_fn(), "No _rev_view_func() found");
    out = view_info.rev_view_fn()(new_view);
  }
  return THPVariable_Wrap(std::move(out));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_rev_view_func_unsafe(
    PyObject* self_,
    PyObject* arg) {
  return rev_view_func_impl(self_, arg);
}

// Instantiates a subclass of self with the same data.
static PyObject* THPVariable_as_subclass(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(_self);
  static PythonArgParser parser({
      "as_subclass(PyObject* cls)",
  });
  ParsedArgs<1> parsed_args{};
  auto r = parser.parse(_self, args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  TORCH_CHECK_TYPE(
      PyType_Check(cls),
      "cls must be a type (got ",
      Py_TYPE(cls)->tp_name,
      ")");
  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      self.alias(),
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_make_subclass(
    PyObject* _ignored,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, Device? device_for_backend_keys=None)",
  });
  ParsedArgs<7> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  TORCH_CHECK_TYPE(
      PyType_Check(cls),
      "cls must be a type (got ",
      Py_TYPE(cls)->tp_name,
      ")");
  // guard completely turns off torch dispatch modes, doesn't just pop off the
  // stack
  torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
  c10::impl::DisablePythonDispatcher dpd_g;
  auto data =
      r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED)
  // We set `data`'s `allow_tensor_metadata_change` to true here, because we
  // want to allow the following use case for backward compatibility:
  //
  // ```python
  // rnn = torch.nn.RNN(100, 100, 2)
  // # The following calls `torch._cudnn_rnn_flatten_weight(rnn._flat_weights,
  // # ...)`, which changes storage of `rnn`'s weights in-place
  // rnn.flatten_parameters()
  // ```
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  data.set_requires_grad(r.toBool(2));
  const auto sizes_strides_policy = r.stringViewOptional(3);
  if (sizes_strides_policy.has_value()) {
    data.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
        parseSizesStridesPolicyArgument(*sizes_strides_policy));
  }
  if (r.toBool(4)) {
    data.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(5)) {
    data.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }
  if (!r.isNone(6)) {
    data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      data,
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_make_wrapper_subclass(
    PyObject*,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // NB: pin_memory doesn't actually do anything
  // TODO: strides variant?

  // cls: Python subclass type
  // size, strides, storage_offset, memory_format, dtype: self-explanatory
  // layout: memory layout, e.g. for types of Nested Tensors or other sparse
  //         tensors
  // pin_memory, requires_grad: self-explanatory
  // dispatch_sizes_strides_policy: string - which sizes/strides we should
  //                                dispatch to a custom python implementation.
  // dispatch_device: whether to dispatch to a custom python implementation
  //                  for device
  // dispatch_layout: whether to dispatch to a custom python implementation
  //                  for layout
  // _extra_dispatch_keys: additional dispatch keys to add to the tensor
  // storage_size: if provided, skip storage size calculation and just use the
  //               value provided. One use case is for Nested Tensor, where the
  //               storage size cannot be calculated from the sizes/strides
  //               (because they contain a NestedInt).
  static PythonArgParser parser({
      "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef? strides=None, "
      "SymInt? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
      "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
      "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, "
      "DispatchKeySet _extra_dispatch_keys=None, SymInt? storage_size=None)",
  });
  ParsedArgs<15> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);

  TORCH_CHECK_TYPE(
      PyType_Check(cls),
      "cls must be a type (got ",
      Py_TYPE(cls)->tp_name,
      ")");

  // This is an important safety check; without it, the default behavior will be
  // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
  // key, which will immediately segfault because the data pointer is null.  By
  // forcing users to define __torch_dispatch__ we ensure this does not happen
  // TODO: This check is not complete; because the user can disable torch
  // dispatch and then go again, triggering segfault.  TBH I'm thinking I want
  // to delete this function entirely
  py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
  TORCH_CHECK_TYPE(
      attr.ptr() != nullptr &&
          attr.ptr() != torch::disabled_torch_dispatch_impl(),
      ((PyTypeObject*)cls)->tp_name,
      " must define __torch_dispatch__");

  const auto options = TensorOptions()
                           .dtype(r.scalartype(5))
                           .device(r.device(7))
                           .layout(r.layoutOptional(6))
                           // NB: long standing issue, requires_grad is not
                           // respected here; you have to set it post facto, see
                           // https://github.com/pytorch/pytorch/issues/26428
                           // .requires_grad(r.toBool(7))
                           .pinned_memory(r.toBool(8));

  // don't bother releasing GIL here, as we are not allocating any nontrivial
  // data
  Tensor tensor;

  {
    AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
    tracer::impl::NoTracerDispatchMode tracer_guard{};

    auto sym_sizes = r.symintlist(1);
    auto sym_strides_own = r.symintlistOptional(2);
    auto sym_strides =
        static_cast<std::optional<c10::SymIntArrayRef>>(sym_strides_own);
    auto sym_storage_offset = r.toSymIntOptional(3);

    c10::SymInt size_bytes;
    auto dtype_itemsize = static_cast<int64_t>(options.dtype().itemsize());
    auto storage_size = r.toSymIntOptional(14);

    if (storage_size.has_value()) {
      size_bytes = storage_size.value();
    } else if (sym_strides.has_value()) {
      size_bytes = at::detail::computeStorageNbytes(
          sym_sizes,
          sym_strides.value(),
          dtype_itemsize,
          sym_storage_offset.value_or(0));
    } else {
      size_bytes = at::detail::computeStorageNbytesContiguous(
          sym_sizes, dtype_itemsize, sym_storage_offset.value_or(0));
    }

    // We use storages **only** to track aliasing of subclasses during tracing.
    // The actual data pointers are not valid.
    Storage storage{
        Storage::use_byte_size_t{},
        size_bytes,
        /*allocator=*/c10::GetAllocator(c10::kMeta),
        /*resizable=*/true};
    // TODO: constructor should probably accept data pointer
    storage.set_data_ptr_noswap(at::DataPtr{nullptr, r.device(7)});

    auto keys = c10::DispatchKeySet({options.computeDispatchKey()});
    if (auto mb_extra_keys = r.toDispatchKeySetOptional(13)) {
      keys = keys | *mb_extra_keys;
    }
    tensor = at::detail::make_tensor<TensorImpl>(
        std::move(storage), keys, options.dtype());

    TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();

    if (sym_strides.has_value()) {
      tensor_impl->set_sizes_and_strides(
          sym_sizes, sym_strides.value(), sym_storage_offset);
    } else {
      TORCH_CHECK(
          !sym_storage_offset.has_value(),
          "setting storage offset without stride not supported");
      tensor_impl->generic_set_sizes_contiguous(sym_sizes);
    }

    const auto sizes_strides_policy = r.stringViewOptional(10);
    if (sizes_strides_policy.has_value()) {
      tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
          parseSizesStridesPolicyArgument(*sizes_strides_policy));
    }
  }

  tensor.set_requires_grad(r.toBool(9));

  if (r.toBool(11)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(12)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      tensor,
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}
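
// Example Python usage (illustrative sketch of the pattern this binding
// serves; the class name is hypothetical): a wrapper subclass calls this from
// __new__ and must define __torch_dispatch__, per the check above:
//
// ```python
// class MyTensor(torch.Tensor):
//     @staticmethod
//     def __new__(cls, elem):
//         return torch.Tensor._make_wrapper_subclass(
//             cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
//             device=elem.device, requires_grad=elem.requires_grad)
//
//     @classmethod
//     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
//         ...
// ```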

using getter = PyObject* (*)(PyObject*, void*);
using setter = int (*)(PyObject*, PyObject*, void*);

PyObject* THPVariable_get_python_dispatch(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  const auto& var = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(
      var.unsafeGetTensorImpl()->is_python_dispatch());
  END_HANDLE_TH_ERRORS
}

// CRTP base class to implement the python bindings for a Tensor property in
// PyTorch. A class that implements a property is expected to have:
// - static constexpr const char* name;
//   - This variable should hold the Python name of the property
// - static Tensor fn(const Tensor&);
//   - This function calls the relevant ATen on the tensor
template <typename T>
struct GetterBase {
  static PyObject* getter(THPVariable* self, void* /*unused*/) {
    HANDLE_TH_ERRORS
    if (check_has_torch_function((PyObject*)self)) {
      return handle_torch_function_getter(self, T::name);
    }
    return THPVariable_Wrap(T::fn(THPVariable_Unpack(self)));
    END_HANDLE_TH_ERRORS
  }
};

struct PropertyT : GetterBase<PropertyT> {
  static constexpr const char* name = "T";
  static Tensor fn(const Tensor& t) {
    return t.numpy_T();
  }
};

struct PropertyH : GetterBase<PropertyH> {
  static constexpr const char* name = "H";
  static Tensor fn(const Tensor& t) {
    return t.matrix_H();
  }
};

struct PropertymT : GetterBase<PropertymT> {
  static constexpr const char* name = "mT";
  static Tensor fn(const Tensor& t) {
    return t.mT();
  }
};

struct PropertymH : GetterBase<PropertymH> {
  static constexpr const char* name = "mH";
  static Tensor fn(const Tensor& t) {
    return t.mH();
  }
};

struct PropertyData : GetterBase<PropertyData> {
  static constexpr const char* name = "data";
  static Tensor fn(const Tensor& t) {
    return t.variable_data();
  }
};

struct PropertyGrad : GetterBase<PropertyGrad> {
  static constexpr const char* name = "grad";
  static Tensor fn(const Tensor& t) {
    return t.grad();
  }
};

struct PropertyReal : GetterBase<PropertyReal> {
  static constexpr const char* name = "real";
  static Tensor fn(const Tensor& t) {
    return at::real(t);
  }
};

struct PropertyImag : GetterBase<PropertyImag> {
  static constexpr const char* name = "imag";
  static Tensor fn(const Tensor& t) {
    return at::imag(t);
  }
};

PyObject* THPVariable_get_cdata(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_cdata");
  }
  const auto& var = THPVariable_Unpack(self);
  return PyLong_FromVoidPtr(var.unsafeGetTensorImpl());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_version(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_version");
  }
  const auto& var = THPVariable_Unpack(self);
  return PyInt_FromLong(var._version());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_grad_fn(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "grad_fn");
  }
  const auto& var = THPVariable_Unpack(self);
  if (!var.grad_fn()) {
    Py_RETURN_NONE;
  }
  return functionToPyObject(var.grad_fn());
  END_HANDLE_TH_ERRORS
}

static int THPVariable_set_grad_fn(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "_grad_fn", obj);
  }
  TORCH_CHECK(obj, "Deletion of _grad_fn not allowed. Detach tensor instead!");
  TORCH_CHECK(obj == Py_None, "_grad_fn can only be set to None");
  THPVariable_Unpack(self).detach_();
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

static PyObject* THPVariable_is_leaf(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_leaf");
  }
  return PyBool_FromLong(!THPVariable_Unpack(self).grad_fn());
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_data(THPVariable* self, PyObject* data, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "data", data);
  }
  TORCH_CHECK(
      data, "Deleting tensor data is not allowed. Delete tensor instead!");
  TORCH_CHECK_TYPE(
      THPVariable_Check(data),
      "Variable data has to be a tensor, but got ",
      Py_TYPE(data)->tp_name);

  THPVariable_Unpack(self).set_data(THPVariable_Unpack(data));
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

int THPVariable_set_grad(THPVariable* self, PyObject* py_grad, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "grad", py_grad);
  }
  const auto& var = THPVariable_Unpack(self);
  if (!py_grad || py_grad == Py_None) {
    var.mutable_grad().reset();
    return 0;
  }

  TORCH_CHECK_TYPE(
      THPVariable_Check(py_grad),
      "assigned grad expected to be a Tensor or None but got grad of type ",
      THPUtils_typename(py_grad));
  TORCH_CHECK(
      self != (THPVariable*)py_grad, "can't assign Variable as its own grad");

  const auto& grad = THPVariable_Unpack(py_grad);
  TORCH_CHECK(
      var.dtype() == grad.dtype(),
      "attempting to assign a gradient with dtype '",
      grad.dtype(),
      "' to a tensor with dtype '",
      var.dtype(),
      "'. Please ensure that the gradient and the tensor have the same dtype");
  TORCH_CHECK(
      var.device().type() == grad.device().type(),
      "attempting to assign a gradient with device type '",
      grad.device().type(),
      "' to a tensor with device type '",
      var.device().type(),
      "'. Please ensure that the gradient and the tensor are on the same device");
  if (grad.layout() != kSparse) {
    TORCH_CHECK(
        grad.options().type_equal(var.options()),
        "attempting to assign a gradient to a tensor that has data of a different type");
  }
  TORCH_CHECK(
      grad.get_device() == var.get_device(),
      "attempting to assign a gradient located on device with index '",
      grad.get_device(),
      "' to a tensor located on device with index '",
      var.get_device(),
      "'. Please ensure that the gradient and the tensor are on the same device");
  TORCH_CHECK(
      grad.sym_sizes().equals(var.sym_sizes()),
      "attempting to assign a gradient of size '",
      grad.sym_sizes(),
      "' to a tensor of size '",
      var.sym_sizes(),
      "'. Please ensure that the gradient and the tensor are the same size");

  var.mutable_grad() = grad;
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}
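
// Illustrative Python behavior (hypothetical values): `t.grad = g` succeeds
// only when g matches t's dtype, device, and size; e.g. a float32 CPU tensor
// of shape (2, 3) accepts only a float32 CPU gradient of shape (2, 3), and
// `t.grad = None` clears the gradient via mutable_grad().reset() above.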

PyObject* THPVariable_get_volatile(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "volatile");
  }
  const char* msg = "volatile was removed (Variable.volatile is always False)";
  auto r = PyErr_WarnEx(PyExc_UserWarning, msg, 1);
  if (r != 0)
    throw python_error();
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_volatile(THPVariable* self, PyObject* obj, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "volatile", obj);
  }
  auto r = PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
  if (r != 0)
    throw python_error();
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_output_nr(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "output_nr");
  }
  const auto output_nr =
      static_cast<long>(THPVariable_Unpack(self).output_nr());
  return PyInt_FromLong(output_nr);
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_requires_grad(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "requires_grad");
  }
  if (THPVariable_Unpack(self).requires_grad()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_retains_grad(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "retains_grad");
  }
  if (THPVariable_Unpack(self).retains_grad()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_ndim(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "ndim");
  }
  return PyInt_FromLong(THPVariable_Unpack(self).dim());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_names(PyObject* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_getter((THPVariable*)self, "names");
  }
  // The long-term plan is to return a list of (python) torch.Dimname.
  // However, for now, return a tuple of strings.
  const auto& tensor = THPVariable_Unpack(self);
  auto size = tensor.dim();
  THPObjectPtr tuple(PyTuple_New(size));
  if (!tuple)
    throw python_error();

  const auto dimnames = tensor.names();
  for (const auto i : c10::irange(size)) {
    PyObject* str = nullptr;
    if (dimnames[i].type() == at::NameType::WILDCARD) {
      // PyTuple_SET_ITEM steals a reference to the object. When the tuple is
      // deallocated, it'll decrement the refcount on Py_None, which is bad.
      // To avoid this, we "create" a new reference to Py_None by increasing
      // the refcount.
      // Sources:
      // - https://docs.python.org/3/c-api/tuple.html#c.PyTuple_SetItem
      // -
      // https://stackoverflow.com/questions/16400600/how-to-return-a-tuple-containing-a-none-value-from-the-c-api
      Py_INCREF(Py_None);
      str = Py_None;
    } else {
      str = THPUtils_packString(dimnames[i].symbol().toUnqualString());
      if (!str)
        throw python_error();
    }
    PyTuple_SET_ITEM(tuple.get(), i, str);
  }
  return tuple.release();
  END_HANDLE_TH_ERRORS
}
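
// Illustrative Python behavior: for torch.zeros(2, 3, names=('N', None)), the
// getter above returns ('N', None); wildcard dims map to Python None.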

int THPVariable_set_names(PyObject* self, PyObject* names, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_setter((THPVariable*)self, "names", names);
  }
  const auto& var = THPVariable_Unpack(self);
  if (names == Py_None) {
    at::internal_set_names_inplace(var, std::nullopt);
  } else {
    TORCH_CHECK(
        THPUtils_checkDimnameList(names),
        "names must either be None or a tuple of dim names");
    at::internal_set_names_inplace(var, torch::parseDimnameList(names));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

int THPVariable_set_requires_grad(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "requires_grad", obj);
  }
  TORCH_CHECK(obj && PyBool_Check(obj), "requires_grad must be a bool");
  const auto& var = THPVariable_Unpack(self);
  auto requires_grad = (obj == Py_True);
  if (!var.is_leaf()) {
    THPUtils_setError(
        autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str());
    return -1;
  }
  if (requires_grad &&
      !isDifferentiableType(at::typeMetaToScalarType((var.dtype())))) {
    THPUtils_setError(
        "only Tensors of floating point and complex dtype can require gradients");
    return -1;
  }
  var.set_requires_grad(requires_grad);
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_name(THPVariable* self, void* unused) {
  if (check_has_torch_function((PyObject*)self)) {
    HANDLE_TH_ERRORS
    return handle_torch_function_getter(self, "name");
    END_HANDLE_TH_ERRORS
  }
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.name().empty())
    Py_RETURN_NONE;
  return THPUtils_packString(tensor.name().c_str());
}

PyObject* THPVariable_get_backwards_hooks(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_backward_hooks");
  }
  if (self->backward_hooks) {
    Py_INCREF(self->backward_hooks);
    return self->backward_hooks;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_backwards_hooks(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "_backward_hooks", obj);
  }
  TORCH_CHECK(obj, "Deletion of _backward_hooks not allowed!");
1295   if (obj == Py_None) {
1296     obj = nullptr;
1297   }
1298   Py_XINCREF(obj);
1299   Py_XDECREF(self->backward_hooks);
1300   self->backward_hooks = obj;
1301   const auto& tensor = THPVariable_Unpack(self);
1302   torch::autograd::impl::clear_hooks(tensor);
1303   if (obj) {
1304     torch::autograd::impl::add_hook(
1305         tensor, std::make_unique<PyFunctionTensorPreHook>(obj, 0));
1306   }
1307   return 0;
1308   END_HANDLE_TH_ERRORS_RET(-1)
1309 }

PyObject* THPVariable_get_post_accumulate_grad_hooks(
    THPVariable* self,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_post_accumulate_grad_hooks");
  }
  if (self->post_accumulate_grad_hooks) {
    Py_INCREF(self->post_accumulate_grad_hooks);
    return self->post_accumulate_grad_hooks;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_post_accumulate_grad_hooks(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(
        self, "_post_accumulate_grad_hooks", obj);
  }
  TORCH_CHECK(obj, "Deletion of _post_accumulate_grad_hooks not allowed!");
  if (obj == Py_None) {
    obj = nullptr;
  }
  Py_XINCREF(obj);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  self->post_accumulate_grad_hooks = obj;
  const auto& tensor = THPVariable_Unpack(self);
  if (obj) {
    torch::autograd::impl::set_post_acc_grad_hooks(
        tensor, std::make_unique<PyFunctionTensorPostAccGradHooks>(obj));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_base(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_base");
  }
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.is_view()) {
    return THPVariable_Wrap(tensor._base());
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_shape(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "shape");
  }
  return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_cpu(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_cpu");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_cpu());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_cuda");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_cuda());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_mtia(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_mtia");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_mtia());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_xla(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_xla");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_xla());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_ipu");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_ipu());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_xpu(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_xpu");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_xpu());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_sparse(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_sparse");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_sparse());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_sparse_csr(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_sparse_csr");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_sparse_csr());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_mkldnn(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_mkldnn");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_mkldnn());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_mps(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_mps");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_mps());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_maia(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_maia");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_maia());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_vulkan(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_vulkan");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_vulkan());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_quantized(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_quantized");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_quantized());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_meta(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_meta");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_meta());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_complex(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_complex");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_complex());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_is_nested(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_nested");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.is_nested());
  END_HANDLE_TH_ERRORS
}

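// Each predicate getter above surfaces as a read-only bool property on the
// Python side; a small illustrative session (values assume a plain CPU
// tensor):
//
//   >>> t = torch.randn(2, 3)
//   >>> t.is_cpu, t.is_cuda, t.is_sparse, t.is_complex
//   (True, False, False, False)
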
PyObject* THPVariable_has_symbolic_sizes_strides(
    THPVariable* self,
    void* unused) {
  HANDLE_TH_ERRORS
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(
      self_.unsafeGetTensorImpl()->has_symbolic_sizes_strides());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_dtype(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "dtype");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.scalar_type());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_layout(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "layout");
  }
  auto& self_ = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(self_.layout());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_device(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "device");
  }
  return THPDevice_New(THPVariable_Unpack(self).device());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "nbytes");
  }
  return PyLong_FromSize_t(THPVariable_Unpack(self).nbytes());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "itemsize");
  }
  return PyLong_FromSize_t(THPVariable_Unpack(self).itemsize());
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_real(PyObject* self, PyObject* real, void* unused) {
  HANDLE_TH_ERRORS
  auto& self_ = THPVariable_Unpack(self);
  auto self_real = at::real(self_);
  auto real_ = valueToTensor(self_real.options(), real, self_real.device());
  {
    pybind11::gil_scoped_release no_gil;
    self_real.copy_(real_);
    return 0;
  }
  END_HANDLE_TH_ERRORS_RET(-1)
}

int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) {
  HANDLE_TH_ERRORS
  auto& self_ = THPVariable_Unpack(self);
  auto self_imag = at::imag(self_);
  auto imag_ = valueToTensor(self_imag.options(), imag, self_imag.device());
  {
    pybind11::gil_scoped_release no_gil;
    self_imag.copy_(imag_);
    return 0;
  }
  END_HANDLE_TH_ERRORS_RET(-1)
}
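
// A sketch of the Python-side effect of the two setters above (illustrative):
//
//   >>> z = torch.zeros(2, dtype=torch.complex64)
//   >>> z.real = torch.tensor([1.0, 2.0])
//   >>> z.imag = torch.tensor([3.0, 4.0])
//   >>> z
//   tensor([1.+3.j, 2.+4.j])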

PyObject* THPVariable__use_count(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto& t = THPVariable_Unpack(self);
  return THPUtils_packUInt64(t.use_count());
  END_HANDLE_TH_ERRORS
}

// properties are registered here because we are currently only able to bind
// them manually. TODO: make declarable in native_functions
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPVariable_properties[] = {
    {"_python_dispatch",
     (getter)THPVariable_get_python_dispatch,
     nullptr,
     nullptr,
     nullptr},
    {"T", (getter)PropertyT::getter, nullptr, nullptr, nullptr},
    {"H", (getter)PropertyH::getter, nullptr, nullptr, nullptr},
    {"mT", (getter)PropertymT::getter, nullptr, nullptr, nullptr},
    {"mH", (getter)PropertymH::getter, nullptr, nullptr, nullptr},
    {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr},
    {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr},
    {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr},
    {"_grad_fn",
     (getter)THPVariable_get_grad_fn,
     (setter)THPVariable_set_grad_fn,
     nullptr,
     nullptr},
    {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr},
    {"retains_grad",
     (getter)THPVariable_retains_grad,
     nullptr,
     nullptr,
     nullptr},
    {"data",
     (getter)PropertyData::getter,
     (setter)THPVariable_set_data,
     nullptr,
     nullptr},
    {"_grad",
     (getter)PropertyGrad::getter,
     (setter)THPVariable_set_grad,
     nullptr,
     nullptr}, // Allows the python class to override .grad
    {"grad",
     (getter)PropertyGrad::getter,
     (setter)THPVariable_set_grad,
     nullptr,
     nullptr},
    {"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr},
    {"volatile",
     (getter)THPVariable_get_volatile,
     (setter)THPVariable_set_volatile,
     nullptr,
     nullptr},
    {"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr},
    {"requires_grad",
     (getter)THPVariable_get_requires_grad,
     (setter)THPVariable_set_requires_grad,
     nullptr,
     nullptr},
    {"_backward_hooks",
     (getter)THPVariable_get_backwards_hooks,
     (setter)THPVariable_set_backwards_hooks,
     nullptr,
     nullptr},
    {"_post_accumulate_grad_hooks",
     (getter)THPVariable_get_post_accumulate_grad_hooks,
     (setter)THPVariable_set_post_accumulate_grad_hooks,
     nullptr,
     nullptr},
    {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
    {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
    {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
    {"is_mtia", (getter)THPVariable_is_mtia, nullptr, nullptr, nullptr},
    {"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr},
    {"is_xla", (getter)THPVariable_is_xla, nullptr, nullptr, nullptr},
    {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
    {"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr},
    {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
    {"is_sparse_csr",
     (getter)THPVariable_is_sparse_csr,
     nullptr,
     nullptr,
     nullptr},
    {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
    {"is_mps", (getter)THPVariable_is_mps, nullptr, nullptr, nullptr},
    {"is_maia", (getter)THPVariable_is_maia, nullptr, nullptr, nullptr},
    {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
    {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
    {"is_quantized",
     (getter)THPVariable_is_quantized,
     nullptr,
     nullptr,
     nullptr},
    {"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
    {"is_nested", (getter)THPVariable_is_nested, nullptr, nullptr, nullptr},
    {"_has_symbolic_sizes_strides",
     (getter)THPVariable_has_symbolic_sizes_strides,
     nullptr,
     nullptr,
     nullptr},
    {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr},
    {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
    {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
    {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
    {"nbytes", (getter)THPVariable_get_nbytes, nullptr, nullptr, nullptr},
    {"itemsize", (getter)THPVariable_get_itemsize, nullptr, nullptr, nullptr},
    {"names",
     (getter)THPVariable_get_names,
     (setter)THPVariable_set_names,
     nullptr,
     nullptr},
    {"real",
     (getter)PropertyReal::getter,
     (setter)THPVariable_set_real,
     nullptr,
     nullptr},
    {"imag",
     (getter)PropertyImag::getter,
     (setter)THPVariable_set_imag,
     nullptr,
     nullptr},
    {nullptr}};
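
// For reference, each PyGetSetDef entry above is {name, getter, setter,
// docstring, closure}; a nullptr setter makes the attribute read-only from
// Python. A hypothetical extra entry would look like:
//
//   {"my_prop", (getter)THPVariable_get_my_prop, nullptr, nullptr, nullptr},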

static PyMappingMethods THPVariable_as_mapping = {
    THPVariable_length,
    THPVariable_getitem,
    THPVariable_setitem,
};

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef extra_methods[] = {
    {"as_subclass",
     castPyCFunctionWithKeywords(THPVariable_as_subclass),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_make_subclass",
     castPyCFunctionWithKeywords(THPVariable_make_subclass),
     METH_STATIC | METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_make_wrapper_subclass",
     castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass),
     METH_STATIC | METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
    {"_view_func",
     castPyCFunctionWithKeywords(THPVariable_view_func),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_view_func_unsafe",
     castPyCFunctionWithKeywords(THPVariable_view_func_unsafe),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_rev_view_func_unsafe",
     THPVariable_rev_view_func_unsafe,
     METH_O,
     nullptr},
    {"_use_count", THPVariable__use_count, METH_NOARGS, nullptr},
    {nullptr}};

struct THPVariableMeta {
  PyHeapTypeObject base;
};

int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);

PyTypeObject THPVariableMetaType = {
    PyVarObject_HEAD_INIT(
        DEFERRED_ADDRESS(&PyType_Type),
        0) "torch._C._TensorMeta", /* tp_name */
    sizeof(THPVariableMeta), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    THPVariableMetaType_init, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};

PyTypeObject THPVariableType = {
    PyVarObject_HEAD_INIT(
        &THPVariableMetaType,
        0) "torch._C.TensorBase", /* tp_name */
    sizeof(THPVariable), /* tp_basicsize */
    0, /* tp_itemsize */
    // This is left unspecified because it is illegal to create a
    // THPVariableType directly.  Subclasses will have their tp_dealloc set
    // appropriately by the metaclass.
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    &THPVariable_as_mapping, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
        Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    // Also set by the metaclass
    (traverseproc)THPFunction_traverse, /* tp_traverse */
    (inquiry)THPVariable_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    THPVariable_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    // Although tp_new is provided here, it is illegal to call it with cls ==
    // THPVariableType.  Instead, subclass it first and then construct that
    // subclass.
    THPVariable_pynew, /* tp_new */
};
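
// The metaclass wiring above means any Python subclass of Tensor picks up
// the custom dealloc/traverse installed by THPVariableMetaType_init; a
// minimal sketch of the intended usage (the class name is illustrative):
//
//   >>> class MyTensor(torch.Tensor):  # metaclass is torch._C._TensorMeta
//   ...     pass
//   >>> isinstance(MyTensor([1.0, 2.0]), torch.Tensor)
//   True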

PyObject* THPVariable_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      type != &THPVariableType,
      "Cannot directly construct TensorBase; subclass it and then construct that");
  jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
  auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
  // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
  // given a raw pointer, it will refcount bump it instead.
  // NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
  // alias(), lift_fresh()) which can return Tensor subclasses.  We allow
  // these to be passed on directly.
  return THPVariable_NewWithVar(
      type,
      std::move(tensor),
      c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED,
      /*allow_preexisting_pyobj=*/true);
  END_HANDLE_TH_ERRORS
}
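
// Observable Python behavior of the TORCH_CHECK above (illustrative):
//
//   >>> torch._C.TensorBase()
//   RuntimeError: Cannot directly construct TensorBase; subclass it and then
//   construct that
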
1898 
1899 // NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
1900 // on subclasses.  It's never valid to construct a THPVariable so it's not
1901 // necessary to implement the dealloc for that case
THPVariable_subclass_dealloc(PyObject * self)1902 void THPVariable_subclass_dealloc(PyObject* self) {
1903   if (THPVariable_tryResurrect((THPVariable*)self))
1904     return;
1905 
1906   // This is like a crappy version of subtype_dealloc.
1907   // Unfortunately, we cannot directly delegate to
1908   // subtype_dealloc as it will start walking the parent
1909   // chain *starting with* the type of self, which will cause
1910   // us to go back to our custom dealloc.
1911   //
1912   // We have to replicate the subtype_dealloc logic to ensure
1913   // that finalizers are handled correctly
1914   PyTypeObject* type = Py_TYPE(self);
1915   TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
1916   TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented");
1917 
1918   PyObject_GC_UnTrack(self);
1919   // TODO: consider using trash can
1920 
1921   bool has_finalizer = type->tp_finalize || type->tp_del;
1922 
1923   if (type->tp_finalize) {
1924     PyObject_GC_Track(self);
1925     if (PyObject_CallFinalizerFromDealloc(self) < 0) {
1926       /* Resurrected */
1927       return;
1928     }
1929     PyObject_GC_UnTrack(self);
1930   }
1931 
1932   // base test is unnecessary as THPVariable does not set this
1933   if (type->tp_weaklistoffset) {
1934     PyObject_ClearWeakRefs(self);
1935   }
1936 
1937   if (type->tp_del) {
1938     PyObject_GC_Track(self);
1939     type->tp_del(self);
1940     if (Py_REFCNT(self) > 0) {
1941       /* Resurrected */
1942       return;
1943     }
1944     PyObject_GC_UnTrack(self);
1945   }
1946 
1947   if (has_finalizer) {
1948     /* New weakrefs could be created during the finalizer call.
1949        If this occurs, clear them out without calling their
1950        finalizers since they might rely on part of the object
1951        being finalized that has already been destroyed. */
1952     if (type->tp_weaklistoffset) {
1953       /* Modeled after GET_WEAKREFS_LISTPTR() */
1954       PyWeakReference** list =
1955           (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
1956       while (*list)
1957         _PyWeakref_ClearRef(*list);
1958     }
1959   }
1960 
1961   // Clear all slots until we get to base class THPVariableType
1962   {
1963     PyTypeObject* base = type;
1964     while (base != &THPVariableType) {
1965       if (Py_SIZE(base)) {
1966         clear_slots(base, self);
1967       }
1968       base = base->tp_base;
1969       TORCH_INTERNAL_ASSERT(base);
1970     }
1971   }
1972 
1973   // All Python defined classes have __dict__
1974   if (C10_LIKELY(type->tp_dictoffset)) {
1975     PyObject** dictptr = _PyObject_GetDictPtr(self);
1976     if (dictptr != nullptr) {
1977       PyObject* dict = *dictptr;
1978       if (dict != nullptr) {
1979         Py_DECREF(dict);
1980         *dictptr = nullptr;
1981       }
1982     }
1983   }
1984 
1985   // subtype_dealloc allows for this but we don't
1986   TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
1987 
1988   // Finally clear out the base THPVariable
1989   THPVariable_clear((THPVariable*)self);
1990   ((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
1991   Py_TYPE(self)->tp_free(self);
1992 
1993   // Python defined subclasses should always be on the heap
1994   TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
1995   Py_DECREF(type);
1996 }

// Creates a new Python object for a Variable.  The status parameter
// specifies what the interpreter tag status on the object is; for
// example, if you ran check_pyobj, the returned optional tells you
// whether the tensor was already tagged or not, so you can pass
// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where
// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED.
// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED.
static PyObject* THPVariable_NewWithVar(
    PyTypeObject* type,
    Variable _var,
    c10::impl::PyInterpreterStatus status,
    bool allow_preexisting_pyobj) {
  // Make sure that the reinterpret into a THPVariable* will be valid
  TORCH_CHECK(
      PyType_IsSubtype(type, &THPVariableType),
      "Creating a Tensor subclass from a class ",
      "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");

  // This function overwrites the Tensor's pyobj field without extra checks.
  // Make sure it is not already set, otherwise we would leak memory.
  auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
      getPyInterpreter(), /*ignore_hermetic_tls=*/false);

  // Under some circumstances, we may attempt to create a new Python
  // object for a variable that already has a Python object.  The most common
  // situation this can occur is if you have a TorchDispatchMode active that
  // is returning a subclass from lift_fresh (which is invoked to
  // appropriately "wrap" a constant tensor into whatever ambient modes are
  // active.)
  //
  // In general, it is impossible to handle this case compositionally.
  // Suppose a user calls ATensor([1, 2, 3]) while a mode is active
  // that is transforming all ops (including the internal lift_fresh call that
  // transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
  // BTensor, where ATensor and BTensor are completely unrelated subclasses
  // and there is no way to compose them.  There is no way to satisfy the user
  // request here: in particular, you can't just try to re-invoke the ATensor
  // constructor on the returned BTensor, because (1) this could cause an
  // infinite loop--we are already in ATensor.__new__ and (2) there isn't any
  // guarantee that ATensor.__new__ supports a single element constructor
  // anyway.
  //
  // However, a more common case is a user just called torch.Tensor([1, 2, 3]),
  // and a fake tensor mode is active.  Really, all you want is to get back
  // a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
  // would have returned a fake tensor (concretely, the way this happens
  // is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
  // turns into a FakeTensor when we call lift_fresh on this real tensor).
  // This case is compositional because FakeTensor is a subclass of Tensor, so
  // it's valid for us to return it in place of a Tensor.  So this is what we
  // do.
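  //
  // A sketch of that compositional case (illustrative; the exact module path
  // of FakeTensorMode may vary across versions):
  //
  //   >>> from torch._subclasses.fake_tensor import FakeTensorMode
  //   >>> with FakeTensorMode():
  //   ...     t = torch.Tensor([1, 2, 3])  # lift_fresh returns a FakeTensor
  //   >>> type(t).__name__
  //   'FakeTensor'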

  if (mb_obj.has_value() && mb_obj.value()) {
    TORCH_CHECK(
        allow_preexisting_pyobj,
        "Creating a new Tensor subclass ",
        type->tp_name,
        " but the raw Tensor object is already associated to a python object ",
        "of type ",
        mb_obj.value()->ob_type->tp_name);
    // Even if we allow pre-existing PyObjects, we don't allow completely
    // ignoring the requested type.  Check that we fulfilled a subtype
    // relation here.  In the common case the requested type is Tensor and
    // this always succeeds.
    PyObject* obj = *mb_obj;
    // Check if it's OK to just directly return the Python object without
    // allocating a new variable.  We just check that the existing Python
    // object is a subclass of the requested type.
    PyTypeObject* obj_type = Py_TYPE(obj);
    TORCH_CHECK(
        obj_type == type || PyType_IsSubtype(obj_type, type),
        "Creating a new Tensor subclass ",
        type->tp_name,
        " but the raw Tensor object is already associated to a python object ",
        "of type ",
        mb_obj.value()->ob_type->tp_name,
        " which is not a subclass of the "
        "requested type");
    // We may (in fact, we typically will) need to resurrect this
    return THPVariable_Wrap(std::move(_var));
  }

  PyObject* obj = type->tp_alloc(type, 0);
  if (obj) {
    auto v = (THPVariable*)obj;
    // TODO: named constructor to avoid default initialization
    new (&v->cdata) MaybeOwned<Variable>();
    if (c10::impl::HermeticPyObjectTLS::get_state()) {
      // Do NOT initialize the pyobj field on the tensor; in hermetic mode
      // this PyObject owns the C++ tensor outright.
      v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
      TORCH_INTERNAL_ASSERT(
          !check_has_torch_dispatch(obj),
          "While HermeticPyObject was enabled, we attempted to create a tensor "
          "subclass with __torch_dispatch__.  This violates the invariant that "
          "operations in HermeticPyObject have equivalent C++ implementations. "
          "If your operator registered from Python operator registration isn't "
          "doing anything strange, there may be an internal PyTorch bug involving "
          "not appropriately disabling TorchDispatchMode before executing "
          "Python op registration.");
    } else {
      // Normal codepath
      v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
      const auto& var = THPVariable_Unpack(v);
      var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
          getPyInterpreter(), obj, status);
      if (check_has_torch_dispatch(obj)) {
        var.unsafeGetTensorImpl()->set_python_dispatch(true);
      }
    }
  }
  return obj;
}

/// NOTE [ PyObject Traversal ]
///
/// PyObjects that wrap c++ objects can lead to non-trivial traverse logic,
/// and it can be tricky to know what to traverse and when. This note tries
/// to clarify what the danger is here and gives a simple algorithm for
/// choosing how to write the tp_traverse and tp_clear functions. If you're
/// not already familiar with how the CPython GC works, you should read this
/// in-depth description: https://devguide.python.org/garbage_collector/
///
/// The complexity for us comes from the fact that some c++ shared_ptr objects
/// own references to python objects and are also owned both by other python
/// objects and by c++ objects. This means that to allow the GC to collect all
/// cycles, we need to properly implement the traverse/clear methods that take
/// these C++ ownership links into account.
///
/// The main danger here comes from the fact that, while all python-related
/// code is thread safe wrt the GC execution (thanks to the GIL), other
/// threads might be using our C++ objects arbitrarily, which can lead to
/// shared_ptr ref counts going up or down in between the different
/// traverse/clear invocations. The one constraint we add here that is not
/// explicitly mentioned in the GC description above is that for a given GC
/// run (meaning while the GIL is held), the traverse/clear pair should never
/// report different ownership relations: if traverse visited a given
/// PyObject, then the clear within that same GC run must still find this
/// object to be the sole owner and clear that PyObject.
///
/// A more mechanical algorithm for deciding what to traverse/clear is as
/// follows:
///   - Any field on this PyObject that contains a strong reference to
///     another PyObject must be visited and cleared. An example of that is
///     the "backward_hooks" field of the THPVariable.
///   - Any field that contains a C++ object that is uniquely owned by this
///     PyObject (either a unique_ptr or a shared_ptr with use_count==1)
///     should have all the PyObjects it owns visited and cleared. An example
///     here would be the tensor hooks.
///   - If that uniquely owned C++ object also uniquely owns other C++
///     objects, these should be visited and cleared as well if they contain
///     any PyObject.
///
/// Caveat: to avoid slow runtime, we limit the depth of this exploration of
/// C++ objects in practice and we do not, for example, go through the whole
/// autograd graph, even if it is uniquely owned. This is a known place where
/// users can create non-collectable cycles, as described in:
/// https://github.com/pytorch/pytorch/issues/7343
///

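/// A minimal sketch of the pattern this note prescribes, for a hypothetical
/// wrapper type (all names below are illustrative, not real PyTorch types):
///
///   static int MyWrapper_traverse(PyObject* self, visitproc visit, void* arg) {
///     auto* w = (MyWrapper*)self;
///     Py_VISIT(w->py_field);                  // direct strong PyObject ref
///     if (w->cpp_obj && w->cpp_obj.use_count() == 1) {
///       Py_VISIT(w->cpp_obj->owned_pyobj());  // uniquely owned C++ -> PyObject
///     }
///     return 0;
///   }
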
static int traverse_slots(
    PyTypeObject* type,
    PyObject* self,
    visitproc visit,
    void* arg) {
  auto n = Py_SIZE(type);
  auto mp = type->tp_members;
  for (Py_ssize_t i = 0; i < n; i++, mp++) {
    if (mp->type == T_OBJECT_EX) {
      char* addr = (char*)self + mp->offset;
      PyObject* obj = *(PyObject**)addr;
      if (obj != nullptr) {
        int err = visit(obj, arg);
        if (err)
          return err;
      }
    }
  }
  return 0;
}

static int THPVariable_subclass_traverse(
    PyObject* self,
    visitproc visit,
    void* arg) {
  // If the tensor is eligible to be resurrected, don't traverse it; instead
  // treat all of its references as a root (as they WOULD be a root since we
  // can treat the inbound C++ references as root owners).
  //
  // This works because unlike conventional GCs, Python's GC operates in two
  // phases: first it uses traverse to discover roots, and then it uses
  // traverse to do reachability.  Bypassing traverse during root discovery
  // forces Python to treat self as a root for everything it refers to.  For
  // a full explanation of the algorithm see
  // https://devguide.python.org/garbage_collector/
  //
  // NB: if we don't hold an owning reference to the underlying Tensor, it is
  // possible that the underlying Tensor has already gone dead.  In that case,
  // it's not safe to access it.  But it is also safe to skip traversal,
  // because if the underlying Tensor *is* live, then root discovery will
  // determine that self is live, and nothing will get GC'ed anyway
  // (resurrection cannot happen if the C++ object owns the PyObject).
  THPVariable* var = reinterpret_cast<THPVariable*>(self);
  if (isResurrectable(var)) {
    return 0;
  }

  // Crappy version of subtype_traverse; same deal as
  // THPVariable_subclass_dealloc

  PyTypeObject* type = Py_TYPE(self);
  // Traverse slots until we get to the base class THPVariableType
  {
    PyTypeObject* base = type;
    while (base != &THPVariableType) {
      if (Py_SIZE(base)) {
        int err = traverse_slots(base, self, visit, arg);
        if (err)
          return err;
      }
      base = base->tp_base;
      TORCH_INTERNAL_ASSERT(base);
    }
  }

  // All Python defined classes have __dict__
  if (C10_LIKELY(type->tp_dictoffset)) {
    PyObject** dictptr = _PyObject_GetDictPtr(self);
    if (dictptr && *dictptr)
      Py_VISIT(*dictptr);
  }

  TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
  Py_VISIT(type);

  // Finally traverse THPVariable special stuff
  Py_VISIT(var->backward_hooks);
  Py_VISIT(var->post_accumulate_grad_hooks);
  if (!var->cdata.unsafeIsBorrowed()) {
    const auto& tensor = THPVariable_Unpack(var);
    if (tensor.defined()) {
      // WARNING: The grad_fn traversal logic is very subtle, if you change
      // this, be very careful not to re-introduce this bug:
      // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c

      // We ensure that we follow NOTE [ PyObject Traversal ] here by checking
      // that this python object is the sole owner of the underlying Tensor
      // and that this Tensor is the sole owner of its grad_fn. In this case,
      // the only way to get a new reference to the grad_fn is by using this
      // python object, which requires the GIL to be accessed. Note that this
      // is only valid as long as users don't share non-owning references
      // across different threads (which is crazy and should never be done).
      auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
      if (tensor.use_count() == 1) {
        if (autograd_meta) {
          // Do NOT call grad_fn() here as that might trigger a recompute
          const auto& grad_fn = autograd_meta->grad_fn_;
          if (grad_fn && grad_fn.use_count() == 1) {
            // Every Node can have a pyobj (stored in "pyobj_")
            Py_VISIT(grad_fn->pyobj());
            // PyNodes are special as they also have an "obj" field
            if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
              Py_VISIT(py_node_fn->obj);
            }
          }
        }
      }
      if (autograd_meta) {
        for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
          if (auto pyhook =
                  dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
            Py_VISIT(pyhook->dict);
          }
        }
      }
    }
  }

  return 0;
}

int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
  if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
    return -1;
  }
  ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
  ((PyTypeObject*)cls)->tp_traverse =
      (traverseproc)THPVariable_subclass_traverse;

  // Don't do anything for the base Tensor class
  if (!THPVariableClass) {
    return 0;
  }

  // Forbid subclassing _TensorBase directly
  py::tuple mro =
      py::reinterpret_borrow<py::tuple>(((PyTypeObject*)cls)->tp_mro);
  bool is_subclass_of_thpvariable = false;
  for (py::handle h : mro) {
    if (h.ptr() == THPVariableClass) {
      is_subclass_of_thpvariable = true;
      break;
    }
  }
  if (!is_subclass_of_thpvariable) {
    PyErr_SetString(PyExc_RuntimeError, "Cannot subclass _TensorBase directly");
    return -1;
  }

  // If the user provided a torch_dispatch implementation, disable
  // torch_function.
  py::object torch_dispatch_impl = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(cls, "__torch_dispatch__"));
  py::object torch_dispatch_default = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(THPVariableClass, "__torch_dispatch__"));
  if (torch_dispatch_impl.ptr() != torch_dispatch_default.ptr()) {
    py::object torch_function_impl = py::reinterpret_steal<py::object>(
        PyObject_GetAttrString(cls, "__torch_function__"));
    py::object torch_function_default_bound = py::reinterpret_steal<py::object>(
        PyObject_GetAttrString(THPVariableClass, "__torch_function__"));

    // Since our __torch_function__ is a classmethod, we need to unbind the
    // method to get the raw function
    py::object torch_function_default = py::reinterpret_steal<py::object>(
        PyObject_GetAttrString(torch_function_default_bound.ptr(), "__func__"));

    // User-defined __torch_function__ might not be a classmethod
    if (PyObject_HasAttrString(torch_function_impl.ptr(), "__func__")) {
      torch_function_impl = py::reinterpret_steal<py::object>(
          PyObject_GetAttrString(torch_function_impl.ptr(), "__func__"));
    }
    if (torch_function_impl.ptr() == torch_function_default.ptr()) {
      PyObject_SetAttrString(
          cls, "__torch_function__", torch::disabled_torch_function_impl());
    }
  }

  return 0;
}
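
// Python-side effect of the __torch_dispatch__/__torch_function__ interaction
// above (a sketch; the class name is illustrative):
//
//   >>> class MySub(torch.Tensor):
//   ...     @classmethod
//   ...     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
//   ...         return func(*args, **(kwargs or {}))
//   ...
//   # MySub overrides __torch_dispatch__ but not __torch_function__, so its
//   # __torch_function__ is replaced with the disabled implementation.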

namespace torch::autograd {

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
extern PyMethodDef variable_methods[];
extern void initTorchFunctions(PyObject* module);

void initTensorImplConversion(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  m.def("_wrap_tensor_impl", [](void* ptr) {
    auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
        unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
    TORCH_CHECK(p.defined(), "Can't wrap undefined tensor");
    auto tensor = at::Tensor::wrap_tensor_impl(std::move(p));
    return py::cast(std::move(tensor));
  });
  // Set on the module level to avoid mixing pybind and plain CPython
  // extensions.
  m.def("_tensor_impl_raw_handle", [](torch::autograd::Variable* t) -> void* {
    // We return a raw non-owning pointer here; we rely on surrounding
    // code to keep the original tensor alive
    return t->getIntrusivePtr().get();
  });
}
} // namespace torch::autograd

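// A sketch of the intended round trip through the two bindings above; the
// caller must keep `t` alive for as long as the raw handle is in use:
//
//   >>> t = torch.randn(2)
//   >>> h = torch._C._tensor_impl_raw_handle(t)  # non-owning TensorImpl*
//   >>> t2 = torch._C._wrap_tensor_impl(h)       # new owning Tensor
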
bool THPVariable_initModule(PyObject* module) {
  THPVariableMetaType.tp_base = &PyType_Type;
  if (PyType_Ready(&THPVariableMetaType) < 0)
    return false;
  Py_INCREF(&THPVariableMetaType);
  PyModule_AddObject(module, "_TensorMeta", (PyObject*)&THPVariableMetaType);

  static std::vector<PyMethodDef> methods;
  THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);
  THPUtils_addPyMethodDefs(methods, extra_methods);
  THPVariableType.tp_methods = methods.data();
  if (PyType_Ready(&THPVariableType) < 0)
    return false;
  Py_INCREF(&THPVariableType);
  PyModule_AddObject(module, "TensorBase", (PyObject*)&THPVariableType);
  PyModule_AddObject(module, "_TensorBase", (PyObject*)&THPVariableType);
  torch::autograd::initTorchFunctions(module);
  torch::autograd::initTensorImplConversion(module);
  torch::utils::validate_numpy_for_dlpack_deleter_bug();
  return true;
}