xref: /aosp_15_r20/external/pytorch/tools/jit/gen_unboxing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.
2
3from __future__ import annotations
4
5import argparse
6import os
7import sys
8from dataclasses import dataclass
9from pathlib import Path
10from typing import Literal, Sequence, TYPE_CHECKING
11
12import yaml
13
14from torchgen.api import cpp, unboxing
15from torchgen.api.translate import translate
16from torchgen.api.types import CppSignatureGroup
17from torchgen.api.unboxing import convert_arguments
18from torchgen.context import method_with_native_function
19from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
20from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
21from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
22
23
24if TYPE_CHECKING:
25    from torchgen.selective_build.selector import SelectiveBuilder
26
27
# Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
@dataclass(frozen=True)
class ComputeUnboxingFunctions:
    """Produce the at::unboxing C++ code for a single native function.

    With ``target=Target.DECLARATION`` the result is the prototype that goes
    into UnboxingFunctions.h; any other target produces the definition that
    goes into UnboxingFunctions.cpp. Functions that ``selector`` does not
    report as root operators yield an empty string.
    """

    # Which of the two artifacts (.h prototype / .cpp definition) to render.
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    # Selective-build filter, queried with "aten::<function name>".
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Return the C++ snippet for ``f``, or ``""`` when ``f`` is not selected."""
        if not self.selector.is_root_operator(f"aten::{f.func.name}"):
            return ""
        if self.target is Target.DECLARATION:
            return self._declaration(f)
        return self._definition(f)

    def _declaration(self, f: NativeFunction) -> str:
        """Return the ``TORCH_API void <name>(Stack & stack);`` prototype for ``f``."""
        # Note [The ATen Codegen Unboxing API]
        # Similar to the ATen Operators API, the ATen Codegen Unboxing API lives
        # in the at::unboxing namespace and is used by the codegen unboxing
        # wrappers (CodegenUnboxingWrappers.cpp). The wrappers are registered
        # into torch::jit::OperatorRegistry using the RegisterOperators API.
        #
        # Important characteristics of the Codegen Unboxing API:
        # (1) It follows the OperatorRegistry API, which is needed to avoid
        #     overhead. For example, if it followed the C++ API, every faithful
        #     C++ factory function would have to wrap its arguments into
        #     TensorOptions only to unwrap them again.
        # (2) Under the hood it calls the C++ API.
        return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack);
"""

    def _definition(self, f: NativeFunction) -> str:
        """Return a body that unboxes ``f``'s arguments from the stack, drops
        them, calls the C++ API, and packs any result back onto the stack."""
        sig = CppSignatureGroup.from_native_function(
            f, method=Variant.method in f.variants
        ).most_faithful_signature()
        # Bindings for the unboxed arguments, plus the C++ statements that
        # produce them from the stack.
        bindings, unbox_stmts = convert_arguments(f)
        # Method variants are invoked on the unboxed self (presumably the
        # "self_base" local emitted by convert_arguments - confirm); all other
        # functions go through the at:: namespace.
        receiver = "self_base." if sig.method else "at::"
        call_args = ", ".join(
            expr.expr
            for expr in translate(bindings, sig.arguments(), method=sig.method)
        )
        if f.func.returns:
            capture = "auto result_ = "
            # Kept byte-identical to the historical template, including the
            # 16 trailing spaces after the final newline.
            push_back = "\n    pack(stack, std::move(result_));\n" + " " * 16
        else:
            capture = ""
            push_back = ""
        unbox_code = "\n\t".join(unbox_stmts)
        return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack) {{
    {unbox_code}

    drop(stack, {len(bindings)});

    {capture}{receiver}{sig.name()}({call_args});
    {push_back}
}}
"""
91
92
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
    """Produce one ``OperatorGenerator(...)`` registration entry per native function.

    Each entry carries the operator name, overload name, argument and return
    descriptors, and a lambda that records the call and forwards the stack to
    ``at::unboxing::<name>``. Functions that ``selector`` does not report as
    root operators yield an empty string.
    """

    # Selective-build filter, queried with "aten::<function name>".
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Return the registration snippet for ``f``, or ``""`` when unselected."""
        if not self.selector.is_root_operator(f"aten::{f.func.name}"):
            return ""

        # Function wrappers are generated unconditionally, always from the
        # faithful, non-method C++ signature.
        sig = CppSignatureGroup.from_native_function(
            f, method=False
        ).most_faithful_signature()

        # cpp_string() escapes double quotes and wraps the text in quotes;
        # slicing drops those surrounding quotes.
        schema = cpp_string(str(sig.func))[1:-1]

        joiner = ",\n\t\t"
        arg_entries: list[str] = []
        for binding in sig.arguments():
            argument = binding.argument
            # The non-method faithful API should only yield plain Arguments
            # (no SelfArgument / TensorOptionsArguments).
            assert isinstance(argument, Argument)
            if argument.default:
                # Unboxing uses the faithful C++ API to avoid the overhead of
                # wrapping/unwrapping TensorOptions, but schema parsing still
                # needs the default values, so render them here.
                default = cpp.default_expr(
                    argument.default, argument.type, symint=False
                )
                # Brace-initialized defaults are wrapped in c10::IntArrayRef.
                boxed = (
                    f"c10::IntArrayRef({default})"
                    if default.startswith("{")
                    else f"c10::IValue({default})"
                )
            else:
                boxed = "c10::IValue(::std::nullopt)"
            arg_entries.append(
                f'c10::Argument("{binding.name}", nullptr, ::std::nullopt, {boxed})'
            )

        # Unnamed returns get an empty name.
        return_entries = [
            f'c10::Argument("{ret.name or ""}")' for ret in f.func.returns
        ]

        return f"""
// aten::{schema}
OperatorGenerator(
    "aten::{f.func.name.name}",
    "{f.func.name.overload_name}",
    {{
        {joiner.join(arg_entries)}
    }},
    {{
        {joiner.join(return_entries)}
    }},
    [](Stack & stack) {{
        RECORD_FUNCTION("{sig.name()}", std::vector<c10::IValue>());
        at::unboxing::{unboxing.name(f)}(stack);
    }},
    aliasAnalysisFromSchema()
),
"""
157
158
def gen_unboxing(
    *,
    native_functions: Sequence[NativeFunction],
    cpu_fm: FileManager,
    selector: SelectiveBuilder,
) -> None:
    """Write UnboxingFunctions.cpp, UnboxingFunctions.h and
    RegisterCodegenUnboxedKernels.cpp through ``cpu_fm``.

    Args:
        native_functions: all parsed native functions; those the selector does
            not report as root operators render as empty strings.
        cpu_fm: file manager used for all three outputs.
        selector: selective-build selector, used both for operator filtering
            and for the sharding decision.

    The two .cpp outputs are written as a single shard when fewer than 100
    operators are selected, otherwise as 5 (definitions) or 10
    (registrations) shards.
    """

    def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
        # Shard key handed to write_sharded.
        return fn.root_name

    # A best-practice threshold of selected operators above which sharding
    # is enabled.
    # NOTE(review): a selector that includes all operators without listing
    # them may have an empty `operators` mapping, in which case this count is
    # 0 and sharding stays off - confirm that is intended.
    shard_threshold = 100
    few_ops = len(selector.operators) < shard_threshold

    define = ComputeUnboxingFunctions(Target.DEFINITION, selector)
    declare = ComputeUnboxingFunctions(Target.DECLARATION, selector)
    register = ComputeCodegenUnboxedKernels(selector)

    cpu_fm.write_sharded(
        "UnboxingFunctions.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {"definitions": [define(fn)]},
        num_shards=1 if few_ops else 5,
        sharded_keys={"definitions"},
    )
    cpu_fm.write(
        "UnboxingFunctions.h",
        lambda: {"declarations": list(mapMaybe(declare, native_functions))},
    )
    cpu_fm.write_sharded(
        "RegisterCodegenUnboxedKernels.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {"unboxed_ops": [register(fn)]},
        num_shards=1 if few_ops else 10,
        sharded_keys={"unboxed_ops"},
    )
