# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two files:
#  Functions.h/cpp: subclasses of autograd::Node
#  python_functions.h/cpp: Python bindings for the above classes
#

from __future__ import annotations

from typing import Sequence

from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    SavedAttribute,
    uses_retain_variables,
    uses_single_grad,
)
from torchgen.api.types import (
    ArrayRefCType,
    BaseCppType,
    BaseCType,
    Binding,
    boolT,
    doubleT,
    intArrayRefT,
    iTensorListRefT,
    ListCType,
    longT,
    MutRefCType,
    OptionalCType,
    optionalIntArrayRefT,
    optionalSymIntArrayRefT,
    scalarT,
    stringT,
    symIntArrayRefT,
    SymIntT,
    TENSOR_LIST_LIKE_CTYPES,
    tensorListT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, FunctionSchema
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import VIEW_FUNCTIONS


FUNCTION_DECLARATION = CodeTemplate(
    """\
#ifdef _WIN32
struct ${op} : public ${superclass} {
  TORCH_API ${op}() = default;
#else
struct TORCH_API ${op} : public ${superclass} {
#endif
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${thread_lock}
    ${release_variables}
  }
  ${will_release_variables}
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
  ${saved_variables}
  ${saved_list_sizes}
};
"""
)

WILL_RELEASE_VARIABLES = CodeTemplate(
    """\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
"""
)

FUNCTION_DEFINITION = CodeTemplate(
    """\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
void ${op}::compiled_args(CompiledNodeArgs& args) {
    ${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
    ${apply_with_saved_before}
    variable_list result = apply(variable_list(grads));
    ${apply_with_saved_after}
    return result;
}
"""
)

GRAD_INPUT_MASK = CodeTemplate(
    """\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
"""
)

DERIVATIVE_SINGLE = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)

# note(crcrpar): For foreach functions, the `self` argument and the other optional
# positional arguments are essentially lists of n `Tensor`s, so we iterate over
# `grads` in order to reuse the existing per-`Tensor` derivative definitions for
# each element of `self` and of the other arguments.
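# The comment above describes the template that follows. As an illustrative sketch
# (the per-element formula `grads[i] * 2` is hypothetical, not taken from
# derivatives.yaml), substituting name="self" and that formula into
# DERIVATIVE_SINGLE_FOREACH would yield C++ along these lines:
#
#   if (task_should_compute_output({ self_ix })) {
#     std::vector<Tensor> grad_result;
#     grad_result.reserve(grads.size());
#     for (const auto & i : c10::irange(grads.size())) {
#       if (grads[i].defined()) {
#         grad_result.emplace_back(grads[i] * 2);
#       } else {
#         grad_result.emplace_back(Tensor());
#       }
#     }
#     copy_range(grad_inputs, self_ix, grad_result);
#   }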
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  std::vector<Tensor> grad_result;
  grad_result.reserve(grads.size());
  for (const auto & i : c10::irange(grads.size())) {
    if (grads[i].defined()) {
      grad_result.emplace_back(${derivative});
    } else {
      grad_result.emplace_back(Tensor());
    }
  }
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
  if (task_should_compute_output({ ${name}_ix })) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
"""
)

DERIVATIVE_MULTI = CodeTemplate(
    """\
if (task_should_compute_output({ ${idx_ranges} })) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
"""
)

# Generates python bindings
#
# This generates the definitions for:
# (1) The PyTypeObject for each backward grad_fn subclassing Node
# (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs)
#     We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs
#     Each PyGetSetDef has a function ptr to a getter, also defined here (3).
# (3) Getters for each of grad_fn's saved inputs and outputs.
#
PY_FUNCTION_DEFINITION = CodeTemplate(
    """\
static PyTypeObject ${op}Class;
addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties);
"""
)

PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
    """\
${all_getter_definitions}

static struct PyGetSetDef ${op}_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  ${all_getsetdef_structs}
  {nullptr} /* sentinel */
};

"""
)

PY_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
)

PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
)

# Getter templates
GETTER_DEFINITION = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.list.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.list.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

# Getter body
GETTER_BODY_SAVEDVAR = """\
return THPVariable_Wrap(prop.unpack(self->cdata));
"""

GETTER_BODY_RAW_SAVEDVAR = """\
pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference);
return obj.release().ptr();
"""

GETTER_BODY_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i: c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
}
return tup;
"""

GETTER_BODY_RAW_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference);
  PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr());
}
return tup;
"""

GETTER_BODY_ARRAYREF_LONG = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i]));
}
return tup;
"""

GETTER_BODY_ARRAYREF_SYMINT = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  auto si = prop[i];
  if (auto m = si.maybe_as_int()) {
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m));
  } else {
    auto py_symint = py::cast(si).release().ptr();
    PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
  }
}
return tup;
"""

GETTER_BODY_ARRAYREF_DOUBLE = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
}
return tup;
"""

GETTER_BODY_INT64_T = """\
return PyLong_FromUnsignedLong((int64_t) prop);
"""

GETTER_BODY_SYMINT = """\
if (auto m = prop.maybe_as_int()) {
  return PyLong_FromUnsignedLong(*m);
} else {
  return py::cast(prop).release().ptr();
}
"""

GETTER_BODY_DOUBLE = """\
return PyFloat_FromDouble((double) prop);
"""

GETTER_BODY_BOOL = """\
if (prop) {
  Py_RETURN_TRUE;
} else {
  Py_RETURN_FALSE;
}
"""

GETTER_BODY_STRING = """\
return PyUnicode_FromStringAndSize(prop.data(), prop.size());
"""

GETTER_BODY_SCALAR = """\
if (prop.isComplex()) {
  auto cprop = prop.to<c10::complex<double>>();
  return PyComplex_FromDoubles(cprop.real(), cprop.imag());
} else if (prop.isFloatingPoint()) {
  return PyFloat_FromDouble(prop.to<double>());
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""


GETTER_BODY_VEC_SCALAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i: c10::irange(prop.size())) {
  if (prop[i].isComplex()) {
    auto cprop = prop[i].to<c10::complex<double>>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyComplex_FromDoubles(cprop.real(), cprop.imag()));
  } else if (prop[i].isFloatingPoint()) {
    auto double_prop = prop[i].to<double>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble(double_prop));
  } else if (prop[i].isIntegral(/*includeBool=*/false)) {
    auto long_prop = prop[i].to<int64_t>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLong(long_prop));
  } else if (prop[i].isBoolean()) {
    if (prop[i].to<bool>()) {
      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_True);
    } else {
      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_False);
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
    return nullptr;
  }
}
return tup;
"""


MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    OptionalCType(BaseCType(SymIntT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SYMINT),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}

# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS


def get_infos_with_derivatives_list(
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]]
) -> list[DifferentiabilityInfo]:
    diff_info_list = [
        info
        for diffinfo_dict in differentiability_infos.values()
        for info in diffinfo_dict.values()
    ]

    return list(filter(lambda info: info.args_with_derivatives, diff_info_list))


def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """

    # Get a flat (1D) list of diffinfos; we do not need them grouped per FunctionSchema/DispatchKey here.
    # Infos with different dispatch keys but the same name will still end up in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )


def gen_autograd_functions_python(
    out: str,
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}(PyObject* module);"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}(module);"
                for i in range(num_shards)
            ],
        },
    )

    # Get a flat (1D) list of diffinfos; we do not need them grouped per FunctionSchema/DispatchKey here.
    # Infos with different dispatch keys but the same name will still end up in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )


def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
    saved_variables: list[str] = []
    release_variables: list[str] = []
    saved_list_sizes: list[str] = []
    unpack: list[str] = []
    asserts: list[str] = []
    compute_index_ranges: list[str] = []
    getter_definitions: list[str] = []
    py_getsetdef_structs: list[str] = []
    compiled_args: list[str] = []
    apply_with_saved_before: list[str] = []
    apply_with_saved_after: list[str] = []

    for arg in info.args_with_derivatives:
        if arg.type in TENSOR_LIST_LIKE_CTYPES:
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")

    def save_var(var: SavedAttribute, is_output: bool) -> None:
        name = var.nctype.name
        type = var.nctype.type
        should_append_getsetdef = True
        should_append_raw_getsetdef = False
        visit_name = name
        uses_cpp_saved_variable_cls = False

        if (
            type == BaseCType(tensorT)
            or type == OptionalCType(BaseCType(tensorT))
            or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
            or (type == BaseCType(scalarT) and is_output)
        ):
            uses_cpp_saved_variable_cls = True
saved_variables.append(f"SavedVariable {name}_;") 579 release_variables.append(f"{name}_.reset_data();") 580 ptr = "shared_from_this()" if is_output else "" 581 unpack.append(f"auto {name} = {name}_.unpack({ptr});") 582 getter_definitions.append( 583 GETTER_DEFINITION_SAVEDVAR.substitute( 584 op=info.op, name=name, body=GETTER_BODY_SAVEDVAR 585 ) 586 ) 587 getter_definitions.append( 588 GETTER_DEFINITION_RAW_SAVEDVAR.substitute( 589 op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR 590 ) 591 ) 592 should_append_raw_getsetdef = True 593 visit_name = f"{name}_" 594 elif ( 595 type == BaseCType(tensorListT) 596 or type == BaseCType(iTensorListRefT) 597 or type == VectorCType(BaseCType(tensorT)) 598 ): 599 # note(crcrpar): [nuanced return type of out-of-place foreach functions] 600 # When an out-of-place foreach function whose return signature is `Tensor[]` 601 # spells out its backward definitions in `derivatives.yaml`, and some of them depend on 602 # `result`, `result`'s type is interpreted and treated as `std::vector<Tensor>`. 603 # An out-of-place foreach whose backwards rely on their output doesn't suffer from this 604 # difference if the definitions are codegen'ed. 605 # This special case is needed for `_foreach_pow.List` and `_foreach_pow.ScalarAndTensor` 606 # as of https://github.com/pytorch/pytorch/pull/105504. 607 if type == VectorCType(BaseCType(tensorT)): 608 assert ( 609 info.func.func.name.name.base.startswith("_foreach") and is_output 610 ) 611 uses_cpp_saved_variable_cls = True 612 saved_variables.append(f"std::vector<SavedVariable> {name}_;") 613 saved_variables.append(f"bool {name}_released_ = false;") 614 # Just clear() is sufficient, we don't need to loop and clear each variable. 615 # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. 616 release_variables.append(f"{name}_.clear();") 617 release_variables.append(f"{name}_released_ = true;") 618 ptr = "shared_from_this()" if is_output else "nullptr" 619 unpack.append(f"auto {name} = unpack_list({name}_, {ptr});") 620 asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") 621 getter_definitions.append( 622 GETTER_DEFINITION_VEC_SAVEDVAR.substitute( 623 op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR 624 ) 625 ) 626 getter_definitions.append( 627 GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute( 628 op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR 629 ) 630 ) 631 should_append_raw_getsetdef = True 632 visit_name = f"{name}_" 633 elif type == ListCType(OptionalCType(BaseCType(tensorT))): 634 uses_cpp_saved_variable_cls = True 635 saved_variables.append(f"std::vector<SavedVariable> {name}_;") 636 saved_variables.append(f"bool {name}_released_ = false;") 637 # Just clear() is sufficient, we don't need to loop and clear each variable. 638 # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. 
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = unpack_opt_list({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(symIntArrayRefT):
            saved_variables.append(f"std::vector<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == BaseCType(optionalIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(optionalSymIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == OptionalCType(BaseCType(intArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == OptionalCType(BaseCType(symIntArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f"c10::OptionalArray<double> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
                )
            )
        elif type == BaseCType(longT):
            saved_variables.append(f"{type.cpp_type()} {name} = 0;")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_INT64_T
                )
            )
        elif type == BaseCType(SymIntT):
            saved_variables.append(f"c10::SymInt {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SYMINT
                )
            )
        elif type == BaseCType(stringT):
            saved_variables.append(f"std::string {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f"std::optional<std::string> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == ArrayRefCType(
            elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
        ):
            saved_variables.append(f"std::vector<at::Scalar> {name};")
saved_variables.append(f"bool {name}_released_ = false;") 737 # Just clear() is sufficient, we don't need to loop and clear each variable. 738 # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. 739 release_variables.append(f"{name}.clear();") 740 # release_variables.append(f"{name}_released_ = true;") 741 # unpack.append(f"auto {name} = unpack_list({name}_);") 742 # asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") 743 getter_definitions.append( 744 CodeTemplate( 745 """\ 746PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { 747 HANDLE_TH_ERRORS 748 const auto *node = static_cast<${op}*>(self->cdata.get()); 749 const auto& prop = node->${name}; 750 if (node->${name}_released_) { 751 PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); 752 return nullptr; 753 } 754 ${body} 755 END_HANDLE_TH_ERRORS 756} 757 """ 758 ).substitute( 759 op=info.op, 760 name=name, 761 body=GETTER_BODY_VEC_SCALAR, 762 ) 763 ) 764 else: 765 # Check for indicators that you're putting a non-owning reference 766 # into the saved variable field. If this is spuriously firing, 767 # edit this field. Otherwise, you probably need to add a case 768 # above. 769 assert ( 770 "ref" not in type.cpp_type().lower() 771 and "view" not in type.cpp_type().lower() 772 and "*" not in type.cpp_type() 773 and "&" not in type.cpp_type() 774 ), f"{type.cpp_type()} looks like it contains a non-owning reference" 775 saved_variables.append(f"{type.cpp_type()} {name};") 776 777 if type in MISC_GETTER_DEFS: 778 getter_def, body = MISC_GETTER_DEFS[type] 779 getter_definitions.append( 780 getter_def.substitute(op=info.op, name=name, body=body) 781 ) 782 else: 783 # Types we don't expose python bindings to yet: 784 # TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry, 785 # std::vector<std::vector<int64_t>>, std::vector<at::ScalarType> 786 should_append_getsetdef = False 787 788 if should_append_getsetdef: 789 py_getsetdef_structs.append( 790 PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name) 791 ) 792 if should_append_raw_getsetdef: 793 py_getsetdef_structs.append( 794 PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name) 795 ) 796 797 if uses_cpp_saved_variable_cls: 798 compiled_args.append( 799 f"args.collect({visit_name}, {'true' if is_output else 'false'});" 800 ) 801 else: 802 compiled_args.append(f"args.collect({visit_name});") 803 apply_with_saved_before.append(f"saved.before({visit_name});") 804 apply_with_saved_after.append(f"saved.after({visit_name});") 805 806 for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)): 807 save_var(var, is_output=False) 808 for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)): 809 save_var(var, is_output=True) 810 811 # lock the mutex when we release variables and in Node::apply to protect thread safety 812 # see Note [Thread Safety on Autograd Node] 813 if len(release_variables) > 0: 814 thread_lock = "std::lock_guard<std::mutex> lock(mutex_);" 815 else: 816 thread_lock = "" 817 818 if uses_retain_variables(info): 819 will_release_variables = WILL_RELEASE_VARIABLES.substitute() 820 else: 821 will_release_variables = "" 822 823 body: list[str] = [] 824 825 if uses_single_grad(info): 826 body.append("const auto& grad = grads[0];") 827 else: 828 # Generate aliases for gradients named for returned values. 
        body.extend(
            f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
            for name in sorted(info.used_named_gradients)
        )

    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
    ) -> tuple[bool, str]:
        formula = derivative.formula
        var_names = derivative.var_names
        if len(var_names) == 1:
            checks_any_grad_defined = False
            if "not_implemented" not in formula:
                matching_args = [
                    arg for arg in args_with_derivatives if arg.name == var_names[0]
                ]
                if len(matching_args) == 1:
                    # We can add undefined grad support if the input variable is a Tensor
                    arg = matching_args[0]
                    if isinstance(arg.argument, Argument) and str(
                        arg.argument.type
                    ) in ("Tensor", "Tensor?"):
                        formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                        checks_any_grad_defined = True
            if info.name.startswith("_foreach_"):
                derivative_template = DERIVATIVE_SINGLE_FOREACH
            else:
                derivative_template = DERIVATIVE_SINGLE
            return (
                checks_any_grad_defined,
                derivative_template.substitute(name=var_names[0], derivative=formula),
            )
        else:
            if "grad_input_mask" in formula:
                masks = [
                    f"task_should_compute_output({{ {n}_ix }})," for n in var_names
                ]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                )
            else:
                grad_input_mask = ""
            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            copy_ranges: list[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges,
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in info.derivatives:
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
        )
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check if grads are
    # defined, only perform the check once, before all the formulas
    if need_any_grad_defined_var:
        body.insert(
            -len(info.derivatives),
            "bool any_grad_defined = any_variable_defined(grads);",
        )

    if info.name in UNTRACEABLE_FUNCTIONS:
        superclass = "Node"
    else:
        superclass = "TraceableFunction"

    all_getsetdef_structs = (
        ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
    )
    all_getter_definitions = "\n".join(getter_definitions)

    return template.substitute(
        op=info.op,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
        saved_list_sizes=saved_list_sizes,
        asserts=asserts,
        thread_lock=thread_lock,
        will_release_variables=will_release_variables,
        body=body,
        superclass=superclass,
        all_getter_definitions=all_getter_definitions,
        all_getsetdef_structs=all_getsetdef_structs,
        compiled_args=compiled_args,
        apply_with_saved_before=apply_with_saved_before,
        apply_with_saved_after=apply_with_saved_after,
    )
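
# For reference, filling FUNCTION_DECLARATION in for a hypothetical traceable op whose
# backward saves two tensors (`self` and `other`) would produce roughly the struct
# below. This is a simplified sketch (only the non-_WIN32 branch is shown, and real
# generated structs may also save scalar types, sizes, etc., as dictated by
# derivatives.yaml):
#
#   struct TORCH_API MulBackward0 : public TraceableFunction {
#     using TraceableFunction::TraceableFunction;
#     variable_list apply(variable_list&& grads) override;
#     std::string name() const override { return "MulBackward0"; }
#     void release_variables() override {
#       std::lock_guard<std::mutex> lock(mutex_);
#       self_.reset_data();
#       other_.reset_data();
#     }
#     void compiled_args(CompiledNodeArgs& args) override;
#     variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
#     SavedVariable self_;
#     SavedVariable other_;
#   };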