# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two files:
#  Functions.h/cpp: subclasses of autograd::Node
#  python_functions.h/cpp: Python bindings for the above classes
#

from __future__ import annotations

from typing import Sequence

from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    SavedAttribute,
    uses_retain_variables,
    uses_single_grad,
)
from torchgen.api.types import (
    ArrayRefCType,
    BaseCppType,
    BaseCType,
    Binding,
    boolT,
    doubleT,
    intArrayRefT,
    iTensorListRefT,
    ListCType,
    longT,
    MutRefCType,
    OptionalCType,
    optionalIntArrayRefT,
    optionalSymIntArrayRefT,
    scalarT,
    stringT,
    symIntArrayRefT,
    SymIntT,
    TENSOR_LIST_LIKE_CTYPES,
    tensorListT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, FunctionSchema
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import VIEW_FUNCTIONS


FUNCTION_DECLARATION = CodeTemplate(
    """\
#ifdef _WIN32
struct ${op} : public ${superclass} {
  TORCH_API ${op}() = default;
#else
struct TORCH_API ${op} : public ${superclass} {
#endif
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${thread_lock}
    ${release_variables}
  }
  ${will_release_variables}
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
  ${saved_variables}
  ${saved_list_sizes}
};
"""
)
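
# For illustration only: substituting this template for a hypothetical op named
# "ExampleBackward0" that saves a single tensor `self` (the name and fields are
# assumptions, not taken from derivatives.yaml) yields roughly:
#
#   struct TORCH_API ExampleBackward0 : public TraceableFunction {
#     using TraceableFunction::TraceableFunction;
#     variable_list apply(variable_list&& grads) override;
#     std::string name() const override { return "ExampleBackward0"; }
#     void release_variables() override {
#       std::lock_guard<std::mutex> lock(mutex_);
#       self_.reset_data();
#     }
#     void compiled_args(CompiledNodeArgs& args) override;
#     variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
#     SavedVariable self_;
#   };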

WILL_RELEASE_VARIABLES = CodeTemplate(
    """\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
"""
)

FUNCTION_DEFINITION = CodeTemplate(
    """\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
void ${op}::compiled_args(CompiledNodeArgs& args) {
    ${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
    ${apply_with_saved_before}
    variable_list result = apply(variable_list(grads));
    ${apply_with_saved_after}
    return result;
}
"""
)
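
# A minimal sketch (assumed, not generated from a real derivative definition)
# of what FUNCTION_DEFINITION expands to for a single-input op:
#
#   variable_list ExampleBackward0::apply(variable_list&& grads) {
#     std::lock_guard<std::mutex> lock(mutex_);
#     IndexRangeGenerator gen;
#     auto self_ix = gen.range(1);
#     variable_list grad_inputs(gen.size());
#     const auto& grad = grads[0];
#     auto self = self_.unpack();
#     if (task_should_compute_output({ self_ix })) {
#       auto grad_result = grad * 2;  // hypothetical formula
#       copy_range(grad_inputs, self_ix, grad_result);
#     }
#     return grad_inputs;
#   }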

GRAD_INPUT_MASK = CodeTemplate(
    """\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
"""
)

DERIVATIVE_SINGLE = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)

# note(crcrpar): The `self` argument and other (optional) positional arguments
# of foreach functions are essentially lists of n `Tensor`s, so we iterate over
# `grads` in order to apply the existing per-`Tensor` derivative definitions to
# each element of `self` and of the other list arguments.
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  std::vector<Tensor> grad_result;
  grad_result.reserve(grads.size());
  for (const auto & i : c10::irange(grads.size())) {
    if (grads[i].defined()) {
      grad_result.emplace_back(${derivative});
    } else {
      grad_result.emplace_back(Tensor());
    }
  }
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)
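
# For a hypothetical `_foreach_mul`-style derivative (names and formula assumed
# for illustration), the template above expands to a per-element loop such as:
#
#   if (task_should_compute_output({ self_ix })) {
#     std::vector<Tensor> grad_result;
#     grad_result.reserve(grads.size());
#     for (const auto & i : c10::irange(grads.size())) {
#       if (grads[i].defined()) {
#         grad_result.emplace_back(grads[i] * other[i]);
#       } else {
#         grad_result.emplace_back(Tensor());
#       }
#     }
#     copy_range(grad_inputs, self_ix, grad_result);
#   }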

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
  if (task_should_compute_output({ ${name}_ix })) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
"""
)

DERIVATIVE_MULTI = CodeTemplate(
    """\
if (task_should_compute_output({ ${idx_ranges} })) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
"""
)
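
# A sketch (hypothetical names and formula) of a multi-variable derivative
# expansion, combining GRAD_INPUT_MASK and DERIVATIVE_MULTI_COPY_RANGE:
#
#   if (task_should_compute_output({ self_ix, other_ix })) {
#     auto grad_input_mask = std::array<bool, 2>{
#       task_should_compute_output({ self_ix }),
#       task_should_compute_output({ other_ix }),
#     };
#     auto grad_result = example_backward(grad, self, other, grad_input_mask);
#     if (task_should_compute_output({ self_ix })) {
#       copy_range(grad_inputs, self_ix, std::get<0>(grad_result));
#     }
#     if (task_should_compute_output({ other_ix })) {
#       copy_range(grad_inputs, other_ix, std::get<1>(grad_result));
#     }
#   }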

# Generates Python bindings
#
# This generates the definitions for:
#   (1) The PyTypeObject for each backward grad_fn subclassing Node
#   (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs)
#       We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs
#       Each PyGetSetDef has a function ptr to a getter, also defined here (3).
#   (3) Getters for each of grad_fn's saved inputs and outputs.
#
PY_FUNCTION_DEFINITION = CodeTemplate(
    """\
static PyTypeObject ${op}Class;
addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties);
"""
)

PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
    """\
${all_getter_definitions}

static struct PyGetSetDef ${op}_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  ${all_getsetdef_structs}
  {nullptr} /* sentinel */
};

"""
)

PY_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
)

PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
)
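
# For a hypothetical op saving a tensor `self`, the pieces above combine into
# the following properties array (names assumed for illustration):
#
#   static struct PyGetSetDef ExampleBackward0_properties[] = {
#     THP_FUNCTION_DEFAULT_PROPERTIES,
#     {(char*)"_saved_self", (getter)THPExampleBackward0_self_getter, nullptr, nullptr, nullptr},
#     {(char*)"_raw_saved_self", (getter)THPExampleBackward0_self_raw_getter, nullptr, nullptr, nullptr},
#     {nullptr} /* sentinel */
#   };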

# Getter templates
GETTER_DEFINITION = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.list.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.list.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
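
# Substituting GETTER_DEFINITION_OPT for a hypothetical optional double
# property `alpha` on an assumed op produces:
#
#   PyObject* THPExampleBackward0_alpha_getter(THPCppFunction *self, void *_unused) {
#     HANDLE_TH_ERRORS
#     auto opt_prop = static_cast<ExampleBackward0*>(self->cdata.get())->alpha;
#     if (!opt_prop.has_value()) {
#       Py_RETURN_NONE;
#     }
#     auto prop = opt_prop.value();
#     return PyFloat_FromDouble((double) prop);
#     END_HANDLE_TH_ERRORS
#   }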

# Getter body
GETTER_BODY_SAVEDVAR = """\
return THPVariable_Wrap(prop.unpack(self->cdata));
"""

GETTER_BODY_RAW_SAVEDVAR = """\
pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference);
return obj.release().ptr();
"""

GETTER_BODY_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i: c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
}
return tup;
"""

GETTER_BODY_RAW_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference);
  PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr());
}
return tup;
"""

GETTER_BODY_ARRAYREF_LONG = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i]));
}
return tup;
"""

GETTER_BODY_ARRAYREF_SYMINT = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
    auto si = prop[i];
    if (auto m = si.maybe_as_int()) {
      PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m));
    } else {
      auto py_symint = py::cast(si).release().ptr();
      PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
    }
}
return tup;
"""

GETTER_BODY_ARRAYREF_DOUBLE = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
}
return tup;
"""

GETTER_BODY_INT64_T = """\
return PyLong_FromUnsignedLong((int64_t) prop);
"""

GETTER_BODY_SYMINT = """\
if (auto m = prop.maybe_as_int()) {
  return PyLong_FromUnsignedLong(*m);
} else {
  return py::cast(prop).release().ptr();
}
"""

GETTER_BODY_DOUBLE = """\
return PyFloat_FromDouble((double) prop);
"""

GETTER_BODY_BOOL = """\
if (prop) {
  Py_RETURN_TRUE;
} else {
  Py_RETURN_FALSE;
}
"""

GETTER_BODY_STRING = """\
return PyUnicode_FromStringAndSize(prop.data(), prop.size());
"""

GETTER_BODY_SCALAR = """\
if (prop.isComplex()) {
  auto cprop = prop.to<c10::complex<double>>();
  return PyComplex_FromDoubles(cprop.real(), cprop.imag());
} else if (prop.isFloatingPoint()) {
  return PyFloat_FromDouble(prop.to<double>());
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""


GETTER_BODY_VEC_SCALAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i: c10::irange(prop.size())) {
  if (prop[i].isComplex()) {
    auto cprop = prop[i].to<c10::complex<double>>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyComplex_FromDoubles(cprop.real(), cprop.imag()));
  } else if (prop[i].isFloatingPoint()) {
    auto double_prop = prop[i].to<double>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble(double_prop));
  } else if (prop[i].isIntegral(/*includeBool=*/false)) {
    auto long_prop = prop[i].to<int64_t>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLong(long_prop));
  } else if (prop[i].isBoolean()) {
    // PyTuple_SetItem steals a reference, so the borrowed Py_True/Py_False
    // singletons must be increfed before being stored.
    if (prop[i].to<bool>()) {
      Py_INCREF(Py_True);
      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_True);
    } else {
      Py_INCREF(Py_False);
      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_False);
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
    return nullptr;
  }
}
return tup;
"""


MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    OptionalCType(BaseCType(SymIntT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SYMINT),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}
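
# A minimal sketch of how this map is consumed in process_function below
# (op and property names are hypothetical):
#
#   getter_def, body = MISC_GETTER_DEFS[OptionalCType(BaseCType(doubleT))]
#   src = getter_def.substitute(op="ExampleBackward0", name="alpha", body=body)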

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


def get_infos_with_derivatives_list(
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
) -> list[DifferentiabilityInfo]:
    diff_info_list = [
        info
        for diffinfo_dict in differentiability_infos.values()
        for info in diffinfo_dict.values()
    ]

    return list(filter(lambda info: info.args_with_derivatives, diff_info_list))


def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """

    # Get a flat list of diff infos; we do not need them keyed by
    # FunctionSchema/DispatchKey here. Infos with different dispatch keys but
    # the same name will still end up in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )


def gen_autograd_functions_python(
    out: str,
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}(PyObject* module);"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}(module);"
                for i in range(num_shards)
            ],
        },
    )

    # Get a flat list of diff infos; we do not need them keyed by
    # FunctionSchema/DispatchKey here. Infos with different dispatch keys but
    # the same name will still end up in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )


def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
    saved_variables: list[str] = []
    release_variables: list[str] = []
    saved_list_sizes: list[str] = []
    unpack: list[str] = []
    asserts: list[str] = []
    compute_index_ranges: list[str] = []
    getter_definitions: list[str] = []
    py_getsetdef_structs: list[str] = []
    compiled_args: list[str] = []
    apply_with_saved_before: list[str] = []
    apply_with_saved_after: list[str] = []

    for arg in info.args_with_derivatives:
        if arg.type in TENSOR_LIST_LIKE_CTYPES:
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
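
    # For a hypothetical op differentiable in a `Tensor self` and a
    # `Tensor[] tensors` argument, the loop above emits (illustrative only):
    #
    #   auto self_ix = gen.range(1);
    #   auto tensors_ix = gen.range(tensors_size_);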

    def save_var(var: SavedAttribute, is_output: bool) -> None:
        name = var.nctype.name
        type = var.nctype.type
        should_append_getsetdef = True
        should_append_raw_getsetdef = False
        visit_name = name
        uses_cpp_saved_variable_cls = False

        if (
            type == BaseCType(tensorT)
            or type == OptionalCType(BaseCType(tensorT))
            or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
            or (type == BaseCType(scalarT) and is_output)
        ):
            uses_cpp_saved_variable_cls = True
            saved_variables.append(f"SavedVariable {name}_;")
            release_variables.append(f"{name}_.reset_data();")
            ptr = "shared_from_this()" if is_output else ""
            unpack.append(f"auto {name} = {name}_.unpack({ptr});")
            getter_definitions.append(
                GETTER_DEFINITION_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
        elif (
            type == BaseCType(tensorListT)
            or type == BaseCType(iTensorListRefT)
            or type == VectorCType(BaseCType(tensorT))
        ):
            # note(crcrpar): [nuanced return type of out-of-place foreach functions]
            # When an out-of-place foreach function whose return signature is `Tensor[]`
            # spells out its backward definitions in `derivatives.yaml` and some of them
            # depend on `result`, `result`'s type is interpreted and treated as
            # `std::vector<Tensor>`. An out-of-place foreach whose backward relies on its
            # output does not hit this difference if the definitions are codegen'ed.
            # This special case is needed for `_foreach_pow.List` and
            # `_foreach_pow.ScalarAndTensor` as of https://github.com/pytorch/pytorch/pull/105504.
            if type == VectorCType(BaseCType(tensorT)):
                assert (
                    info.func.func.name.name.base.startswith("_foreach") and is_output
                )
            uses_cpp_saved_variable_cls = True
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient; we don't need to loop and clear each variable.
            # Because each SavedVariable owns a tensor and a grad_fn, removing the
            # SavedVariable makes them go away as well.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            ptr = "shared_from_this()" if is_output else "nullptr"
            unpack.append(f"auto {name} = unpack_list({name}_, {ptr});")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            uses_cpp_saved_variable_cls = True
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient; we don't need to loop and clear each variable.
            # Because each SavedVariable owns a tensor and a grad_fn, removing the
            # SavedVariable makes them go away as well.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = unpack_opt_list({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(symIntArrayRefT):
            saved_variables.append(f"std::vector<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == BaseCType(optionalIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(optionalSymIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == OptionalCType(BaseCType(intArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == OptionalCType(BaseCType(symIntArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f"c10::OptionalArray<double> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
                )
            )
        elif type == BaseCType(longT):
            saved_variables.append(f"{type.cpp_type()} {name} = 0;")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_INT64_T
                )
            )
        elif type == BaseCType(SymIntT):
            saved_variables.append(f"c10::SymInt {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SYMINT
                )
            )
        elif type == BaseCType(stringT):
            saved_variables.append(f"std::string {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f"std::optional<std::string> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == ArrayRefCType(
            elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
        ):
            saved_variables.append(f"std::vector<at::Scalar> {name};")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient to release the saved scalars; unlike the
            # SavedVariable cases above, there are no tensors or grad_fns to reset.
            release_variables.append(f"{name}.clear();")
            # release_variables.append(f"{name}_released_ = true;")
            # unpack.append(f"auto {name} = unpack_list({name}_);")
            # asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                CodeTemplate(
                    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name};
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
                ).substitute(
                    op=info.op,
                    name=name,
                    body=GETTER_BODY_VEC_SCALAR,
                )
            )
        else:
            # Check for indicators that you're putting a non-owning reference
            # into the saved variable field.  If this is spuriously firing,
            # edit this field.  Otherwise, you probably need to add a case
            # above.
            assert (
                "ref" not in type.cpp_type().lower()
                and "view" not in type.cpp_type().lower()
                and "*" not in type.cpp_type()
                and "&" not in type.cpp_type()
            ), f"{type.cpp_type()} looks like it contains a non-owning reference"
            saved_variables.append(f"{type.cpp_type()} {name};")

            if type in MISC_GETTER_DEFS:
                getter_def, body = MISC_GETTER_DEFS[type]
                getter_definitions.append(
                    getter_def.substitute(op=info.op, name=name, body=body)
                )
            else:
                # Types we don't expose Python bindings to yet:
                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                should_append_getsetdef = False

        if should_append_getsetdef:
            py_getsetdef_structs.append(
                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )
        if should_append_raw_getsetdef:
            py_getsetdef_structs.append(
                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )

        if uses_cpp_saved_variable_cls:
            compiled_args.append(
                f"args.collect({visit_name}, {'true' if is_output else 'false'});"
            )
        else:
            compiled_args.append(f"args.collect({visit_name});")
        apply_with_saved_before.append(f"saved.before({visit_name});")
        apply_with_saved_after.append(f"saved.after({visit_name});")

    for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
        save_var(var, is_output=False)
    for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):
        save_var(var, is_output=True)

    # Lock the mutex when we release variables and in Node::apply to ensure
    # thread safety; see Note [Thread Safety on Autograd Node].
    if len(release_variables) > 0:
        thread_lock = "std::lock_guard<std::mutex> lock(mutex_);"
    else:
        thread_lock = ""

    if uses_retain_variables(info):
        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
    else:
        will_release_variables = ""

    body: list[str] = []

    if uses_single_grad(info):
        body.append("const auto& grad = grads[0];")
    else:
        # Generate aliases for gradients named for returned values.
        body.extend(
            f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
            for name in sorted(info.used_named_gradients)
        )
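
    # For a hypothetical op with named outputs (result0, result1), the branch
    # above emits aliases such as (illustrative only):
    #
    #   const auto& grad_result0 = grads[0];
    #   const auto& grad_result1 = grads[1];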

    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
    ) -> tuple[bool, str]:
        formula = derivative.formula
        var_names = derivative.var_names
        if len(var_names) == 1:
            checks_any_grad_defined = False
            if "not_implemented" not in formula:
                matching_args = [
                    arg for arg in args_with_derivatives if arg.name == var_names[0]
                ]
                if len(matching_args) == 1:
                    # We can add undefined grad support if the input variable is a Tensor
                    arg = matching_args[0]
                    if isinstance(arg.argument, Argument) and str(
                        arg.argument.type
                    ) in ("Tensor", "Tensor?"):
                        formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                        checks_any_grad_defined = True
            if info.name.startswith("_foreach_"):
                derivative_template = DERIVATIVE_SINGLE_FOREACH
            else:
                derivative_template = DERIVATIVE_SINGLE
            return (
                checks_any_grad_defined,
                derivative_template.substitute(name=var_names[0], derivative=formula),
            )
        else:
            if "grad_input_mask" in formula:
                masks = [
                    f"task_should_compute_output({{ {n}_ix }})," for n in var_names
                ]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                )
            else:
                grad_input_mask = ""
            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            copy_ranges: list[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges,
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in info.derivatives:
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
        )
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check whether grads are
    # defined, perform the check only once, before all the formulas.
    if need_any_grad_defined_var:
        body.insert(
            -len(info.derivatives),
            "bool any_grad_defined = any_variable_defined(grads);",
        )
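
    # Illustrative ordering of the generated body at this point (hypothetical
    # single-Tensor op): the any_grad_defined check lands after the unpacks and
    # before every derivative block:
    #
    #   auto self = self_.unpack();
    #   bool any_grad_defined = any_variable_defined(grads);
    #   if (task_should_compute_output({ self_ix })) { ... }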

    if info.name in UNTRACEABLE_FUNCTIONS:
        superclass = "Node"
    else:
        superclass = "TraceableFunction"

    all_getsetdef_structs = (
        ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
    )
    all_getter_definitions = "\n".join(getter_definitions)

    return template.substitute(
        op=info.op,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
        saved_list_sizes=saved_list_sizes,
        asserts=asserts,
        thread_lock=thread_lock,
        will_release_variables=will_release_variables,
        body=body,
        superclass=superclass,
        all_getter_definitions=all_getter_definitions,
        all_getsetdef_structs=all_getsetdef_structs,
        compiled_args=compiled_args,
        apply_with_saved_before=apply_with_saved_before,
        apply_with_saved_after=apply_with_saved_after,
    )