201    )
202
203
def main(args: list[str]) -> None:
    """Parse command-line ``args`` and generate the unboxing source files.

    The operator allowlist comes from ``--op-registration-allowlist`` when it
    is non-empty, otherwise from the TEST-ONLY allowlist YAML file, otherwise
    it is ``None``. That allowlist and ``--op-selection-yaml-path`` are handed
    to ``get_custom_build_selector``. When ``--output-dependencies`` is given,
    ``cpu_fm.write_outputs`` is called with that file's stem and resolved
    path after generation has run.
    """
    parser = argparse.ArgumentParser(description="Generate unboxing source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    # NOTE(review): despite "and exit" in the help text, generation still runs
    # before the dependency file is written.
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--op-registration-allowlist",
        "--op_registration_allowlist",
        nargs="*",
        help="filter op registrations by the allowlist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--TEST-ONLY-op-registration-allowlist-yaml-path",
        "--TEST_ONLY_op_registration_allowlist_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "which contains a list of operators. It is to serve testing purpose and "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )

    options = parser.parse_args(args)
    if options.op_registration_allowlist:
        op_registration_allowlist = options.op_registration_allowlist
    elif options.TEST_ONLY_op_registration_allowlist_yaml_path:
        # Explicit encoding so decoding does not depend on the host locale;
        # safe_load builds only plain YAML data types.
        with open(
            options.TEST_ONLY_op_registration_allowlist_yaml_path, encoding="utf-8"
        ) as f:
            op_registration_allowlist = yaml.safe_load(f)
    else:
        op_registration_allowlist = None

    selector = get_custom_build_selector(
        op_registration_allowlist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
    # Only the native functions are needed here; backend indices are unused.
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions

    cpu_fm = make_file_manager(options=options)
    gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        cpu_fm.write_outputs(depfile_path.stem, str(depfile_path))
287
288
if __name__ == "__main__":
    # Forward the command-line arguments, minus the program name.
    cli_args = sys.argv[1:]
    main(cli_args)
291