xref: /aosp_15_r20/external/pytorch/tools/setup_helpers/generate_code.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import os
5import sys
6from pathlib import Path
7from typing import Any, cast
8
9import yaml
10
11
12try:
13    # use faster C loader if available
14    from yaml import CSafeLoader as YamlLoader
15except ImportError:
16    from yaml import SafeLoader as YamlLoader  # type: ignore[assignment, misc]
17
# Default locations of the ATen tag and operator-schema YAML files. Both are
# relative paths, used by generate_code() when no explicit path is supplied
# (presumably resolved against the repository root as the working directory).
TAGS_PATH: str = "aten/src/ATen/native/tags.yaml"
NATIVE_FUNCTIONS_PATH: str = "aten/src/ATen/native/native_functions.yaml"
20
21
def generate_code(
    gen_dir: Path,
    native_functions_path: str | None = None,
    tags_path: str | None = None,
    install_dir: str | None = None,
    subset: str | None = None,
    disable_autograd: bool = False,
    force_schema_registration: bool = False,
    operator_selector: Any = None,
) -> None:
    """Run the autograd code generators.

    Args:
        gen_dir: Root directory for generated output when ``install_dir`` is
            not given. Autograd sources go to
            ``<gen_dir>/torch/csrc/autograd/generated`` and the annotated
            Python output to ``<gen_dir>/torch/testing/_internal/generated``.
        native_functions_path: Path to ``native_functions.yaml``; defaults to
            ``NATIVE_FUNCTIONS_PATH``.
        tags_path: Path to ``tags.yaml``; defaults to ``TAGS_PATH``.
        install_dir: Deprecated override. When set, autograd sources go to
            ``<install_dir>/autograd/generated`` and the annotated Python
            output goes directly into ``install_dir``.
        subset: ``"pybindings"`` (runs ``gen_autograd_python``),
            ``"libtorch"`` (runs ``gen_autograd``) or ``"python"`` (runs
            ``gen_annotated``). ``None`` or ``""`` runs all three.
        disable_autograd: Forwarded to ``gen_autograd``.
        force_schema_registration: Accepted for backward compatibility; this
            function does not use it.
        operator_selector: Forwarded to ``gen_autograd``; defaults to
            ``SelectiveBuilder.get_nop_selector()``.

    Raises:
        ValueError: If ``subset`` is a non-empty value other than the three
            supported ones.
    """
    valid_subsets = ("pybindings", "libtorch", "python")
    # Validate before importing the generators or creating any directories so
    # a typo fails loudly instead of silently generating nothing.
    if subset and subset not in valid_subsets:
        raise ValueError(
            f"Unknown subset {subset!r}; expected one of {valid_subsets} or None"
        )

    from tools.autograd.gen_annotated_fn_args import gen_annotated
    from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python

    from torchgen.selective_build.selector import SelectiveBuilder

    # Resolve the input YAML defaults once; every generator below shares them.
    native_functions_path = native_functions_path or NATIVE_FUNCTIONS_PATH
    tags_path = tags_path or TAGS_PATH

    # Work out (and create) the output directories.
    if install_dir is None:
        install_dir = os.fspath(gen_dir / "torch/csrc")
        python_install_dir = os.fspath(gen_dir / "torch/testing/_internal/generated")
    else:
        python_install_dir = install_dir
    autograd_gen_dir = os.path.join(install_dir, "autograd", "generated")
    for d in (autograd_gen_dir, python_install_dir):
        os.makedirs(d, exist_ok=True)
    # Sibling ``tools/autograd`` directory next to this file's parent package.
    autograd_dir = os.fspath(Path(__file__).parent.parent / "autograd")

    if not subset or subset == "pybindings":
        gen_autograd_python(
            native_functions_path,
            tags_path,
            autograd_gen_dir,
            autograd_dir,
        )

    if operator_selector is None:
        operator_selector = SelectiveBuilder.get_nop_selector()

    if not subset or subset == "libtorch":
        gen_autograd(
            native_functions_path,
            tags_path,
            autograd_gen_dir,
            autograd_dir,
            disable_autograd=disable_autograd,
            operator_selector=operator_selector,
        )

    if not subset or subset == "python":
        gen_annotated(
            native_functions_path,
            tags_path,
            python_install_dir,
            autograd_dir,
        )
76
77
def get_selector_from_legacy_operator_selection_list(
    selected_op_list_path: str,
) -> Any:
    """Build a selector from a legacy YAML list of operator names.

    The file is parsed with ``YamlLoader`` and each entry is truncated at its
    first ``.``, dropping the overload name (an entry such as
    ``aten::add.Tensor`` would become ``aten::add``). The resulting set is
    passed to ``SelectiveBuilder.from_legacy_op_registration_allow_list``
    with ``is_root_operator`` and ``is_used_for_training`` both ``True``.
    """
    with open(selected_op_list_path) as yaml_file:
        raw_op_names = yaml.load(yaml_file, Loader=YamlLoader)

    # Strip the overload part. This is only for the legacy config -- do NOT
    # copy this code!
    base_op_names = {name.split(".", 1)[0] for name in raw_op_names}

    from torchgen.selective_build.selector import SelectiveBuilder

    # Only the OSS build still uses this flag (internal builds no longer do).
    # Every operator counts as a root operator, so unboxing code is generated
    # for it (consistent with the existing behavior). It also counts as used
    # for training, since OSS does not support training on mobile for now.
    return SelectiveBuilder.from_legacy_op_registration_allow_list(
        base_op_names,
        True,  # is_root_operator
        True,  # is_used_for_training
    )
