// xref: external/pytorch/torch/csrc/autograd/python_function.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/autograd/python_function.h>

#include <ATen/ATen.h>
#include <ATen/SequenceNumber.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <structmember.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/FuncTorchTLS.h>
#include <ATen/functorch/DynamicLayer.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>

#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace torch;
using namespace torch::autograd;
using at::Tensor;

PyObject* THPFunctionClass = nullptr;
PyObject* THPGradientEdgeClass = nullptr;

#define THPFunction_assert(condition, ...) \
  if (!(condition)) {                      \
    THPUtils_setError(__VA_ARGS__);        \
    throw python_error();                  \
  }
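// Usage sketch (mirrors the call sites below): the macro folds its variadic
// arguments into an error message via THPUtils_setError and converts the
// failure into a C++ python_error, e.g.
//
//   THPFunction_assert(
//       PyTuple_Check(obj), "expected a tuple but got ", THPUtils_typename(obj));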

// Anonymous namespace for helpful functions used in this file
namespace {

// TODO: We shouldn't need to call this function because the engine
// can already persist the errors for us. This still seems to be
// needed for the DistEngine however.
//
// python test/distributed/rpc/test_tensorpipe_agent.py -k
// test_backward_autograd_engine_error
//
// See Note [ Persisting PyErr state across autograd engine threads ]
void throw_python_error() {
  python_error err;
  err.persist();
  throw std::move(err);
}

static PyObject* unpack_saved_variables(
    THPFunction* self,
    const std::function<PyObject*(const Variable&)>& unpack_fn) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
  auto& saved_variables = self->saved_variables;
  if (saved_variables.empty())
    return PyTuple_New(0);

  auto num_saved = saved_variables.size();
  THPObjectPtr saved(PyTuple_New(static_cast<Py_ssize_t>(num_saved)));
  if (!saved)
    return nullptr;
  auto saved_for = self->cdata.lock();
  // This is really a true assert, because we've already tested for the
  // self->has_freed_buffers case at the beginning of this function:
  // buffers are freed when PyNode dies; if the buffers are not freed,
  // PyNode must be live.  (Note that the buffers could be freed
  // even though the PyNode is live, but that doesn't matter here
  // because we will never hit this line of code if the buffers are freed--
  // and in any case saved_for will be non-NULL.)
  TORCH_INTERNAL_ASSERT(saved_for);
  for (const auto i : c10::irange(num_saved)) {
    auto unpacked_var = saved_variables[i].unpack(saved_for);
    THPObjectPtr value;
    if (!unpacked_var.defined()) {
      Py_INCREF(Py_None);
      value = Py_None;
    } else {
      value = unpack_fn(unpacked_var);
    }
    PyTuple_SET_ITEM(saved.get(), i, value.release());
  }
  return saved.release();
  END_HANDLE_TH_ERRORS
}

PyObject* to_py_size(const std::vector<c10::SymInt>& size) {
  c10::SymIntArrayRef sym_sizes(size);

  auto ret = THPObjectPtr(THPSizeType.tp_alloc(
      &THPSizeType, static_cast<Py_ssize_t>(sym_sizes.size())));
  if (!ret)
    throw python_error();

  for (auto i : c10::irange(sym_sizes.size())) {
    auto symint = sym_sizes[i];
    if (auto maybe_int = symint.maybe_as_int(); maybe_int.has_value()) {
      PyTuple_SET_ITEM(ret.get(), i, THPUtils_packInt64(*maybe_int));
    } else {
      auto py_symint = py::cast(symint).release().ptr();
      PyTuple_SET_ITEM(ret.get(), i, py_symint);
    }
  }
  return ret.release();
}
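// A minimal sketch of the result: a size vector {2, s0} (with s0 symbolic)
// is packed into a torch.Size whose first element is the plain Python int 2
// and whose second element is a torch.SymInt object, so hinted dims become
// ints while symbolic dims stay symbolic.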

} // namespace

namespace torch::autograd {

// NOTE: this function is written in a way that assumes it's only called for
// backward; it's used by engine.cpp.  This is responsible for forwarding a call
// from C++'s Node::apply to a Python method "apply".
auto PyNode::apply(variable_list&& inputs) -> variable_list {
  pybind11::gil_scoped_acquire gil;
  at::OptionalDeviceGuard _device_guard;
  THPFunction* py_fn = (THPFunction*)obj;

  // Massage a C++ variable_list into a Python arguments tuple
  THPObjectPtr pyInputs(to_py_args(inputs, &_device_guard));

  THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
  if (!apply_fn)
    throw_python_error();
  THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
  if (!r)
    throw_python_error();
  ensure_tuple(r);

  auto& is_variable_input = py_fn->is_variable_input;
  auto num_outputs = PyTuple_GET_SIZE(r.get());
  auto num_forward_inputs = static_cast<Py_ssize_t>(is_variable_input.size());
  // Returning too many results is ok, but only as long as they're all None.
  // Truncate the result tuple in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_none = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None;
    }
    if (all_none) {
      num_outputs = num_forward_inputs;
      r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
      if (!r)
        throw_python_error();
    }
  }

  // Now the number of gradients should match
  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got ";
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // Massage the Python results tuple back into a C++ variable_list
  return to_variable_list(r.get(), is_variable_input);
}
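// For reference, a sketch of the Python side that apply() ultimately reaches
// (hypothetical user code, not part of this file): backward must return one
// entry per forward input, with None at positions whose input was not a
// Variable, exactly as the checks above and in to_variable_list enforce:
//
//   class Scale(torch.autograd.Function):
//       @staticmethod
//       def forward(ctx, x, scale):   # `scale` is a plain Python number
//           ctx.scale = scale
//           return x * scale
//
//       @staticmethod
//       def backward(ctx, grad_out):  # reached via PyNode::apply
//           return grad_out * ctx.scale, None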

auto PyNode::defer_to_dynamo(
    variable_list&& inputs,
    std::optional<PyObject*> compiler) -> variable_list {
  pybind11::gil_scoped_acquire gil;
  at::OptionalDeviceGuard _device_guard;
  THPFunction* py_fn = (THPFunction*)obj;

  // Massage a C++ variable_list into a Python arguments tuple
  THPObjectPtr pyInputs(to_py_args(inputs, &_device_guard));

  const auto& is_variable_input = py_fn->is_variable_input;
  const auto& input_infos = py_fn->input_info;
  // input_info only contains info from variable inputs and should be a subset
  TORCH_INTERNAL_ASSERT(is_variable_input.size() >= input_infos.size());

  // The gradients returned in the backwards need to match the number of inputs
  // to the forward, and their metadata, so we pass the fwdInputs
  THPObjectPtr fwdInputMetadatas(
      PyTuple_New(static_cast<Py_ssize_t>(is_variable_input.size())));
  if (!fwdInputMetadatas)
    throw python_error();

  int offset = 0;
  for (const auto i : c10::irange(is_variable_input.size())) {
    if (!is_variable_input[i]) {
      // input at i is not a variable, skip index
      PyTuple_SET_ITEM(fwdInputMetadatas.get(), i, Py_None);
      offset++;
      continue;
    }

    const auto& input_info = input_infos[i - offset];

    PyObject* device(THPDevice_New(input_info.device));
    if (!device)
      throw_python_error();
    // Metadata is a tuple of 4 elements: (layout, device, dtype, size)
    PyObject* fwdInputMetadata = PyTuple_Pack(
        4,
        autograd::utils::wrap(input_info.layout),
        device,
        autograd::utils::wrap(input_info.scalar_type),
        to_py_size(input_info.size));
    if (!fwdInputMetadata)
      throw python_error();

    PyTuple_SET_ITEM(fwdInputMetadatas.get(), i, fwdInputMetadata);
  }
  THPObjectPtr saved_tensors(unpack_saved_variables(
      py_fn, [](const Variable& var) { return THPVariable_Wrap(var); }));
  TORCH_INTERNAL_ASSERT(
      _backward_idx.has_value(),
      "indices should already be set by compiled_args, called before apply_with_saved");
  TORCH_INTERNAL_ASSERT(!_backward_state_idx.has_value());
  THPObjectPtr r(PyObject_CallMethod(
      *compiler,
      "proxy_call_backward",
      "OOOi",
      pyInputs.get(),
      fwdInputMetadatas.get(),
      saved_tensors.get(),
      *_backward_idx));

  if (!r)
    throw_python_error();
  ensure_tuple(r);

  // Massage the Python results tuple back into a C++ variable_list
  return to_variable_list(r.get(), is_variable_input);
}

auto PyNode::is_traceable() -> bool {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr forward_class{PyObject_GetAttrString(obj, "_forward_cls")};
  if (!forward_class)
    throw_python_error();
  THPObjectPtr traceable_py_bool{
      PyObject_GetAttrString(forward_class, "is_traceable")};
  if (!traceable_py_bool)
    throw_python_error();
  return traceable_py_bool == Py_True;
}

auto PyNode::release_variables() -> void {
  // This function is called as part of the Node destructor!
  // Since this object might be kept alive by C++, it is possible
  // that the python interpreter is already dead here. In that case
  // we just leak the saved objects.
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    auto f = (THPFunction*)obj;
    f->saved_variables.clear();
    f->has_freed_buffers = 1;
  }
}

auto PyNode::name() const -> std::string {
  pybind11::gil_scoped_acquire gil;
  auto f = (THPFunction*)obj;
  auto name = std::string(Py_TYPE(f)->tp_name);
  return name;
}

auto PyNode::compiled_autograd_should_lift() const -> bool {
  pybind11::gil_scoped_acquire gil;
  static PyObject* attr_name =
      PyUnicode_InternFromString("_compiled_autograd_should_lift");
  THPObjectPtr should_lift(PyObject_GetAttr(obj, attr_name));
  return PyObject_IsTrue(should_lift.get()) == 1;
}

void PyNode::compiled_args(CompiledNodeArgs& args) {
  static PyObject* method_name =
      PyUnicode_InternFromString("_compiled_autograd_key");
  THPObjectPtr pykey(PyObject_CallMethodNoArgs(obj, method_name));
  if (!pykey)
    throw_python_error();
  TORCH_CHECK(
      PyTuple_CheckExact(pykey.get()),
      "_compiled_autograd_key should return tuple of ints");
  auto size = PyTuple_GET_SIZE(pykey.get());
  TORCH_INTERNAL_ASSERT(size > 0);
  // first value is unique id managed by AUTOGRAD_FUNCTION_COUNTER
  auto key = PyLong_AsSsize_t(PyTuple_GET_ITEM(pykey.get(), 0));
  if (C10_UNLIKELY(key < 0)) {
    TORCH_CHECK(PyErr_Occurred(), "key must be non-negative");
    throw_python_error();
  }
  args.collect_size(static_cast<size_t>(key));
  args.collect_size(static_cast<size_t>(size));

  auto f = (THPFunction*)obj;
  f->compiled_autograd_symints.clear();
  f->compiled_autograd_symints.reserve(size - 1);
  for (const auto i : c10::irange(1, size)) {
    auto val = PyLong_AsSsize_t(PyTuple_GET_ITEM(pykey.get(), i));
    if (C10_UNLIKELY(val == -1 && PyErr_Occurred()))
      throw_python_error();
    f->compiled_autograd_symints.emplace_back(val);
  }

  // AotAutograd symints are all dynamic
  auto prior =
      args.set_default_dyn_type(torch::dynamo::autograd::SizeInput::DYNAMIC);
  args.collect(f->compiled_autograd_symints);
  args.set_default_dyn_type(prior);

  args.collect(f->saved_variables, true); // always unpacked as output in eager
  args.collect(f->materialize_grads);
  args.collect(f->is_variable_input);
  args.collect(f->needs_input_grad);
  args.collect(f->materialize_non_diff_grads);
  args.collect(f->output_info);
  args.collect(f->input_info);

  if (compiled_autograd_should_lift()) {
    Py_INCREF(obj);
    _backward_idx =
        args.add_backward(c10::SafePyObject(obj, getPyInterpreter()));
  }

  PyObject* bw_state = f->compiled_autograd_backward_state;
  if (args.cond(bw_state != nullptr)) {
    Py_INCREF(bw_state);
    _backward_state_idx = args.add_backward_state(
        c10::SafePyObject(bw_state, getPyInterpreter()));
  }
}
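// Sketch of the key contract consumed above (illustrative Python; the
// authoritative default is defined on the Python Function class):
// _compiled_autograd_key returns a tuple whose first element is the
// per-class id drawn from AUTOGRAD_FUNCTION_COUNTER, roughly
//
//   @classmethod
//   def _compiled_autograd_key(cls, ctx):
//       return (ctx._autograd_function_id,)
//
// Subclasses may append extra ints; those trailing values are collected as
// dynamic symints above.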

variable_list PyNode::apply_with_saved(
    const variable_list& inputs,
    SwapSavedVariables& saved) {
  auto f = (THPFunction*)obj;
  TORCH_INTERNAL_ASSERT(!f->compiled_autograd_tracing);
  saved.before(f->compiled_autograd_symints);
  saved.before(f->saved_variables);
  saved.before(f->needs_input_grad);
  saved.before(f->materialize_non_diff_grads);
  saved.before(f->output_info);
  saved.before(f->input_info);
  f->compiled_autograd_tracing = true;
  variable_list result;
  if (!compiled_autograd_should_lift()) {
    if (_backward_state_idx.has_value()) {
      PyObject* r = PyObject_CallMethod(
          saved.get_py_compiler(),
          "bind_backward_state",
          "i",
          *_backward_state_idx);
      if (r == nullptr) {
        throw python_error();
      }
      THPObjectPtr prior(f->compiled_autograd_backward_state);
      f->compiled_autograd_backward_state = r;
      result = apply(variable_list(inputs));
      Py_CLEAR(f->compiled_autograd_backward_state);
      f->compiled_autograd_backward_state = prior.release();
    } else {
      result = apply(variable_list(inputs));
    }
  } else {
    result = defer_to_dynamo(variable_list(inputs), saved.get_py_compiler());
  }
  f->compiled_autograd_tracing = false;
  saved.after(f->compiled_autograd_symints);
  saved.after(f->saved_variables);
  saved.after(f->needs_input_grad);
  saved.after(f->materialize_non_diff_grads);
  saved.after(f->output_info);
  saved.after(f->input_info);
  return result;
}

PyObject* PyNode::to_py_args(
    const variable_list& inputs,
    at::OptionalDeviceGuard* device_guard) {
  THPFunction* py_fn = (THPFunction*)obj;

  auto zeros_without_gil = [](const VariableInfo& variable,
                              at::OptionalDeviceGuard& dg) {
    pybind11::gil_scoped_release gil;
    return variable.zeros(dg);
  };

  auto num_inputs = inputs.size();
  PyObject* pyInputs = PyTuple_New(static_cast<Py_ssize_t>(num_inputs));
  if (!pyInputs)
    throw_python_error();
  auto& output_info = py_fn->output_info;
  for (const auto i : c10::irange(num_inputs)) {
    PyObject* input = nullptr;
    if (inputs[i].defined() || !py_fn->materialize_grads ||
        (input_metadata(i).was_default_constructed() &&
         !py_fn->materialize_non_diff_grads)) {
      input = THPVariable_Wrap(inputs[i]);
    } else {
      input =
          THPVariable_Wrap(zeros_without_gil(output_info[i], *device_guard));
    }
    if (!input)
      throw_python_error();
    PyTuple_SET_ITEM(pyInputs, i, input);
  }

  return pyInputs;
}

variable_list PyNode::to_variable_list(
    const PyObject* outputs,
    const std::vector<bool>& is_variable_input) {
  auto num_outputs = PyTuple_GET_SIZE(outputs);
  variable_list results;
  results.reserve(num_outputs);
  for (int i = 0; i != num_outputs; ++i) {
    PyObject* output = PyTuple_GET_ITEM(outputs, i);
    bool was_variable = is_variable_input[i];
    if (!was_variable) {
      if (output != Py_None) {
        std::string msg("function ");
        msg += name() + " returned a gradient different than None at position ";
        msg += std::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    if (output == Py_None) {
      results.emplace_back();
    } else {
      if (!THPVariable_Check(output)) {
        std::string msg("expected Variable or None (got ");
        msg += THPUtils_typename(output);
        msg += ")";
        throw std::runtime_error(msg);
      }
      results.emplace_back(THPVariable_Unpack(output));
    }
  }

  return results;
}

} // namespace torch::autograd

// Traverse and clear are required for supporting Python's GC cycle handling.
static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
  // NB: We should not traverse the PyObject stored on PyNode, since we only
  // hold a weak reference to the PyNode.
  Py_VISIT(self->to_save);
  Py_VISIT(self->non_differentiable);
  Py_VISIT(self->dirty_tensors);
  Py_VISIT(self->compiled_autograd_backward_state);
  Py_VISIT(self->saved_for_forward);
  return 0;
}

static int THPFunction_clear(THPFunction* self) {
  // Note that the cdata might not be expired yet in the case where this
  // object is part of a cycle and the GC happens to tp_clear this PyObject
  // before the other ones that trigger the de-allocation of the cdata

  Py_CLEAR(self->needs_input_grad);

  Py_CLEAR(self->to_save);
  Py_CLEAR(self->non_differentiable);
  Py_CLEAR(self->dirty_tensors);
  Py_CLEAR(self->compiled_autograd_backward_state);
  Py_CLEAR(self->saved_for_forward);

  self->output_info.clear();
  self->input_info.clear();
  self->saved_variables.clear();
  self->is_variable_input.clear();

  return 0;
}

static void THPFunction_dealloc(THPFunction* self) {
  // Why is this guaranteed to be true?  Suppose that self->cdata is non-null
  // (otherwise the condition is trivially true).  Then there is a PyNode
  // which contains an owning reference to this object.  But we are only
  // allowed to clear if all owning references are gone!  Contradiction.
  //
  // However, note that THPFunction_clear is typically called in the shared_ptr
  // destructor of PyNode; in that case, per
  // https://cplusplus.github.io/LWG/lwg-active.html#2751 it's not currently
  // specified in the standard that this is guaranteed.  If you see this
  // assert triggering in the wild, feel free to comment it out.  They're
  // likely to standardize that you ARE guaranteed to see the weak pointers
  // as expired in the destructor in the future, so we'll keep this for now.
  TORCH_INTERNAL_ASSERT(self->cdata.expired());

  PyObject_GC_UnTrack(self);
  THPFunction_clear(self);
  self->cdata.~weak_ptr<PyNode>();
  self->output_info.~vector();
  self->input_info.~vector();
  self->saved_variables.~vector();
  self->is_variable_input.~vector();
  Py_TYPE(self)->tp_free((PyObject*)self);
}

PyObject* THPFunction_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  PyObject* obj = type->tp_alloc(type, 0);
  if (!obj)
    return nullptr;
  // Python zero-initializes the object memory, so there's no need to
  // initialize most fields
  THPFunction* self = (THPFunction*)obj;
  // Setup the PyNode later; we can't keep it live here
  new (&self->cdata) std::weak_ptr<PyNode>();
  new (&self->output_info) std::vector<VariableInfo>();
  new (&self->input_info) std::vector<VariableInfo>();
  new (&self->saved_variables) std::vector<SavedVariable>();
  new (&self->is_variable_input) std::vector<bool>();
  self->materialize_grads = true;
  self->materialize_non_diff_grads = true;
  self->compiled_autograd_tracing = false;
  return obj;
}

////////////////////////////////////////////////////////////////////////////////
// Forward
////////////////////////////////////////////////////////////////////////////////

// Bump the counters of all recorded dirty input tensors, adding each of them
// into dirty_inputs.  Also does some sanity checking.
static std::unordered_set<at::TensorImpl*> _mark_dirty(THPFunction* self) {
  // Increase versions of modified tensors
  std::unordered_set<at::TensorImpl*> dirty_inputs;
  if (!self->dirty_tensors)
    return dirty_inputs;

  THPFunction_assert(
      PyTuple_Check(self->dirty_tensors),
      "autograd "
      "internal error: dirty_tensors attribute is expected to be a tuple "
      "but is ",
      THPUtils_typename(self->dirty_tensors));
  Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
  dirty_inputs.reserve(num_dirty);
  for (const auto i : c10::irange(num_dirty)) {
    PyObject* obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
    THPFunction_assert(
        THPVariable_Check(obj),
        "mark_dirty can "
        "only accept variables, but argument ",
        i,
        " is of type ",
        THPUtils_typename(obj));

    const auto& tensor = THPVariable_Unpack(obj);
    dirty_inputs.insert(tensor.unsafeGetTensorImpl());
    torch::autograd::impl::bump_version(tensor);
  }
  // We're not going to ever need this so let's remove references now
  Py_CLEAR(self->dirty_tensors);
  return dirty_inputs;
}
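// Typical producer of self->dirty_tensors, for illustration (hypothetical
// user code): an in-place forward reports the mutated input via
// ctx.mark_dirty so its version counter gets bumped here:
//
//   @staticmethod
//   def forward(ctx, x):
//       x.add_(1)          # modify the input in place
//       ctx.mark_dirty(x)  # ends up in self->dirty_tensors
//       return x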

static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
    THPFunction* self);

// Given a Python tuple of raw output tensors (raw_output), set each of
// the corresponding entries in a different Python tuple (outputs) with
// these tensors wrapped with variables.  We save the gradient function (self)
// to the variable if the output requires grad.
//
// There is a considerable amount of complexity to handle if the operation
// that produced these output tensors is inplace.  The list of *input*
// variables (input_vars) is used to test if this occurred, and the set of
// dirty tensors (dirty_inputs) is used to figure out what to do in this
// case.
static void _wrap_outputs(
    const std::shared_ptr<PyNode>& cdata,
    THPFunction* self,
    const variable_list& input_vars,
    PyObject* raw_output,
    PyObject* outputs,
    bool is_executable,
    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
  auto cdata_if_executable = is_executable ? cdata : nullptr;
  Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
  if (is_executable) {
    self->output_info.clear();
    self->output_info.reserve(num_outputs);
  }

  auto non_differentiable = _parse_non_differentiable(self);
  auto dirty_inputs = _mark_dirty(self);

  std::vector<std::optional<Variable>> raw_output_vars;
  raw_output_vars.reserve(num_outputs);
  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
    // Only process tensors as outputs for autograd purposes.
    if (THPVariable_Check(obj)) {
      raw_output_vars.emplace_back(THPVariable_Unpack(obj));
    } else {
      raw_output_vars.emplace_back();
    }
  }

  _jvp_fn_t jvp_user_function = [self](
                                    variable_list inputs,
                                    variable_list grad_inputs) {
    pybind11::gil_scoped_acquire gil;

    // Massage a C++ variable_list into a Python arguments tuple
    // Making sure to introduce the proper None for non-Tensor inputs
    auto num_inputs = self->is_variable_input.size();
    THPObjectPtr pyInputs(PyTuple_New(static_cast<Py_ssize_t>(num_inputs)));
    if (!pyInputs)
      throw_python_error();
    int64_t variable_idx = 0;
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = nullptr;
      if (self->is_variable_input[i]) {
        if (grad_inputs[variable_idx].defined() || !self->materialize_grads ||
            !isDifferentiableType(inputs[variable_idx].scalar_type())) {
          input = THPVariable_Wrap(grad_inputs[variable_idx]);
        } else {
          input = THPVariable_Wrap(at::zeros_like(inputs[variable_idx]));
        }
        if (!input) {
          throw_python_error();
        }
        variable_idx++;
      } else {
        Py_INCREF(Py_None);
        input = Py_None;
      }
      PyTuple_SET_ITEM(pyInputs.get(), i, input);
    }

    THPObjectPtr apply_jvp_fn(
        PyObject_GetAttrString((PyObject*)self, "apply_jvp"));
    if (!apply_jvp_fn)
      throw_python_error();
    THPObjectPtr r(PyObject_CallObject(apply_jvp_fn, pyInputs.get()));
    if (!r)
      throw_python_error();
    ensure_tuple(r);

    // Massage the Python results tuple back into a C++ variable_list
    // Don't do any check on the number of results here as
    // it is handled by the caller
    const int num_outputs = PyTuple_GET_SIZE(r.get());
    variable_list results;
    results.reserve(num_outputs);
    for (const auto i : c10::irange(num_outputs)) {
      PyObject* output = PyTuple_GET_ITEM(r.get(), i);
      if (output == Py_None) {
        results.emplace_back();
      } else {
        TORCH_CHECK(
            THPVariable_Check(output),
            "expected Variable or None (got ",
            THPUtils_typename(output),
            ") for grad output ",
            i,
            ".")
        results.emplace_back(THPVariable_Unpack(output));
      }
    }

    return results;
  };

  auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr py_x(THPVariable_Wrap(x));
    THPObjectPtr py_view_as_method(PyObject_GetAttrString(py_x, "view_as"));
    if (!py_view_as_method)
      throw python_error();
    THPObjectPtr args(PyTuple_Pack(1, py_x.get()));
    if (!args)
      throw python_error();
    THPObjectPtr result(PyObject_CallObject(py_view_as_method, args));
    if (!result)
      throw python_error();
    return THPVariable_Unpack(result);
  };

  // Wrap only the tensor outputs.
  auto wrapped_outputs = _wrap_outputs(
      input_vars,
      non_differentiable,
      dirty_inputs,
      raw_output_vars,
      cdata_if_executable,
      jvp_user_function,
      to_save_if_setup_context,
      view_as_self_fn);

  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GetItem(raw_output, i);
    // Keep the non-tensor outputs as is.
    if (!THPVariable_Check(obj)) {
      if (is_executable) {
        self->output_info.emplace_back();
      }
      Py_INCREF(obj);
      PyTuple_SetItem(outputs, i, obj);
    } else {
      if (is_executable) {
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        self->output_info.emplace_back(*wrapped_outputs[i]);
      }
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
    }
  }
}

static void _get_tensors_to_save(
    THPFunction* self,
    std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
    std::vector<std::optional<at::Tensor>>& tensors_to_save,
    bool overridden_setup_context,
    bool is_executable) {
  if (self->saved_for_forward && overridden_setup_context) {
    // We look at saved_for_forward here purely for the purpose of populating
    // to_save_if_setup_context; the actual saving is not done here.
    THPFunction_assert(
        PyTuple_Check(self->saved_for_forward),
        "autograd internal "
        "error: saved_for_forward attribute is expected to be a tuple but is ",
        THPUtils_typename(self->saved_for_forward));
    Py_ssize_t num_saved_for_forward =
        PyTuple_GET_SIZE(self->saved_for_forward);
    for (const auto i : c10::irange(num_saved_for_forward)) {
      PyObject* obj = PyTuple_GET_ITEM(self->saved_for_forward, i);
      if (THPVariable_Check(obj)) {
        const auto& tensor = THPVariable_Unpack(obj);
        to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
      }
    }
  }
  if (self->to_save) {
    THPFunction_assert(
        PyTuple_Check(self->to_save),
        "autograd internal "
        "error: to_save attribute is expected to be a tuple but is ",
        THPUtils_typename(self->to_save));

    Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
    for (const auto i : c10::irange(num_saved)) {
      PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
      if (obj == Py_None) {
        tensors_to_save.emplace_back(std::nullopt);
        continue;
      } else if (THPVariable_Check(obj)) {
        const auto& tensor = THPVariable_Unpack(obj);
        if (overridden_setup_context) {
          to_save_if_setup_context.insert(tensor.unsafeGetTensorImpl());
        }
        if (is_executable) {
          tensors_to_save.emplace_back(tensor);
        }
      } else {
        if (is_executable) {
          // TODO: We should really just ALWAYS throw an error here, but
          // doing so will break some internal tests. We should fix those.
          throw torch::TypeError(
              "save_for_backward can only save variables, but argument %ld is of "
              "type %s",
              i,
              Py_TYPE(obj)->tp_name);
        }
      }
    }
  }
}
// Save any variables requested by to_save
static void _save_variables(
    const std::vector<std::optional<at::Tensor>>& tensors_to_save,
    const std::shared_ptr<PyNode>& cdata_ptr,
    THPFunction* self) {
  if (!self->to_save)
    return;
  size_t num_saved = tensors_to_save.size();
  self->saved_variables.clear();
  self->saved_variables.reserve(num_saved);
  for (const auto& opt_tensor : tensors_to_save) {
    if (!opt_tensor.has_value()) {
      self->saved_variables.emplace_back();
    } else {
      bool is_output = opt_tensor.value().grad_fn().get() == cdata_ptr.get();
      self->saved_variables.emplace_back(opt_tensor.value(), is_output);
    }
  }
  // Free .to_save
  Py_CLEAR(self->to_save);
}
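// The to_save tuple consumed above is produced on the Python side, e.g.
// (illustrative user code):
//
//   ctx.save_for_backward(x, None, y)  # Nones are kept positionally and
//                                      # become empty SavedVariables here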

// Mark requires_grad = 0 on non-differentiable variables (as per
// non_differentiable)
static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
    THPFunction* self) {
  std::unordered_set<at::TensorImpl*> set;
  if (!self->non_differentiable)
    return set;

  THPFunction_assert(
      PyTuple_Check(self->non_differentiable),
      "autograd "
      "internal error: non_differentiable attribute is expected to be a "
      "tuple but is ",
      THPUtils_typename(self->non_differentiable));
  Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable);
  set.reserve(num_nondiff);
  for (const auto i : c10::irange(num_nondiff)) {
    PyObject* t = PyTuple_GET_ITEM(self->non_differentiable, i);
    THPFunction_assert(
        THPVariable_Check(t),
        "mark_non_differentiable "
        "only accepts variable arguments, but got ",
        THPUtils_typename(t));
    set.insert(THPVariable_Unpack(t).unsafeGetTensorImpl());
  }
  Py_CLEAR(self->non_differentiable);
  return set;
}
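// Populated from Python via mark_non_differentiable, for illustration
// (hypothetical user code): an integer-valued output is excluded from
// gradient computation:
//
//   @staticmethod
//   def forward(ctx, x):
//       idx = x.argmax()
//       ctx.mark_non_differentiable(idx)
//       return idx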

struct UnpackedInput {
  THPObjectPtr input_tuple;
  variable_list input_vars;
  // record_function_inputs is for RECORD_FUNCTION only
  std::vector<c10::IValue> record_function_inputs;
};

struct InputFlags {
  bool is_executable = false;
  edge_list next_edges;
  THPObjectPtr needs_input_grad;
  std::vector<bool> is_variable_input;
};

template <bool enforce_variables>
std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
  UnpackedInput unpacked;
  InputFlags flags;

  auto num_args = PyTuple_GET_SIZE(args);
  unpacked.input_tuple = PyTuple_New(num_args);
  flags.needs_input_grad = PyTuple_New(num_args);
  bool profiler_need_input = torch::autograd::profiler::profilerEnabled() &&
      torch::autograd::profiler::getProfilerConfig().report_input_shapes;

  for (const auto i : c10::irange(num_args)) {
    PyObject* arg = PyTuple_GET_ITEM(args, i);

    bool is_variable = THPVariable_Check(arg);
    flags.is_variable_input.push_back(is_variable);
    if (!is_variable) {
      // TODO: remove this code path once Variable and Tensor are merged in
      // Python
      if (enforce_variables) {
        THPUtils_setError(
            "expected a Tensor argument, but got ", THPUtils_typename(arg));
        throw python_error();
      }
      Py_INCREF(Py_False);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);

      if (profiler_need_input) {
        // The following conversion from PyObject to IValue is expensive
        // Only do it if profiler is enabled and needs input shapes
        auto match = torch::jit::tryToInferPrimitiveType(arg);
        if (match.success()) {
          unpacked.record_function_inputs.push_back(
              torch::jit::toIValue(arg, match.type()));
        }
      }
    } else {
      const auto& tensor = THPVariable_Unpack(arg);
      unpacked.input_vars.push_back(tensor);
      PyObject* needs_grad = tensor.requires_grad() ? Py_True : Py_False;
      Py_INCREF(needs_grad);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
      unpacked.record_function_inputs.emplace_back(tensor);
    }
    Py_INCREF(arg);
    PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
  }

  flags.is_executable =
      GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
  flags.next_edges =
      (flags.is_executable ? collect_next_edges(unpacked.input_vars)
                           : edge_list());
  return std::make_pair(std::move(unpacked), std::move(flags));
}
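// For example, calling MyFn.apply(t, 3) with t a tensor that requires grad
// yields (roughly) flags.is_variable_input == {true, false} and a
// needs_input_grad tuple of (True, False), which user code later reads as
// ctx.needs_input_grad.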

// Given a prim::PythonOp node, _append_subgraph creates a subgraph such that:
// (1) It has the same inputs as the prim::PythonOp node.
// (2) The intermediate nodes used in the PythonOp are cloned and stored in
//     the subgraph.
// (3) trace_outputs stores the Value* objects before a new trace value is
//     assigned by the prim::PythonOp node, and helps to eventually route the
//     outputs of the subgraph correctly.
// This newly created subgraph is then added to the prim::PythonOp node as a
// subgraph attribute.
static void _append_subgraph(
    torch::jit::Node* node,
    torch::jit::Graph* graph,
    std::vector<torch::jit::Value*> trace_outputs,
    bool unpack_output) {
  using Value = torch::jit::Value;
  node->g_(
      torch::jit::attr::Subgraph,
      std::make_shared<torch::jit::Graph>(graph->current_scope()));
  auto subgraph = node->g(torch::jit::attr::Subgraph);

  std::unordered_map<Value*, Value*> value_map;
  auto value_map_func = [&](Value* v) { return value_map.at(v); };
  for (size_t i = 0; i < node->inputs().size(); ++i) {
    auto subgraph_input = subgraph->addInput();
    subgraph_input->copyMetadata(node->inputs().at(i));
    value_map[node->inputs().at(i)] = subgraph_input;
  }
  // Find the node's position in its owning block; all subsequent nodes are
  // added to the subgraph
  auto owning_block = node->owningBlock();
  auto it = std::find(
      owning_block->nodes().begin(), owning_block->nodes().end(), node);
  // Skip TupleUnpack node if created
  if (!unpack_output) {
    it++;
  }
  for (it++; it != owning_block->nodes().end(); ++it) {
    torch::jit::Node* node = *it;
    auto* clone_node =
        subgraph->insertNode(subgraph->createClone(node, value_map_func));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      value_map[node->outputs()[i]] = clone_node->outputs()[i];
      auto trace_it = std::find(
          trace_outputs.begin(), trace_outputs.end(), node->outputs()[i]);
      if (trace_it != trace_outputs.end()) {
        subgraph->registerOutput(clone_node->outputs()[i]);
      }
    }
  }
}

static torch::jit::Node* _trace_pre_record(
    PyObject* op_obj,
    PyObject* input_objects,
    const variable_list& input_vars) {
  if (!jit::tracer::isTracing()) {
    return nullptr;
  }

  // Save scalar args and the calling convention
  auto num_args = PyTuple_GET_SIZE(input_objects);
  pyobj_list scalar_args;
  std::string arg_types;
  arg_types.reserve(num_args);
  scalar_args.reserve(num_args);
  for (const auto i : c10::irange(num_args)) {
    PyObject* arg_object = PyTuple_GET_ITEM(input_objects, i);
    if (THPVariable_Check(arg_object)) {
      arg_types.push_back('d');
    } else {
      arg_types.push_back('c');
      Py_INCREF(arg_object);
      scalar_args.emplace_back(arg_object);
    }
  }

  Py_INCREF(op_obj);
  auto pyobj = THPObjectPtr(op_obj);
  return jit::tracer::preRecordPythonTrace(
      std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
}

static void _trace_post_record(
    torch::jit::Node* node,
    PyObject* op_obj,
    const variable_list& input_vars,
    PyObject* output_objects,
    bool is_inplace,
    bool unpack_output) {
  if (!jit::tracer::isTracing()) {
    return;
  }

  node->i_(jit::attr::inplace, is_inplace);
  if (PyObject* module_name = PyDict_GetItemString(
          ((PyTypeObject*)op_obj)->tp_dict, "__module__")) {
    if (auto ptr = PyUnicode_AsUTF8(module_name)) {
      node->s_(jit::attr::module, std::string(ptr));
    }
  }

  // Isolate C variable ptrs in a vector
  int num_outputs = PyTuple_GET_SIZE(output_objects);
  auto graph = node->owningGraph();
  node->addOutput();
  auto old_node = node;
  if (!unpack_output) {
    std::vector<at::TypePtr> tuple_values(num_outputs, at::TensorType::get());
    auto tuple_type = at::TupleType::create(std::move(tuple_values));
    // Original type is a tuple of tensors "without" element type and shape.
    // The missing parts will be added below.
    node->output()->setType(std::move(tuple_type));
    auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
    node = unpacked;
  }

  std::vector<torch::jit::Value*> trace_outputs;
  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
    if (THPVariable_Check(obj)) {
      auto value = node->outputs()[i];
      const auto& tensor = THPVariable_Unpack(obj);
      if (tensor.defined()) {
        value->inferTypeFrom(tensor);
        trace_outputs.push_back(jit::tracer::getValueTrace(tensor));
        jit::tracer::setValueTrace(tensor, value);
      }
    }
  }
  py::object onnx_globals = py::module::import("torch.onnx._globals");
  py::bool_ is_in_onnx_export =
      py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
  py::bool_ is_autograd_inlining_enabled =
      py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));

  if (py::cast<bool>(is_in_onnx_export) &&
      py::cast<bool>(is_autograd_inlining_enabled)) {
    _append_subgraph(old_node, graph, std::move(trace_outputs), unpack_output);
  }

  // If TupleUnpack operator is created, we copy its output type back
  // to the original tuple type.
  if (!unpack_output) {
    std::vector<at::TypePtr> new_tuple_values;
    for (const auto i : c10::irange(num_outputs)) {
      auto ptr = node->outputs()[i]->type();
      new_tuple_values.push_back(ptr);
    }
    auto tuple_type = at::TupleType::create(std::move(new_tuple_values));
    // The i-th tuple element receives a new tensor type with element type and
    // shape.
    old_node->output()->setType(std::move(tuple_type));
  }
}

PyObject* process_outputs(
    PyObject* op_obj,
    const std::shared_ptr<PyNode>& cdata,
    THPFunction* grad_fn,
    const UnpackedInput& unpacked,
    PyObject* inputs,
    THPObjectPtr&& raw_output,
    bool is_executable,
    torch::jit::Node* node,
    bool overridden_setup_context) {
  bool unpack_output = ensure_tuple(raw_output);

  auto num_outputs = PyTuple_GET_SIZE(raw_output.get());

  THPObjectPtr outputs(PyTuple_New(num_outputs));
  if (!outputs)
    throw python_error();

  cdata->clear_input_metadata();

  // Record type, device, and size information about inputs
  if (is_executable) {
    grad_fn->input_info.clear();
    grad_fn->input_info.reserve(unpacked.input_vars.size());
    for (auto& var : unpacked.input_vars) {
      grad_fn->input_info.emplace_back(var);
    }
  }

  std::unordered_set<at::TensorImpl*> to_save_if_setup_context{};
  std::vector<std::optional<at::Tensor>> tensors_to_save{};
  _get_tensors_to_save(
      grad_fn,
      to_save_if_setup_context,
      tensors_to_save,
      overridden_setup_context,
      is_executable);

  bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
  _wrap_outputs(
      cdata,
      grad_fn,
      unpacked.input_vars,
      raw_output,
      outputs,
      is_executable,
      to_save_if_setup_context);
  _trace_post_record(
      node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);

  // It is important that creating the SavedVariables happens after the output
  // wrapping, as the outputs must have their grad_fn/fw_grad properly set
  // before we save them.
  if (is_executable) {
    _save_variables(tensors_to_save, cdata, grad_fn);
  } else {
    // Remove unnecessary attributes
    Py_XDECREF(grad_fn->to_save);
    grad_fn->to_save = nullptr;
    Py_XDECREF(grad_fn->non_differentiable);
    grad_fn->non_differentiable = nullptr;
  }

  Py_XDECREF(grad_fn->saved_for_forward);
  grad_fn->saved_for_forward = nullptr;

  // Unpack the output, unless .forward() returned a tuple
  if (unpack_output) {
    PyObject* output = PyTuple_GET_ITEM(outputs.get(), 0);
    Py_INCREF(output);
    return output;
  }

  return outputs.release();
}

PyObject* THPFunction_name(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto cdata = ((THPFunction*)self)->cdata.lock();
  TORCH_CHECK(
      cdata,
      "Attribute 'name' is invalid for this instance of _C._FunctionBase. "
      "Accessing this attribute directly on an instance of autograd.Function is a legacy "
      "access pattern that is no longer supported. For examples on how to use new-style "
      "autograd functions, see "
      "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
  return THPUtils_packString(cdata->name());
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_sequence_nr(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS;
  auto cdata = ((THPFunction*)self)->cdata.lock();
  return THPUtils_packUInt64(cdata->sequence_nr());
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
  HANDLE_TH_ERRORS;
  auto cdata = ((THPFunction*)self)->cdata.lock();
  cdata->set_sequence_nr(THPUtils_unpackUInt64(sequence_nr));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_input_metadata(PyObject* self, void* unused) {
  HANDLE_TH_ERRORS;
  auto cdata = ((THPFunction*)self)->cdata.lock();
  const auto num_inputs = cdata->num_inputs();
  THPObjectPtr list(PyTuple_New(num_inputs));
  if (!list) {
    return nullptr;
  }
  for (size_t i = 0; i < num_inputs; ++i) {
    const auto& metadata = cdata->input_metadata(i);
    THPObjectPtr item(py::cast(metadata).release().ptr());
    if (!item) {
      return nullptr;
    }
    PyTuple_SET_ITEM(list.get(), i, item.release());
  }
  return list.release();
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_maybe_clear_saved_tensors(
    PyObject* self,
    PyObject* noargs) {
  HANDLE_TH_ERRORS;
  auto cdata = ((THPFunction*)self)->cdata.lock();
  if (!get_current_graph_task_keep_graph()) {
    cdata->release_variables();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

namespace {

THPObjectPtr make_ctx_input_tuple(
    THPFunction* ctx,
    const UnpackedInput& unpacked_input,
    int64_t num_args) {
  THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
  if (!ctx_input_tuple)
    return {};
  Py_INCREF(ctx);
  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
  for (const auto i : c10::irange(num_args)) {
    PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
    Py_INCREF(arg);
    PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
  }
  return ctx_input_tuple;
}

THPObjectPtr make_ctx_input_output_tuple(
    THPFunction* ctx,
    UnpackedInput& unpacked_input,
    PyObject* output) {
  THPObjectPtr result(PyTuple_New(3));
  if (!result)
    return {};
  Py_INCREF(ctx);
  Py_INCREF(unpacked_input.input_tuple.get());
  Py_INCREF(output);
  PyTuple_SET_ITEM(result.get(), 0, (PyObject*)ctx);
  PyTuple_SET_ITEM(result.get(), 1, unpacked_input.input_tuple.get());
  PyTuple_SET_ITEM(result.get(), 2, output);
  return result;
}

} // namespace

static PyObject* THPFunction_setup_context = nullptr;

static PyObject* get_base_setup_context() {
  if (THPFunction_setup_context != nullptr) {
    return THPFunction_setup_context;
  }

  auto module = THPObjectPtr(PyImport_ImportModule("torch.autograd.function"));
  if (!module)
    return nullptr;

  auto function =
      THPObjectPtr(PyObject_GetAttrString(module, "_SingleLevelFunction"));
  if (!function)
    return nullptr;

  // setup_context gets "leaked" - we return a new reference and hold onto it
  // forever.
  auto setup_context = PyObject_GetAttrString(function, "setup_context");
  if (!setup_context)
    return nullptr;
  THPFunction_setup_context = setup_context;
  return THPFunction_setup_context;
}

PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
  HANDLE_TH_ERRORS

  // save a local copy of seq_id before it gets incremented
  auto seq_id = at::sequence_number::peek();
  auto info_pair = unpack_input<false>(inputs);
  UnpackedInput& unpacked_input = info_pair.first;
  InputFlags& input_info = info_pair.second;

  // Call record function after all the inputs have been decoded, but
  // before context has been allocated.
  RECORD_FUNCTION(
      ((PyTypeObject*)cls)->tp_name,
      unpacked_input.record_function_inputs,
      seq_id);

  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    // autograd.Function support for functorch is handled in Python.
    // If we have gotten here, then either we are dealing with a
    // torch.autograd.function._SingleLevelFunction, or something in
    // the implementation went wrong.
    // The following code is useful for debugging when something goes wrong
    // because it'll raise a loud error (instead of being silently incorrect).
    functorch_tls->checkSupportsSingleLevelAutogradFunction();
  }

  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls)
    return nullptr;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
  if (!ctx_obj)
    return nullptr;
  THPFunction* ctx = (THPFunction*)ctx_obj.get();

  auto cdata =
      std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
  ctx->cdata = cdata;

  // Record input nodes if tracing
  auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);

  // Initialize backward function (and ctx)
  bool is_executable = input_info.is_executable;
  cdata->set_next_edges(std::move(input_info.next_edges));
  ctx->needs_input_grad = input_info.needs_input_grad.release();
  ctx->is_variable_input = std::move(input_info.is_variable_input);

  // autograd.Function may optionally override a setup_context staticmethod.
  // In this case, autograd.Function.forward does NOT accept a ctx object.
  // Determine if this is the case.
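  // For illustration (hypothetical user code), the overridden-setup_context
  // style splits forward and context setup:
  //
  //   class MyFn(torch.autograd.Function):
  //       @staticmethod
  //       def forward(x):                          # no ctx argument
  //           return x * 2
  //
  //       @staticmethod
  //       def setup_context(ctx, inputs, output):  # called right after forward
  //           ctx.save_for_backward(*inputs)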
1328   auto cls_setup_context =
1329       THPObjectPtr(PyObject_GetAttrString(cls, "setup_context"));
1330   if (!cls_setup_context) {
1331     return nullptr;
1332   }
1333   auto orig_setup_context = get_base_setup_context();
1334   if (!orig_setup_context) {
1335     return nullptr;
1336   }
1337   auto overridden_setup_context = cls_setup_context.get() != orig_setup_context;
1338 
1339   auto num_args = PyTuple_GET_SIZE(inputs);
1340 
1341   // Call forward
1342   THPObjectPtr output;
1343   {
1344     AutoGradMode grad_mode(false);
1345     at::AutoFwGradMode fw_grad_mode(false);
1346     THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
1347     if (!forward_fn)
1348       return nullptr;
1349     if (overridden_setup_context) {
1350       // call forward followed by setup_context
1351       output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
1352       if (!output) {
1353         return nullptr;
1354       }
1355       // signature is setup_context(ctx, inputs, output)
1356       auto ctx_input_output_tuple =
1357           make_ctx_input_output_tuple(ctx, unpacked_input, output);
1358       if (!ctx_input_output_tuple) {
1359         return nullptr;
1360       }
1361       THPObjectPtr setup_context_fn(
1362           PyObject_GetAttrString(cls, "setup_context"));
1363       auto result =
1364           PyObject_CallObject(setup_context_fn, ctx_input_output_tuple);
1365       if (!result) {
1366         return nullptr;
1367       }
1368     } else {
1369       // call forward
1370       auto ctx_input_tuple =
1371           make_ctx_input_tuple(ctx, unpacked_input, num_args);
1372       if (!ctx_input_tuple) {
1373         return nullptr;
1374       }
1375       output = PyObject_CallObject(forward_fn, ctx_input_tuple);
1376     }
1377     if (!output)
1378       return nullptr;
1379   }
1380 
1381   return process_outputs(
1382       cls,
1383       cdata,
1384       ctx,
1385       unpacked_input,
1386       inputs,
1387       std::move(output),
1388       is_executable,
1389       node,
1390       overridden_setup_context);
1391   END_HANDLE_TH_ERRORS
1392 }
1393 
1394 ////////////////////////////////////////////////////////////////////////////////
1395 // Other methods / attributes
1396 ////////////////////////////////////////////////////////////////////////////////
1397 
THPFunction__register_hook_dict(PyObject * _self,PyObject * _var)1398 PyObject* THPFunction__register_hook_dict(PyObject* _self, PyObject* _var) {
1399   HANDLE_TH_ERRORS
1400   TORCH_CHECK(THPVariable_Check(_var), "_register_hook_dict expected a Tensor");
1401   THPVariable* var = reinterpret_cast<THPVariable*>(_var);
1402   const auto& tensor = THPVariable_Unpack(var);
1403   std::unique_ptr<FunctionPreHook> hook(
1404       new PyFunctionTensorPreHook(var->backward_hooks, tensor.output_nr()));
1405   auto self = (THPFunction*)_self;
1406   auto cdata = self->cdata.lock();
1407   TORCH_CHECK(
1408       cdata,
1409       "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. "
1410       "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1411       "access pattern that is no longer supported. For examples on how to use new-style "
1412       "autograd functions, see "
1413       "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1414   cdata->add_tensor_pre_hook(std::move(hook));
1415   Py_RETURN_NONE;
1416   END_HANDLE_TH_ERRORS
1417 }
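
// For context, a hedged sketch of a Python path that can reach this
// binding: Tensor.register_hook on a non-leaf produced by a custom
// Function installs the tensor's hook dict on its grad_fn the first time
// a hook is added. Scale is the illustrative Function sketched above.
//
//   x = torch.randn(3, requires_grad=True)
//   y = Scale.apply(x, 2.0)                  # y.grad_fn is a THPFunction
//   handle = y.register_hook(lambda g: g * 0.5)
//   y.sum().backward()                       # x.grad == 1.0 everywhere
//   handle.remove()                          # hooks come with removable handles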
1418 
1419 PyObject* THPFunction_register_hook(PyObject* _self, PyObject* hook) {
1420   HANDLE_TH_ERRORS
1421   auto self = (THPFunction*)_self;
1422   auto cdata = self->cdata.lock();
1423   TORCH_CHECK(
1424       cdata,
1425       "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. "
1426       "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1427       "access pattern that is no longer supported. For examples on how to use new-style "
1428       "autograd functions, see "
1429       "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1430   return torch::autograd::registerFunctionHook(*cdata, hook);
1431   END_HANDLE_TH_ERRORS
1432 }
1433 
1434 PyObject* THPFunction_register_prehook(PyObject* _self, PyObject* hook) {
1435   HANDLE_TH_ERRORS
1436   auto self = (THPFunction*)_self;
1437   auto cdata = self->cdata.lock();
1438   TORCH_CHECK(
1439       cdata,
1440       "Attribute 'register_prehook' is invalid for this instance of _C._FunctionBase. "
1441       "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1442       "access pattern that is no longer supported. For examples on how to use new-style "
1443       "autograd functions, see "
1444       "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1445   return torch::autograd::registerFunctionPreHook(*cdata, hook);
1446   END_HANDLE_TH_ERRORS
1447 }
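
// A hedged sketch of the Node-level hook API these two bindings expose,
// assuming a y produced by a custom Function as above:
//
//   # post-hook: called with (grad_inputs, grad_outputs) after the node
//   # runs; may return replacement grad_inputs.
//   y.grad_fn.register_hook(lambda gI, gO: None)
//
//   # pre-hook: called with grad_outputs before the node runs;
//   # may return replacement grad_outputs.
//   y.grad_fn.register_prehook(lambda gO: None)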
1448 
1449 int THPFunction_set_materialize_grads(
1450     THPFunction* self,
1451     PyObject* value,
1452     void* unused) {
1453   HANDLE_TH_ERRORS
1454   if (!PyBool_Check(value)) {
1455     THPUtils_invalidArguments(
1456         value, nullptr, "set_materialize_grads", 1, "(bool)");
1457     return -1;
1458   }
1459   self->materialize_grads = (value == Py_True);
1460   return 0;
1461   END_HANDLE_TH_ERRORS_RET(-1)
1462 }
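
// The setter above backs ctx.set_materialize_grads; a hedged sketch of the
// documented effect: with materialization off, backward may see None for
// outputs that did not receive a gradient, instead of zero-filled tensors.
//
//   class Twice(torch.autograd.Function):
//       @staticmethod
//       def forward(ctx, x):
//           ctx.set_materialize_grads(False)
//           return x.clone(), x.clone()
//
//       @staticmethod
//       def backward(ctx, g1, g2):
//           grad = None            # g1 / g2 can now be None
//           if g1 is not None:
//               grad = g1
//           if g2 is not None:
//               grad = g2 if grad is None else grad + g2
//           return grad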
1463 
1464 PyObject* THPFunction_get_materialize_non_diff_grads(
1465     THPFunction* self,
1466     void* _unused) {
1467   HANDLE_TH_ERRORS
1468   if (self->materialize_non_diff_grads) {
1469     Py_RETURN_TRUE;
1470   } else {
1471     Py_RETURN_FALSE;
1472   }
1473   END_HANDLE_TH_ERRORS
1474 }
1475 
1476 int THPFunction_set_materialize_non_diff_grads(
1477     THPFunction* self,
1478     PyObject* value,
1479     void* unused) {
1480   HANDLE_TH_ERRORS
1481   if (!PyBool_Check(value)) {
1482     THPUtils_invalidArguments(
1483         value, nullptr, "set_materialize_non_diff_grads", 1, "(bool)");
1484     return -1;
1485   }
1486   self->materialize_non_diff_grads = (value == Py_True);
1487   return 0;
1488   END_HANDLE_TH_ERRORS_RET(-1)
1489 }
1490 
1491 PyObject* THPFunction_saved_tensors(THPFunction* self, void* _unused) {
1492   HANDLE_TH_ERRORS
1493   if (self->saved_for_forward) {
1494     Py_INCREF(self->saved_for_forward);
1495     return self->saved_for_forward;
1496   } else {
1497     return unpack_saved_variables(
1498         self, [](const Variable& var) { return THPVariable_Wrap(var); });
1499   }
1500   END_HANDLE_TH_ERRORS
1501 }
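
// A hedged sketch of the contract this getter serves: tensors stashed with
// ctx.save_for_backward in forward come back out of ctx.saved_tensors in
// backward (the saved_for_forward branch above serves forward-mode AD):
//
//   class Square(torch.autograd.Function):
//       @staticmethod
//       def forward(ctx, x):
//           ctx.save_for_backward(x)
//           return x * x
//
//       @staticmethod
//       def backward(ctx, grad_out):
//           (x,) = ctx.saved_tensors
//           return 2 * x * grad_out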
1502 
1503 PyObject* THPFunction_saved_variables(THPFunction* self, void* _unused) {
1504   HANDLE_TH_ERRORS
1505   auto r = PyErr_WarnEx(
1506       PyExc_DeprecationWarning,
1507       "'saved_variables' is deprecated; use 'saved_tensors'",
1508       0);
1509   if (r != 0)
1510     throw python_error();
1511   return unpack_saved_variables(
1512       self, [](const Variable& var) { return THPVariable_Wrap(var); });
1513   END_HANDLE_TH_ERRORS
1514 }
1515 
1516 PyObject* THPFunction_is_compiled_autograd_tracing(
1517     PyObject* self,
1518     PyObject* _unused) {
1519   HANDLE_TH_ERRORS
1520   if (((THPFunction*)self)->compiled_autograd_tracing) {
1521     Py_RETURN_TRUE;
1522   } else {
1523     Py_RETURN_FALSE;
1524   }
1525   END_HANDLE_TH_ERRORS
1526 }
1527 
1528 PyObject* THPFunction_get_compiled_autograd_symints(
1529     PyObject* _self,
1530     PyObject* _unused) {
1531   HANDLE_TH_ERRORS
1532   auto self = (THPFunction*)_self;
1533   auto size = self->compiled_autograd_symints.size();
1534   PyObject* result = PyTuple_New(static_cast<Py_ssize_t>(size));
1535   if (!result) {
1536     throw python_error();
1537   }
1538   for (const auto i : c10::irange(size)) {
1539     PyTuple_SET_ITEM(
1540         result,
1541         i,
1542         py::cast(self->compiled_autograd_symints[i]).release().ptr());
1543   }
1544   return result;
1545   END_HANDLE_TH_ERRORS
1546 }
1547 
1548 PyObject* THPFunction_get_compiled_autograd_backward_state(
1549     PyObject* _self,
1550     void* _unused) {
1551   HANDLE_TH_ERRORS
1552   auto self = (THPFunction*)_self;
1553   PyObject* bw_state = self->compiled_autograd_backward_state;
1554   if (bw_state == nullptr) {
1555     bw_state = Py_None;
1556   }
1557   Py_INCREF(bw_state);
1558   return bw_state;
1559   END_HANDLE_TH_ERRORS
1560 }
1561 
1562 int THPFunction_set_compiled_autograd_backward_state(
1563     PyObject* _self,
1564     PyObject* bw_state,
1565     void* _unused) {
1566   HANDLE_TH_ERRORS
1567   auto self = (THPFunction*)_self;
1568   TORCH_INTERNAL_ASSERT(self->compiled_autograd_backward_state == nullptr);
1569   Py_INCREF(bw_state);
1570   self->compiled_autograd_backward_state = bw_state;
1571   return 0;
1572   END_HANDLE_TH_ERRORS_RET(-1)
1573 }
1574 
1575 PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) {
1576   HANDLE_TH_ERRORS
1577   // Error out if the user tries to access saved variables after they have
1578   // already been freed
1578   TORCH_CHECK(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
1579   const auto& saved_variables = self->saved_variables;
1580   if (saved_variables.empty())
1581     return PyTuple_New(0);
1582   size_t num_saved = saved_variables.size();
1583   THPObjectPtr saved(PyTuple_New(static_cast<Py_ssize_t>(num_saved)));
1584   if (!saved) {
1585     return nullptr;
1586   }
1587   for (const auto i : c10::irange(num_saved)) {
1588     py::object obj =
1589         py::cast(saved_variables[i], py::return_value_policy::reference);
1590     PyTuple_SET_ITEM(saved.get(), i, obj.release().ptr());
1591   }
1592   return saved.release();
1593   END_HANDLE_TH_ERRORS
1594 }
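
// A hedged sketch of the private surface this getter provides, assuming
// the SavedTensor.register_hooks binding: each entry is a SavedVariable
// wrapper whose pack/unpack hooks can be installed before backward runs.
// Square is the custom Function sketched above.
//
//   x = torch.randn(3, requires_grad=True)
//   y = Square.apply(x)
//   (sv,) = y.grad_fn._raw_saved_tensors
//   sv.register_hooks(lambda t: t, lambda t: t)   # identity pack / unpack
//   y.sum().backward()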
1595 
1596 PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) {
1597   HANDLE_TH_ERRORS
1598   auto cdata = self->cdata.lock();
1599   TORCH_CHECK(
1600       cdata,
1601       "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. "
1602       "Accessing this attribute directly on an instance of autograd.Function is a legacy "
1603       "access pattern that is no longer supported. For examples on how to use new-style "
1604       "autograd functions, see "
1605       "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
1606   const auto num_outputs = cdata->num_outputs();
1607   THPObjectPtr result(PyTuple_New(num_outputs));
1608   if (!result)
1609     return nullptr;
1610   for (const auto i : c10::irange(num_outputs)) {
1611     THPObjectPtr fn_tuple(PyTuple_New(2));
1612     if (!fn_tuple)
1613       return nullptr;
1614     const auto& edge = cdata->next_edge(i);
1615     PyObject* fn = functionToPyObject(edge.function);
1616     if (!fn)
1617       return nullptr;
1618     PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
1619     PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr));
1620     PyTuple_SET_ITEM(result.get(), i, fn_tuple.release());
1621   }
1622   return result.release();
1623   END_HANDLE_TH_ERRORS
1624 }
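
// A hedged sketch of the graph walk enabled by the tuple built above; each
// entry pairs a producer node with the input index it feeds. Double is a
// hypothetical custom Function whose forward and backward both multiply
// by 2:
//
//   x = torch.randn(2, requires_grad=True)
//   y = Double.apply(x)
//   ((fn, input_nr),) = y.grad_fn.next_functions
//   # fn is the AccumulateGrad node for the leaf x; input_nr is 0.
//   # Edges with no producer node typically appear as (None, 0).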
1625 
1626 PyObject* THPFunction_metadata(THPFunction* self, void* _unused) {
1627   HANDLE_TH_ERRORS
1628   auto cdata = self->cdata.lock();
1629   // The correct way to solve this problem is to stop exposing grad_fn
1630   // of PyFunctions as THPFunction; instead, we should use THPCppFunction
1631   // like everyone else.  But this is a BC-breaking change as it would
1632   // mean that you no longer get the property that grad_fn is a subclass
1633   // of the autograd function class that you defined in the custom case,
1634   // so I didn't fix it here.
1635   TORCH_CHECK(
1636       cdata,
1637       "You attempted to access the anomaly metadata of a custom autograd function "
1638       "but the underlying PyNode has already been deallocated.  The most likely "
1639       "reason this occurred is because you assigned x.grad_fn to a local variable "
1640       "and then let the original variable get deallocated.  Don't do that!  If "
1641       "you really have no way of restructuring your code so this is the case, "
1642       "please file an issue reporting that you are affected by this.");
1643   auto metadata = static_cast<PyAnomalyMetadata*>(cdata->metadata())->dict();
1644 
1645   Py_INCREF(metadata);
1646   return metadata;
1647   END_HANDLE_TH_ERRORS
1648 }
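
// The metadata dict above is what anomaly mode writes into; a hedged
// sketch assuming the documented detect_anomaly API (MyFunc and x are
// hypothetical):
//
//   with torch.autograd.detect_anomaly():
//       y = MyFunc.apply(x)   # forward traceback is recorded in metadata
//       y.sum().backward()    # a backward error then reports that traceback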
1649 
1650 using getter = PyObject* (*)(PyObject*, void*);
1651 using setter = int (*)(PyObject*, PyObject*, void*);
1652 
1653 namespace {
1654 
1655 template <PyObject* THPFunction::*ptr>
1656 PyObject* getObject(PyObject* obj, void* _unused) {
1657   auto self = (THPFunction*)obj;
1658   PyObject* value = self->*ptr;
1659   if (!value) {
1660     Py_RETURN_NONE;
1661   }
1662   Py_INCREF(value);
1663   return value;
1664 }
1665 
1666 template <PyObject* THPFunction::*ptr>
1667 int setObject(PyObject* obj, PyObject* value, void* _unused) {
1668   auto self = (THPFunction*)obj;
1669   if (value == Py_None) {
1670     value = nullptr;
1671   }
1672   Py_XDECREF((self->*ptr));
1673   Py_XINCREF(value);
1674   self->*ptr = value;
1675   return 0;
1676 }
1677 
1678 template <typename M, M THPFunction::*ptr, PyObject* (*Convert)(long)>
1679 PyObject* getMember(PyObject* obj, void* _unused) {
1680   auto self = (THPFunction*)obj;
1681   return Convert(self->*ptr);
1682 }
1683 
1684 template <typename M, M autograd::Node::*ptr, PyObject* (*Convert)(long)>
1685 PyObject* getImplMember(PyObject* obj, void* _unused) {
1686   auto self = (THPFunction*)obj;
1687   return Convert(self->cdata.*ptr);
1688 }
1689 
1690 PyObject* getRequiresGrad(PyObject* obj, void* _unused) {
1691   Py_RETURN_TRUE;
1692 }
1693 
1694 } // namespace
1695 
1696 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
1697 static struct PyGetSetDef THPFunction_properties[] = {
1698     {"saved_tensors",
1699      (getter)THPFunction_saved_tensors,
1700      nullptr,
1701      nullptr,
1702      nullptr},
1703     {"saved_variables",
1704      (getter)THPFunction_saved_variables,
1705      nullptr,
1706      nullptr,
1707      nullptr},
1708     {"_raw_saved_tensors",
1709      (getter)THPFunction_raw_saved_tensors,
1710      nullptr,
1711      nullptr,
1712      nullptr},
1713     {"next_functions",
1714      (getter)THPFunction_next_functions,
1715      nullptr,
1716      nullptr,
1717      nullptr},
1718     {"to_save",
1719      &getObject<&THPFunction::to_save>,
1720      &setObject<&THPFunction::to_save>,
1721      nullptr,
1722      nullptr},
1723     {"non_differentiable",
1724      &getObject<&THPFunction::non_differentiable>,
1725      &setObject<&THPFunction::non_differentiable>,
1726      nullptr,
1727      nullptr},
1728     {"dirty_tensors",
1729      &getObject<&THPFunction::dirty_tensors>,
1730      &setObject<&THPFunction::dirty_tensors>,
1731      nullptr,
1732      nullptr},
1733     {"saved_for_forward",
1734      &getObject<&THPFunction::saved_for_forward>,
1735      &setObject<&THPFunction::saved_for_forward>,
1736      nullptr,
1737      nullptr},
1738     {"needs_input_grad",
1739      &getObject<&THPFunction::needs_input_grad>,
1740      &setObject<&THPFunction::needs_input_grad>,
1741      nullptr,
1742      nullptr},
1743     {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
1744     {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
1745     {"_input_metadata",
1746      (getter)THPFunction_input_metadata,
1747      nullptr,
1748      nullptr,
1749      nullptr},
1750     {"materialize_grads",
1751      nullptr,
1752      (setter)THPFunction_set_materialize_grads,
1753      nullptr,
1754      nullptr},
1755     {"_materialize_non_diff_grads",
1756      (getter)THPFunction_get_materialize_non_diff_grads,
1757      (setter)THPFunction_set_materialize_non_diff_grads,
1758      nullptr,
1759      nullptr},
1760     {"_compiled_autograd_backward_state",
1761      (getter)THPFunction_get_compiled_autograd_backward_state,
1762      (setter)THPFunction_set_compiled_autograd_backward_state,
1763      nullptr,
1764      nullptr},
1765     {nullptr}};
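
// The to_save / dirty_tensors / non_differentiable slots in the table above
// are populated by the Python ctx helpers; a hedged sketch of the documented
// helper calls (ctx.mark_non_differentiable fills non_differentiable the
// same way):
//
//   class InplaceRelu(torch.autograd.Function):
//       @staticmethod
//       def forward(ctx, x):
//           out = x.relu_()              # modifies x in place
//           ctx.mark_dirty(x)            # -> dirty_tensors
//           ctx.save_for_backward(out)   # -> to_save
//           return out
//
//       @staticmethod
//       def backward(ctx, grad_out):
//           (out,) = ctx.saved_tensors
//           return grad_out * (out > 0)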
1766 
1767 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
1768 static struct PyMethodDef THPFunction_methods[] = {
1769     {(char*)"name", THPFunction_name, METH_NOARGS, nullptr},
1770     {(char*)"_sequence_nr", THPFunction_sequence_nr, METH_NOARGS, nullptr},
1771     {(char*)"_set_sequence_nr", THPFunction_set_sequence_nr, METH_O, nullptr},
1772     {(char*)"maybe_clear_saved_tensors",
1773      THPFunction_maybe_clear_saved_tensors,
1774      METH_NOARGS,
1775      nullptr},
1776     {(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
1777     {(char*)"_register_hook_dict",
1778      THPFunction__register_hook_dict,
1779      METH_O,
1780      nullptr},
1781     {(char*)"register_hook", THPFunction_register_hook, METH_O, nullptr},
1782     {(char*)"register_prehook", THPFunction_register_prehook, METH_O, nullptr},
1783     {(char*)"_is_compiled_autograd_tracing",
1784      THPFunction_is_compiled_autograd_tracing,
1785      METH_NOARGS,
1786      nullptr},
1787     {(char*)"_get_compiled_autograd_symints",
1788      THPFunction_get_compiled_autograd_symints,
1789      METH_NOARGS,
1790      nullptr},
1791     {nullptr}};
1792 
1793 PyTypeObject THPFunctionType = {
1794     PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._FunctionBase", /* tp_name */
1795     sizeof(THPFunction), /* tp_basicsize */
1796     0, /* tp_itemsize */
1797     (destructor)THPFunction_dealloc, /* tp_dealloc */
1798     0, /* tp_vectorcall_offset */
1799     nullptr, /* tp_getattr */
1800     nullptr, /* tp_setattr */
1801     nullptr, /* tp_reserved */
1802     nullptr, /* tp_repr */
1803     nullptr, /* tp_as_number */
1804     nullptr, /* tp_as_sequence */
1805     nullptr, /* tp_as_mapping */
1806     nullptr, /* tp_hash  */
1807     nullptr, /* tp_call */
1808     nullptr, /* tp_str */
1809     nullptr, /* tp_getattro */
1810     nullptr, /* tp_setattro */
1811     nullptr, /* tp_as_buffer */
1812     // NOLINTNEXTLINE(misc-redundant-expression)
1813     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
1814         Py_TPFLAGS_HAVE_GC, /* tp_flags */
1815     nullptr, /* tp_doc */
1816     (traverseproc)THPFunction_traverse, /* tp_traverse */
1817     (inquiry)THPFunction_clear, /* tp_clear */
1818     nullptr, /* tp_richcompare */
1819     0, /* tp_weaklistoffset */
1820     nullptr, /* tp_iter */
1821     nullptr, /* tp_iternext */
1822     THPFunction_methods, /* tp_methods */
1823     nullptr, /* tp_members */
1824     THPFunction_properties, /* tp_getset */
1825     nullptr, /* tp_base */
1826     nullptr, /* tp_dict */
1827     nullptr, /* tp_descr_get */
1828     nullptr, /* tp_descr_set */
1829     0, /* tp_dictoffset */
1830     nullptr, /* tp_init */
1831     nullptr, /* tp_alloc */
1832     THPFunction_new /* tp_new */
1833 };
1834 
1835 bool THPFunction_initModule(PyObject* module) {
1836   if (PyType_Ready(&THPFunctionType) < 0)
1837     return false;
1838   Py_INCREF(&THPFunctionType);
1839   PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType);
1840   return true;
1841 }
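
// A hedged note on what the registration above provides on the Python
// side: _FunctionBase is the C-level base that torch.autograd.Function
// ultimately derives from.
//
//   import torch
//   assert issubclass(torch.autograd.Function, torch._C._FunctionBase)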
1842