xref: /aosp_15_r20/external/pytorch/tools/autograd/gen_autograd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker"""
2*da0073e9SAndroid Build Coastguard WorkerTo run this file by hand from the root of the PyTorch
3*da0073e9SAndroid Build Coastguard Workerrepository, run:
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerpython -m tools.autograd.gen_autograd \
6*da0073e9SAndroid Build Coastguard Worker       aten/src/ATen/native/native_functions.yaml \
7*da0073e9SAndroid Build Coastguard Worker       aten/src/ATen/native/tags.yaml \
8*da0073e9SAndroid Build Coastguard Worker       $OUTPUT_DIR \
9*da0073e9SAndroid Build Coastguard Worker       tools/autograd
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard WorkerWhere $OUTPUT_DIR is where you would like the files to be
12*da0073e9SAndroid Build Coastguard Workergenerated.  In the full build system, OUTPUT_DIR is
13*da0073e9SAndroid Build Coastguard Workertorch/csrc/autograd/generated/
14*da0073e9SAndroid Build Coastguard Worker"""
15*da0073e9SAndroid Build Coastguard Worker
# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#    (Functions.h/cpp) and their Python bindings
#  gen_variable_type.py: generates VariableType.h/cpp, which contain all tensor methods
#  gen_inplace_or_view_type.py: generates the inplace-or-view type
#  gen_trace_type.py: generates the tracing type
#  gen_variable_factories.py: generates variable_factories.h
#  gen_view_funcs.py: generates ViewFuncs.h/cpp
#  gen_python_functions.py: generates Python bindings to THPVariable
#
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Workerimport argparse
28*da0073e9SAndroid Build Coastguard Workerimport os
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
31*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.autograd import (
32*da0073e9SAndroid Build Coastguard Worker    match_differentiability_info,
33*da0073e9SAndroid Build Coastguard Worker    NativeFunctionWithDifferentiabilityInfo,
34*da0073e9SAndroid Build Coastguard Worker)
35*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen import parse_native_yaml
36*da0073e9SAndroid Build Coastguard Workerfrom torchgen.selective_build.selector import SelectiveBuilder
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Workerfrom . import gen_python_functions
39*da0073e9SAndroid Build Coastguard Workerfrom .gen_autograd_functions import (
40*da0073e9SAndroid Build Coastguard Worker    gen_autograd_functions_lib,
41*da0073e9SAndroid Build Coastguard Worker    gen_autograd_functions_python,
42*da0073e9SAndroid Build Coastguard Worker)
43*da0073e9SAndroid Build Coastguard Workerfrom .gen_inplace_or_view_type import gen_inplace_or_view_type
44*da0073e9SAndroid Build Coastguard Workerfrom .gen_trace_type import gen_trace_type
45*da0073e9SAndroid Build Coastguard Workerfrom .gen_variable_factories import gen_variable_factories
46*da0073e9SAndroid Build Coastguard Workerfrom .gen_variable_type import gen_variable_type
47*da0073e9SAndroid Build Coastguard Workerfrom .gen_view_funcs import gen_view_funcs
48*da0073e9SAndroid Build Coastguard Workerfrom .load_derivatives import load_derivatives
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker
def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Generate the C++ autograd sources into the directory ``out``.

    Reads ``derivatives.yaml`` from ``autograd_dir`` and parses the native
    function and tag YAML files. It matches the derivative info against the
    native functions that ``operator_selector`` selects for training (sorted by
    their C++ name). Every generator is driven with the templates found in
    ``autograd_dir/templates``.

    When ``disable_autograd`` is true, the VariableType, inplace-or-view and
    trace-type generators are skipped. Functions.h/cpp, variable_factories.h
    and ViewFuncs.h/cpp are always generated.
    """
    derivatives_yaml = os.path.join(autograd_dir, "derivatives.yaml")
    diff_infos, dispatch_keys = load_derivatives(
        derivatives_yaml, native_functions_path, tags_path
    )

    templates = os.path.join(autograd_dir, "templates")

    all_native_fns = parse_native_yaml(native_functions_path, tags_path).native_functions
    is_selected = operator_selector.is_native_function_selected_for_training
    # Deterministic ordering: sort the selected ops by their C++ API name.
    selected_fns = sorted(
        (fn for fn in all_native_fns if is_selected(fn)),
        key=lambda fn: cpp.name(fn.func),
    )
    fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo] = (
        match_differentiability_info(selected_fns, diff_infos)
    )

    if not disable_autograd:
        # VariableType.h/cpp
        gen_variable_type(
            out,
            native_functions_path,
            tags_path,
            fns_with_diff_infos,
            templates,
            dispatch_keys,
        )
        gen_inplace_or_view_type(
            out, native_functions_path, tags_path, fns_with_diff_infos, templates
        )
        # Tracing sources are excluded from selective builds, so the trace type
        # is generated from the full, unfiltered list of native functions.
        gen_trace_type(out, all_native_fns, templates)

    # Functions.h/cpp
    gen_autograd_functions_lib(out, diff_infos, templates)
    # variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path, templates)
    # ViewFuncs.h/cpp
    gen_view_funcs(out, fns_with_diff_infos, templates)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into the directory ``out``.

    Loads ``derivatives.yaml`` from ``autograd_dir`` and emits Python bindings
    for the autograd function classes derived from it. It then emits the Python
    bindings for the native functions, also handing ``deprecated.yaml`` (from
    ``autograd_dir``) to that generator. Templates come from
    ``autograd_dir/templates``.
    """
    derivatives_yaml = os.path.join(autograd_dir, "derivatives.yaml")
    # Only the differentiability infos are needed here; the second element of
    # the returned pair is ignored.
    diff_infos, _unused = load_derivatives(
        derivatives_yaml, native_functions_path, tags_path
    )

    templates = os.path.join(autograd_dir, "templates")

    # Python bindings for the autograd function classes. The C++ Functions.h/cpp
    # themselves are produced by gen_autograd (gen_autograd_functions_lib).
    gen_autograd_functions_python(out, diff_infos, templates)

    # Python bindings for the native functions.
    deprecated_yaml = os.path.join(autograd_dir, "deprecated.yaml")
    gen_python_functions.gen(
        out, native_functions_path, tags_path, deprecated_yaml, templates
    )
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker
def main() -> None:
    """Command-line entry point: parse the positional arguments and run ``gen_autograd``.

    Expects, in order: the path to native_functions.yaml, the path to
    tags.yaml, the output directory, and the autograd directory (the one that
    holds derivatives.yaml and the templates). Generation runs with
    ``SelectiveBuilder.get_nop_selector()`` as the operator selector and with
    ``disable_autograd`` left at its default.
    """
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    # Give tags.yaml its own metavar so usage/help output does not label two
    # different positionals "NATIVE".
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_autograd(
        args.native_functions,
        args.tags,
        args.out,
        args.autograd,
        SelectiveBuilder.get_nop_selector(),
    )
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Only run the generator when executed as a script or via `python -m`;
    # importing this module has no side effects.
    main()
147*da0073e9SAndroid Build Coastguard Worker    main()
148