#include <ATen/NamedTensorUtils.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable_indexing.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>

#include <torch/csrc/utils/torch_dispatch_mode.h>

#include <ATen/ATen.h>

#include <c10/core/SymIntArrayRef.h>
#include <structmember.h>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

using namespace at;
using namespace torch;
using namespace torch::autograd;

std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments) {
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");
  const auto& schema = op.schema();
  py::dict kwargs;
  // About all the pointers:
  //
  // f(int x, int y = 0, *, int z = 0)
  //                                  ^- arguments.size()
  //                        ^- kwarg_only_start
  //          ^- positional_default_start
  //   ^- 0

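  // Illustrative example (the argument values here are hypothetical): for
  // f(int x, int y = 0, *, int z = 0) recorded with arguments = [1, 0, 2],
  // kwarg_only_start == 2 and positional_default_start == 1, so the result is
  // args == (1,) and kwargs == {"z": 2}; the defaulted positional y and any
  // defaulted kwarg-only arguments are dropped.
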
  // Find the split point between kwarg-only and regular. Since most functions
  // don't have kwarg-only arguments, it is more efficient to scan from the
  // right (but ideally, this would just be precomputed in FunctionSchema
  // itself). (NB: minus one in the loop is because we're testing if the
  // *next* argument is kwarg-only before we advance the starting index)
  int64_t kwarg_only_start = static_cast<int64_t>(arguments.size());
  for (; kwarg_only_start > 0; kwarg_only_start--) {
    const auto& arg = schema.arguments()[kwarg_only_start - 1];
    if (!arg.kwarg_only()) {
      break;
    }
  }

  // Find the first positional argument that isn't defaulted
  auto is_default = [&](size_t idx) -> bool {
    const auto& arg = schema.arguments()[idx];
    if (!arg.default_value().has_value()) {
      return false;
    }
    const auto& default_ivalue = *arg.default_value();
    const auto& ivalue = arguments[idx];
    if (default_ivalue != ivalue) {
      return false;
    }
    return true;
  };

  int64_t positional_default_start = kwarg_only_start;
  for (; positional_default_start > 0; positional_default_start--) {
    if (!is_default(positional_default_start - 1)) {
      break;
    }
  }

  auto args =
      py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));

  auto schemaAwareToPyObject = [&](size_t idx) -> py::object {
    const auto& arg = schema.arguments()[idx];
    auto match = [&](c10::TypeKind kind) {
      const auto& t = arg.real_type();
      if (t->kind() == kind)
        return true;
      if (auto opt_t = t->cast<c10::OptionalType>()) {
        if (opt_t->getElementType()->kind() == kind)
          return true;
      }
      return false;
    };
    if (arguments[idx].isNone()) {
      return py::none();
    } else if (match(c10::ScalarTypeType::Kind)) {
      auto* obj =
          getTHPDtype(static_cast<c10::ScalarType>(arguments[idx].toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::LayoutType::Kind)) {
      auto* obj =
          getTHPLayout(static_cast<c10::Layout>(arguments[idx].toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::MemoryFormatType::Kind)) {
      return py::cast(static_cast<c10::MemoryFormat>(arguments[idx].toInt()));
    } else {
      return torch::jit::toPyObject(arguments[idx]);
    }
  };

  // Populate positional arguments
  for (const auto idx : c10::irange(positional_default_start)) {
    PyTuple_SET_ITEM(
        args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr());
  }

  // Populate keyword arguments
  for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) {
    // But don't populate default keyword arguments
    if (is_default(idx))
      continue;
    const auto& arg = schema.arguments()[idx];
    kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx);
  }
  return std::make_pair(std::move(args), std::move(kwargs));
}

void pushPyOutToStack(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    py::object out,
    const char* msg) {
  TORCH_CHECK(
      PyGILState_Check(), "GIL must be held before you call pushPyOutToStack");
  auto schema_returns = op.schema().returns();
  const auto num_returns = schema_returns.size();
  if (num_returns == 0) {
    // Check that we got a None return from Python. Anything else is an error.
    TORCH_CHECK(
        out.is_none(),
        "Expected ",
        msg,
        " for ",
        op.operator_name(),
        " to return None but it returned something else instead.");
  } else if (num_returns == 1) {
    torch::jit::push(
        stack, torch::jit::toIValue(out.ptr(), schema_returns[0].real_type()));
  } else {
    auto outs = py::cast<py::sequence>(out);
    for (const auto idx : c10::irange(outs.size())) {
      torch::jit::push(
          stack,
          torch::jit::toIValue(
              outs[idx].ptr(), schema_returns[idx].real_type()));
    }
  }
}

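// A minimal sketch (hypothetical handler; `py_handler` is assumed, real call
// sites live elsewhere in torch) of how the two helpers above are usually
// paired by a Python-dispatch boxed fallback:
//
//   const auto num_args = op.schema().arguments().size();
//   auto arguments = torch::jit::last(*stack, num_args).vec();  // read inputs
//   py::gil_scoped_acquire gil;
//   auto [args, kwargs] = parseIValuesToPyArgsKwargs(op, arguments);
//   py::object out = py::reinterpret_steal<py::object>(
//       PyObject_Call(py_handler.ptr(), args.ptr(), kwargs.ptr()));
//   torch::jit::drop(*stack, num_args);                         // pop inputs
//   pushPyOutToStack(op, stack, std::move(out), "__torch_dispatch__");
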
namespace {

c10::TensorImpl::SizesStridesPolicy parseSizesStridesPolicyArgument(
    c10::string_view arg) {
  if (arg == "strides") {
    return c10::TensorImpl::SizesStridesPolicy::CustomStrides;
  }

  if (arg == "sizes") {
    return c10::TensorImpl::SizesStridesPolicy::CustomSizes;
  }

  TORCH_CHECK_VALUE(
      false,
      "Unknown sizes_strides_policy: ",
      arg,
      "; expected 'strides' or 'sizes'");
}
} // anonymous namespace

PyObject* THPVariableClass = nullptr;

PyObject* ParameterClass = nullptr;

static PyObject* THPVariable_NewWithVar(
    PyTypeObject* type,
    Variable _var,
    c10::impl::PyInterpreterStatus status,
    bool allow_preexisting_pyobj = false);

// clang-tidy gets confused by static const
static const char* VOLATILE_WARNING =
    "volatile was removed and now has no effect. Use "
    "`with torch.no_grad():` instead.";

static bool check_has_torch_dispatch(PyObject* obj) {
  PyTypeObject* tp = Py_TYPE(obj);
  if (THPVariable_CheckTypeExact(tp)) {
    return false;
  }
  py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
  return (
      attr.ptr() != nullptr &&
      attr.ptr() != torch::disabled_torch_dispatch_impl());
}

// NOLINTNEXTLINE(*-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyObject* device_to_py_class_[static_cast<size_t>(
    c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];

void registerPythonTensorClass(
    const std::string& device,
    PyObject* python_tensor_class) {
  c10::Device dev(device);

  TORCH_CHECK(
      dev.type() == kXLA, "Only the python class for XLA can be overridden");
  if (device_to_py_class_[static_cast<size_t>(dev.type())] != nullptr) {
    TORCH_WARN(
        "Overriding a previously registered python class for ", dev.str());
  }

  device_to_py_class_[static_cast<size_t>(dev.type())] = python_tensor_class;
}

static PyObject* getPythonTensorClass(c10::Device d) {
  return device_to_py_class_[static_cast<size_t>(d.type())];
}

void activateGPUTrace() {
  c10::impl::GPUTrace::set_trace(getPyInterpreter());
}

// TODO: Make this take Variable by const reference
PyObject* THPVariable_Wrap(at::TensorBase var) {
  if (!var.defined()) {
    Py_RETURN_NONE;
  }

  if (c10::impl::HermeticPyObjectTLS::get_state()) {
    return THPVariable_NewWithVar(
        (PyTypeObject*)THPVariableClass,
        std::move(var),
        c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  }

  std::optional<PyObject*> mb_obj =
      var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  c10::impl::PyInterpreterStatus status{};
  if (mb_obj.has_value()) {
    auto obj = *mb_obj;
    if (obj) {
      if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
        // C++ owns the Python object; this implies there weren't any other
        // owning references to the Python object. Since we're making the
        // object "live" again on the Python side, let's flip back the
        // ownership (Python owns C++) as it would now be unsound to
        // deallocate the C++ object if all C++ references go to zero
        var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
        reinterpret_cast<THPVariable*>(obj)->cdata =
            MaybeOwned<Variable>::owned(std::move(var));
        // NB: incref is not necessary, because we are "stealing" the previous
        // ownership from the Variable to return it here for the wrap
        return obj;
      }
      Py_INCREF(obj);
      return obj;
    }
    // TODO: a better invariant is that if we tagged, we MUST have a valid
    // PyObject. That's PyObject preservation
    // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR
    // being a thing, the PyObject field will get cleared when all references
    // to the Python object are removed.
    status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
  } else {
    // Assumption: if a Tensor has been shared across threads, this induces
    // a refcount bump. Therefore, if the use count is 1, we are the sole
    // thread with access to this tensor and no race is possible.
    if (var.use_count() <= 1) {
      status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
    } else {
      status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
    }
  }

  if (C10_LIKELY(var.device().type() != c10::kXLA)) {
    return THPVariable_NewWithVar(
        (PyTypeObject*)THPVariableClass, std::move(var), status);
  }

  if (auto clazz = getPythonTensorClass(var.device())) {
    return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status);
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)THPVariableClass, std::move(var), status);
}
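
// Illustrative use from C++ (a sketch, not a call site in this file):
// THPVariable_Wrap returns a new reference (or Py_None for an undefined
// tensor) and must be called with the GIL held.
//
//   at::Tensor t = at::zeros({2, 3});
//   PyObject* obj = THPVariable_Wrap(t);  // new reference
//   // ... hand obj off to Python, or drop it:
//   Py_DECREF(obj);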

bool isResurrectable(THPVariable* self) {
  // We want to divide this check into 2 cases.

  // 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is
  // true). You might think that in this case, it is impossible for tp_clear to
  // be called: surely the C++ reference to the PyObject is keeping it live? And
  // you'd be right! In fact, when C++ owns the PyObject, we have an invariant
  // that the refcount on the PyObject should be precisely one (because if you
  // take out another reference to the PyObject, we're supposed to flip the
  // ownership pointer back). In reality, you can violate this invariant
  // temporarily with weak references, so we don't test for it in asserts.

  // 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is
  // false). In this case, tp_clear can get called if the PyObject is referenced
  // from a dead cycle, and nowhere else. But if resurrection did not occur,
  // then the reference to C++ from the PyObject must be the ONLY reference to
  // the C++ object.
  if (self->cdata.unsafeIsBorrowed()) {
    return false;
  }
  auto const& tensor = THPVariable_Unpack(self);
  if (!tensor.defined() || tensor.use_count() <= 1) {
    return false;
  }
  // Check if this is hermetic. If it is, no resurrection.
  if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false) !=
      std::make_optional((PyObject*)self)) {
    return false;
  }
  return true;
}

// returns true if successfully rezzed; if so, cancel the
// rest of deallocation
static bool THPVariable_tryResurrect(THPVariable* self) {
  const auto& tensor = THPVariable_Unpack(self);

  if (!isResurrectable(self)) {
    return false;
  }

  // At this point, we are definitely going to resurrect the tensor. So, the
  // tensor better be defined :)
  TORCH_INTERNAL_ASSERT(tensor.defined());

  // There are other C++ owners of the tensor. Flip ownership
  // so that C++ owns this Python object, and cancel deallocation.
  TORCH_INTERNAL_ASSERT(
      !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());

  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
  auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
      getPyInterpreter(),
      /*ignore_hermetic_tls=*/false);

  TORCH_INTERNAL_ASSERT(
      maybe_pyobj.has_value(),
      "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");

  tensor_impl->pyobj_slot()->set_owns_pyobj(true);

  // Resurrect the Python object. This is something CPython does
  // internally occasionally, see
  // https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259
  // so we just copy the pattern here. Note that we don't have to worry
  // about saving and restoring the refcount (as the quoted code does)
  // because we actually DO need to reset the refcount to one here, we
  // can't assume that some other code has taken care of it.
  // NB: this will overreport _Py_RefTotal but based on inspection of object.c
  // there is no way to avoid this
#ifdef Py_TRACE_REFS
  _Py_AddToAllObjects(reinterpret_cast<PyObject*>(self), 1);
#endif
  Py_INCREF(self);

  // Flip THPVariable to be non-owning
  // (near use-after-free miss here: fresh MaybeOwned is created breaking
  // reference on Tensor in struct BEFORE we overwrite the old one)
  TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
  self->cdata = MaybeOwned<Variable>::borrowed(tensor);

  // NB: At this point, tensor *could* be dead (e.g., some other C++ thread
  // decrefed it.) At this point, it is probably waiting on the GIL to
  // deallocate the Python object and will kill self, BUT NOT YET.

  return true;
}

static int THPVariable_clear(THPVariable* self) {
  // Is it OK for an object to still be live after running
  // tp_clear? Yes. When Python is breaking reference cycles, it can't assume
  // that an object will dealloc after it's cleared. The source code explicitly
  // handles this case:
  // https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025

  // Note that we don't need to actually resurrect here. There are 2 cases:
  // 1. The PyObject is not part of a reference cycle. In this case, we don't
  // need to do anything. The GC will move on to try and break the reference
  // cycle on another object, which will eventually trigger tp_dealloc (and thus
  // resurrection).

  // 2. The PyObject is part of a reference cycle. This case should not actually
  // be possible, due to the logic in our tp_traverse
  // (THPVariable_subclass_traverse).

  // In fact, resurrecting here breaks the invariant that "C++ owns Python only
  // when PyObject's refcount would otherwise be 0". Most immediately, as we're
  // merely breaking reference cycles here, there can be other references to the
  // PyObject. *However*, if other objects in the refcycle resurrect, then we
  // will be in a state where the PyObject has multiple Python references, yet
  // C++ owns the PyObject.

  // See https://github.com/pytorch/pytorch/pull/75933 for more discussion.
  if (isResurrectable((THPVariable*)self)) {
    return 0;
  }
  Py_CLEAR(self->backward_hooks);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.defined()) {
    // Two situations to consider:
    //   PyObject -owns-> Tensor
    //     unsafeIsBorrowed() is FALSE. We're obligated to look through
    //     Tensor to break references. Clearing cdata must induce the
    //     destruction of the C++ Tensor. If there were other references
    //     to C++ tensor, the Python object would have been resurrected
    //     by flipping the ownership.
    //   Tensor -owns-> PyObject
    //     unsafeIsBorrowed() is TRUE. We're deallocating the PyObject
    //     because Tensor asked us to (it's already destructing).

    if (!self->cdata.unsafeIsBorrowed() &&
        tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
            getPyInterpreter(), /*ignore_hermetic_tls=*/false) ==
            std::make_optional((PyObject*)self)) {
      // TODO: empirically, on OS X this assert appears to be untrue
      // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
      // distributed/rpc/test_process_group_agent.py
      //
      //  libc++abi.dylib: terminating with uncaught exception of type
      //  c10::Error:
      //  !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
      //  ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
      //  please report a bug to PyTorch. Exception raised from
      //  THPVariable_clear at
      //  ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
      //  first): frame #0: c10::Error::Error(c10::SourceLocation,
      //  std::__1::basic_string<char, std::__1::char_traits<char>,
      //  std::__1::allocator<char> >) + 98 (0x1158a0442 in libc10.dylib) frame
      //  #1: c10::detail::torchCheckFail(char const*, char const*, unsigned
      //  int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2:
      //  c10::detail::torchInternalAssertFail(char const*, char const*,
      //  unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
      //  (0x1141e3f89 in libtorch_python.dylib) frame #3:
      //  THPVariable_clear(THPVariable*) + 412 (0x1148a547c in
      //  libtorch_python.dylib) frame #4:
      //  THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
      //  libtorch_python.dylib) frame #5: (anonymous
      //  namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
      //  _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6:
      //  c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in
      //  libc10.dylib) frame #7:
      //  c10::MaybeOwned<at::Tensor>::operator=(c10::MaybeOwned<at::Tensor>&&)
      //  + 91 (0x11488c11b in libtorch_python.dylib) frame #8:
      //  THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in
      //  libtorch_python.dylib) <omitting python frames> frame #47: start + 1
      //  (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???)
      // TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
      if (auto grad_acc =
              torch::autograd::impl::try_get_grad_accumulator(tensor)) {
        grad_acc->pre_hooks().clear();
        grad_acc->tensor_pre_hooks().clear();
        grad_acc->retains_grad_hooks().clear();
      }
    }
  }
  TORCH_INTERNAL_ASSERT(!isResurrectable((THPVariable*)self));
  {
    // MapAllocator can take significant time to release large tensors;
    // release the GIL here to avoid impacting main thread perf.
    pybind11::gil_scoped_release no_gil;
    self->cdata = MaybeOwned<Variable>();
  }
  return 0;
}

int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
  TORCH_INTERNAL_ASSERT(
      false, "Tensor tp_traverse function was not overridden properly");
  return 0;
}

PyObject* THPVariable_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs);

static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
  const auto& var = THPVariable_Unpack(self);
  Py_DECREF(THPVariable_Wrap(var));
  Py_RETURN_NONE;
}

// Maps the given python callable over a vector of items, returning a vector
// of the same type of items.
template <typename T>
static std::vector<T> map_py_func(
    const py::function& func,
    const std::vector<T>& items) {
  std::vector<T> new_items;
  new_items.reserve(items.size());
  for (auto& item : items) {
    new_items.emplace_back(py::cast<T>(func(item)));
  }
  return new_items;
}

template <>
std::vector<at::Tensor> map_py_func(
    const py::function& func,
    const std::vector<at::Tensor>& items) {
  std::vector<at::Tensor> new_items;
  new_items.reserve(items.size());
  for (auto& item : items) {
    auto output = func(item);
    if (output.is(py::none())) {
      // treat None value as an undefined tensor
      new_items.emplace_back();
    } else {
      new_items.emplace_back(py::cast<at::Tensor>(output));
    }
  }
  return new_items;
}

static PyObject* view_func_impl(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs,
    bool check_has_same_meta) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(_self);

  static PythonArgParser parser({
      "_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)",
  });
  ParsedArgs<3> parsed_args{};
  auto r = parser.parse(_self, args, kwargs, parsed_args);
  auto new_base = r.tensor(0);
  PyObject* symint_visitor_fn = r.pyobject(1);
  PyObject* tensor_visitor_fn = r.pyobject(2);

  // Ensure that self is indeed a backward differentiable view
  // If not, we return an undefined Tensor (None) and let the user handle it.
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  at::Tensor out;
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    const auto& view_info = diff_view_meta->get_backward_view();
    // Ensure that the newly provided base is similar to the original base
    if (!check_has_same_meta ||
        torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
      // Do the actual view replay
      if (view_info.has_view_fn()) {
        auto& view_func = view_info.view_fn();

        // Determine new SymInt / tensor state as needed.
        std::optional<std::vector<c10::SymInt>> new_symints = std::nullopt;
        if (symint_visitor_fn != Py_None) {
          new_symints = map_py_func(
              py::cast<py::function>(symint_visitor_fn),
              view_func.get_symints());
        }

        std::optional<std::vector<at::Tensor>> new_tensors = std::nullopt;
        if (tensor_visitor_fn != Py_None) {
          new_tensors = map_py_func(
              py::cast<py::function>(tensor_visitor_fn),
              view_func.get_tensors());
        }

        // call view func
        if (new_symints.has_value() || new_tensors.has_value()) {
          out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base);
        } else {
          out = view_func(new_base);
        }
      } else {
        out = new_base.as_strided(
            self.sizes(), self.strides(), self.storage_offset());
      }
    }
  }
  return THPVariable_Wrap(std::move(out));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_view_func(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/true);
}

static PyObject* THPVariable_view_func_unsafe(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  return view_func_impl(self_, args, kwargs, /*check_has_same_meta=*/false);
}

static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(self_);
  TORCH_CHECK(
      THPVariable_Check(arg),
      "_rev_view_func expects a single argument that is a Tensor");
  const auto& new_view = THPVariable_Unpack(arg);

  // Ensure that self is indeed a backward differentiable view
  // If not, we return an undefined Tensor (None) and let the user handle it.
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  at::Tensor out;
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    const auto& view_info = diff_view_meta->get_backward_view();
    // Do the actual view replay
    TORCH_CHECK(view_info.has_view_fn(), "No _rev_view_func() found");
    out = view_info.rev_view_fn()(new_view);
  }
  return THPVariable_Wrap(std::move(out));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPVariable_rev_view_func_unsafe(
    PyObject* self_,
    PyObject* arg) {
  return rev_view_func_impl(self_, arg);
}

// Instantiates a subclass of self with the same data.
static PyObject* THPVariable_as_subclass(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(_self);
  static PythonArgParser parser({
      "as_subclass(PyObject* cls)",
  });
  ParsedArgs<1> parsed_args{};
  auto r = parser.parse(_self, args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  TORCH_CHECK_TYPE(
      PyType_Check(cls),
      "cls must be a type (got ",
      Py_TYPE(cls)->tp_name,
      ")");
  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      self.alias(),
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}

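// Illustrative Python-level usage of Tensor.as_subclass (a sketch):
//
//   class MyTensor(torch.Tensor):
//       pass
//
//   t = torch.randn(3)
//   mt = t.as_subclass(MyTensor)  # aliases t's data; isinstance(mt, MyTensor)
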
static PyObject* THPVariable_make_subclass(
    PyObject* _ignored,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, Device? device_for_backend_keys=None)",
  });
  ParsedArgs<7> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  TORCH_CHECK_TYPE(
      PyType_Check(cls),
      "cls must be a type (got ",
      Py_TYPE(cls)->tp_name,
      ")");
  // guard completely turns off torch dispatch modes, doesn't just pop off the
  // stack
  torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
  c10::impl::DisablePythonDispatcher dpd_g;
  auto data =
      r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED)
  // We set `data`'s `allow_tensor_metadata_change` to true here, because we
  // want to allow the following use case for backward compatibility:
  //
  // ```python
  // rnn = torch.nn.RNN(100, 100, 2)
  // # The following calls `torch._cudnn_rnn_flatten_weight(rnn._flat_weights, ...)`,
  // # which changes storage of `rnn`'s weights in-place
  // rnn.flatten_parameters()
  // ```
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  data.set_requires_grad(r.toBool(2));
  const auto sizes_strides_policy = r.stringViewOptional(3);
  if (sizes_strides_policy.has_value()) {
    data.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
        parseSizesStridesPolicyArgument(*sizes_strides_policy));
  }
  if (r.toBool(4)) {
    data.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(5)) {
    data.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }
  if (!r.isNone(6)) {
    data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      data,
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}

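// Illustrative Python-level usage of the _make_subclass binding above
// (a sketch; torch.nn.Parameter.__new__ issues a call of this shape):
//
//   class MyTensor(torch.Tensor):
//       pass
//
//   t = torch.Tensor._make_subclass(MyTensor, torch.randn(3), True)
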
static PyObject* THPVariable_make_wrapper_subclass(
    PyObject*,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // NB: pin_memory doesn't actually do anything
  // TODO: strides variant?

  // cls: Python subclass type
  // size, strides, storage_offset, memory_format, dtype: self-explanatory
  // layout: memory layout, e.g. for types of Nested Tensors or other sparse
  //         tensors
  // pin_memory, requires_grad: self-explanatory
  // dispatch_sizes_strides_policy: string - which sizes/strides we should
  //                                dispatch to a custom python implementation.
  // dispatch_device: whether to dispatch to a custom python implementation
  //                  for device
  // dispatch_layout: whether to dispatch to a custom python implementation
  //                  for layout
  // _extra_dispatch_keys: additional dispatch keys to add to the tensor
  // storage_size: if provided, skip storage size calculation and just use the
  //               value provided. One use case is for Nested Tensor, where the
  //               storage size cannot be calculated from the sizes/strides
  //               (because they contain a NestedInt).
  static PythonArgParser parser({
      "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef? strides=None, "
      "SymInt? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
      "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
      "c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, "
      "DispatchKeySet _extra_dispatch_keys=None, SymInt? storage_size=None)",
  });
  ParsedArgs<15> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);

  TORCH_CHECK_TYPE(
      PyType_Check(cls),
      "cls must be a type (got ",
      Py_TYPE(cls)->tp_name,
      ")");

  // This is an important safety check; without it, the default behavior will be
  // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
  // key, which will immediately segfault because the data pointer is null. By
  // forcing users to define __torch_dispatch__ we ensure this does not happen
  // TODO: This check is not complete; because the user can disable torch
  // dispatch and then go again, triggering segfault. TBH I'm thinking I want
  // to delete this function entirely
  py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
  TORCH_CHECK_TYPE(
      attr.ptr() != nullptr &&
          attr.ptr() != torch::disabled_torch_dispatch_impl(),
      ((PyTypeObject*)cls)->tp_name,
      " must define __torch_dispatch__");

  const auto options = TensorOptions()
                           .dtype(r.scalartype(5))
                           .device(r.device(7))
                           .layout(r.layoutOptional(6))
                           // NB: long standing issue, requires_grad is not
                           // respected here; you have to set it post facto, see
                           // https://github.com/pytorch/pytorch/issues/26428
                           // .requires_grad(r.toBool(7))
                           .pinned_memory(r.toBool(8));

  // don't bother releasing GIL here, as we are not allocating any nontrivial
  // data
  Tensor tensor;

  {
    AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
    tracer::impl::NoTracerDispatchMode tracer_guard{};

    auto sym_sizes = r.symintlist(1);
    auto sym_strides_own = r.symintlistOptional(2);
    auto sym_strides =
        static_cast<std::optional<c10::SymIntArrayRef>>(sym_strides_own);
    auto sym_storage_offset = r.toSymIntOptional(3);

    c10::SymInt size_bytes;
    auto dtype_itemsize = static_cast<int64_t>(options.dtype().itemsize());
    auto storage_size = r.toSymIntOptional(14);

    if (storage_size.has_value()) {
      size_bytes = storage_size.value();
    } else if (sym_strides.has_value()) {
      size_bytes = at::detail::computeStorageNbytes(
          sym_sizes,
          sym_strides.value(),
          dtype_itemsize,
          sym_storage_offset.value_or(0));
    } else {
      size_bytes = at::detail::computeStorageNbytesContiguous(
          sym_sizes, dtype_itemsize, sym_storage_offset.value_or(0));
    }

    // We use storages **only** to track aliasing of subclasses during tracing.
    // The actual data pointers are not valid.
    Storage storage{
        Storage::use_byte_size_t{},
        size_bytes,
        /*allocator=*/c10::GetAllocator(c10::kMeta),
        /*resizable=*/true};
    // TODO: constructor should probably accept data pointer
    storage.set_data_ptr_noswap(at::DataPtr{nullptr, r.device(7)});

    auto keys = c10::DispatchKeySet({options.computeDispatchKey()});
    if (auto mb_extra_keys = r.toDispatchKeySetOptional(13)) {
      keys = keys | *mb_extra_keys;
    }
    tensor = at::detail::make_tensor<TensorImpl>(
        std::move(storage), keys, options.dtype());

    TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();

    if (sym_strides.has_value()) {
      tensor_impl->set_sizes_and_strides(
          sym_sizes, sym_strides.value(), sym_storage_offset);
    } else {
      TORCH_CHECK(
          !sym_storage_offset.has_value(),
          "setting storage offset without stride not supported");
      tensor_impl->generic_set_sizes_contiguous(sym_sizes);
    }

    const auto sizes_strides_policy = r.stringViewOptional(10);
    if (sizes_strides_policy.has_value()) {
      tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
          parseSizesStridesPolicyArgument(*sizes_strides_policy));
    }
  }

  tensor.set_requires_grad(r.toBool(9));

  if (r.toBool(11)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(12)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }

  return THPVariable_NewWithVar(
      (PyTypeObject*)cls,
      tensor,
      c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  END_HANDLE_TH_ERRORS
}

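// Illustrative Python-level usage of _make_wrapper_subclass (a sketch of the
// common wrapper-subclass pattern; details vary by subclass):
//
//   class MyWrapper(torch.Tensor):
//       @staticmethod
//       def __new__(cls, elem):
//           return torch.Tensor._make_wrapper_subclass(
//               cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
//               device=elem.device, requires_grad=elem.requires_grad)
//
//       @classmethod
//       def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
//           ...  # unwrap, call func, rewrap
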
using getter = PyObject* (*)(PyObject*, void*);
using setter = int (*)(PyObject*, PyObject*, void*);

PyObject* THPVariable_get_python_dispatch(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  const auto& var = THPVariable_Unpack(self);
  return torch::autograd::utils::wrap(
      var.unsafeGetTensorImpl()->is_python_dispatch());
  END_HANDLE_TH_ERRORS
}

// CRTP base class to implement the python bindings for a Tensor property in
// PyTorch. A class that implements a property is expected to have:
// - static constexpr const char* name;
//   - This variable should hold the Python name of the property
// - static Tensor fn(const Tensor&);
//   - This function calls the relevant ATen on the tensor
template <typename T>
struct GetterBase {
  static PyObject* getter(THPVariable* self, void* /*unused*/) {
    HANDLE_TH_ERRORS
    if (check_has_torch_function((PyObject*)self)) {
      return handle_torch_function_getter(self, T::name);
    }
    return THPVariable_Wrap(T::fn(THPVariable_Unpack(self)));
    END_HANDLE_TH_ERRORS
  }
};

struct PropertyT : GetterBase<PropertyT> {
  static constexpr const char* name = "T";
  static Tensor fn(const Tensor& t) {
    return t.numpy_T();
  }
};

struct PropertyH : GetterBase<PropertyH> {
  static constexpr const char* name = "H";
  static Tensor fn(const Tensor& t) {
    return t.matrix_H();
  }
};

struct PropertymT : GetterBase<PropertymT> {
  static constexpr const char* name = "mT";
  static Tensor fn(const Tensor& t) {
    return t.mT();
  }
};

struct PropertymH : GetterBase<PropertymH> {
  static constexpr const char* name = "mH";
  static Tensor fn(const Tensor& t) {
    return t.mH();
  }
};

struct PropertyData : GetterBase<PropertyData> {
  static constexpr const char* name = "data";
  static Tensor fn(const Tensor& t) {
    return t.variable_data();
  }
};

struct PropertyGrad : GetterBase<PropertyGrad> {
  static constexpr const char* name = "grad";
  static Tensor fn(const Tensor& t) {
    return t.grad();
  }
};

struct PropertyReal : GetterBase<PropertyReal> {
  static constexpr const char* name = "real";
  static Tensor fn(const Tensor& t) {
    return at::real(t);
  }
};

struct PropertyImag : GetterBase<PropertyImag> {
  static constexpr const char* name = "imag";
  static Tensor fn(const Tensor& t) {
    return at::imag(t);
  }
};

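// The Property structs above are wired into the class's PyGetSetDef table
// further down in this file; a sketch of one such entry:
//
//   {"T", (getter)PropertyT::getter, nullptr, nullptr, nullptr},
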
PyObject* THPVariable_get_cdata(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_cdata");
  }
  const auto& var = THPVariable_Unpack(self);
  return PyLong_FromVoidPtr(var.unsafeGetTensorImpl());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_version(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_version");
  }
  const auto& var = THPVariable_Unpack(self);
  return PyInt_FromLong(var._version());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_grad_fn(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "grad_fn");
  }
  const auto& var = THPVariable_Unpack(self);
  if (!var.grad_fn()) {
    Py_RETURN_NONE;
  }
  return functionToPyObject(var.grad_fn());
  END_HANDLE_TH_ERRORS
}

static int THPVariable_set_grad_fn(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "_grad_fn", obj);
  }
  TORCH_CHECK(obj, "Deletion of _grad_fn not allowed. Detach tensor instead!");
  TORCH_CHECK(obj == Py_None, "_grad_fn can be only set to None");
  THPVariable_Unpack(self).detach_();
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

static PyObject* THPVariable_is_leaf(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "is_leaf");
  }
  return PyBool_FromLong(!THPVariable_Unpack(self).grad_fn());
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_data(THPVariable* self, PyObject* data, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "data", data);
  }
  TORCH_CHECK(
      data, "Deleting tensor data is not allowed. Delete tensor instead!");
  TORCH_CHECK_TYPE(
      THPVariable_Check(data),
      "Variable data has to be a tensor, but got ",
      Py_TYPE(data)->tp_name);

  THPVariable_Unpack(self).set_data(THPVariable_Unpack(data));
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

int THPVariable_set_grad(THPVariable* self, PyObject* py_grad, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "grad", py_grad);
  }
  const auto& var = THPVariable_Unpack(self);
  if (!py_grad || py_grad == Py_None) {
    var.mutable_grad().reset();
    return 0;
  }

  TORCH_CHECK_TYPE(
      THPVariable_Check(py_grad),
      "assigned grad expected to be a Tensor or None but got grad of type ",
      THPUtils_typename(py_grad));
  TORCH_CHECK(
      self != (THPVariable*)py_grad, "can't assign Variable as its own grad");

  const auto& grad = THPVariable_Unpack(py_grad);
  TORCH_CHECK(
      var.dtype() == grad.dtype(),
      "attempting to assign a gradient with dtype '",
      grad.dtype(),
      "' to a tensor with dtype '",
      var.dtype(),
      "'. Please ensure that the gradient and the tensor have the same dtype");
  TORCH_CHECK(
      var.device().type() == grad.device().type(),
      "attempting to assign a gradient with device type '",
      grad.device().type(),
      "' to a tensor with device type '",
      var.device().type(),
      "'. Please ensure that the gradient and the tensor are on the same device");
  if (grad.layout() != kSparse) {
    TORCH_CHECK(
        grad.options().type_equal(var.options()),
        "attempting to assign a gradient to a tensor that has data of a different type");
  }
  TORCH_CHECK(
      grad.get_device() == var.get_device(),
      "attempting to assign a gradient located on device with index '",
      grad.get_device(),
      "' to a tensor located on device with index '",
      var.get_device(),
      "'. Please ensure that the gradient and the tensor are on the same device");
  TORCH_CHECK(
      grad.sym_sizes().equals(var.sym_sizes()),
      "attempting to assign a gradient of size '",
      grad.sym_sizes(),
      "' to a tensor of size '",
      var.sym_sizes(),
      "'. Please ensure that the gradient and the tensor are the same size");

  var.mutable_grad() = grad;
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

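// Illustrative Python-level behavior of the .grad setter above (a sketch):
//
//   x = torch.randn(3, requires_grad=True)
//   x.grad = torch.zeros(3)   # OK: dtype, device and size all match
//   x.grad = None             # clears the gradient
//   x.grad = torch.zeros(4)   # raises: size mismatch
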
PyObject* THPVariable_get_volatile(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "volatile");
  }
  const char* msg = "volatile was removed (Variable.volatile is always False)";
  auto r = PyErr_WarnEx(PyExc_UserWarning, msg, 1);
  if (r != 0)
    throw python_error();
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_volatile(THPVariable* self, PyObject* obj, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "volatile", obj);
  }
  auto r = PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1);
  if (r != 0)
    throw python_error();
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_output_nr(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "output_nr");
  }
  const auto output_nr =
      static_cast<long>(THPVariable_Unpack(self).output_nr());
  return PyInt_FromLong(output_nr);
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_requires_grad(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "requires_grad");
  }
  if (THPVariable_Unpack(self).requires_grad()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_retains_grad(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "retains_grad");
  }
  if (THPVariable_Unpack(self).retains_grad()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_ndim(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "ndim");
  }
  return PyInt_FromLong(THPVariable_Unpack(self).dim());
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_names(PyObject* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_getter((THPVariable*)self, "names");
  }
  // The long-term plan is to return a list of (python) torch.Dimname.
  // However, for now, return a list of string.
  const auto& tensor = THPVariable_Unpack(self);
  auto size = tensor.dim();
  THPObjectPtr tuple(PyTuple_New(size));
  if (!tuple)
    throw python_error();

  const auto dimnames = tensor.names();
  for (const auto i : c10::irange(size)) {
    PyObject* str = nullptr;
    if (dimnames[i].type() == at::NameType::WILDCARD) {
      // PyTuple_SET_ITEM steals a reference to the object. When the tuple is
      // deallocated, it'll decrement the refcount on Py_None, which is bad.
      // To avoid this, we "create" a new reference to Py_None by increasing
      // the refcount.
      // Sources:
      // - https://docs.python.org/3/c-api/tuple.html#c.PyTuple_SetItem
      // - https://stackoverflow.com/questions/16400600/how-to-return-a-tuple-containing-a-none-value-from-the-c-api
      Py_INCREF(Py_None);
      str = Py_None;
    } else {
      str = THPUtils_packString(dimnames[i].symbol().toUnqualString());
      if (!str)
        throw python_error();
    }
    PyTuple_SET_ITEM(tuple.get(), i, str);
  }
  return tuple.release();
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_names(PyObject* self, PyObject* names, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_setter((THPVariable*)self, "names", names);
  }
  const auto& var = THPVariable_Unpack(self);
  if (names == Py_None) {
    at::internal_set_names_inplace(var, std::nullopt);
  } else {
    TORCH_CHECK(
        THPUtils_checkDimnameList(names),
        "names must either be None or a tuple of dim names");
    at::internal_set_names_inplace(var, torch::parseDimnameList(names));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

int THPVariable_set_requires_grad(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "requires_grad", obj);
  }
  TORCH_CHECK(obj && PyBool_Check(obj), "requires_grad must be a bool");
  const auto& var = THPVariable_Unpack(self);
  auto requires_grad = (obj == Py_True);
  if (!var.is_leaf()) {
    THPUtils_setError(
        autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str());
    return -1;
  }
  if (requires_grad &&
      !isDifferentiableType(at::typeMetaToScalarType((var.dtype())))) {
    THPUtils_setError(
        "only Tensors of floating point and complex dtype can require gradients");
    return -1;
  }
  var.set_requires_grad(requires_grad);
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_name(THPVariable* self, void* unused) {
  if (check_has_torch_function((PyObject*)self)) {
    HANDLE_TH_ERRORS
    return handle_torch_function_getter(self, "name");
    END_HANDLE_TH_ERRORS
  }
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.name().empty())
    Py_RETURN_NONE;
  return THPUtils_packString(tensor.name().c_str());
}

PyObject* THPVariable_get_backwards_hooks(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_backward_hooks");
  }
  if (self->backward_hooks) {
    Py_INCREF(self->backward_hooks);
    return self->backward_hooks;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_backwards_hooks(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(self, "_backward_hooks", obj);
  }
  TORCH_CHECK(obj, "Deletion of _backwards_hooks not allowed!");
  if (obj == Py_None) {
    obj = nullptr;
  }
  Py_XINCREF(obj);
  Py_XDECREF(self->backward_hooks);
  self->backward_hooks = obj;
  const auto& tensor = THPVariable_Unpack(self);
  torch::autograd::impl::clear_hooks(tensor);
  if (obj) {
    torch::autograd::impl::add_hook(
        tensor, std::make_unique<PyFunctionTensorPreHook>(obj, 0));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

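// Illustrative Python-level path through the setter above (a sketch):
// Tensor.register_hook keeps its hooks in the dict held by _backward_hooks,
// and assigning that attribute routes through THPVariable_set_backwards_hooks:
//
//   t = torch.randn(3, requires_grad=True)
//   h = t.register_hook(lambda grad: 2 * grad)  # populates t._backward_hooks
//   t.sum().backward()                          # hook sees the incoming grad
//   h.remove()
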
PyObject* THPVariable_get_post_accumulate_grad_hooks(
    THPVariable* self,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_post_accumulate_grad_hooks");
  }
  if (self->post_accumulate_grad_hooks) {
    Py_INCREF(self->post_accumulate_grad_hooks);
    return self->post_accumulate_grad_hooks;
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

int THPVariable_set_post_accumulate_grad_hooks(
    THPVariable* self,
    PyObject* obj,
    void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_setter(
        self, "_post_accumulate_grad_hooks", obj);
  }
  TORCH_CHECK(obj, "Deletion of _post_accumulate_grad_hooks not allowed!");
  if (obj == Py_None) {
    obj = nullptr;
  }
  Py_XINCREF(obj);
  Py_CLEAR(self->post_accumulate_grad_hooks);
  self->post_accumulate_grad_hooks = obj;
  const auto& tensor = THPVariable_Unpack(self);
  if (obj) {
    torch::autograd::impl::set_post_acc_grad_hooks(
        tensor, std::make_unique<PyFunctionTensorPostAccGradHooks>(obj));
  }
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable_get_base(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "_base");
  }
  const auto& tensor = THPVariable_Unpack(self);
  if (tensor.is_view()) {
    return THPVariable_Wrap(tensor._base());
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPVariable_get_shape(THPVariable* self, void* unused) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function((PyObject*)self)) {
    return handle_torch_function_getter(self, "shape");
  }
  return THPSize_NewFromSymSizes(THPVariable_Unpack(self));
  END_HANDLE_TH_ERRORS
}

THPVariable_is_cpu(THPVariable * self,void * unused)1373 PyObject* THPVariable_is_cpu(THPVariable* self, void* unused) {
1374 HANDLE_TH_ERRORS
1375 if (check_has_torch_function((PyObject*)self)) {
1376 return handle_torch_function_getter(self, "is_cpu");
1377 }
1378 auto& self_ = THPVariable_Unpack(self);
1379 return torch::autograd::utils::wrap(self_.is_cpu());
1380 END_HANDLE_TH_ERRORS
1381 }
1382
THPVariable_is_cuda(THPVariable * self,void * unused)1383 PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) {
1384 HANDLE_TH_ERRORS
1385 if (check_has_torch_function((PyObject*)self)) {
1386 return handle_torch_function_getter(self, "is_cuda");
1387 }
1388 auto& self_ = THPVariable_Unpack(self);
1389 return torch::autograd::utils::wrap(self_.is_cuda());
1390 END_HANDLE_TH_ERRORS
1391 }
1392
THPVariable_is_mtia(THPVariable * self,void * unused)1393 PyObject* THPVariable_is_mtia(THPVariable* self, void* unused) {
1394 HANDLE_TH_ERRORS
1395 if (check_has_torch_function((PyObject*)self)) {
1396 return handle_torch_function_getter(self, "is_mtia");
1397 }
1398 auto& self_ = THPVariable_Unpack(self);
1399 return torch::autograd::utils::wrap(self_.is_mtia());
1400 END_HANDLE_TH_ERRORS
1401 }
1402
THPVariable_is_xla(THPVariable * self,void * unused)1403 PyObject* THPVariable_is_xla(THPVariable* self, void* unused) {
1404 HANDLE_TH_ERRORS
1405 if (check_has_torch_function((PyObject*)self)) {
1406 return handle_torch_function_getter(self, "is_xla");
1407 }
1408 auto& self_ = THPVariable_Unpack(self);
1409 return torch::autograd::utils::wrap(self_.is_xla());
1410 END_HANDLE_TH_ERRORS
1411 }
1412
THPVariable_is_ipu(THPVariable * self,void * unused)1413 PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) {
1414 HANDLE_TH_ERRORS
1415 if (check_has_torch_function((PyObject*)self)) {
1416 return handle_torch_function_getter(self, "is_ipu");
1417 }
1418 auto& self_ = THPVariable_Unpack(self);
1419 return torch::autograd::utils::wrap(self_.is_ipu());
1420 END_HANDLE_TH_ERRORS
1421 }
1422
THPVariable_is_xpu(THPVariable * self,void * unused)1423 PyObject* THPVariable_is_xpu(THPVariable* self, void* unused) {
1424 HANDLE_TH_ERRORS
1425 if (check_has_torch_function((PyObject*)self)) {
1426 return handle_torch_function_getter(self, "is_xpu");
1427 }
1428 auto& self_ = THPVariable_Unpack(self);
1429 return torch::autograd::utils::wrap(self_.is_xpu());
1430 END_HANDLE_TH_ERRORS
1431 }
1432
THPVariable_is_sparse(THPVariable * self,void * unused)1433 PyObject* THPVariable_is_sparse(THPVariable* self, void* unused) {
1434 HANDLE_TH_ERRORS
1435 if (check_has_torch_function((PyObject*)self)) {
1436 return handle_torch_function_getter(self, "is_sparse");
1437 }
1438 auto& self_ = THPVariable_Unpack(self);
1439 return torch::autograd::utils::wrap(self_.is_sparse());
1440 END_HANDLE_TH_ERRORS
1441 }
1442
THPVariable_is_sparse_csr(THPVariable * self,void * unused)1443 PyObject* THPVariable_is_sparse_csr(THPVariable* self, void* unused) {
1444 HANDLE_TH_ERRORS
1445 if (check_has_torch_function((PyObject*)self)) {
1446 return handle_torch_function_getter(self, "is_sparse_csr");
1447 }
1448 auto& self_ = THPVariable_Unpack(self);
1449 return torch::autograd::utils::wrap(self_.is_sparse_csr());
1450 END_HANDLE_TH_ERRORS
1451 }
1452
THPVariable_is_mkldnn(THPVariable * self,void * unused)1453 PyObject* THPVariable_is_mkldnn(THPVariable* self, void* unused) {
1454 HANDLE_TH_ERRORS
1455 if (check_has_torch_function((PyObject*)self)) {
1456 return handle_torch_function_getter(self, "is_mkldnn");
1457 }
1458 auto& self_ = THPVariable_Unpack(self);
1459 return torch::autograd::utils::wrap(self_.is_mkldnn());
1460 END_HANDLE_TH_ERRORS
1461 }
1462
THPVariable_is_mps(THPVariable * self,void * unused)1463 PyObject* THPVariable_is_mps(THPVariable* self, void* unused) {
1464 HANDLE_TH_ERRORS
1465 if (check_has_torch_function((PyObject*)self)) {
1466 return handle_torch_function_getter(self, "is_mps");
1467 }
1468 auto& self_ = THPVariable_Unpack(self);
1469 return torch::autograd::utils::wrap(self_.is_mps());
1470 END_HANDLE_TH_ERRORS
1471 }
1472
1473 PyObject* THPVariable_is_maia(THPVariable* self, void* unused) {
1474 HANDLE_TH_ERRORS
1475 if (check_has_torch_function((PyObject*)self)) {
1476 return handle_torch_function_getter(self, "is_maia");
1477 }
1478 auto& self_ = THPVariable_Unpack(self);
1479 return torch::autograd::utils::wrap(self_.is_maia());
1480 END_HANDLE_TH_ERRORS
1481 }
1482
1483 PyObject* THPVariable_is_vulkan(THPVariable* self, void* unused) {
1484 HANDLE_TH_ERRORS
1485 if (check_has_torch_function((PyObject*)self)) {
1486 return handle_torch_function_getter(self, "is_vulkan");
1487 }
1488 auto& self_ = THPVariable_Unpack(self);
1489 return torch::autograd::utils::wrap(self_.is_vulkan());
1490 END_HANDLE_TH_ERRORS
1491 }
1492
1493 PyObject* THPVariable_is_quantized(THPVariable* self, void* unused) {
1494 HANDLE_TH_ERRORS
1495 if (check_has_torch_function((PyObject*)self)) {
1496 return handle_torch_function_getter(self, "is_quantized");
1497 }
1498 auto& self_ = THPVariable_Unpack(self);
1499 return torch::autograd::utils::wrap(self_.is_quantized());
1500 END_HANDLE_TH_ERRORS
1501 }
1502
1503 PyObject* THPVariable_is_meta(THPVariable* self, void* unused) {
1504 HANDLE_TH_ERRORS
1505 if (check_has_torch_function((PyObject*)self)) {
1506 return handle_torch_function_getter(self, "is_meta");
1507 }
1508 auto& self_ = THPVariable_Unpack(self);
1509 return torch::autograd::utils::wrap(self_.is_meta());
1510 END_HANDLE_TH_ERRORS
1511 }
1512
1513 PyObject* THPVariable_is_complex(THPVariable* self, void* unused) {
1514 HANDLE_TH_ERRORS
1515 if (check_has_torch_function((PyObject*)self)) {
1516 return handle_torch_function_getter(self, "is_complex");
1517 }
1518 auto& self_ = THPVariable_Unpack(self);
1519 return torch::autograd::utils::wrap(self_.is_complex());
1520 END_HANDLE_TH_ERRORS
1521 }
1522
1523 PyObject* THPVariable_is_nested(THPVariable* self, void* unused) {
1524 HANDLE_TH_ERRORS
1525 if (check_has_torch_function((PyObject*)self)) {
1526 return handle_torch_function_getter(self, "is_nested");
1527 }
1528 auto& self_ = THPVariable_Unpack(self);
1529 return torch::autograd::utils::wrap(self_.is_nested());
1530 END_HANDLE_TH_ERRORS
1531 }
1532
1533 PyObject* THPVariable_has_symbolic_sizes_strides(
1534 THPVariable* self,
1535 void* unused) {
1536 HANDLE_TH_ERRORS
1537 auto& self_ = THPVariable_Unpack(self);
1538 return torch::autograd::utils::wrap(
1539 self_.unsafeGetTensorImpl()->has_symbolic_sizes_strides());
1540 END_HANDLE_TH_ERRORS
1541 }
1542
1543 static PyObject* THPVariable_dtype(THPVariable* self, void* unused) {
1544 HANDLE_TH_ERRORS
1545 if (check_has_torch_function((PyObject*)self)) {
1546 return handle_torch_function_getter(self, "dtype");
1547 }
1548 auto& self_ = THPVariable_Unpack(self);
1549 return torch::autograd::utils::wrap(self_.scalar_type());
1550 END_HANDLE_TH_ERRORS
1551 }
1552
1553 static PyObject* THPVariable_layout(THPVariable* self, void* unused) {
1554 HANDLE_TH_ERRORS
1555 if (check_has_torch_function((PyObject*)self)) {
1556 return handle_torch_function_getter(self, "layout");
1557 }
1558 auto& self_ = THPVariable_Unpack(self);
1559 return torch::autograd::utils::wrap(self_.layout());
1560 END_HANDLE_TH_ERRORS
1561 }
1562
1563 static PyObject* THPVariable_device(THPVariable* self, void* unused) {
1564 HANDLE_TH_ERRORS
1565 if (check_has_torch_function((PyObject*)self)) {
1566 return handle_torch_function_getter(self, "device");
1567 }
1568 return THPDevice_New(THPVariable_Unpack(self).device());
1569 END_HANDLE_TH_ERRORS
1570 }
1571
1572 static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) {
1573 HANDLE_TH_ERRORS
1574 if (check_has_torch_function((PyObject*)self)) {
1575 return handle_torch_function_getter(self, "nbytes");
1576 }
1577 return PyLong_FromSize_t(THPVariable_Unpack(self).nbytes());
1578 END_HANDLE_TH_ERRORS
1579 }
1580
1581 static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) {
1582 HANDLE_TH_ERRORS
1583 if (check_has_torch_function((PyObject*)self)) {
1584 return handle_torch_function_getter(self, "itemsize");
1585 }
1586 return PyLong_FromSize_t(THPVariable_Unpack(self).itemsize());
1587 END_HANDLE_TH_ERRORS
1588 }
1589
1590 int THPVariable_set_real(PyObject* self, PyObject* real, void* unused) {
1591 HANDLE_TH_ERRORS
1592 auto& self_ = THPVariable_Unpack(self);
1593 auto self_real = at::real(self_);
1594 auto real_ = valueToTensor(self_real.options(), real, self_real.device());
1595 {
1596 pybind11::gil_scoped_release no_gil;
1597 self_real.copy_(real_);
1598 return 0;
1599 }
1600 END_HANDLE_TH_ERRORS_RET(-1)
1601 }
1602
1603 int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) {
1604 HANDLE_TH_ERRORS
1605 auto& self_ = THPVariable_Unpack(self);
1606 auto self_imag = at::imag(self_);
1607 auto imag_ = valueToTensor(self_imag.options(), imag, self_imag.device());
1608 {
1609 pybind11::gil_scoped_release no_gil;
1610 self_imag.copy_(imag_);
1611 return 0;
1612 }
1613 END_HANDLE_TH_ERRORS_RET(-1)
1614 }
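// [Editorial sketch, not part of the original source] The two setters above
// share one pattern: take a view of the tensor (at::real / at::imag), convert
// the incoming PyObject with valueToTensor, and copy into that view with the
// GIL released. A hedged, hypothetical helper expressing the pattern (the
// name example_copy_into_view is made up for illustration):
[[maybe_unused]] static int example_copy_into_view(
    at::Tensor view,
    PyObject* value) {
  // Convert the Python value to a tensor matching the view's options/device.
  auto value_ = valueToTensor(view.options(), value, view.device());
  {
    // Release the GIL around copy_, mirroring THPVariable_set_real/imag; the
    // copy may dispatch device work and should not hold the GIL.
    pybind11::gil_scoped_release no_gil;
    view.copy_(value_);
  }
  return 0;
}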
1615
1616 PyObject* THPVariable__use_count(PyObject* self, PyObject* noargs) {
1617 HANDLE_TH_ERRORS
1618 const auto& t = THPVariable_Unpack(self);
1619 return THPUtils_packUInt64(t.use_count());
1620 END_HANDLE_TH_ERRORS
1621 }
1622
1623 // properties are registered here because we are currently only able to bind
1624 // them manually. TODO: make declarable in native_functions
1625 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
1626 static struct PyGetSetDef THPVariable_properties[] = {
1627 {"_python_dispatch",
1628 (getter)THPVariable_get_python_dispatch,
1629 nullptr,
1630 nullptr,
1631 nullptr},
1632 {"T", (getter)PropertyT::getter, nullptr, nullptr, nullptr},
1633 {"H", (getter)PropertyH::getter, nullptr, nullptr, nullptr},
1634 {"mT", (getter)PropertymT::getter, nullptr, nullptr, nullptr},
1635 {"mH", (getter)PropertymH::getter, nullptr, nullptr, nullptr},
1636 {"_cdata", (getter)THPVariable_get_cdata, nullptr, nullptr, nullptr},
1637 {"_version", (getter)THPVariable_get_version, nullptr, nullptr, nullptr},
1638 {"grad_fn", (getter)THPVariable_get_grad_fn, nullptr, nullptr, nullptr},
1639 {"_grad_fn",
1640 (getter)THPVariable_get_grad_fn,
1641 (setter)THPVariable_set_grad_fn,
1642 nullptr,
1643 nullptr},
1644 {"is_leaf", (getter)THPVariable_is_leaf, nullptr, nullptr, nullptr},
1645 {"retains_grad",
1646 (getter)THPVariable_retains_grad,
1647 nullptr,
1648 nullptr,
1649 nullptr},
1650 {"data",
1651 (getter)PropertyData::getter,
1652 (setter)THPVariable_set_data,
1653 nullptr,
1654 nullptr},
1655 {"_grad",
1656 (getter)PropertyGrad::getter,
1657 (setter)THPVariable_set_grad,
1658 nullptr,
1659 nullptr}, // Allows the python class to override .grad
1660 {"grad",
1661 (getter)PropertyGrad::getter,
1662 (setter)THPVariable_set_grad,
1663 nullptr,
1664 nullptr},
1665 {"_base", (getter)THPVariable_get_base, nullptr, nullptr, nullptr},
1666 {"volatile",
1667 (getter)THPVariable_get_volatile,
1668 (setter)THPVariable_set_volatile,
1669 nullptr,
1670 nullptr},
1671 {"output_nr", (getter)THPVariable_get_output_nr, nullptr, nullptr, nullptr},
1672 {"requires_grad",
1673 (getter)THPVariable_get_requires_grad,
1674 (setter)THPVariable_set_requires_grad,
1675 nullptr,
1676 nullptr},
1677 {"_backward_hooks",
1678 (getter)THPVariable_get_backwards_hooks,
1679 (setter)THPVariable_set_backwards_hooks,
1680 nullptr,
1681 nullptr},
1682 {"_post_accumulate_grad_hooks",
1683 (getter)THPVariable_get_post_accumulate_grad_hooks,
1684 (setter)THPVariable_set_post_accumulate_grad_hooks,
1685 nullptr,
1686 nullptr},
1687 {"name", (getter)THPVariable_get_name, nullptr, nullptr, nullptr},
1688 {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
1689 {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
1690 {"is_mtia", (getter)THPVariable_is_mtia, nullptr, nullptr, nullptr},
1691 {"is_cpu", (getter)THPVariable_is_cpu, nullptr, nullptr, nullptr},
1692 {"is_xla", (getter)THPVariable_is_xla, nullptr, nullptr, nullptr},
1693 {"is_xpu", (getter)THPVariable_is_xpu, nullptr, nullptr, nullptr},
1694 {"is_ipu", (getter)THPVariable_is_ipu, nullptr, nullptr, nullptr},
1695 {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
1696 {"is_sparse_csr",
1697 (getter)THPVariable_is_sparse_csr,
1698 nullptr,
1699 nullptr,
1700 nullptr},
1701 {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
1702 {"is_mps", (getter)THPVariable_is_mps, nullptr, nullptr, nullptr},
1703 {"is_maia", (getter)THPVariable_is_maia, nullptr, nullptr, nullptr},
1704 {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
1705 {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
1706 {"is_quantized",
1707 (getter)THPVariable_is_quantized,
1708 nullptr,
1709 nullptr,
1710 nullptr},
1711 {"is_meta", (getter)THPVariable_is_meta, nullptr, nullptr, nullptr},
1712 {"is_nested", (getter)THPVariable_is_nested, nullptr, nullptr, nullptr},
1713 {"_has_symbolic_sizes_strides",
1714 (getter)THPVariable_has_symbolic_sizes_strides,
1715 nullptr,
1716 nullptr,
1717 nullptr},
1718 {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr},
1719 {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
1720 {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
1721 {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
1722 {"nbytes", (getter)THPVariable_get_nbytes, nullptr, nullptr, nullptr},
1723 {"itemsize", (getter)THPVariable_get_itemsize, nullptr, nullptr, nullptr},
1724 {"names",
1725 (getter)THPVariable_get_names,
1726 (setter)THPVariable_set_names,
1727 nullptr,
1728 nullptr},
1729 {"real",
1730 (getter)PropertyReal::getter,
1731 (setter)THPVariable_set_real,
1732 nullptr,
1733 nullptr},
1734 {"imag",
1735 (getter)PropertyImag::getter,
1736 (setter)THPVariable_set_imag,
1737 nullptr,
1738 nullptr},
1739 {nullptr}};
1740
1741 static PyMappingMethods THPVariable_as_mapping = {
1742 THPVariable_length,
1743 THPVariable_getitem,
1744 THPVariable_setitem,
1745 };
1746
1747 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
1748 static PyMethodDef extra_methods[] = {
1749 {"as_subclass",
1750 castPyCFunctionWithKeywords(THPVariable_as_subclass),
1751 METH_VARARGS | METH_KEYWORDS,
1752 nullptr},
1753 {"_make_subclass",
1754 castPyCFunctionWithKeywords(THPVariable_make_subclass),
1755 METH_STATIC | METH_VARARGS | METH_KEYWORDS,
1756 nullptr},
1757 {"_make_wrapper_subclass",
1758 castPyCFunctionWithKeywords(THPVariable_make_wrapper_subclass),
1759 METH_STATIC | METH_VARARGS | METH_KEYWORDS,
1760 nullptr},
1761 {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr},
1762 {"_view_func",
1763 castPyCFunctionWithKeywords(THPVariable_view_func),
1764 METH_VARARGS | METH_KEYWORDS,
1765 nullptr},
1766 {"_view_func_unsafe",
1767 castPyCFunctionWithKeywords(THPVariable_view_func_unsafe),
1768 METH_VARARGS | METH_KEYWORDS,
1769 nullptr},
1770 {"_rev_view_func_unsafe",
1771 THPVariable_rev_view_func_unsafe,
1772 METH_O,
1773 nullptr},
1774 {"_use_count", THPVariable__use_count, METH_NOARGS, nullptr},
1775 {nullptr}};
1776
1777 struct THPVariableMeta {
1778 PyHeapTypeObject base;
1779 };
1780
1781 int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);
1782
1783 PyTypeObject THPVariableMetaType = {
1784 PyVarObject_HEAD_INIT(
1785 DEFERRED_ADDRESS(&PyType_Type),
1786 0) "torch._C._TensorMeta", /* tp_name */
1787 sizeof(THPVariableMeta), /* tp_basicsize */
1788 0, /* tp_itemsize */
1789 nullptr, /* tp_dealloc */
1790 0, /* tp_vectorcall_offset */
1791 nullptr, /* tp_getattr */
1792 nullptr, /* tp_setattr */
1793 nullptr, /* tp_reserved */
1794 nullptr, /* tp_repr */
1795 nullptr, /* tp_as_number */
1796 nullptr, /* tp_as_sequence */
1797 nullptr, /* tp_as_mapping */
1798 nullptr, /* tp_hash */
1799 nullptr, /* tp_call */
1800 nullptr, /* tp_str */
1801 nullptr, /* tp_getattro */
1802 nullptr, /* tp_setattro */
1803 nullptr, /* tp_as_buffer */
1804 // NOLINTNEXTLINE(misc-redundant-expression)
1805 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
1806 nullptr, /* tp_doc */
1807 nullptr, /* tp_traverse */
1808 nullptr, /* tp_clear */
1809 nullptr, /* tp_richcompare */
1810 0, /* tp_weaklistoffset */
1811 nullptr, /* tp_iter */
1812 nullptr, /* tp_iternext */
1813 nullptr, /* tp_methods */
1814 nullptr, /* tp_members */
1815 nullptr, /* tp_getset */
1816 DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
1817 nullptr, /* tp_dict */
1818 nullptr, /* tp_descr_get */
1819 nullptr, /* tp_descr_set */
1820 0, /* tp_dictoffset */
1821 THPVariableMetaType_init, /* tp_init */
1822 nullptr, /* tp_alloc */
1823 nullptr, /* tp_new */
1824 };
1825
1826 PyTypeObject THPVariableType = {
1827 PyVarObject_HEAD_INIT(
1828 &THPVariableMetaType,
1829 0) "torch._C.TensorBase", /* tp_name */
1830 sizeof(THPVariable), /* tp_basicsize */
1831 0, /* tp_itemsize */
1832 // This is unspecified, because it is illegal to create a THPVariableType
1833 // directly. Subclasses will have their tp_dealloc set appropriately
1834 // by the metaclass
1835 nullptr, /* tp_dealloc */
1836 0, /* tp_vectorcall_offset */
1837 nullptr, /* tp_getattr */
1838 nullptr, /* tp_setattr */
1839 nullptr, /* tp_reserved */
1840 nullptr, /* tp_repr */
1841 nullptr, /* tp_as_number */
1842 nullptr, /* tp_as_sequence */
1843 &THPVariable_as_mapping, /* tp_as_mapping */
1844 nullptr, /* tp_hash */
1845 nullptr, /* tp_call */
1846 nullptr, /* tp_str */
1847 nullptr, /* tp_getattro */
1848 nullptr, /* tp_setattro */
1849 nullptr, /* tp_as_buffer */
1850 // NOLINTNEXTLINE(misc-redundant-expression)
1851 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
1852 Py_TPFLAGS_HAVE_GC, /* tp_flags */
1853 nullptr, /* tp_doc */
1854 // Also set by metaclass
1855 (traverseproc)THPFunction_traverse, /* tp_traverse */
1856 (inquiry)THPVariable_clear, /* tp_clear */
1857 nullptr, /* tp_richcompare */
1858 0, /* tp_weaklistoffset */
1859 nullptr, /* tp_iter */
1860 nullptr, /* tp_iternext */
1861 nullptr, /* tp_methods */
1862 nullptr, /* tp_members */
1863 THPVariable_properties, /* tp_getset */
1864 nullptr, /* tp_base */
1865 nullptr, /* tp_dict */
1866 nullptr, /* tp_descr_get */
1867 nullptr, /* tp_descr_set */
1868 0, /* tp_dictoffset */
1869 nullptr, /* tp_init */
1870 nullptr, /* tp_alloc */
1871     // Although new is provided here, it is illegal to call this with cls ==
1872     // THPVariableType. Instead, subclass it first and then construct it
1873 THPVariable_pynew, /* tp_new */
1874 };
1875
1876 PyObject* THPVariable_pynew(
1877 PyTypeObject* type,
1878 PyObject* args,
1879 PyObject* kwargs) {
1880 HANDLE_TH_ERRORS
1881 TORCH_CHECK(
1882 type != &THPVariableType,
1883 "Cannot directly construct TensorBase; subclass it and then construct that");
1884 jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
1885 auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
1886   // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
1887   // given a raw pointer, it will just bump the refcount of an existing tensor
1888 // NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
1889 // alias(), lift_fresh()) which can return Tensor subclasses. We allow
1890 // these to be passed on directly.
1891 return THPVariable_NewWithVar(
1892 type,
1893 std::move(tensor),
1894 c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED,
1895 /*allow_preexisting_pyobj=*/true);
1896 END_HANDLE_TH_ERRORS
1897 }
1898
1899 // NB: this is not the tp_dealloc on THPVariable; instead, it's the dealloc
1900 // on subclasses. It's never valid to construct a THPVariable so it's not
1901 // necessary to implement the dealloc for that case
1902 void THPVariable_subclass_dealloc(PyObject* self) {
1903 if (THPVariable_tryResurrect((THPVariable*)self))
1904 return;
1905
1906 // This is like a crappy version of subtype_dealloc.
1907 // Unfortunately, we cannot directly delegate to
1908 // subtype_dealloc as it will start walking the parent
1909 // chain *starting with* the type of self, which will cause
1910 // us to go back to our custom dealloc.
1911 //
1912 // We have to replicate the subtype_dealloc logic to ensure
1913 // that finalizers are handled correctly
1914 PyTypeObject* type = Py_TYPE(self);
1915 TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
1916 TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented");
1917
1918 PyObject_GC_UnTrack(self);
1919 // TODO: consider using trash can
1920
1921 bool has_finalizer = type->tp_finalize || type->tp_del;
1922
1923 if (type->tp_finalize) {
1924 PyObject_GC_Track(self);
1925 if (PyObject_CallFinalizerFromDealloc(self) < 0) {
1926 /* Resurrected */
1927 return;
1928 }
1929 PyObject_GC_UnTrack(self);
1930 }
1931
1932 // base test is unnecessary as THPVariable does not set this
1933 if (type->tp_weaklistoffset) {
1934 PyObject_ClearWeakRefs(self);
1935 }
1936
1937 if (type->tp_del) {
1938 PyObject_GC_Track(self);
1939 type->tp_del(self);
1940 if (Py_REFCNT(self) > 0) {
1941 /* Resurrected */
1942 return;
1943 }
1944 PyObject_GC_UnTrack(self);
1945 }
1946
1947 if (has_finalizer) {
1948 /* New weakrefs could be created during the finalizer call.
1949 If this occurs, clear them out without calling their
1950 finalizers since they might rely on part of the object
1951 being finalized that has already been destroyed. */
1952 if (type->tp_weaklistoffset) {
1953 /* Modeled after GET_WEAKREFS_LISTPTR() */
1954 PyWeakReference** list =
1955 (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
1956 while (*list)
1957 _PyWeakref_ClearRef(*list);
1958 }
1959 }
1960
1961 // Clear all slots until we get to base class THPVariableType
1962 {
1963 PyTypeObject* base = type;
1964 while (base != &THPVariableType) {
1965 if (Py_SIZE(base)) {
1966 clear_slots(base, self);
1967 }
1968 base = base->tp_base;
1969 TORCH_INTERNAL_ASSERT(base);
1970 }
1971 }
1972
1973 // All Python defined classes have __dict__
1974 if (C10_LIKELY(type->tp_dictoffset)) {
1975 PyObject** dictptr = _PyObject_GetDictPtr(self);
1976 if (dictptr != nullptr) {
1977 PyObject* dict = *dictptr;
1978 if (dict != nullptr) {
1979 Py_DECREF(dict);
1980 *dictptr = nullptr;
1981 }
1982 }
1983 }
1984
1985 // subtype_dealloc allows for this but we don't
1986 TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
1987
1988 // Finally clear out the base THPVariable
1989 THPVariable_clear((THPVariable*)self);
1990 ((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
1991 Py_TYPE(self)->tp_free(self);
1992
1993 // Python defined subclasses should always be on the heap
1994 TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
1995 Py_DECREF(type);
1996 }
1997
1998 // Creates a new Python object for a Variable. The status parameter
1999 // specifies what the interpreter tag status on the object is; for
2000 // example, if you ran check_pyobj, the return optional of this object
2001 // tells you if the tensor was already tagged or not so you can pass
2002 // TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where
2003 // var came from and can directly assert that it's DEFINITELY_UNINITIALIZED.
2004 // It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED.
2005 static PyObject* THPVariable_NewWithVar(
2006 PyTypeObject* type,
2007 Variable _var,
2008 c10::impl::PyInterpreterStatus status,
2009 bool allow_preexisting_pyobj) {
2010 // Make sure that the reinterpret into a THPVariable* will be valid
2011 TORCH_CHECK(
2012 PyType_IsSubtype(type, &THPVariableType),
2013 "Creating a Tensor subclass from a class ",
2014 "that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
2015
2016   // This function overwrites the Tensor's pyobj field without extra checks.
2017   // Make sure it is not already set, otherwise we would leak memory
2018 auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
2019 getPyInterpreter(), /*ignore_hermetic_tls=*/false);
2020
2021 // Under some circumstances, we may attempt to create a new Python
2022 // object for a variable that already has a Python object. The most common
2023 // situation this can occur is if you have a TorchDispatchMode active that
2024 // is returning a subclass from lift_fresh (which is invoked to
2025 // appropriately "wrap" a constant tensor into whatever ambient modes are
2026 // active.)
2027 //
2028 // In general, it is impossible to handle this case compositionally.
2029   // Suppose a user calls ATensor([1, 2, 3]) while a mode is active
2030 // that is transforming all ops (including the internal lift_fresh call that
2031 // transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
2032 // BTensor, where ATensor and BTensor are completely unrelated subclasses
2033 // and there is no way to compose them. There is no way to satisfy the user
2034 // request here: in particular, you can't just try to re-invoke the ATensor
2035 // constructor on the returned BTensor, because (1) this could cause an
2036 // infinite loop--we are already in ATensor.__new__ and (2) there isn't any
2037 // guarantee that ATensor.__new__ supports a single element constructor
2038 // anyway.
2039 //
2040   // However, a more common case is that a user just called torch.Tensor([1, 2, 3]),
2041 // and a fake tensor mode is active. Really, all you want is to get back
2042 // a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
2043 // would have returned a fake tensor (concretely, the way this happens
2044 // is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
2045 // turns into a FakeTensor when we call lift_fresh on this real tensor).
2046 // This case is compositional because FakeTensor is a subclass of Tensor, so
2047 // it's valid for us to return it in place of a Tensor. So this is what we
2048 // do.
2049
2050 if (mb_obj.has_value() && mb_obj.value()) {
2051 TORCH_CHECK(
2052 allow_preexisting_pyobj,
2053 "Creating a new Tensor subclass ",
2054 type->tp_name,
2055 " but the raw Tensor object is already associated to a python object ",
2056 "of type ",
2057 mb_obj.value()->ob_type->tp_name);
2058 // Even if we allow pre-existing PyObject, we don't allow completely
2059 // ignoring the requested type. Check that we fulfilled a subtype
2060 // relation here. In the common case the requested type is Tensor and
2061 // this always succeeds.
2062 PyObject* obj = *mb_obj;
2063 // Check if it's OK to just directly return the Python object without
2064 // allocating a new variable. We just check that the existing Python
2065 // object is a subclass of the requested type.
2066 PyTypeObject* obj_type = Py_TYPE(obj);
2067 TORCH_CHECK(
2068 obj_type == type || PyType_IsSubtype(obj_type, type),
2069 "Creating a new Tensor subclass ",
2070 type->tp_name,
2071 " but the raw Tensor object is already associated to a python object ",
2072 "of type ",
2073 mb_obj.value()->ob_type->tp_name,
2074 " which is not a subclass of the "
2075 "requested type");
2076 // We may (in fact, we typically will) need to resurrect this
2077 return THPVariable_Wrap(std::move(_var));
2078 }
2079
2080 PyObject* obj = type->tp_alloc(type, 0);
2081 if (obj) {
2082 auto v = (THPVariable*)obj;
2083 // TODO: named constructor to avoid default initialization
2084 new (&v->cdata) MaybeOwned<Variable>();
2085 if (c10::impl::HermeticPyObjectTLS::get_state()) {
2086 // Do NOT initialize pyobj field on the tensor, you own the C++
2087 v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
2088 TORCH_INTERNAL_ASSERT(
2089 !check_has_torch_dispatch(obj),
2090 "While HermeticPyObject was enabled, we attempted to create a tensor "
2091 "subclass with __torch_dispatch__. This violates the invariant that "
2092 "operations in HermeticPyObject have equivalent C++ implementations. "
2093 "If your operator registered from Python operator registration isn't "
2094 "doing anything strange, there may be an internal PyTorch bug involving "
2095 "not appropriately disabling TorchDispatchMode before executing "
2096 "Python op registration.");
2097 } else {
2098 // Normal codepath
2099 v->cdata = MaybeOwned<Variable>::owned(std::move(_var));
2100 const auto& var = THPVariable_Unpack(v);
2101 var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
2102 getPyInterpreter(), obj, status);
2103 if (check_has_torch_dispatch(obj)) {
2104 var.unsafeGetTensorImpl()->set_python_dispatch(true);
2105 }
2106 }
2107 }
2108 return obj;
2109 }
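// [Editorial sketch, not part of the original source] Per the comment above
// THPVariable_NewWithVar, the status argument can be derived from
// check_pyobj(): has_value() means the tensor was already tagged by us,
// otherwise MAYBE_UNINITIALIZED is the safe answer unless the caller knows
// the variable's provenance. A hedged, hypothetical helper (the name
// example_guess_status is made up for illustration):
[[maybe_unused]] static c10::impl::PyInterpreterStatus example_guess_status(
    const Variable& var) {
  auto mb_obj = var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
      getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  if (mb_obj.has_value()) {
    // Already tagged by this interpreter.
    return c10::impl::PyInterpreterStatus::TAGGED_BY_US;
  }
  // Not tagged so far; only claim MAYBE_UNINITIALIZED here.
  // DEFINITELY_UNINITIALIZED is reserved for callers that know where the
  // variable came from, as the comment above explains.
  return c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
}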
2110
2111 /// NOTE [ PyObject Traversal ]
2112 ///
2113 /// PyObjects that are wrapping c++ objects can lead to non-trivial traverse
2114 /// logic and it can be tricky to know what to traverse and when. This note
2115 /// tries to clarify what is the danger here and a simple algorithm to choose
2116 /// how to write the tp_traverse and tp_clear functions. If you're not already
2117 /// familiar with how the CPython GC works, you should read this in-depth
2118 /// description: https://devguide.python.org/garbage_collector/
2119 ///
2120 /// The complexity for us comes from the fact that some c++ shared_ptr objects
2121 /// own references to python objects and are also owned both by other python
2122 /// objects and c++ objects. This means that to allow the GC to collect all
2123 /// cycles, we need to properly implement the traverse/clear methods that take
2124 /// into account these C++ ownership links.
2125 ///
2126 /// The main danger here comes from the fact that, while all python-related code
2127 /// is thread safe wrt the GC execution (thanks to the GIL), other threads might
2128 /// be using our C++ objects arbitrarily which can lead to shared_ptr ref count
2129 /// going up or down in between the different traverse/clear invocations. The
2130 /// one constraint we add here that is not explicitly mentioned in the GC
2131 /// description above is that for a given GC run (meaning while the GIL is
2132 /// held), the traverse/clear pair should never report different ownership
2133 /// relations: if traverse visited a given PyObject, then during that same
2134 /// GC run clear must still find it solely owned and must clear that PyObject.
2135 ///
2136 /// A more mechanical algorithm to know what to traverse/clear is as follows:
2137 /// - Any field on this PyObject that contains a strong reference to another
2138 /// PyObject
2139 /// must be visited and cleared. An example of that is the "backward_hooks"
2140 /// field of the THPVariable.
2141 /// - Any field that contains a C++ object that is uniquely owned by this
2142 /// PyObject (either
2143 /// a unique_ptr or a shared_ptr with use_count==1) should have all the
2144 /// PyObject it owns visited and cleared. An example would be here the
2145 /// tensor hooks.
2146 /// - If that uniquely owned C++ object also uniquely owns other C++ objects,
2147 /// these should be
2148 /// visited and cleared as well if they contain any PyObject.
2149 ///
2150 /// Caveat: to avoid slow runtime, we limit the depth of this exploration of C++
2151 /// objects in practice and we do not, for example, go through the whole
2152 /// autograd graph, even if it is uniquely owned. This is a known place where
2153 /// users can create noncollectable cycles as described in:
2154 /// https://github.com/pytorch/pytorch/issues/7343
2155 ///
2156
2157 static int traverse_slots(
2158 PyTypeObject* type,
2159 PyObject* self,
2160 visitproc visit,
2161 void* arg) {
2162 auto n = Py_SIZE(type);
2163 auto mp = type->tp_members;
2164 for (Py_ssize_t i = 0; i < n; i++, mp++) {
2165 if (mp->type == T_OBJECT_EX) {
2166 char* addr = (char*)self + mp->offset;
2167 PyObject* obj = *(PyObject**)addr;
2168 if (obj != nullptr) {
2169 int err = visit(obj, arg);
2170 if (err)
2171 return err;
2172 }
2173 }
2174 }
2175 return 0;
2176 }
2177
2178 static int THPVariable_subclass_traverse(
2179 PyObject* self,
2180 visitproc visit,
2181 void* arg) {
2182 // If the tensor is eligible to be resurrected, don't traverse it; instead
2183 // treat all of its references as a root (as they WOULD be a root since we
2184 // can treat the inbound C++ references as root owners).
2185 //
2186 // This works because unlike conventional GCs, Python's GC operates in two
2187 // phases: first it uses traverse to discover roots, and then it uses traverse
2188 // to do reachability. Bypassing traverse during root discovery forces Python
2189 // to treat self as a root for everything it refers to. For a full
2190 // explanation of the algorithm see
2191 // https://devguide.python.org/garbage_collector/
2192 //
2193 // NB: if we don't hold an owning reference to the underlying Tensor, it is
2194 // possible that the underlying Tensor has already gone dead. In that case,
2195 // it's not safe to access it. But it's also safe to traverse, because if
2196 // the underlying Tensor *is* live, then root discovery will determine that
2197 // self is live, and nothing will get GC'ed anyway (resurrection cannot happen
2198   // if the C++ object owns the PyObject)
2199 THPVariable* var = reinterpret_cast<THPVariable*>(self);
2200 if (isResurrectable(var)) {
2201 return 0;
2202 }
2203
2204 // Crappy version of subtype_traverse; same deal as
2205 // THPVariable_subclass_dealloc
2206
2207 PyTypeObject* type = Py_TYPE(self);
2208 // Traverse slots until we get to base class THPVariableType
2209 {
2210 PyTypeObject* base = type;
2211 while (base != &THPVariableType) {
2212 if (Py_SIZE(base)) {
2213 int err = traverse_slots(base, self, visit, arg);
2214 if (err)
2215 return err;
2216 }
2217 base = base->tp_base;
2218 TORCH_INTERNAL_ASSERT(base);
2219 }
2220 }
2221
2222 // All Python defined classes have __dict__
2223 if (C10_LIKELY(type->tp_dictoffset)) {
2224 PyObject** dictptr = _PyObject_GetDictPtr(self);
2225 if (dictptr && *dictptr)
2226 Py_VISIT(*dictptr);
2227 }
2228
2229 TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
2230 Py_VISIT(type);
2231
2232 // Finally traverse THPVariable special stuff
2233 Py_VISIT(var->backward_hooks);
2234 Py_VISIT(var->post_accumulate_grad_hooks);
2235 if (!var->cdata.unsafeIsBorrowed()) {
2236 const auto& tensor = THPVariable_Unpack(var);
2237 if (tensor.defined()) {
2238 // WARNING: The grad_fn traversal logic is very subtle, if you change
2239 // this, be very careful not to re-introduce this bug:
2240 // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
2241
2242       // We ensure that we follow NOTE [ PyObject Traversal ] here by checking
2243       // that this python object is the sole owner of the underlying Tensor and
2244       // that this Tensor is the sole owner of its grad_fn. In this case, the
2245       // only way to get a new reference to the grad_fn is by using this python
2246       // object, which requires the GIL to be held. Note that this is only
2247       // valid as long as users don't share non-owning references across
2248       // different threads (which is crazy and should never be done).
2249 auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
2250 if (tensor.use_count() == 1) {
2251 if (autograd_meta) {
2252 // Do NOT call grad_fn() here as that might trigger a recompute
2253 const auto& grad_fn = autograd_meta->grad_fn_;
2254 if (grad_fn && grad_fn.use_count() == 1) {
2255 // All Node can have a pyobj (stored in "pyobj_")
2256 Py_VISIT(grad_fn->pyobj());
2257 // PyNode are special as they also have an "obj" field
2258 if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
2259 Py_VISIT(py_node_fn->obj);
2260 }
2261 }
2262 }
2263 }
2264 if (autograd_meta) {
2265 for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
2266 if (auto pyhook =
2267 dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
2268 Py_VISIT(pyhook->dict);
2269 }
2270 }
2271 }
2272 }
2273 }
2274
2275 return 0;
2276 }
2277
2278 int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
2279 if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
2280 return -1;
2281 }
2282 ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
2283 ((PyTypeObject*)cls)->tp_traverse =
2284 (traverseproc)THPVariable_subclass_traverse;
2285
2286 // Don't do anything for the base Tensor class
2287 if (!THPVariableClass) {
2288 return 0;
2289 }
2290
2291 // Forbid subclassing _TensorBase directly
2292 py::tuple mro =
2293 py::reinterpret_borrow<py::tuple>(((PyTypeObject*)cls)->tp_mro);
2294 bool is_subclass_of_thpvariable = false;
2295 for (py::handle h : mro) {
2296 if (h.ptr() == THPVariableClass) {
2297 is_subclass_of_thpvariable = true;
2298 break;
2299 }
2300 }
2301 if (!is_subclass_of_thpvariable) {
2302 PyErr_SetString(PyExc_RuntimeError, "Cannot subclass _TensorBase directly");
2303 return -1;
2304 }
2305
2306 // If the user provided a torch_dispatch implementation, disable
2307 // torch_function.
2308 py::object torch_dispatch_impl = py::reinterpret_steal<py::object>(
2309 PyObject_GetAttrString(cls, "__torch_dispatch__"));
2310 py::object torch_dispatch_default = py::reinterpret_steal<py::object>(
2311 PyObject_GetAttrString(THPVariableClass, "__torch_dispatch__"));
2312 if (torch_dispatch_impl.ptr() != torch_dispatch_default.ptr()) {
2313 py::object torch_function_impl = py::reinterpret_steal<py::object>(
2314 PyObject_GetAttrString(cls, "__torch_function__"));
2315 py::object torch_function_default_bound = py::reinterpret_steal<py::object>(
2316 PyObject_GetAttrString(THPVariableClass, "__torch_function__"));
2317
2318     // Since our __torch_function__ is a classmethod, we need to "unbind" the
2319     // method to get the raw function
2320 py::object torch_function_default = py::reinterpret_steal<py::object>(
2321 PyObject_GetAttrString(torch_function_default_bound.ptr(), "__func__"));
2322
2323 // User-defined __torch_function__ might not be a classmethod
2324 if (PyObject_HasAttrString(torch_function_impl.ptr(), "__func__")) {
2325 torch_function_impl = py::reinterpret_steal<py::object>(
2326 PyObject_GetAttrString(torch_function_impl.ptr(), "__func__"));
2327 }
2328 if (torch_function_impl.ptr() == torch_function_default.ptr()) {
2329 PyObject_SetAttrString(
2330 cls, "__torch_function__", torch::disabled_torch_function_impl());
2331 }
2332 }
2333
2334 return 0;
2335 }
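// [Editorial sketch, not part of the original source] The override detection
// in THPVariableMetaType_init compares the underlying functions of
// classmethods rather than their bound wrappers. A hedged, hypothetical
// predicate capturing that comparison (names are made up; it assumes both
// classes actually expose the attribute):
[[maybe_unused]] static bool example_overrides_classmethod(
    PyObject* cls,
    PyObject* base,
    const char* attr_name) {
  py::object impl = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(cls, attr_name));
  py::object base_bound = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(base, attr_name));
  // Unbind the base classmethod to reach the raw function object.
  py::object base_fn = py::reinterpret_steal<py::object>(
      PyObject_GetAttrString(base_bound.ptr(), "__func__"));
  // The user-defined attribute might not be a classmethod at all.
  if (PyObject_HasAttrString(impl.ptr(), "__func__")) {
    impl = py::reinterpret_steal<py::object>(
        PyObject_GetAttrString(impl.ptr(), "__func__"));
  }
  return impl.ptr() != base_fn.ptr();
}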
2336
2337 namespace torch::autograd {
2338
2339 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
2340 extern PyMethodDef variable_methods[];
2341 extern void initTorchFunctions(PyObject* module);
2342
2343 void initTensorImplConversion(PyObject* module) {
2344 auto m = py::handle(module).cast<py::module>();
2345 m.def("_wrap_tensor_impl", [](void* ptr) {
2346 auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
2347 unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
2348 TORCH_CHECK(p.defined(), "Can't wrap undefined tensor");
2349 auto tensor = at::Tensor::wrap_tensor_impl(std::move(p));
2350 return py::cast(std::move(tensor));
2351 });
2352 // set on the module level to avoid mixing pybind and plain CPython extensions
2353 m.def("_tensor_impl_raw_handle", [](torch::autograd::Variable* t) -> void* {
2354     // We return a raw non-owning pointer here; we rely on the surrounding
2355     // code to keep the original tensor alive
2356 return t->getIntrusivePtr().get();
2357 });
2358 }
2359 } // namespace torch::autograd
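// [Editorial sketch, not part of the original source] The intended round trip
// for the two bindings above, spelled out in C++; the function name is made
// up. _tensor_impl_raw_handle hands out a raw non-owning TensorImpl*, and
// _wrap_tensor_impl reconstitutes a Tensor from such a pointer, so the
// original tensor must be kept alive by the caller in the meantime.
[[maybe_unused]] static void example_tensor_impl_roundtrip() {
  at::Tensor t = at::ones({2, 2});
  // What _tensor_impl_raw_handle returns: a raw, non-owning handle.
  void* raw = t.getIntrusivePtr().get();
  // What _wrap_tensor_impl does with it: rebuild an owning handle and wrap it
  // back into a Tensor.
  auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
      unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(raw));
  TORCH_CHECK(p.defined(), "expected a defined TensorImpl");
  at::Tensor wrapped = at::Tensor::wrap_tensor_impl(std::move(p));
  // Both handles now refer to the same underlying storage.
  TORCH_INTERNAL_ASSERT(wrapped.data_ptr() == t.data_ptr());
}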
2360
2361 bool THPVariable_initModule(PyObject* module) {
2362 THPVariableMetaType.tp_base = &PyType_Type;
2363 if (PyType_Ready(&THPVariableMetaType) < 0)
2364 return false;
2365 Py_INCREF(&THPVariableMetaType);
2366 PyModule_AddObject(module, "_TensorMeta", (PyObject*)&THPVariableMetaType);
2367
2368 static std::vector<PyMethodDef> methods;
2369 THPUtils_addPyMethodDefs(methods, torch::autograd::variable_methods);
2370 THPUtils_addPyMethodDefs(methods, extra_methods);
2371 THPVariableType.tp_methods = methods.data();
2372 if (PyType_Ready(&THPVariableType) < 0)
2373 return false;
2374 Py_INCREF(&THPVariableType);
2375 PyModule_AddObject(module, "TensorBase", (PyObject*)&THPVariableType);
2376 PyModule_AddObject(module, "_TensorBase", (PyObject*)&THPVariableType);
2377 torch::autograd::initTorchFunctions(module);
2378 torch::autograd::initTensorImplConversion(module);
2379 torch::utils::validate_numpy_for_dlpack_deleter_bug();
2380 return true;
2381 }
2382