"""
To run this file by hand from the root of the PyTorch
repository, run:

python -m tools.autograd.gen_autograd \
       aten/src/ATen/native/native_functions.yaml \
       aten/src/ATen/native/tags.yaml \
       $OUTPUT_DIR \
       tools/autograd

Where $OUTPUT_DIR is where you would like the files to be
generated.  In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/
"""

# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#  gen_variable_type.py: generates VariableType.h which contains all tensor methods
#  gen_python_functions.py: generates Python bindings to THPVariable
#

from __future__ import annotations

import argparse
import os

from torchgen.api import cpp
from torchgen.api.autograd import (
    match_differentiability_info,
    NativeFunctionWithDifferentiabilityInfo,
)
from torchgen.gen import parse_native_yaml
from torchgen.selective_build.selector import SelectiveBuilder

from . import gen_python_functions
from .gen_autograd_functions import (
    gen_autograd_functions_lib,
    gen_autograd_functions_python,
)
from .gen_inplace_or_view_type import gen_inplace_or_view_type
from .gen_trace_type import gen_trace_type
from .gen_variable_factories import gen_variable_factories
from .gen_variable_type import gen_variable_type
from .gen_view_funcs import gen_view_funcs
from .load_derivatives import load_derivatives


def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Run the C++ autograd code generators, writing their output to ``out``.

    Args:
        native_functions_path: Path to native_functions.yaml.
        tags_path: Path to tags.yaml.
        out: Output directory handed to every generator.
        autograd_dir: Autograd directory; ``derivatives.yaml`` and the
            ``templates`` subdirectory are looked up inside it.
        operator_selector: Selector whose
            ``is_native_function_selected_for_training`` predicate filters
            the native functions that are matched with derivative info.
        disable_autograd: When true, the VariableType.h/cpp generation step
            is skipped; all other generators still run.
    """
    # Parse and load derivatives.yaml
    differentiability_infos, used_dispatch_keys = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
    # Keep only the functions selected for training, ordered by their C++ name.
    fns = sorted(
        filter(
            operator_selector.is_native_function_selected_for_training, native_funcs
        ),
        key=lambda f: cpp.name(f.func),
    )
    fns_with_diff_infos: list[
        NativeFunctionWithDifferentiabilityInfo
    ] = match_differentiability_info(fns, differentiability_infos)

    # Generate VariableType.h/cpp
    if not disable_autograd:
        gen_variable_type(
            out,
            native_functions_path,
            tags_path,
            fns_with_diff_infos,
            template_path,
            used_dispatch_keys,
        )

        gen_inplace_or_view_type(
            out, native_functions_path, tags_path, fns_with_diff_infos, template_path
        )

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_funcs, template_path)
    # Generate Functions.h/cpp
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path, template_path)

    # Generate ViewFuncs.h/cpp
    gen_view_funcs(out, fns_with_diff_infos, template_path)


def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Run the Python-binding generators, writing their output to ``out``.

    Args:
        native_functions_path: Path to native_functions.yaml.
        tags_path: Path to tags.yaml.
        out: Output directory handed to every generator.
        autograd_dir: Autograd directory; ``derivatives.yaml``,
            ``deprecated.yaml`` and the ``templates`` subdirectory are
            looked up inside it.
    """
    # The dispatch keys returned alongside the derivative info are not needed here.
    differentiability_infos, _ = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    # Generate the Python-facing part of the autograd functions
    # (NOTE(review): exact output file names live in gen_autograd_functions.py)
    gen_autograd_functions_python(out, differentiability_infos, template_path)

    # Generate Python bindings
    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
    gen_python_functions.gen(
        out, native_functions_path, tags_path, deprecated_path, template_path
    )


def main() -> None:
    """Parse the command line and run ``gen_autograd`` with a no-op selector."""
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    # metavar only affects the usage/help text; the value is still args.tags.
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_autograd(
        args.native_functions,
        args.tags,
        args.out,
        args.autograd,
        SelectiveBuilder.get_nop_selector(),
    )


if __name__ == "__main__":
    main()