1 #include <torch/csrc/autograd/python_function.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/SequenceNumber.h>
5 #include <c10/util/irange.h>
6 #include <pybind11/pybind11.h>
7 #include <structmember.h>
8 #include <torch/csrc/PyInterpreter.h>
9 #include <torch/csrc/python_headers.h>
10 #include <torch/csrc/utils/pybind.h>
11
12 #include <ATen/FuncTorchTLS.h>
13 #include <ATen/functorch/DynamicLayer.h>
14 #include <torch/csrc/DynamicTypes.h>
15 #include <torch/csrc/Exceptions.h>
16 #include <torch/csrc/THP.h>
17 #include <torch/csrc/autograd/functions/accumulate_grad.h>
18 #include <torch/csrc/autograd/functions/basic_ops.h>
19 #include <torch/csrc/autograd/functions/utils.h>
20 #include <torch/csrc/autograd/grad_mode.h>
21 #include <torch/csrc/autograd/graph_task.h>
22 #include <torch/csrc/autograd/python_anomaly_mode.h>
23 #include <torch/csrc/autograd/python_cpp_function.h>
24 #include <torch/csrc/autograd/python_hook.h>
25 #include <torch/csrc/autograd/saved_variable.h>
26 #include <torch/csrc/autograd/utils/wrap_outputs.h>
27 #include <torch/csrc/dynamo/compiled_autograd.h>
28 #include <torch/csrc/jit/frontend/tracer.h>
29 #include <torch/csrc/jit/ir/ir.h>
30 #include <torch/csrc/jit/python/pybind_utils.h>
31 #include <torch/csrc/jit/python/python_tracer.h>
32 #include <torch/csrc/profiler/api.h>
33 #include <torch/csrc/utils/python_strings.h>
34 #include <torch/csrc/utils/tensor_dtypes.h>
35
36 #include <functional>
37 #include <memory>
38 #include <stdexcept>
39 #include <string>
40 #include <unordered_map>
41 #include <unordered_set>
42 #include <utility>
43 #include <vector>
44
45 using namespace torch;
46 using namespace torch::autograd;
47 using at::Tensor;
48
// Global handles to Python class objects; judging by the names, the autograd
// Function class and the GradientEdge class. They start out null and are
// presumably assigned during module initialization elsewhere — TODO confirm.
PyObject* THPFunctionClass = nullptr;
PyObject* THPGradientEdgeClass = nullptr;
51
// If `condition` is false, sets a Python error built from the remaining
// arguments (passed through to THPUtils_setError) and throws python_error so
// the pending Python exception propagates to the caller.
//
// The body is wrapped in do { ... } while (0) so the macro expands to exactly
// one statement. It is then safe inside an unbraced if/else, and it requires
// the trailing semicolon that every call site already writes.
#define THPFunction_assert(condition, ...) \
  do {                                     \
    if (!(condition)) {                    \
      THPUtils_setError(__VA_ARGS__);      \
      throw python_error();                \
    }                                      \
  } while (0)
