xref: /aosp_15_r20/external/pytorch/tools/autograd/gen_autograd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2To run this file by hand from the root of the PyTorch
3repository, run:
4
5python -m tools.autograd.gen_autograd \
6       aten/src/ATen/native/native_functions.yaml \
7       aten/src/ATen/native/tags.yaml \
8       $OUTPUT_DIR \
9       tools/autograd
10
11Where $OUTPUT_DIR is where you would like the files to be
12generated.  In the full build system, OUTPUT_DIR is
13torch/csrc/autograd/generated/
14"""
15
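# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#  gen_variable_type.py: generates VariableType.h which contains all tensor methods
#  gen_python_functions.py: generates Python bindings to THPVariable
#  gen_inplace_or_view_type.py, gen_trace_type.py: generated together with
#    VariableType, and skipped when autograd is disabled
#  gen_variable_factories.py: generates variable_factories.h
#  gen_view_funcs.py: generates ViewFuncs.h/cpp
#  load_derivatives.py: parses derivatives.yaml
#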
24
25from __future__ import annotations
26
27import argparse
28import os
29
30from torchgen.api import cpp
31from torchgen.api.autograd import (
32    match_differentiability_info,
33    NativeFunctionWithDifferentiabilityInfo,
34)
35from torchgen.gen import parse_native_yaml
36from torchgen.selective_build.selector import SelectiveBuilder
37
38from . import gen_python_functions
39from .gen_autograd_functions import (
40    gen_autograd_functions_lib,
41    gen_autograd_functions_python,
42)
43from .gen_inplace_or_view_type import gen_inplace_or_view_type
44from .gen_trace_type import gen_trace_type
45from .gen_variable_factories import gen_variable_factories
46from .gen_variable_type import gen_variable_type
47from .gen_view_funcs import gen_view_funcs
48from .load_derivatives import load_derivatives
49
50
def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Generate the C++ autograd sources into ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: directory that receives the generated files.
        autograd_dir: directory containing ``derivatives.yaml`` and the
            ``templates`` subdirectory.
        operator_selector: keeps only the native functions selected for
            training; the surviving functions feed VariableType,
            InplaceOrView and ViewFuncs generation.
        disable_autograd: when True, VariableType, InplaceOrView and
            TraceType generation is skipped.  Functions.h/cpp,
            variable_factories.h and ViewFuncs.h/cpp are still generated.
    """
    templates = os.path.join(autograd_dir, "templates")

    # Parse derivatives.yaml.  The second result (the dispatch keys it uses)
    # is only needed by the VariableType generator.
    diff_infos, dispatch_keys_in_use = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    all_native_functions = parse_native_yaml(
        native_functions_path, tags_path
    ).native_functions

    # Keep only the operators selected for training, ordered by C++ name.
    selected = [
        fn
        for fn in all_native_functions
        if operator_selector.is_native_function_selected_for_training(fn)
    ]
    selected.sort(key=lambda fn: cpp.name(fn.func))

    paired: list[NativeFunctionWithDifferentiabilityInfo] = (
        match_differentiability_info(selected, diff_infos)
    )

    if not disable_autograd:
        # VariableType.h/cpp
        gen_variable_type(
            out,
            native_functions_path,
            tags_path,
            paired,
            templates,
            dispatch_keys_in_use,
        )
        gen_inplace_or_view_type(
            out, native_functions_path, tags_path, paired, templates
        )
        # Deliberately unfiltered: tracing sources are excluded from
        # selective builds, so the operator selector does not apply here.
        gen_trace_type(out, all_native_functions, templates)

    # Functions.h/cpp
    gen_autograd_functions_lib(out, diff_infos, templates)
    # variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path, templates)
    # ViewFuncs.h/cpp
    gen_view_funcs(out, paired, templates)
102
103
def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: directory that receives the generated files.
        autograd_dir: directory containing ``derivatives.yaml``,
            ``deprecated.yaml`` and the ``templates`` subdirectory.
    """
    templates = os.path.join(autograd_dir, "templates")

    # Only the differentiability infos are needed here; the dispatch keys
    # are used solely by the C++ VariableType generator.
    diff_infos, _dispatch_keys = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    # Python-side output for the autograd Function nodes.  NOTE(review):
    # presumably python_functions.h/cpp; Functions.h/cpp itself comes from
    # gen_autograd_functions_lib.  Confirm in gen_autograd_functions.py.
    gen_autograd_functions_python(out, diff_infos, templates)

    # Python bindings for the native functions, including deprecated signatures.
    gen_python_functions.gen(
        out,
        native_functions_path,
        tags_path,
        os.path.join(autograd_dir, "deprecated.yaml"),
        templates,
    )
124
125
def main() -> None:
    """Command-line entry point: ``NATIVE TAGS OUT AUTOGRAD``.

    Parses the four positional paths and runs ``gen_autograd`` with
    ``SelectiveBuilder.get_nop_selector()`` as the operator selector.
    Only the C++ autograd files are produced here; ``gen_autograd_python``
    is not invoked from this entry point.
    """
    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
    parser.add_argument(
        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
    )
    # Own metavar so usage/help reads "NATIVE TAGS OUT AUTOGRAD".  It was
    # previously a copy-pasted "NATIVE", which showed NATIVE twice.
    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
    parser.add_argument("out", metavar="OUT", help="path to output directory")
    parser.add_argument(
        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
    )
    args = parser.parse_args()
    gen_autograd(
        args.native_functions,
        args.tags,
        args.out,
        args.autograd,
        SelectiveBuilder.get_nop_selector(),
    )
144
145
if __name__ == "__main__":
    # Script / `python -m` entry point; see the module docstring for usage.
    main()
148