106
107
def get_selector(
    selected_op_list_path: str | None,
    operators_yaml_path: str | None,
) -> Any:
    """Return the operator selector for a (possibly custom) build.

    At most one of the two paths may be set:

    * neither: ``SelectiveBuilder.get_nop_selector()``
    * ``selected_op_list_path``: built by
      ``get_selector_from_legacy_operator_selection_list``
    * ``operators_yaml_path``: ``SelectiveBuilder.from_yaml_path``

    Raises:
        ValueError: If both paths are set.
    """
    # An explicit exception (not ``assert``) so the check also runs under
    # ``python -O``; done before any sys.path mutation.
    if selected_op_list_path is not None and operators_yaml_path is not None:
        raise ValueError(
            "Expected at most one of selected_op_list_path and "
            "operators_yaml_path to be set."
        )

    # Put the directory two levels above this file's directory (presumably
    # the repository root) first on sys.path so ``torchgen`` resolves from
    # this checkout when the script is run directly.
    root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    sys.path.insert(0, root)
    from torchgen.selective_build.selector import SelectiveBuilder

    if selected_op_list_path is None and operators_yaml_path is None:
        return SelectiveBuilder.get_nop_selector()
    if selected_op_list_path is not None:
        return get_selector_from_legacy_operator_selection_list(selected_op_list_path)
    # Only operators_yaml_path can be set here; cast() narrows it for mypy.
    return SelectiveBuilder.from_yaml_path(cast(str, operators_yaml_path))
130
131
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Autogenerate code")
    parser.add_argument("--native-functions-path")
    parser.add_argument("--tags-path")
    parser.add_argument(
        "--gen-dir",
        type=Path,
        default=Path("."),
        help="Root directory where to install files. Defaults to the current working directory.",
    )
    parser.add_argument(
        "--install-dir",
        "--install_dir",
        help=(
            "Deprecated. Use --gen-dir instead. The semantics are different, do not change "
            "blindly."
        ),
    )
    parser.add_argument(
        "--subset",
        help=(
            'Subset of source files to generate. Can be "libtorch", "pybindings" '
            'or "python". Generates all three when omitted.'
        ),
    )
    parser.add_argument(
        "--disable-autograd",
        default=False,
        action="store_true",
        help="It can skip generating autograd related code when the flag is set",
    )
    parser.add_argument(
        "--selected-op-list-path",
        help="Path to the YAML file that contains the list of operators to include for custom build.",
    )
    parser.add_argument(
        "--operators-yaml-path",
        "--operators_yaml_path",
        help="Path to the model YAML file that contains the list of operators to include for custom build.",
    )
    # NOTE: generate_code() currently accepts but does not use this flag.
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        help="force it to generate schema-only registrations for ops that are not "
        "listed on --selected-op-list-path",
    )
    parser.add_argument(
        "--gen-lazy-ts-backend",
        "--gen_lazy_ts_backend",
        action="store_true",
        help="Enable generation of the torch::lazy TorchScript backend",
    )
    parser.add_argument(
        "--per-operator-headers",
        "--per_operator_headers",
        action="store_true",
        help="Build lazy tensor ts backend with per-operator ATen headers, must match how ATen was built",
    )
    return parser


def _generate_lazy_ts_backend(options: argparse.Namespace) -> None:
    """Generate the torch::lazy TorchScript backend sources.

    The ATen directory is taken as two levels above ``native_functions.yaml``,
    which defaults to ``NATIVE_FUNCTIONS_PATH`` as in ``generate_code``.
    Output goes to ``<install dir>/lazy/generated``.

    Raises:
        FileNotFoundError: If ``ts_native_functions.yaml`` or the TorchScript
            backend implementation file is missing.
    """
    # --native-functions-path is optional; without this fallback,
    # os.path.dirname(None) would raise a TypeError.
    native_functions_path = options.native_functions_path or NATIVE_FUNCTIONS_PATH
    aten_path = os.path.dirname(os.path.dirname(native_functions_path))
    ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
    ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
    ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"

    # Explicit checks (asserts vanish under -O), done before creating output
    # directories.
    if not os.path.isfile(ts_backend_yaml):
        raise FileNotFoundError(f"Unable to access ts_backend_yaml: {ts_backend_yaml}")
    if not os.path.isfile(ts_native_functions):
        raise FileNotFoundError(f"Unable to access {ts_native_functions}")

    install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc")
    lazy_install_dir = os.path.join(install_dir, "lazy/generated")
    os.makedirs(lazy_install_dir, exist_ok=True)

    from torchgen.dest.lazy_ir import GenTSLazyIR
    from torchgen.gen_lazy_tensor import run_gen_lazy_tensor

    run_gen_lazy_tensor(
        aten_path=aten_path,
        source_yaml=ts_backend_yaml,
        backend_name="TorchScript",
        output_dir=lazy_install_dir,
        dry_run=False,
        impl_path=ts_native_functions,
        node_base="TsNode",
        node_base_hdr=ts_node_base,
        build_in_tree=True,
        lazy_ir_generator=GenTSLazyIR,
        per_operator_headers=options.per_operator_headers,
        gen_forced_fallback_code=True,
    )


def main() -> None:
    """Command-line entry point: parse arguments and run the code generators."""
    options = _build_parser().parse_args()

    generate_code(
        options.gen_dir,
        options.native_functions_path,
        options.tags_path,
        options.install_dir,
        options.subset,
        options.disable_autograd,
        options.force_schema_registration,
        operator_selector=get_selector(
            options.selected_op_list_path, options.operators_yaml_path
        ),
    )

    if options.gen_lazy_ts_backend:
        _generate_lazy_ts_backend(options)
236
237
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
240