57
// Anonymous namespace for helper functions used only in this file
59 namespace {
60
61 // TODO: We shouldn't need to call this function because the engine
62 // can already persist the errors for us. This still seems to be
63 // needed for the DistEngine however.
64 //
65 // python test/distributed/rpc/test_tensorpipe_agent.py -k
66 // test_backward_autograd_engine_error
67 //
68 // See Note [ Persisting PyErr state across autograd engine threads ]
throw_python_error()69 void throw_python_error() {
70 python_error err;
71 err.persist();
72 throw std::move(err);
73 }
74
unpack_saved_variables(THPFunction * self,const std::function<PyObject * (const Variable &)> & unpack_fn)75 static PyObject* unpack_saved_variables(
76 THPFunction* self,
77 const std::function<PyObject*(const Variable&)>& unpack_fn) {
78 HANDLE_TH_ERRORS
79 TORCH_CHECK(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
80 auto& saved_variables = self->saved_variables;
81 if (saved_variables.empty())
82 return PyTuple_New(0);
83
84 auto num_saved = saved_variables.size();
85 THPObjectPtr saved(PyTuple_New(static_cast<Py_ssize_t>(num_saved)));
86 if (!saved)
87 return nullptr;
88 auto saved_for = self->cdata.lock();
89 // This is really a true assert, because we've already tested for the
90 // self->has_freed_buffers case at the beginning of this function:
91 // buffers are freed when PyNode dies; if the buffers are not freed,
92 // PyNode must be live. (Note that the buffers could be freed
93 // even though the PyNode is live, but that doesn't matter here
94 // because we will never hit this line of code if the buffers are freed--
95 // and in any case saved_for will be non-NULL.)
96 TORCH_INTERNAL_ASSERT(saved_for);
97 for (const auto i : c10::irange(num_saved)) {
98 auto unpacked_var = saved_variables[i].unpack(saved_for);
99 THPObjectPtr value;
100 if (!unpacked_var.defined()) {
101 Py_INCREF(Py_None);
102 value = Py_None;
103 } else {
104 value = unpack_fn(unpacked_var);
105 }
106 PyTuple_SET_ITEM(saved.get(), i, value.release());
107 }
108 return saved.release();
109 END_HANDLE_TH_ERRORS
110 }
111
to_py_size(const std::vector<c10::SymInt> & size)112 PyObject* to_py_size(const std::vector<c10::SymInt>& size) {
113 c10::SymIntArrayRef sym_sizes(size);
114
115 auto ret = THPObjectPtr(THPSizeType.tp_alloc(
116 &THPSizeType, static_cast<Py_ssize_t>(sym_sizes.size())));
117 if (!ret)
118 throw python_error();
119
120 for (auto i : c10::irange(sym_sizes.size())) {
121 auto symint = sym_sizes[i];
122 if (auto maybe_int = symint.maybe_as_int(); maybe_int.has_value()) {
123 PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(*maybe_int));
124 } else {
125 auto py_symint = py::cast(symint).release().ptr();
126 PyTuple_SET_ITEM(ret.get(), i, py_symint);
127 }
128 }
129 return ret.release();
130 }
131
132 } // namespace
133
134 namespace torch::autograd {
135
136 // NOTE: this function is written in a way that assumes it's only called for
137 // backward; it's used by engine.cpp. This is responsible for forwarding a call
138 // from C++'s Node::apply to a Python method "apply".
apply(variable_list && inputs)139 auto PyNode::apply(variable_list&& inputs) -> variable_list {
140 pybind11::gil_scoped_acquire gil;
141 at::OptionalDeviceGuard _device_guard;
142 THPFunction* py_fn = (THPFunction*)obj;
143
144 // Massage a C++ variable_list into a Python arguments tuple
145 THPObjectPtr pyInputs(to_py_args(inputs, &_device_guard));
146
147 THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
148 if (!apply_fn)
149 throw_python_error();
150 THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
151 if (!r)
152 throw_python_error();
153 ensure_tuple(r);
154
155 auto& is_variable_input = py_fn->is_variable_input;
156 auto num_outputs = PyTuple_GET_SIZE(r.get());
157 auto num_forward_inputs = static_cast<Py_ssize_t>(is_variable_input.size());
158 // Returning too many results is ok, but only as long as they're all None.
159 // Truncate the result tuple in that case.
160 if (num_outputs > num_forward_inputs) {
161 bool all_none = true;
162 for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
163 all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None;
164 }
165 if (all_none) {
166 num_outputs = num_forward_inputs;
167 r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
168 if (!r)
169 throw_python_error();
170 }
171 }
172
173 // Now the number of gradients should match
174 if (num_outputs != num_forward_inputs) {
175 std::string msg("function ");
176 msg += name() + " returned an incorrect number of gradients (expected ";
177 msg += std::to_string(num_forward_inputs) + ", got ";
178 msg += std::to_string(num_outputs) + ")";
179 throw std::runtime_error(msg);
180 }
181
182 // Massage the Python results tuple back into a C++ variable_list
183 return to_variable_list(r.get(), is_variable_input);
184 }
185
defer_to_dynamo(variable_list && inputs,std::optional<PyObject * > compiler)186 auto PyNode::defer_to_dynamo(
187 variable_list&& inputs,
188 std::optional<PyObject*> compiler) -> variable_list {
189 pybind11::gil_scoped_acquire gil;
190 at::OptionalDeviceGuard _device_guard;
191 THPFunction* py_fn = (THPFunction*)obj;
192
193 // Massage a C++ variable_list into a Python arguments tuple
194 THPObjectPtr pyInputs(to_py_args(inputs, &_device_guard));
195
196 const auto& is_variable_input = py_fn->is_variable_input;
197 const auto& input_infos = py_fn->input_info;
198 // input_info only contains info from variable inputs and should be a subset
199 TORCH_INTERNAL_ASSERT(is_variable_input.size() >= input_infos.size());
200
201 // The gradients returned in the backwards need to match the number of inputs
202 // to the forward, and their metadata, so we pass the fwdInputs
203 THPObjectPtr fwdInputMetadatas(
204 PyTuple_New(static_cast<Py_ssize_t>(is_variable_input.size())));
205 if (!fwdInputMetadatas)
206 throw python_error();
207
208 int offset = 0;
209 for (const auto i : c10::irange(is_variable_input.size())) {
210 if (!is_variable_input[i]) {
211 // input at i is not a variable, skip index
212 PyTuple_SET_ITEM(fwdInputMetadatas.get(), i, Py_None);
213 offset++;
214 continue;
215 }
216
217 const auto& input_info = input_infos[i - offset];
218
219 PyObject* device(THPDevice_New(input_info.device));
220 if (!device)
221 throw_python_error();
222 // Metadata is a tuple of 4 elements: (layout, device, dtype, size)
223 PyObject* fwdInputMetadata = PyTuple_Pack(
224 4,
225 autograd::utils::wrap(input_info.layout),
226 device,
227 autograd::utils::wrap(input_info.scalar_type),
228 to_py_size(input_info.size));
229 if (!fwdInputMetadata)
230 throw python_error();
231
232 PyTuple_SET_ITEM(fwdInputMetadatas.get(), i, fwdInputMetadata);
233 }
234 THPObjectPtr saved_tensors(unpack_saved_variables(
235 py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
236 TORCH_INTERNAL_ASSERT(
237 _backward_idx.has_value(),
238 "indices should already be set by compiled_args, called before apply_with_saved");
239 TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value());
240 THPObjectPtr r(PyObject_CallMethod(
241 *compiler,
242 "proxy_call_backward",
243 "OOOi",
244 pyInputs.get(),
245 fwdInputMetadatas.get(),
246 saved_tensors.get(),
247 *_backward_idx));
248
249 if (!r)
250 throw_python_error();
251 ensure_tuple(r);
252
253 // Massage the Python results tuple back into a C++ variable_list
254 return to_variable_list(r.get(), is_variable_input);
255 }
256
is_traceable()257 auto PyNode::is_traceable() -> bool {
258 pybind11::gil_scoped_acquire gil;
259 THPObjectPtr forward_class{PyObject_GetAttrString(obj, "_forward_cls")};
260 if (!forward_class)
261 throw_python_error();
262 THPObjectPtr traceable_py_bool{
263 PyObject_GetAttrString(forward_class, "is_traceable")};
264 if (!traceable_py_bool)
265 throw_python_error();
266 return traceable_py_bool == Py_True;
267 }
268
release_variables()269 auto PyNode::release_variables() -> void {
270 // This function is called as part of the Node destructor!
271 // Since this object might be kept alive by C++, it is possible
272 // that the python interpreter is already dead here. In that case
273 // we just leak the saved objects.
274 if (Py_IsInitialized()) {
275 pybind11::gil_scoped_acquire gil;
276 auto f = (THPFunction*)obj;
277 f->saved_variables.clear();
278 f->has_freed_buffers = 1;
279 }
280 }
281
name() const282 auto PyNode::name() const -> std::string {
283 pybind11::gil_scoped_acquire gil;
284 auto f = (THPFunction*)obj;
285 auto name = std::string(Py_TYPE(f)->tp_name);
286 return name;
287 }
288
compiled_autograd_should_lift() const289 auto PyNode::compiled_autograd_should_lift() const -> bool {
290 pybind11::gil_scoped_acquire gil;
291 static PyObject* attr_name =
292 PyUnicode_InternFromString("_compiled_autograd_should_lift");
293 THPObjectPtr should_lift(PyObject_GetAttr(obj, attr_name));
294 return PyObject_IsTrue(should_lift.get()) == 1;
295 }
296
// Feeds this node's state into the compiled-autograd key via `args`, and
// registers Python objects with `args` when needed.
//
// The key starts with the tuple returned by the Python method
// `_compiled_autograd_key()`:
// - Element 0 is a unique id (managed by AUTOGRAD_FUNCTION_COUNTER) and must
//   be non-negative.
// - The remaining elements are ints, stored in compiled_autograd_symints and
//   collected with the DYNAMIC size type.
// The sequence of collect() calls below presumably defines the key layout;
// keep it stable.
//
// NOTE(review): unlike the other PyNode methods here, this one calls the
// Python C API without acquiring the GIL itself. Presumably the caller
// already holds it — confirm.
void PyNode::compiled_args(CompiledNodeArgs& args) {
  static PyObject* method_name =
      PyUnicode_InternFromString("_compiled_autograd_key");
  THPObjectPtr pykey(PyObject_CallMethodNoArgs(obj, method_name));
  if (!pykey)
    throw_python_error();
  TORCH_CHECK(
      PyTuple_CheckExact(pykey.get()),
      "_compiled_autograd_key should return tuple of ints");
  auto size = PyTuple_GET_SIZE(pykey.get());
  TORCH_INTERNAL_ASSERT(size > 0);
  // first value is unique id managed by AUTOGRAD_FUNCTION_COUNTER
  auto key = PyLong_AsSsize_t(PyTuple_GET_ITEM(pykey.get(), 0));
  if (C10_UNLIKELY(key < 0)) {
    // -1 may signal a failed conversion (a Python error is then pending);
    // a negative value without a pending error is rejected here.
    TORCH_CHECK(PyErr_Occurred(), "key must be positive");
    throw_python_error();
  }
  args.collect_size(static_cast<size_t>(key));
  args.collect_size(static_cast<size_t>(size));

  // Cache the remaining key elements on the Function object; apply_with_saved
  // swaps them via SwapSavedVariables.
  auto f = (THPFunction*)obj;
  f->compiled_autograd_symints.clear();
  f->compiled_autograd_symints.reserve(size - 1);
  for (const auto i : c10::irange(1, size)) {
    auto val = PyLong_AsSsize_t(PyTuple_GET_ITEM(pykey.get(), i));
    if (C10_UNLIKELY(val == -1 && PyErr_Occurred()))
      throw_python_error();
    f->compiled_autograd_symints.emplace_back(val);
  }

  // AotAutograd symints are all dynamic
  auto prior =
      args.set_default_dyn_type(torch::dynamo::autograd::SizeInput::DYNAMIC);
  args.collect(f->compiled_autograd_symints);
  args.set_default_dyn_type(prior);

  args.collect(f->saved_variables, true); // always unpacked as output in eager
  args.collect(f->materialize_grads);
  args.collect(f->is_variable_input);
  args.collect(f->needs_input_grad);
  args.collect(f->materialize_non_diff_grads);
  args.collect(f->output_info);
  args.collect(f->input_info);

  if (compiled_autograd_should_lift()) {
    // The extra reference is presumably owned by the SafePyObject from here
    // on — confirm against c10::SafePyObject.
    Py_INCREF(obj);
    _backward_idx =
        args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
  }

  // args.cond() is used both as the branch condition and, presumably, to
  // record that condition in the key — confirm in CompiledNodeArgs.
  PyObject* bw_state = f->compiled_autograd_backward_state;
  if (args.cond(bw_state != nullptr)) {
    Py_INCREF(bw_state);
    _backward_state_idx = args.add_backward_state(
        c10::SafePyObject(bw_state, getPyInterpreter()));
  }
}
354
// Runs this node's backward while compiled autograd is tracing.
//
// The Function fields collected in compiled_args() are passed to
// saved.before(...) before the call and to saved.after(...) afterwards.
// This presumably swaps in traced proxies and then restores the originals —
// confirm in SwapSavedVariables. Between those two steps:
// - If the node is not lifted, apply() is called directly. When a backward
//   state index was registered, the compiler's `bind_backward_state(idx)`
//   result is temporarily installed as compiled_autograd_backward_state.
// - Otherwise the call is handed to dynamo via defer_to_dynamo().
//
// NOTE(review): this is not exception-safe. If bind_backward_state,
// apply() or defer_to_dynamo() throws, compiled_autograd_tracing stays true
// (so the assert at the top fires on the next call) and saved.after(...) is
// skipped. The previous backward state is also dropped rather than swapped
// back. Confirm whether the node can be reused after such a failure.
variable_list PyNode::apply_with_saved(
    const variable_list& inputs,
    SwapSavedVariables& saved) {
  auto f = (THPFunction*)obj;
  TORCH_INTERNAL_ASSERT(!f->compiled_autograd_tracing);
  saved.before(f->compiled_autograd_symints);
  saved.before(f->saved_variables);
  saved.before(f->needs_input_grad);
  saved.before(f->materialize_non_diff_grads);
  saved.before(f->output_info);
  saved.before(f->input_info);
  f->compiled_autograd_tracing = true;
  variable_list result;
  if (!compiled_autograd_should_lift()) {
    if (_backward_state_idx.has_value()) {
      PyObject* r = PyObject_CallMethod(
          saved.get_py_compiler(),
          "bind_backward_state",
          "i",
          *_backward_state_idx);
      if (r == nullptr) {
        throw python_error();
      }
      // Temporarily replace the backward state with the bound one; `prior`
      // owns the original reference until it is put back below.
      THPObjectPtr prior(f->compiled_autograd_backward_state);
      f->compiled_autograd_backward_state = r;
      result = apply(variable_list(inputs));
      Py_CLEAR(f->compiled_autograd_backward_state);
      f->compiled_autograd_backward_state = prior.release();
    } else {
      result = apply(variable_list(inputs));
    }
  } else {
    result = defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
  }
  f->compiled_autograd_tracing = false;
  saved.after(f->compiled_autograd_symints);
  saved.after(f->saved_variables);
  saved.after(f->needs_input_grad);
  saved.after(f->materialize_non_diff_grads);
  saved.after(f->output_info);
  saved.after(f->input_info);
  return result;
}
398
to_py_args(const variable_list & inputs,at::OptionalDeviceGuard * device_guard)399 PyObject* PyNode::to_py_args(
400 const variable_list& inputs,
401 at::OptionalDeviceGuard* device_guard) {
402 THPFunction* py_fn = (THPFunction*)obj;
403
404 auto zeros_without_gil = [](const VariableInfo& variable,
405 at::OptionalDeviceGuard& dg) {
406 pybind11::gil_scoped_release gil;
407 return variable.zeros(dg);
408 };
409
410 auto num_inputs = inputs.size();
411 PyObject* pyInputs = PyTuple_New(static_cast<Py_ssize_t>(num_inputs));
412 if (!pyInputs)
413 throw_python_error();
414 auto& output_info = py_fn->output_info;
415 for (const auto i : c10::irange(num_inputs)) {
416 PyObject* input = nullptr;
417 if (inputs[i].defined() || !py_fn->materialize_grads ||
418 (input_metadata(i).was_default_constructed() &&
419 !py_fn->materialize_non_diff_grads)) {
420 input = THPVariable_Wrap(inputs[i]);
421 } else {
422 input =
423 THPVariable_Wrap(zeros_without_gil(output_info[i], *device_guard));
424 }
425 if (!input)
426 throw_python_error();
427 PyTuple_SET_ITEM(pyInputs, i, input);
428 }
429
430 return pyInputs;
431 }
432
to_variable_list(const PyObject * outputs,const std::vector<bool> & is_variable_input)433 variable_list PyNode::to_variable_list(
434 const PyObject* outputs,
435 const std::vector<bool>& is_variable_input) {
436 auto num_outputs = PyTuple_GET_SIZE(outputs);
437 variable_list results;
438 results.reserve(num_outputs);
439 for (int i = 0; i != num_outputs; ++i) {
440 PyObject* output = PyTuple_GET_ITEM(outputs, i);
441 bool was_variable = is_variable_input[i];
442 if (!was_variable) {
443 if (output != Py_None) {
444 std::string msg("function ");
445 msg += name() + " returned a gradient different than None at position ";
446 msg += std::to_string(i + 1) +
447 ", but the corresponding forward input was not a Variable";
448 throw std::runtime_error(msg);
449 }
450 continue;
451 }
452 if (output == Py_None) {
453 results.emplace_back();
454 } else {
455 if (!THPVariable_Check(output)) {
456 std::string msg("expected Variable or None (got ");
457 msg += THPUtils_typename(output);
458 msg += ")";
459 throw std::runtime_error(msg);
460 }
461 results.emplace_back(THPVariable_Unpack(output));
462 }
463 }
464
465 return results;
466 }
467
468 } // namespace torch::autograd
469
470 // Traverse and clear are required for supporting Python's GC cycle handling.
THPFunction_traverse(THPFunction * self,visitproc visit,void * arg)471 static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
472 // NB: We should not traverse PyObbject stored on PyNode, since we only hold
473 // as weak reference to the PyNode.
474 Py_VISIT(self->to_save);
475 Py_VISIT(self->non_differentiable);
476 Py_VISIT(self->dirty_tensors);
477 Py_VISIT(self->compiled_autograd_backward_state);
478 Py_VISIT(self->saved_for_forward);
479 return 0;
480 }
481
// GC clear hook (presumably installed as tp_clear). Drops every Python
// reference this THPFunction owns and empties its C++ vectors so that cycles
// through it can be broken. The vectors are only cleared here;
// THPFunction_dealloc runs their destructors.
//
// NOTE(review): compiled_autograd_symints is not cleared here — confirm
// whether symbolic SymInts can keep Python objects alive.
static int THPFunction_clear(THPFunction* self) {
  // Note that the cdata might not be expired yet in the case where this
  // object is part of a cycle and the GC happens to tp_clear this PyObject
  // before the other ones that trigger the de-allocation of the cdata

  Py_CLEAR(self->needs_input_grad);

  Py_CLEAR(self->to_save);
  Py_CLEAR(self->non_differentiable);
  Py_CLEAR(self->dirty_tensors);
  Py_CLEAR(self->compiled_autograd_backward_state);
  Py_CLEAR(self->saved_for_forward);

  self->output_info.clear();
  self->input_info.clear();
  self->saved_variables.clear();
  self->is_variable_input.clear();

  return 0;
}
502
THPFunction_dealloc(THPFunction * self)503 static void THPFunction_dealloc(THPFunction* self) {
504 // Why is this guaranteed to be true? Suppose that self->cdata is non-null
505 // (otherwise the condition is trivially true). Then there is a PyNode
506 // which contains an owning reference to this object. But we are only
507 // allowed to clear if all owning references are gone! Contradiction.
508 //
509 // However, note that THPFunction_clear is typically called in the shared_ptr
510 // destructor of PyNode; in that case, per
511 // https://cplusplus.github.io/LWG/lwg-active.html#2751 it's not currently
512 // specified in the standard that this is guaranteed. If you see this
513 // assert triggering in the wild, feel free to comment it out. They're
514 // likely to standardize that you ARE guaranteed to see the weak pointers
515 // as expired in the destructor in the future, so we'll keep this for now.
516 TORCH_INTERNAL_ASSERT(self->cdata.expired());
517
518 PyObject_GC_UnTrack(self);
519 THPFunction_clear(self);
520 self->cdata.~weak_ptr<PyNode>();
521 self->output_info.~vector();
522 self->input_info.~vector();
523 self->saved_variables.~vector();
524 self->is_variable_input.~vector();
525 Py_TYPE(self)->tp_free((PyObject*)self);
526 }
527
THPFunction_new(PyTypeObject * type,PyObject * args,PyObject * kwargs)528 PyObject* THPFunction_new(
529 PyTypeObject* type,
530 PyObject* args,
531 PyObject* kwargs) {
532 PyObject* obj = type->tp_alloc(type, 0);
533 if (!obj)
534 return nullptr;
535 // Python zero-initializes the object memory, so there's no need to initialize
536 // most fields
537 THPFunction* self = (THPFunction*)obj;
538 // Setup the PyNode later; we can't keep it live here
539 new (&self->cdata) std::weak_ptr<PyNode>();
540 new (&self->output_info) std::vector<VariableInfo>();
541 new (&self->input_info) std::vector<VariableInfo>();
542 new (&self->saved_variables) std::vector<SavedVariable>();
543 new (&self->is_variable_input) std::vector<bool>();
544 self->materialize_grads = true;
545 self->materialize_non_diff_grads = true;
546 self->compiled_autograd_tracing = false;
547 return obj;
548 }
549
550 ////////////////////////////////////////////////////////////////////////////////
551 // Forward
552 ////////////////////////////////////////////////////////////////////////////////
553
554 // Bump the counters of all recorded dirty input tensors, adding each of them
555 // into dirty_inputs. Also does some sanity checking.
_mark_dirty(THPFunction * self)556 static std::unordered_set<at::TensorImpl*> _mark_dirty(THPFunction* self) {
557 // Increase versions of modified tensors
558 std::unordered_set<at::TensorImpl*> dirty_inputs;
559 if (!self->dirty_tensors)
560 return dirty_inputs;
561
562 THPFunction_assert(
563 PyTuple_Check(self->dirty_tensors),
564 "autograd "
565 "internal error: dirty_tensors attribute is expected to be a tuple "
566 "but is ",
567 THPUtils_typename(self->dirty_tensors));
568 Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
569 dirty_inputs.reserve(num_dirty);
570 for (const auto i : c10::irange(num_dirty)) {
571 PyObject* obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
572 THPFunction_assert(
573 THPVariable_Check(obj),
574 "mark_dirty can "
575 "only accept variables, but argument ",
576 i,
577 " is of type ",
578 THPUtils_typename(obj));
579
580 const auto& tensor = THPVariable_Unpack(obj);
581 dirty_inputs.insert(tensor.unsafeGetTensorImpl());
582 torch::autograd::impl::bump_version(tensor);
583 }
584 // We're not going to ever need this so let's remove references now
585 Py_CLEAR(self->dirty_tensors);
586 return dirty_inputs;
587 }
588
// Forward declaration: _wrap_outputs (below) calls this before its
// definition further down in the file.
static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
    THPFunction* self);
