"""
To run this file by hand from the root of the PyTorch
repository, run:

python -m tools.autograd.gen_autograd \
       aten/src/ATen/native/native_functions.yaml \
       aten/src/ATen/native/tags.yaml \
       $OUTPUT_DIR \
       tools/autograd

Where $OUTPUT_DIR is where you would like the files to be
generated.  In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/

Note that running the module this way only invokes gen_autograd() (the C++
sources); gen_autograd_python() is not called from the command line.
"""

# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
#  load_derivatives.py: parses derivatives.yaml into differentiability infos
#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#  gen_variable_type.py: generates VariableType.h which contains all tensor methods
#  gen_inplace_or_view_type.py, gen_trace_type.py: further per-operator sources,
#      emitted together with VariableType (all three are skipped when
#      gen_autograd() is called with disable_autograd=True)
#  gen_variable_factories.py: generates variable_factories.h
#  gen_view_funcs.py: generates ViewFuncs.h/cpp
#  gen_python_functions.py: generates Python bindings to THPVariable
#

from __future__ import annotations

import argparse
import os

from torchgen.api import cpp
from torchgen.api.autograd import (
    match_differentiability_info,
    NativeFunctionWithDifferentiabilityInfo,
)
from torchgen.gen import parse_native_yaml
from torchgen.selective_build.selector import SelectiveBuilder

from . import gen_python_functions
from .gen_autograd_functions import (
    gen_autograd_functions_lib,
    gen_autograd_functions_python,
)
from .gen_inplace_or_view_type import gen_inplace_or_view_type
from .gen_trace_type import gen_trace_type
from .gen_variable_factories import gen_variable_factories
from .gen_variable_type import gen_variable_type
from .gen_view_funcs import gen_view_funcs
from .load_derivatives import load_derivatives


def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Generate the C++ autograd sources into ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: directory the generated files are written to.
        autograd_dir: directory containing ``derivatives.yaml`` and the
            ``templates`` subdirectory used by every generator.
        operator_selector: selective-build filter. Only native functions for
            which ``is_native_function_selected_for_training`` is true are
            matched against the derivatives and handed to the per-operator
            generators; gen_trace_type is the exception and receives every
            parsed native function.
        disable_autograd: when True, skip gen_variable_type,
            gen_inplace_or_view_type and gen_trace_type. Functions.h/cpp,
            variable_factories.h and ViewFuncs.h/cpp are generated either way.
    """
    # Parse and load derivatives.yaml
    differentiability_infos, used_dispatch_keys = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
    # Keep only the operators the selector wants for training, in a stable
    # order (by C++ name) so the generated output is deterministic.
    fns = sorted(
        filter(
            operator_selector.is_native_function_selected_for_training, native_funcs
        ),
        key=lambda f: cpp.name(f.func),
    )
    fns_with_diff_infos: list[
        NativeFunctionWithDifferentiabilityInfo
    ] = match_differentiability_info(fns, differentiability_infos)

    # Generate VariableType.h/cpp
    if not disable_autograd:
        gen_variable_type(
            out,
            native_functions_path,
            tags_path,
            fns_with_diff_infos,
            template_path,
            used_dispatch_keys,
        )

        gen_inplace_or_view_type(
            out, native_functions_path, tags_path, fns_with_diff_infos, template_path
        )

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_funcs, template_path)
    # Generate Functions.h/cpp
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path, template_path)

    # Generate ViewFuncs.h/cpp
    gen_view_funcs(out, fns_with_diff_infos, template_path)


def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into ``out``.

    Loads ``derivatives.yaml`` from ``autograd_dir``, passes the resulting
    differentiability infos to gen_autograd_functions_python, and then
    generates the THPVariable Python bindings with gen_python_functions.gen,
    which also reads ``autograd_dir/deprecated.yaml``. Unlike gen_autograd(),
    no operator selector is applied here.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: directory the generated files are written to.
        autograd_dir: directory containing ``derivatives.yaml``,
            ``deprecated.yaml`` and the ``templates`` subdirectory.
    """
    differentiability_infos, _ = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    # Generate the Python side of the autograd functions; the C++
    # Functions.h/cpp are produced by gen_autograd_functions_lib in gen_autograd().
    gen_autograd_functions_python(out, differentiability_infos, template_path)

    # Generate Python bindings
    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
    gen_python_functions.gen(
        out, native_functions_path, tags_path, deprecated_path, template_path
    )


def main() -> None:
    """Command-line entry point; see the module docstring for usage.

    Runs gen_autograd() with ``SelectiveBuilder.get_nop_selector()`` and the
    default ``disable_autograd=False``. The Python sources produced by
    gen_autograd_python() are not generated from here.
    """
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    # Distinct metavar so the usage line reads "NATIVE TAGS OUT AUTOGRAD"
    # instead of repeating NATIVE for two different positional arguments.
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_autograd(
        args.native_functions,
        args.tags,
        args.out,
        args.autograd,
        SelectiveBuilder.get_nop_selector(),
    )


if __name__ == "__main__":
    main()