591
592 // Given a Python tuple of raw output tensors (raw_output), set each of
593 // the corresponding entries in a different Python tuple (outputs) with
594 // these tensors wrapped with variables. We save the gradient function (self)
595 // to the variable if the output requires grad.
596 //
597 // There is a considerable amount of complexity to handle if the operation
598 // that produced these output tensors is inplace. A mapping of *input*
599 // tensors to variables (t2var) is used to test if this occurred, and
600 // the set of dirty tensors (dirty_inputs) is used to figure out what to
601 // do in this case. After this method is run, t2var is extended with
602 // mappings for output tensors as well.
_wrap_outputs(const std::shared_ptr<PyNode> & cdata,THPFunction * self,const variable_list & input_vars,PyObject * raw_output,PyObject * outputs,bool is_executable,const std::unordered_set<at::TensorImpl * > & to_save_if_setup_context)603 static void _wrap_outputs(
604 const std::shared_ptr<PyNode>& cdata,
605 THPFunction* self,
606 const variable_list& input_vars,
607 PyObject* raw_output,
608 PyObject* outputs,
609 bool is_executable,
610 const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
611 auto cdata_if_executable = is_executable ? cdata : nullptr;
612 Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
613 if (is_executable) {
614 self->output_info.clear();
615 self->output_info.reserve(num_outputs);
616 }
617
618 auto non_differentiable = _parse_non_differentiable(self);
619 auto dirty_inputs = _mark_dirty(self);
620
621 std::vector<std::optional<Variable>> raw_output_vars;
622 raw_output_vars.reserve(num_outputs);
623 for (const auto i : c10::irange(num_outputs)) {
624 PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
625 // Only process tensors as outputs for autograd purposes.
626 if (THPVariable_Check(obj)) {
627 raw_output_vars.emplace_back(THPVariable_Unpack(obj));
628 } else {
629 raw_output_vars.emplace_back();
630 }
631 }
632
633 _jvp_fn_t jvp_user_function = [self](
634 variable_list inputs,
635 variable_list grad_inputs) {
636 pybind11::gil_scoped_acquire gil;
637
638 // Massage a C++ variable_list into a Python arguments tuple
639 // Making sure to introduce the proper None for non-Tensor inputs
640 auto num_inputs = self->is_variable_input.size();
641 THPObjectPtr pyInputs(PyTuple_New(static_cast<Py_ssize_t>(num_inputs)));
642 if (!pyInputs)
643 throw_python_error();
644 int64_t variable_idx = 0;
645 for (const auto i : c10::irange(num_inputs)) {
646 PyObject* input = nullptr;
647 if (self->is_variable_input[i]) {
648 if (grad_inputs[variable_idx].defined() || !self->materialize_grads ||
649 !isDifferentiableType(inputs[variable_idx].scalar_type())) {
650 input = THPVariable_Wrap(grad_inputs[variable_idx]);
651 } else {
652 input = THPVariable_Wrap(at::zeros_like(inputs[variable_idx]));
653 }
654 if (!input) {
655 throw_python_error();
656 }
657 variable_idx++;
658 } else {
659 Py_INCREF(Py_None);
660 input = Py_None;
661 }
662 PyTuple_SET_ITEM(pyInputs.get(), i, input);
663 }
664
665 THPObjectPtr apply_jvp_fn(
666 PyObject_GetAttrString((PyObject*)self, "apply_jvp"));
667 if (!apply_jvp_fn)
668 throw_python_error();
669 THPObjectPtr r(PyObject_CallObject(apply_jvp_fn, pyInputs.get()));
670 if (!r)
671 throw_python_error();
672 ensure_tuple(r);
673
674 // Massage the Python results tuple back into a C++ variable_list
675 // Don't do any check on the number of results here as
676 // it is handled by the caller
677 const int num_outputs = PyTuple_GET_SIZE(r.get());
678 variable_list results;
679 results.reserve(num_outputs);
680 for (const auto i : c10::irange(num_outputs)) {
681 PyObject* output = PyTuple_GET_ITEM(r.get(), i);
682 if (output == Py_None) {
683 results.emplace_back();
684 } else {
685 TORCH_CHECK(
686 THPVariable_Check(output),
687 "expected Variable or None (got ",
688 THPUtils_typename(output),
689 ") for grad output ",
690 i,
691 ".")
692 results.emplace_back(THPVariable_Unpack(output));
693 }
694 }
695
696 return results;
697 };
698
699 auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
700 pybind11::gil_scoped_acquire gil;
701 THPObjectPtr py_x(THPVariable_Wrap(x));
702 THPObjectPtr py_view_as_method(PyObject_GetAttrString(py_x, "view_as"));
703 if (!py_view_as_method)
704 throw python_error();
705 THPObjectPtr args(PyTuple_Pack(1, py_x.get()));
706 if (!args)
707 throw python_error();
708 THPObjectPtr result(PyObject_CallObject(py_view_as_method, args));
709 if (!result)
710 throw python_error();
711 return THPVariable_Unpack(result);
712 };
713
714 // Wrap only the tensor outputs.
715 auto wrapped_outputs = _wrap_outputs(
716 input_vars,
717 non_differentiable,
718 dirty_inputs,
719 raw_output_vars,
720 cdata_if_executable,
721 jvp_user_function,
722 to_save_if_setup_context,
723 view_as_self_fn);
724
725 for (const auto i : c10::irange(num_outputs)) {
726 PyObject* obj = PyTuple_GetItem(raw_output, i);
727 // Keep the non-tensor outputs as is.
728 if (!THPVariable_Check(obj)) {
729 if (is_executable) {
730 self->output_info.emplace_back();
731 }
732 Py_INCREF(obj);
733 PyTuple_SetItem(outputs, i, obj);
734 } else {
735 if (is_executable) {
736 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
737 self->output_info.emplace_back(*wrapped_outputs[i]);
738 }
739 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
740 PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
741 }
742 }
743 }
744
_get_tensors_to_save(THPFunction * self,std::unordered_set<at::TensorImpl * > & to_save_if_setup_context,std::vector<std::optional<at::Tensor>> & tensors_to_save,bool overridden_setup_context,bool is_executable)745 static void _get_tensors_to_save(
746 THPFunction* self,
747 std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
748 std::vector<std::optional<at::Tensor>>& tensors_to_save,
749 bool overridden_setup_context,
750 bool is_executable) {
751 if (self->saved_for_forward && overridden_setup_context) {
752 // We look at saved_for_forward here purely for the purpose of populating
753 // to_save_if_setup_context, the actual saving is not done here.
754 THPFunction_assert(
755 PyTuple_Check(self->saved_for_forward),
756 "autograd internal "
757 "error: saved_for_forward attribute is expected to be a tuple but is ",
758 THPUtils_typename(self->saved_for_forward));
759 Py_ssize_t num_saved_for_forward =
760 PyTuple_GET_SIZE(self->saved_for_forward);
761 for (const auto i : c10::irange(num_saved_for_forward)) {
762 PyObject* obj = PyTuple_GET_ITEM(self->saved_for_forward, i);
763 if (THPVariable_Check(obj)) {
764 const auto& tensor = THPVariable_Unpack(obj);
765 to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
766 }
767 }
768 }
769 if (self->to_save) {
770 THPFunction_assert(
771 PyTuple_Check(self->to_save),
772 "autograd internal "
773 "error: to_save attribute is expected to be a tuple but is ",
774 THPUtils_typename(self->to_save));
775
776 Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
777 for (const auto i : c10::irange(num_saved)) {
778 PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
779 if (obj == Py_None) {
780 tensors_to_save.emplace_back(std::nullopt);
781 continue;
782 } else if (THPVariable_Check(obj)) {
783 const auto& tensor = THPVariable_Unpack(obj);
784 if (overridden_setup_context) {
785 to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
786 }
787 if (is_executable) {
788 tensors_to_save.emplace_back(tensor);
789 }
790 } else {
791 if (is_executable) {
792 // TODO: We should really just ALWAYS throw an error here, but
793 // doing so will break some internal tests. We should fix those.
794 throw torch::TypeError(
795 "save_for_backward can only save variables, but argument %ld is of "
796 "type %s",
797 i,
798 Py_TYPE(obj)->tp_name);
799 }
800 }
801 }
802 }
803 }
804 // Save any variables that requested by to_save
_save_variables(const std::vector<std::optional<at::Tensor>> & tensors_to_save,const std::shared_ptr<PyNode> & cdata_ptr,THPFunction * self)805 static void _save_variables(
806 const std::vector<std::optional<at::Tensor>>& tensors_to_save,
807 const std::shared_ptr<PyNode>& cdata_ptr,
808 THPFunction* self) {
809 if (!self->to_save)
810 return;
811 size_t num_saved = tensors_to_save.size();
812 self->saved_variables.clear();
813 self->saved_variables.reserve(num_saved);
814 for (const auto& opt_tensor : tensors_to_save) {
815 if (!opt_tensor.has_value()) {
816 self->saved_variables.emplace_back();
817 } else {
818 bool is_output = opt_tensor.value().grad_fn().get() == cdata_ptr.get();
819 self->saved_variables.emplace_back(opt_tensor.value(), is_output);
820 }
821 }
822 // Free .to_save
823 Py_CLEAR(self->to_save);
824 }
825
826 // Mark requires_grad = 0 on non-differentiable variables (as per
827 // non_differentiable)
_parse_non_differentiable(THPFunction * self)828 static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
829 THPFunction* self) {
830 std::unordered_set<at::TensorImpl*> set;
831 if (!self->non_differentiable)
832 return set;
833
834 THPFunction_assert(
835 PyTuple_Check(self->non_differentiable),
836 "autograd "
837 "internal error: non_differentiable attribute is expected to be a "
838 "tuple but is ",
839 THPUtils_typename(self->non_differentiable));
840 Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable);
841 set.reserve(num_nondiff);
842 for (const auto i : c10::irange(num_nondiff)) {
843 PyObject* t = PyTuple_GET_ITEM(self->non_differentiable, i);
844 THPFunction_assert(
845 THPVariable_Check(t),
846 "mark_non_differentiable "
847 "only accepts variable arguments, but got ",
848 THPUtils_typename(t));
849 set.insert(THPVariable_Unpack(t).unsafeGetTensorImpl());
850 }
851 Py_CLEAR(self->non_differentiable);
852 return set;
853 }
854
// Forward arguments as split up by unpack_input().
struct UnpackedInput {
  // New tuple holding a new reference to every original argument, in order.
  THPObjectPtr input_tuple;
  // The Tensor arguments only, in their original order.
  variable_list input_vars;
  // record_function_inputs is for RECORD_FUNCTION only
  std::vector<c10::IValue> record_function_inputs;
};
861
862 struct InputFlags {
863 bool is_executable = false;
864 edge_list next_edges;
865 THPObjectPtr needs_input_grad;
866 std::vector<bool> is_variable_input;
867 };
868
869 template <bool enforce_variables>
unpack_input(PyObject * args)870 std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
871 UnpackedInput unpacked;
872 InputFlags flags;
873
874 auto num_args = PyTuple_GET_SIZE(args);
875 unpacked.input_tuple = PyTuple_New(num_args);
876 flags.needs_input_grad = PyTuple_New(num_args);
877 bool profiler_need_input = torch::autograd::profiler::profilerEnabled() &&
878 torch::autograd::profiler::getProfilerConfig().report_input_shapes;
879
880 for (const auto i : c10::irange(num_args)) {
881 PyObject* arg = PyTuple_GET_ITEM(args, i);
882
883 bool is_variable = THPVariable_Check(arg);
884 flags.is_variable_input.push_back(is_variable);
885 if (!is_variable) {
886 // TODO: remove this code path once Variable and Tensor are merged in
887 // Python
888 if (enforce_variables) {
889 THPUtils_setError(
890 "expected a Tensor argument, but got ", THPUtils_typename(arg));
891 throw python_error();
892 }
893 Py_INCREF(Py_False);
894 PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
895
896 if (profiler_need_input) {
897 // The following conversion from PyObject to IValue is expensive
898 // Only do it if profiler is enabled and needs input shapes
899 auto match = torch::jit::tryToInferPrimitiveType(arg);
900 if (match.success()) {
901 unpacked.record_function_inputs.push_back(
902 torch::jit::toIValue(arg, match.type()));
903 }
904 }
905 } else {
906 const auto& tensor = THPVariable_Unpack(arg);
907 unpacked.input_vars.push_back(tensor);
908 PyObject* needs_grad = tensor.requires_grad() ? Py_True : Py_False;
909 Py_INCREF(needs_grad);
910 PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
911 unpacked.record_function_inputs.emplace_back(tensor);
912 }
913 Py_INCREF(arg);
914 PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
915 }
916
917 flags.is_executable =
918 GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
919 flags.next_edges =
920 (flags.is_executable ? collect_next_edges(unpacked.input_vars)
921 : edge_list());
922 return std::make_pair(std::move(unpacked), std::move(flags));
923 }
924
925 // Given a prim::PythonOp node, _append_subgraph creates a subgraph such that:
926 // (1) It has the same inputs as the prim::PythonOp node
927 // (2) The intermediate nodes used in the PythonOp are cloned and stored in the
928 // subgraph (3) trace_outputs stores the Value* objects, before a new trace
929 // value is assigned by the prim::PythonOp node and helps to eventually route
930 // the outputs of the subgraph correctly This newly created subgraph is then
931 // added to the prim::PythonOp node as a subgraph attribute
_append_subgraph(torch::jit::Node * node,torch::jit::Graph * graph,std::vector<torch::jit::Value * > trace_outputs,bool unpack_output)932 static void _append_subgraph(
933 torch::jit::Node* node,
934 torch::jit::Graph* graph,
935 std::vector<torch::jit::Value*> trace_outputs,
936 bool unpack_output) {
937 using Value = torch::jit::Value;
938 node->g_(
939 torch::jit::attr::Subgraph,
940 std::make_shared<torch::jit::Graph>(graph->current_scope()));
941 auto subgraph = node->g(torch::jit::attr::Subgraph);
942
943 std::unordered_map<Value*, Value*> value_map;
944 auto value_map_func = [&](Value* v) { return value_map.at(v); };
945 for (size_t i = 0; i < node->inputs().size(); ++i) {
946 auto subgraph_input = subgraph->addInput();
947 subgraph_input->copyMetadata(node->inputs().at(i));
948 value_map[node->inputs().at(i)] = subgraph_input;
949 }
950 // Find node position in owning block, all subsequent nodes after are added to
951 // subgraph
952 auto owning_block = node->owningBlock();
953 auto it = std::find(
954 owning_block->nodes().begin(), owning_block->nodes().end(), node);
955 // Skip TupleUnpack node if created
956 if (!unpack_output) {
957 it++;
958 }
959 for (it++; it != owning_block->nodes().end(); ++it) {
960 torch::jit::Node* node = *it;
961 auto* clone_node =
962 subgraph->insertNode(subgraph->createClone(node, value_map_func));
963 for (size_t i = 0; i < node->outputs().size(); ++i) {
964 value_map[node->outputs()[i]] = clone_node->outputs()[i];
965 auto trace_it = std::find(
966 trace_outputs.begin(), trace_outputs.end(), node->outputs()[i]);
967 if (trace_it != trace_outputs.end()) {
968 subgraph->registerOutput(clone_node->outputs()[i]);
969 }
970 }
971 }
972 }
973
_trace_pre_record(PyObject * op_obj,PyObject * input_objects,const variable_list & input_vars)974 static torch::jit::Node* _trace_pre_record(
975 PyObject* op_obj,
976 PyObject* input_objects,
977 const variable_list& input_vars) {
978 if (!jit::tracer::isTracing()) {
979 return nullptr;
980 }
981
982 // Save scalar args and the calling convention
983 auto num_args = PyTuple_GET_SIZE(input_objects);
984 pyobj_list scalar_args;
985 std::string arg_types;
986 arg_types.reserve(num_args);
987 scalar_args.reserve(num_args);
988 for (const auto i : c10::irange(num_args)) {
989 PyObject* arg_object = PyTuple_GET_ITEM(input_objects, i);
990 if (THPVariable_Check(arg_object)) {
991 arg_types.push_back('d');
992 } else {
993 arg_types.push_back('c');
994 Py_INCREF(arg_object);
995 scalar_args.emplace_back(arg_object);
996 }
997 }
998
999 Py_INCREF(op_obj);
1000 auto pyobj = THPObjectPtr(op_obj);
1001 return jit::tracer::preRecordPythonTrace(
1002 std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
1003 }
1004
_trace_post_record(torch::jit::Node * node,PyObject * op_obj,const variable_list & input_vars,PyObject * output_objects,bool is_inplace,bool unpack_output)1005 static void _trace_post_record(
1006 torch::jit::Node* node,
1007 PyObject* op_obj,
1008 const variable_list& input_vars,
1009 PyObject* output_objects,
1010 bool is_inplace,
1011 bool unpack_output) {
1012 if (!jit::tracer::isTracing()) {
1013 return;
1014 }
1015
1016 node->i_(jit::attr::inplace, is_inplace);
1017 if (PyObject* module_name = PyDict_GetItemString(
1018 ((PyTypeObject*)op_obj)->tp_dict, "__module__")) {
1019 if (auto ptr = PyUnicode_AsUTF8(module_name)) {
1020 node->s_(jit::attr::module, std::string(ptr));
1021 }
1022 }
1023
1024 // Isolate C variable ptrs in a vector
1025 int num_outputs = PyTuple_GET_SIZE(output_objects);
1026 auto graph = node->owningGraph();
1027 node->addOutput();
1028 auto old_node = node;
1029 if (!unpack_output) {
1030 std::vector<at::TypePtr> tuple_values(num_outputs, at::TensorType::get());
1031 auto tuple_type = at::TupleType::create(std::move(tuple_values));
1032 // Original type is tuple of tensors "without" element type and shape.
1033 // The missed parts will be added below.
1034 node->output()->setType(std::move(tuple_type));
1035 auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
1036 node = unpacked;
1037 }
1038
1039 std::vector<torch::jit::Value*> trace_outputs;
1040 for (const auto i : c10::irange(num_outputs)) {
1041 PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
1042 if (THPVariable_Check(obj)) {
1043 auto value = node->outputs()[i];
1044 const auto& tensor = THPVariable_Unpack(obj);
1045 if (tensor.defined()) {
1046 value->inferTypeFrom(tensor);
1047 trace_outputs.push_back(jit::tracer::getValueTrace(tensor));
1048 jit::tracer::setValueTrace(tensor, value);
1049 }
1050 }
1051 }
1052 py::object onnx_globals = py::module::import("torch.onnx._globals");
1053 py::bool_ is_in_onnx_export =
1054 py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
1055 py::bool_ is_autograd_inlining_enabled =
1056 py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
1057
1058 if (py::cast<bool>(is_in_onnx_export) &&
1059 py::cast<bool>(is_autograd_inlining_enabled)) {
1060 _append_subgraph(old_node, graph, std::move(trace_outputs), unpack_output);
1061 }
1062
1063 // If TupleUnpack operator is created, we copy its output type back
1064 // to the original tuple type.
1065 if (!unpack_output) {
1066 std::vector<at::TypePtr> new_tuple_values;
1067 for (const auto i : c10::irange(num_outputs)) {
1068 auto ptr = node->outputs()[i]->type();
1069 new_tuple_values.push_back(ptr);
1070 }
1071 auto tuple_type = at::TupleType::create(std::move(new_tuple_values));
1072 // The i-th tuple element receives a new tensor type with element type and
1073 // shape.
1074 old_node->output()->setType(std::move(tuple_type));
1075 }
1076 }
1077
process_outputs(PyObject * op_obj,const std::shared_ptr<PyNode> & cdata,THPFunction * grad_fn,const UnpackedInput & unpacked,PyObject * inputs,THPObjectPtr && raw_output,bool is_executable,torch::jit::Node * node,bool overridden_setup_context)1078 PyObject* process_outputs(
1079 PyObject* op_obj,
1080 const std::shared_ptr<PyNode>& cdata,
1081 THPFunction* grad_fn,
1082 const UnpackedInput& unpacked,
1083 PyObject* inputs,
1084 THPObjectPtr&& raw_output,
1085 bool is_executable,
1086 torch::jit::Node* node,
1087 bool overridden_setup_context) {
1088 bool unpack_output = ensure_tuple(raw_output);
1089
1090 auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
1091
1092 THPObjectPtr outputs(PyTuple_New(num_outputs));
1093 if (!outputs)
1094 throw python_error();
1095
1096 cdata->clear_input_metadata();
1097
1098 // Record type, device, and size information about inputs
1099 if (is_executable) {
1100 grad_fn->input_info.clear();
1101 grad_fn->input_info.reserve(unpacked.input_vars.size());
1102 for (auto& var : unpacked.input_vars) {
1103 grad_fn->input_info.emplace_back(var);
1104 }
1105 }
1106
1107 std::unordered_set<at::TensorImpl*> to_save_if_setup_context{};
1108 std::vector<std::optional<at::Tensor>> tensors_to_save{};
1109 _get_tensors_to_save(
1110 grad_fn,
1111 to_save_if_setup_context,
1112 tensors_to_save,
1113 overridden_setup_context,
1114 is_executable);
1115
1116 bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
1117 _wrap_outputs(
1118 cdata,
1119 grad_fn,
1120 unpacked.input_vars,
1121 raw_output,
1122 outputs,
1123 is_executable,
1124 to_save_if_setup_context);
1125 _trace_post_record(
1126 node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);
1127
1128 // It is important that creating the SavedVariables happen after the output
1129 // wrapping as the outputs must have their grad_fn/fw_grad properly set before
1130 // we save them.
1131 if (is_executable) {
1132 _save_variables(tensors_to_save, cdata, grad_fn);
1133 } else {
1134 // Remove unnecessary attributes
1135 Py_XDECREF(grad_fn->to_save);
1136 grad_fn->to_save = nullptr;
1137 Py_XDECREF(grad_fn->non_differentiable);
1138 grad_fn->non_differentiable = nullptr;
1139 }
1140
1141 Py_XDECREF(grad_fn->saved_for_forward);
1142 grad_fn->saved_for_forward = nullptr;
1143
1144 // Unpack the output, unless .forward() returned a tuple
1145 if (unpack_output) {
1146 PyObject* output = PyTuple_GET_ITEM(outputs.get(), 0);
1147 Py_INCREF(output);
1148 return output;
1149 }
1150
1151 return outputs.release();
1152 }
1153
THPFunction_name(PyObject * self,PyObject * noargs)1154 PyObject* THPFunction_name(PyObject* self, PyObject* noargs) {
1155 HANDLE_TH_ERRORS
1156 auto cdata = ((THPFunction*)self)->cdata.lock();
1157 TORCH_CHECK(
1158 cdata,
1159 "Attribute 'name' is invalid for this instance of _C._FunctionBase. "
1160 "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1161 "access pattern that is no longer supported. For examples on how to use new-style "
1162 "autograd functions, see "
1163 "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1164 return THPUtils_packString(cdata->name());
1165 END_HANDLE_TH_ERRORS
1166 }
1167
THPFunction_sequence_nr(PyObject * self,PyObject * noargs)1168 PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
1169 HANDLE_TH_ERRORS;
1170 auto cdata = ((THPFunction*)self)->cdata.lock();
1171 return THPUtils_packUInt64(cdata->sequence_nr());
1172 END_HANDLE_TH_ERRORS
1173 }
1174
THPFunction_set_sequence_nr(PyObject * self,PyObject * sequence_nr)1175 PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
1176 HANDLE_TH_ERRORS;
1177 auto cdata = ((THPFunction*)self)->cdata.lock();
1178 cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
1179 Py_RETURN_NONE;
1180 END_HANDLE_TH_ERRORS
1181 }
1182
THPFunction_input_metadata(PyObject * self,void * unused)1183 PyObject* THPFunction_input_metadata(PyObject* self, void* unused) {
1184 HANDLE_TH_ERRORS;
1185 auto cdata = ((THPFunction*)self)->cdata.lock();
1186 const auto num_inputs = cdata->num_inputs();
1187 THPObjectPtr list(PyTuple_New(num_inputs));
1188 if (!list) {
1189 return nullptr;
1190 }
1191 for (size_t i = 0; i < num_inputs; ++i) {
1192 const auto& metadata = cdata->input_metadata(i);
1193 THPObjectPtr item(py::cast(metadata).release().ptr());
1194 if (!item) {
1195 return nullptr;
1196 }
1197 PyTuple_SET_ITEM(list.get(), i, item.release());
1198 }
1199 return list.release();
1200 END_HANDLE_TH_ERRORS
1201 }
1202
THPFunction_maybe_clear_saved_tensors(PyObject * self,PyObject * noargs)1203 PyObject* THPFunction_maybe_clear_saved_tensors(
1204 PyObject* self,
1205 PyObject* noargs) {
1206 HANDLE_TH_ERRORS;
1207 auto cdata = ((THPFunction*)self)->cdata.lock();
1208 if (!get_current_graph_task_keep_graph()) {
1209 cdata->release_variables();
1210 }
1211 Py_RETURN_NONE;
1212 END_HANDLE_TH_ERRORS
1213 }
1214
1215 namespace {
1216
make_ctx_input_tuple(THPFunction * ctx,const UnpackedInput & unpacked_input,int64_t num_args)1217 THPObjectPtr make_ctx_input_tuple(
1218 THPFunction* ctx,
1219 const UnpackedInput& unpacked_input,
1220 int64_t num_args) {
1221 THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
1222 if (!ctx_input_tuple)
1223 return {};
1224 Py_INCREF(ctx);
1225 PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
1226 for (const auto i : c10::irange(num_args)) {
1227 PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
1228 Py_INCREF(arg);
1229 PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
1230 }
1231 return ctx_input_tuple;
1232 }
1233
make_ctx_input_output_tuple(THPFunction * ctx,UnpackedInput & unpacked_input,PyObject * output)1234 THPObjectPtr make_ctx_input_output_tuple(
1235 THPFunction* ctx,
1236 UnpackedInput& unpacked_input,
1237 PyObject* output) {
1238 THPObjectPtr result(PyTuple_New(3));
1239 if (!result)
1240 return {};
1241 Py_INCREF(ctx);
1242 Py_INCREF(unpacked_input.input_tuple.get());
1243 Py_INCREF(output);
1244 PyTuple_SET_ITEM(result.get(), 0, (PyObject*)ctx);
1245 PyTuple_SET_ITEM(result.get(), 1, unpacked_input.input_tuple.get());
1246 PyTuple_SET_ITEM(result.get(), 2, output);
1247 return result;
1248 }
1249
1250 } // namespace
1251
1252 static PyObject* THPFunction_setup_context = nullptr;
1253
get_base_setup_context()1254 static PyObject* get_base_setup_context() {
1255 if (THPFunction_setup_context != nullptr) {
1256 return THPFunction_setup_context;
1257 }
1258
1259 auto module = THPObjectPtr(PyImport_ImportModule("torch.autograd.function"));
1260 if (!module)
1261 return nullptr;
1262
1263 auto function =
1264 THPObjectPtr(PyObject_GetAttrString(module, "_SingleLevelFunction"));
1265 if (!function)
1266 return nullptr;
1267
1268 // setup_context gets "leaked" - we return a new reference and hold onto it
1269 // forever.
1270 auto setup_context = PyObject_GetAttrString(function, "setup_context");
1271 if (!setup_context)
1272 return nullptr;
1273 THPFunction_setup_context = setup_context;
1274 return THPFunction_setup_context;
1275 }
1276
THPFunction_apply(PyObject * cls,PyObject * inputs)1277 PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
1278 HANDLE_TH_ERRORS
1279
1280 // save a local copy of seq_id before it gets incremented
1281 auto seq_id = at::sequence_number::peek();
1282 auto info_pair = unpack_input<false>(inputs);
1283 UnpackedInput& unpacked_input = info_pair.first;
1284 InputFlags& input_info = info_pair.second;
1285
1286 // Call record function after all the inputs have been decoded, but
1287 // before context has been allocated.
1288 RECORD_FUNCTION(
1289 ((PyTypeObject*)cls)->tp_name,
1290 unpacked_input.record_function_inputs,
1291 seq_id);
1292
1293 const auto& functorch_tls = at::functorch::functorchTLSAccessor();
1294 if (functorch_tls) {
1295 // autograd.Function support for functorch is handled in Python.
1296 // If we have gotten here, then either we are dealing with a
1297 // torch.autograd.function._SingleLevelFunction, or something in
1298 // the implementation went wrong.
1299 // The following code is useful for debugging when something goes wrong
1300 // because it'll raise a loud error (instead of being silently incorrect).
1301 functorch_tls->checkSupportsSingleLevelAutogradFunction();
1302 }
1303
1304 THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
1305 if (!backward_cls)
1306 return nullptr;
1307 THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
1308 if (!ctx_obj)
1309 return nullptr;
1310 THPFunction* ctx = (THPFunction*)ctx_obj.get();
1311
1312 auto cdata =
1313 std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
1314 ctx->cdata = cdata;
1315
1316 // Record input nodes if tracing
1317 auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);
1318
1319 // Initialize backward function (and ctx)
1320 bool is_executable = input_info.is_executable;
1321 cdata->set_next_edges(std::move(input_info.next_edges));
1322 ctx->needs_input_grad = input_info.needs_input_grad.release();
1323 ctx->is_variable_input = std::move(input_info.is_variable_input);
1324
1325 // autograd.Function may optionally override a setup_context staticmethod.
1326 // In this case, autograd.Function.forward does NOT accept a ctx object.
1327 // Determine if this is the case.
1328 auto cls_setup_context =
1329 THPObjectPtr(PyObject_GetAttrString(cls, "setup_context"));
1330 if (!cls_setup_context) {
1331 return nullptr;
1332 }
1333 auto orig_setup_context = get_base_setup_context();
1334 if (!orig_setup_context) {
1335 return nullptr;
1336 }
1337 auto overridden_setup_context = cls_setup_context.get() != orig_setup_context;
1338
1339 auto num_args = PyTuple_GET_SIZE(inputs);
1340
1341 // Call forward
1342 THPObjectPtr output;
1343 {
1344 AutoGradMode grad_mode(false);
1345 at::AutoFwGradMode fw_grad_mode(false);
1346 THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
1347 if (!forward_fn)
1348 return nullptr;
1349 if (overridden_setup_context) {
1350 // call forward followed by setup_context
1351 output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
1352 if (!output) {
1353 return nullptr;
1354 }
1355 // signature is setup_context(ctx, inputs, output)
1356 auto ctx_input_output_tuple =
1357 make_ctx_input_output_tuple(ctx, unpacked_input, output);
1358 if (!ctx_input_output_tuple) {
1359 return nullptr;
1360 }
1361 THPObjectPtr setup_context_fn(
1362 PyObject_GetAttrString(cls, "setup_context"));
1363 auto result =
1364 PyObject_CallObject(setup_context_fn, ctx_input_output_tuple);
1365 if (!result) {
1366 return nullptr;
1367 }
1368 } else {
1369 // call forward
1370 auto ctx_input_tuple =
1371 make_ctx_input_tuple(ctx, unpacked_input, num_args);
1372 if (!ctx_input_tuple) {
1373 return nullptr;
1374 }
1375 output = PyObject_CallObject(forward_fn, ctx_input_tuple);
1376 }
1377 if (!output)
1378 return nullptr;
1379 }
1380
1381 return process_outputs(
1382 cls,
1383 cdata,
1384 ctx,
1385 unpacked_input,
1386 inputs,
1387 std::move(output),
1388 is_executable,
1389 node,
1390 overridden_setup_context);
1391 END_HANDLE_TH_ERRORS
1392 }
1393
1394 ////////////////////////////////////////////////////////////////////////////////
1395 // Other methods / attributes
1396 ////////////////////////////////////////////////////////////////////////////////
1397
THPFunction__register_hook_dict(PyObject * _self,PyObject * _var)1398 PyObject* THPFunction__register_hook_dict(PyObject* _self, PyObject* _var) {
1399 HANDLE_TH_ERRORS
1400 TORCH_CHECK(THPVariable_Check(_var), "_register_hook_dict expected a Tensor");
1401 THPVariable* var = reinterpret_cast<THPVariable*>(_var);
1402 const auto& tensor = THPVariable_Unpack(var);
1403 std::unique_ptr<FunctionPreHook> hook(
1404 new PyFunctionTensorPreHook(var->backward_hooks, tensor.output_nr()));
1405 auto self = (THPFunction*)_self;
1406 auto cdata = self->cdata.lock();
1407 TORCH_CHECK(
1408 cdata,
1409 "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. "
1410 "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1411 "access pattern that is no longer supported. For examples on how to use new-style "
1412 "autograd functions, see "
1413 "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1414 cdata->add_tensor_pre_hook(std::move(hook));
1415 Py_RETURN_NONE;
1416 END_HANDLE_TH_ERRORS
1417 }
1418
THPFunction_register_hook(PyObject * _self,PyObject * hook)1419 PyObject* THPFunction_register_hook(PyObject* _self, PyObject* hook) {
1420 HANDLE_TH_ERRORS
1421 auto self = (THPFunction*)_self;
1422 auto cdata = self->cdata.lock();
1423 TORCH_CHECK(
1424 cdata,
1425 "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. "
1426 "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1427 "access pattern that is no longer supported. For examples on how to use new-style "
1428 "autograd functions, see "
1429 "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1430 return torch::autograd::registerFunctionHook(*cdata, hook);
1431 END_HANDLE_TH_ERRORS
1432 }
1433
THPFunction_register_prehook(PyObject * _self,PyObject * hook)1434 PyObject* THPFunction_register_prehook(PyObject* _self, PyObject* hook) {
1435 HANDLE_TH_ERRORS
1436 auto self = (THPFunction*)_self;
1437 auto cdata = self->cdata.lock();
1438 TORCH_CHECK(
1439 cdata,
1440 "Attribute 'register_prehook' is invalid for this instance of _C._FunctionBase. "
1441 "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1442 "access pattern that is no longer supported. For examples on how to use new-style "
1443 "autograd functions, see "
1444 "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1445 return torch::autograd::registerFunctionPreHook(*cdata, hook);
1446 END_HANDLE_TH_ERRORS
1447 }
1448
THPFunction_set_materialize_grads(THPFunction * self,PyObject * value,void * unused)1449 int THPFunction_set_materialize_grads(
1450 THPFunction* self,
1451 PyObject* value,
1452 void* unused) {
1453 HANDLE_TH_ERRORS
1454 if (!PyBool_Check(value)) {
1455 THPUtils_invalidArguments(
1456 value, nullptr, "set_materialize_grads", 1, "(bool)");
1457 return -1;
1458 }
1459 self->materialize_grads = (value == Py_True);
1460 return 0;
1461 END_HANDLE_TH_ERRORS_RET(-1)
1462 }
1463
THPFunction_get_materialize_non_diff_grads(THPFunction * self,void * _unused)1464 PyObject* THPFunction_get_materialize_non_diff_grads(
1465 THPFunction* self,
1466 void* _unused) {
1467 HANDLE_TH_ERRORS
1468 if (self->materialize_non_diff_grads) {
1469 Py_RETURN_TRUE;
1470 } else {
1471 Py_RETURN_FALSE;
1472 }
1473 END_HANDLE_TH_ERRORS
1474 }
1475
THPFunction_set_materialize_non_diff_grads(THPFunction * self,PyObject * value,void * unused)1476 int THPFunction_set_materialize_non_diff_grads(
1477 THPFunction* self,
1478 PyObject* value,
1479 void* unused) {
1480 HANDLE_TH_ERRORS
1481 if (!PyBool_Check(value)) {
1482 THPUtils_invalidArguments(
1483 value, nullptr, "set_materialize_non_diff_grads", 1, "(bool)");
1484 return -1;
1485 }
1486 self->materialize_non_diff_grads = (value == Py_True);
1487 return 0;
1488 END_HANDLE_TH_ERRORS_RET(-1)
1489 }
1490
THPFunction_saved_tensors(THPFunction * self,void * _unused)1491 PyObject* THPFunction_saved_tensors(THPFunction* self, void* _unused) {
1492 HANDLE_TH_ERRORS
1493 if (self->saved_for_forward) {
1494 Py_INCREF(self->saved_for_forward);
1495 return self->saved_for_forward;
1496 } else {
1497 return unpack_saved_variables(
1498 self, [](const Variable& var) { return THPVariable_Wrap(var); });
1499 }
1500 END_HANDLE_TH_ERRORS
1501 }
1502
THPFunction_saved_variables(THPFunction * self,void * _unused)1503 PyObject* THPFunction_saved_variables(THPFunction* self, void* _unused) {
1504 HANDLE_TH_ERRORS
1505 auto r = PyErr_WarnEx(
1506 PyExc_DeprecationWarning,
1507 "'saved_variables' is deprecated; use 'saved_tensors'",
1508 0);
1509 if (r != 0)
1510 throw python_error();
1511 return unpack_saved_variables(
1512 self, [](const Variable& var) { return THPVariable_Wrap(var); });
1513 END_HANDLE_TH_ERRORS
1514 }
1515
THPFunction_is_compiled_autograd_tracing(PyObject * self,PyObject * _unused)1516 PyObject* THPFunction_is_compiled_autograd_tracing(
1517 PyObject* self,
1518 PyObject* _unused) {
1519 HANDLE_TH_ERRORS
1520 if (((THPFunction*)self)->compiled_autograd_tracing) {
1521 Py_RETURN_TRUE;
1522 } else {
1523 Py_RETURN_FALSE;
1524 }
1525 END_HANDLE_TH_ERRORS
1526 }
1527
THPFunction_get_compiled_autograd_symints(PyObject * _self,PyObject * _unused)1528 PyObject* THPFunction_get_compiled_autograd_symints(
1529 PyObject* _self,
1530 PyObject* _unused) {
1531 HANDLE_TH_ERRORS
1532 auto self = (THPFunction*)_self;
1533 auto size = self->compiled_autograd_symints.size();
1534 PyObject* result = PyTuple_New(static_cast<Py_ssize_t>(size));
1535 if (!result) {
1536 throw python_error();
1537 }
1538 for (const auto i : c10::irange(size)) {
1539 PyTuple_SET_ITEM(
1540 result,
1541 i,
1542 py::cast(self->compiled_autograd_symints[i]).release().ptr());
1543 }
1544 return result;
1545 END_HANDLE_TH_ERRORS
1546 }
1547
THPFunction_get_compiled_autograd_backward_state(PyObject * _self,void * _unused)1548 PyObject* THPFunction_get_compiled_autograd_backward_state(
1549 PyObject* _self,
1550 void* _unused) {
1551 HANDLE_TH_ERRORS
1552 auto self = (THPFunction*)_self;
1553 PyObject* bw_state = self->compiled_autograd_backward_state;
1554 if (bw_state == nullptr) {
1555 bw_state = Py_None;
1556 }
1557 Py_INCREF(bw_state);
1558 return bw_state;
1559 END_HANDLE_TH_ERRORS
1560 }
1561
THPFunction_set_compiled_autograd_backward_state(PyObject * _self,PyObject * bw_state,void * _unused)1562 int THPFunction_set_compiled_autograd_backward_state(
1563 PyObject* _self,
1564 PyObject* bw_state,
1565 void* _unused) {
1566 HANDLE_TH_ERRORS
1567 auto self = (THPFunction*)_self;
1568 TORCH_INTERNAL_ASSERT(self->compiled_autograd_backward_state == nullptr);
1569 Py_INCREF(bw_state);
1570 self->compiled_autograd_backward_state = bw_state;
1571 return 0;
1572 END_HANDLE_TH_ERRORS_RET(-1)
1573 }
1574
THPFunction_raw_saved_tensors(THPFunction * self,void * _unused)1575 PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) {
1576 HANDLE_TH_ERRORS
1577 // User tries to access saved variables after they have been freed
1578 TORCH_CHECK(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
1579 const auto& saved_variables = self->saved_variables;
1580 if (saved_variables.empty())
1581 return PyTuple_New(0);
1582 size_t num_saved = saved_variables.size();
1583 THPObjectPtr saved(PyTuple_New(static_cast<Py_ssize_t>(num_saved)));
1584 if (!saved) {
1585 return nullptr;
1586 }
1587 for (const auto i : c10::irange(num_saved)) {
1588 py::object obj =
1589 py::cast(saved_variables[i], py::return_value_policy::reference);
1590 PyTuple_SET_ITEM(saved.get(), i, obj.release().ptr());
1591 }
1592 return saved.release();
1593 END_HANDLE_TH_ERRORS
1594 }
1595
THPFunction_next_functions(THPFunction * self,void * _unused)1596 PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) {
1597 HANDLE_TH_ERRORS
1598 auto cdata = self->cdata.lock();
1599 TORCH_CHECK(
1600 cdata,
1601 "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. "
1602 "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1603 "access pattern that is no longer supported. For examples on how to use new-style "
1604 "autograd functions, see "
1605 "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1606 const auto num_outputs = cdata->num_outputs();
1607 THPObjectPtr result(PyTuple_New(num_outputs));
1608 if (!result)
1609 return nullptr;
1610 for (const auto i : c10::irange(num_outputs)) {
1611 THPObjectPtr fn_tuple(PyTuple_New(2));
1612 if (!fn_tuple)
1613 return nullptr;
1614 const auto& edge = cdata->next_edge(i);
1615 PyObject* fn = functionToPyObject(edge.function);
1616 if (!fn)
1617 return nullptr;
1618 PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
1619 PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr));
1620 PyTuple_SET_ITEM(result.get(), i, fn_tuple.release());
1621 }
1622 return result.release();
1623 END_HANDLE_TH_ERRORS
1624 }
1625
THPFunction_metadata(THPFunction * self,void * _unused)1626 PyObject* THPFunction_metadata(THPFunction* self, void* _unused) {
1627 HANDLE_TH_ERRORS
1628 auto cdata = self->cdata.lock();
1629 // The correct way to solve this problem is to stop exposing grad_fn
1630 // of PyFunctions as THPFunction; instead, we should use THPCppFunction
1631 // like everyone else. But this is a BC-breaking change as it would
1632 // mean that you no longer get the property that grad_fn is a subclass
1633 // of the autograd function class that you defined in the custom case,
1634 // so I didn't fix it here.
1635 TORCH_CHECK(
1636 cdata,
1637 "You attempted to access the anomaly metadata of a custom autograd function "
1638 "but the underlying PyNode has already been deallocated. The most likely "
1639 "reason this occurred is because you assigned x.grad_fn to a local variable "
1640 "and then let the original variable get deallocated. Don't do that! If "
1641 "you really have no way of restructuring your code so this is the case, "
1642 "please file an issue reporting that you are affected by this.");
1643 auto metadata = static_cast<PyAnomalyMetadata*>(cdata->metadata())->dict();
1644
1645 Py_INCREF(metadata);
1646 return metadata;
1647 END_HANDLE_TH_ERRORS
1648 }
1649
1650 using getter = PyObject* (*)(PyObject*, void*);
1651 using setter = int (*)(PyObject*, PyObject*, void*);
1652
1653 namespace {
1654
1655 template <PyObject* THPFunction::*ptr>
getObject(PyObject * obj,void * _unused)1656 PyObject* getObject(PyObject* obj, void* _unused) {
1657 auto self = (THPFunction*)obj;
1658 PyObject* value = self->*ptr;
1659 if (!value) {
1660 Py_RETURN_NONE;
1661 }
1662 Py_INCREF(value);
1663 return value;
1664 }
1665
1666 template <PyObject* THPFunction::*ptr>
setObject(PyObject * obj,PyObject * value,void * _unused)1667 int setObject(PyObject* obj, PyObject* value, void* _unused) {
1668 auto self = (THPFunction*)obj;
1669 if (value == Py_None) {
1670 value = nullptr;
1671 }
1672 Py_XDECREF((self->*ptr));
1673 Py_XINCREF(value);
1674 self->*ptr = value;
1675 return 0;
1676 }
1677
1678 template <typename M, M THPFunction::*ptr, PyObject* (*Convert)(long)>
getMember(PyObject * obj,void * _unused)1679 PyObject* getMember(PyObject* obj, void* _unused) {
1680 auto self = (THPFunction*)obj;
1681 return Convert(self->*ptr);
1682 }
1683
1684 template <typename M, M autograd::Node::*ptr, PyObject* (*Convert)(long)>
getImplMember(PyObject * obj,void * _unused)1685 PyObject* getImplMember(PyObject* obj, void* _unused) {
1686 auto self = (THPFunction*)obj;
1687 return Convert(self->cdata.*ptr);
1688 }
1689
getRequiresGrad(PyObject * obj,void * _unused)1690 PyObject* getRequiresGrad(PyObject* obj, void* _unused) {
1691 Py_RETURN_TRUE;
1692 }
1693
1694 } // namespace
1695
// Attribute (PyGetSetDef) table for torch._C._FunctionBase.
// Each entry is {name, getter, setter, doc, closure}. A nullptr getter or
// setter makes the attribute write-only or read-only, respectively. The
// table is terminated by the {nullptr} sentinel that CPython requires.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPFunction_properties[] = {
    // Read-only attributes backed by dedicated getter functions.
    {"saved_tensors",
     (getter)THPFunction_saved_tensors,
     nullptr,
     nullptr,
     nullptr},
    {"saved_variables",
     (getter)THPFunction_saved_variables,
     nullptr,
     nullptr,
     nullptr},
    {"_raw_saved_tensors",
     (getter)THPFunction_raw_saved_tensors,
     nullptr,
     nullptr,
     nullptr},
    {"next_functions",
     (getter)THPFunction_next_functions,
     nullptr,
     nullptr,
     nullptr},
    // Read/write PyObject* members of THPFunction, exposed through the
    // getObject/setObject templates (None is stored as nullptr).
    {"to_save",
     &getObject<&THPFunction::to_save>,
     &setObject<&THPFunction::to_save>,
     nullptr,
     nullptr},
    {"non_differentiable",
     &getObject<&THPFunction::non_differentiable>,
     &setObject<&THPFunction::non_differentiable>,
     nullptr,
     nullptr},
    {"dirty_tensors",
     &getObject<&THPFunction::dirty_tensors>,
     &setObject<&THPFunction::dirty_tensors>,
     nullptr,
     nullptr},
    {"saved_for_forward",
     &getObject<&THPFunction::saved_for_forward>,
     &setObject<&THPFunction::saved_for_forward>,
     nullptr,
     nullptr},
    {"needs_input_grad",
     &getObject<&THPFunction::needs_input_grad>,
     &setObject<&THPFunction::needs_input_grad>,
     nullptr,
     nullptr},
    // Always reads as True (see getRequiresGrad).
    {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
    {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
    {"_input_metadata",
     (getter)THPFunction_input_metadata,
     nullptr,
     nullptr,
     nullptr},
    // Write-only: there is no getter for materialize_grads.
    {"materialize_grads",
     nullptr,
     (setter)THPFunction_set_materialize_grads,
     nullptr,
     nullptr},
    {"_materialize_non_diff_grads",
     (getter)THPFunction_get_materialize_non_diff_grads,
     (setter)THPFunction_set_materialize_non_diff_grads,
     nullptr,
     nullptr},
    {"_compiled_autograd_backward_state",
     (getter)THPFunction_get_compiled_autograd_backward_state,
     (setter)THPFunction_set_compiled_autograd_backward_state,
     nullptr,
     nullptr},
    // Sentinel terminating the table.
    {nullptr}};
1766
1767 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
1768 static struct PyMethodDef THPFunction_methods[] = {
1769 {(char*)"name", THPFunction_name, METH_NOARGS, nullptr},
1770 {(char*)"_sequence_nr", THPFunction_sequence_nr, METH_NOARGS, nullptr},
1771 {(char*)"_set_sequence_nr", THPFunction_set_sequence_nr, METH_O, nullptr},
1772 {(char*)"maybe_clear_saved_tensors",
1773 THPFunction_maybe_clear_saved_tensors,
1774 METH_NOARGS,
1775 nullptr},
1776 {(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
1777 {(char*)"_register_hook_dict",
1778 THPFunction__register_hook_dict,
1779 METH_O,
1780 nullptr},
1781 {(char*)"register_hook", THPFunction_register_hook, METH_O, nullptr},
1782 {(char*)"register_prehook", THPFunction_register_prehook, METH_O, nullptr},
1783 {(char*)"_is_compiled_autograd_tracing",
1784 THPFunction_is_compiled_autograd_tracing,
1785 METH_NOARGS,
1786 nullptr},
1787 {(char*)"_get_compiled_autograd_symints",
1788 THPFunction_get_compiled_autograd_symints,
1789 METH_NOARGS,
1790 nullptr},
1791 {nullptr}};
1792
// Python type object for torch._C._FunctionBase.
// This is a positional PyTypeObject initializer: the slot order must match
// CPython's struct layout exactly (see the trailing /* tp_* */ labels).
// Py_TPFLAGS_HAVE_GC is set, so the type supplies tp_traverse and tp_clear.
// Py_TPFLAGS_BASETYPE allows the type to be subclassed from Python.
PyTypeObject THPFunctionType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._FunctionBase", /* tp_name */
    sizeof(THPFunction), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)THPFunction_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
        Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)THPFunction_traverse, /* tp_traverse */
    (inquiry)THPFunction_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPFunction_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPFunction_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPFunction_new /* tp_new */
};
1834
THPFunction_initModule(PyObject * module)1835 bool THPFunction_initModule(PyObject* module) {
1836 if (PyType_Ready(&THPFunctionType) < 0)
1837 return false;
1838 Py_INCREF(&THPFunctionType);
1839 PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType);
1840 return true;
1841 }
1842