xref: /aosp_15_r20/external/pytorch/torchgen/gen.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport argparse
4*da0073e9SAndroid Build Coastguard Workerimport functools
5*da0073e9SAndroid Build Coastguard Workerimport json
6*da0073e9SAndroid Build Coastguard Workerimport os
7*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict, namedtuple, OrderedDict
8*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field
9*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
10*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Literal, Sequence, TypeVar
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerimport yaml
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.dispatcher as dispatcher
15*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.meta as meta
16*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.native as native
17*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.structured as structured
18*da0073e9SAndroid Build Coastguard Workerimport torchgen.dest as dest
19*da0073e9SAndroid Build Coastguard Workerfrom torchgen.aoti.fallback_ops import inductor_fallback_ops
20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
21*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.translate import translate
22*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
23*da0073e9SAndroid Build Coastguard Worker    Binding,
24*da0073e9SAndroid Build Coastguard Worker    CppSignature,
25*da0073e9SAndroid Build Coastguard Worker    CppSignatureGroup,
26*da0073e9SAndroid Build Coastguard Worker    DispatcherSignature,
27*da0073e9SAndroid Build Coastguard Worker    NamedCType,
28*da0073e9SAndroid Build Coastguard Worker    NativeSignature,
29*da0073e9SAndroid Build Coastguard Worker    SpecialArgName,
30*da0073e9SAndroid Build Coastguard Worker)
31*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import (
32*da0073e9SAndroid Build Coastguard Worker    method_with_native_function,
33*da0073e9SAndroid Build Coastguard Worker    native_function_manager,
34*da0073e9SAndroid Build Coastguard Worker    with_native_function,
35*da0073e9SAndroid Build Coastguard Worker    with_native_function_and_indices,
36*da0073e9SAndroid Build Coastguard Worker)
37*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen_aoti_c_shim import (
38*da0073e9SAndroid Build Coastguard Worker    gen_aoti_c_shim,
39*da0073e9SAndroid Build Coastguard Worker    gen_static_dispatch_backend_call_signature,
40*da0073e9SAndroid Build Coastguard Worker    get_fallback_op_name,
41*da0073e9SAndroid Build Coastguard Worker    get_header_for_aoti,
42*da0073e9SAndroid Build Coastguard Worker)
43*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen_functionalization_type import (
44*da0073e9SAndroid Build Coastguard Worker    gen_functionalization_definition,
45*da0073e9SAndroid Build Coastguard Worker    gen_functionalization_registration,
46*da0073e9SAndroid Build Coastguard Worker    gen_functionalization_view_inverse_declaration,
47*da0073e9SAndroid Build Coastguard Worker    GenCompositeViewCopyKernel,
48*da0073e9SAndroid Build Coastguard Worker)
49*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
50*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
51*da0073e9SAndroid Build Coastguard Worker    Argument,
52*da0073e9SAndroid Build Coastguard Worker    BackendIndex,
53*da0073e9SAndroid Build Coastguard Worker    BackendMetadata,
54*da0073e9SAndroid Build Coastguard Worker    BaseOperatorName,
55*da0073e9SAndroid Build Coastguard Worker    DEFAULT_KERNEL_NAMESPACE,
56*da0073e9SAndroid Build Coastguard Worker    DispatchKey,
57*da0073e9SAndroid Build Coastguard Worker    FRAGMENT_NAMESPACES,
58*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
59*da0073e9SAndroid Build Coastguard Worker    is_cuda_dispatch_key,
60*da0073e9SAndroid Build Coastguard Worker    is_generic_dispatch_key,
61*da0073e9SAndroid Build Coastguard Worker    is_ufunc_dispatch_key,
62*da0073e9SAndroid Build Coastguard Worker    is_xpu_dispatch_key,
63*da0073e9SAndroid Build Coastguard Worker    Location,
64*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
65*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsGroup,
66*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsViewGroup,
67*da0073e9SAndroid Build Coastguard Worker    OperatorName,
68*da0073e9SAndroid Build Coastguard Worker    OptionalType,
69*da0073e9SAndroid Build Coastguard Worker    SchemaKind,
70*da0073e9SAndroid Build Coastguard Worker    SelfArgument,
71*da0073e9SAndroid Build Coastguard Worker    STRUCTURED_DISPATCH_KEYS,
72*da0073e9SAndroid Build Coastguard Worker    TensorOptionsArguments,
73*da0073e9SAndroid Build Coastguard Worker    Type,
74*da0073e9SAndroid Build Coastguard Worker    Variant,
75*da0073e9SAndroid Build Coastguard Worker    ViewSchemaKind,
76*da0073e9SAndroid Build Coastguard Worker)
77*da0073e9SAndroid Build Coastguard Workerfrom torchgen.native_function_generation import (
78*da0073e9SAndroid Build Coastguard Worker    add_generated_native_functions,
79*da0073e9SAndroid Build Coastguard Worker    gen_composite_functional_kernel,
80*da0073e9SAndroid Build Coastguard Worker    gen_composite_out_kernel,
81*da0073e9SAndroid Build Coastguard Worker    pre_group_native_functions,
82*da0073e9SAndroid Build Coastguard Worker)
83*da0073e9SAndroid Build Coastguard Workerfrom torchgen.selective_build.selector import SelectiveBuilder
84*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import (
85*da0073e9SAndroid Build Coastguard Worker    assert_never,
86*da0073e9SAndroid Build Coastguard Worker    concatMap,
87*da0073e9SAndroid Build Coastguard Worker    context,
88*da0073e9SAndroid Build Coastguard Worker    FileManager,
89*da0073e9SAndroid Build Coastguard Worker    make_file_manager,
90*da0073e9SAndroid Build Coastguard Worker    mapMaybe,
91*da0073e9SAndroid Build Coastguard Worker    NamespaceHelper,
92*da0073e9SAndroid Build Coastguard Worker    Target,
93*da0073e9SAndroid Build Coastguard Worker)
94*da0073e9SAndroid Build Coastguard Workerfrom torchgen.yaml_utils import YamlDumper, YamlLoader
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard WorkerT = TypeVar("T")
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker# Welcome to the ATen code generator v2!  The ATen code generator is
100*da0073e9SAndroid Build Coastguard Worker# responsible for parsing native_functions.yaml and then generating
101*da0073e9SAndroid Build Coastguard Worker# various generated files (e.g., TypeDefault.cpp) based on the operators
102*da0073e9SAndroid Build Coastguard Worker# defined in this file.  This means that the code generator knows how to
103*da0073e9SAndroid Build Coastguard Worker# parse function schema, and then translate this into various C++ types
104*da0073e9SAndroid Build Coastguard Worker# and boilerplate code.
105*da0073e9SAndroid Build Coastguard Worker#
106*da0073e9SAndroid Build Coastguard Worker# Some things to know about this file when you modify it:
107*da0073e9SAndroid Build Coastguard Worker#
108*da0073e9SAndroid Build Coastguard Worker# - This file has STRICT mypy typechecking.  Typecheck it with
109*da0073e9SAndroid Build Coastguard Worker#   `mypy --config mypy-strict.ini` in the root source directory
110*da0073e9SAndroid Build Coastguard Worker#
111*da0073e9SAndroid Build Coastguard Worker# - Most of the heavy lifting lives in external modules:
#   - 'model' has the data model for native_functions.yaml.  The classes
#     in that file represent what you see when you look at
#     a native_functions.yaml
115*da0073e9SAndroid Build Coastguard Worker#   - 'api' has conversions for how to translate JIT schema into
116*da0073e9SAndroid Build Coastguard Worker#     the various C++ APIs that the codegen interacts with.  There
117*da0073e9SAndroid Build Coastguard Worker#     are in fact THREE different C++ APIs: the public C++ API,
118*da0073e9SAndroid Build Coastguard Worker#     the dispatcher API, and the legacy dispatcher API.  See each
119*da0073e9SAndroid Build Coastguard Worker#     of these respective files for more information
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
122*da0073e9SAndroid Build Coastguard Worker#
123*da0073e9SAndroid Build Coastguard Worker#                         HELPER FUNCTIONS
124*da0073e9SAndroid Build Coastguard Worker#
125*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker# A custom loader for YAML to let us also keep track of line numbers
129*da0073e9SAndroid Build Coastguard Worker# of each entry in the YAML file
class LineLoader(YamlLoader):
    """YamlLoader variant that stamps each constructed mapping with its source line.

    The line is stored under the "__line__" key of the mapping.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        result = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Shift by one so that reported line numbers start at 1.
        result["__line__"] = 1 + node.start_mark.line
        return result
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
# (native_functions, backend_indices) pair returned by the native yaml parsers.
ParsedYaml = namedtuple("ParsedYaml", ("native_functions", "backend_indices"))


# Module-level memo tables for the yaml parsers; both are keyed by a path string.
_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker
def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    """Turn an already-loaded native functions YAML document into a ParsedYaml.

    Args:
        es: The loaded YAML document.  Must be a list of dicts, each carrying
            an integer "__line__" entry and a "func" entry.
        valid_tags: Forwarded to NativeFunction.from_yaml.
        ignore_keys: Forwarded to NativeFunction.from_yaml.
        path: File name stored in the Location of every entry.
        skip_native_fns_gen: When true, add_generated_native_functions is
            not invoked.

    Returns:
        ParsedYaml(native_functions, backend_indices).  backend_indices is a
        defaultdict: a dispatch key without any kernels maps to an empty
        BackendIndex whose dispatch_key is DispatchKey.Undefined.
    """
    assert isinstance(es, list)
    native_functions: list[NativeFunction] = []
    metadata_by_key: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
        defaultdict(dict)
    )
    for entry in es:
        assert isinstance(entry, dict), f"expected to be dict: {entry}"
        assert isinstance(entry.get("__line__"), int), entry
        loc = Location(path, entry["__line__"])
        funcs = entry.get("func")
        assert funcs is not None, f"missed 'func' in {entry}"
        with context(lambda: f"in {loc}:\n  {funcs}"):
            native_function, metadata = NativeFunction.from_yaml(
                entry, loc, valid_tags, ignore_keys
            )
            native_functions.append(native_function)
            BackendIndex.grow_index(metadata_by_key, metadata)
    error_check_native_functions(native_functions)

    if not skip_native_fns_gen:
        add_generated_native_functions(native_functions, metadata_by_key)

    def undefined_index() -> BackendIndex:
        return BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )

    # Default dict is to prevent the codegen from barfing when we have a
    # dispatch key that has no kernels yet.
    indices: dict[DispatchKey, BackendIndex] = defaultdict(undefined_index)
    for dispatch_key, kernel_index in metadata_by_key.items():
        # Only cuda-like devices in tree require device guards
        needs_device_guard = is_cuda_dispatch_key(
            dispatch_key
        ) or is_xpu_dispatch_key(dispatch_key)
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[dispatch_key] = BackendIndex(
            dispatch_key=dispatch_key,
            use_out_as_primary=True,
            external=False,
            device_guard=needs_device_guard,
            index=kernel_index,
        )
    return ParsedYaml(native_functions, indices)
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
    """Collect the tag names declared in an already-loaded tags YAML document.

    Args:
        es: The loaded YAML document.  Must be a list of dicts, each carrying
            an integer "__line__" entry, a "tag" name and a non-empty "desc".
        path: File name stored in the Location used for error context.

    Returns:
        The set of tag names found in the document.

    Raises:
        AssertionError: If the document is not a list, or an entry is not a
            dict, lacks "__line__"/"tag", or has an empty description.
    """
    assert isinstance(es, list)
    rs: set[str] = set()
    for e in es:
        # Same entry validation as parse_native_yaml_struct, so a malformed
        # entry fails with a readable message instead of a bare
        # AttributeError/KeyError.
        assert isinstance(e, dict), f"expected to be dict: {e}"
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        tags = e.get("tag")
        assert tags is not None, f"missed 'tag' in {e}"
        with context(lambda: f"in {loc}:\n  {tags}"):
            # ensure that each tag has a non-empty description
            assert (
                e.get("desc", "") != ""
            ), f"tag '{tags}' must have a non-empty 'desc'"
            rs.add(tags)
    return rs
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
    """Return the set of tag names declared in the YAML file at `path`.

    Results are memoized in _GLOBAL_PARSE_TAGS_YAML_CACHE (keyed by `path`)
    in addition to the lru_cache wrapper, so the file is read at most once.
    """
    try:
        return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
    except KeyError:
        pass

    with open(path) as f:
        es = yaml.load(f, Loader=LineLoader)
        _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)

    return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker
def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
) -> ParsedYaml:
    """Parse the native functions YAML at `path`, memoizing the result.

    Args:
        path: Location of the native functions YAML; also the cache key.
        tags_yaml_path: Location of the tags YAML handed to parse_tags_yaml.
        ignore_keys: Forwarded to parse_native_yaml_struct.
        skip_native_fns_gen: Forwarded to parse_native_yaml_struct.
        loaded_yaml: If given, used as the document instead of reading `path`.

    Returns:
        The ParsedYaml for `path`.

    NOTE: the cache is keyed by `path` alone.  A later call with the same
    `path` but different `ignore_keys`, `skip_native_fns_gen`, `loaded_yaml`
    or `tags_yaml_path` gets the result of the first call back.
    """
    if path in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]

    valid_tags = parse_tags_yaml(tags_yaml_path)

    # if a loaded yaml is provided, use that instead of reading from path
    if loaded_yaml is not None:
        es = loaded_yaml
    else:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)

    parsed = parse_native_yaml_struct(
        es,
        valid_tags,
        ignore_keys,
        path=path,
        skip_native_fns_gen=skip_native_fns_gen,
    )
    _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parsed
    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
253*da0073e9SAndroid Build Coastguard Worker# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    """Check invariants that span more than one NativeFunction.

    Two things are verified:
      * every `structured_delegate` names an operator that exists in `funcs`
        and is itself marked structured;
      * every operator tagged `inplace_view` (other than resize_, resize_as_
        and the set_ family) has an in-place base name, and an operator with
        the corresponding out-of-place base name exists in `funcs`.  Only the
        name is looked up; the schema itself is not compared here.

    Args:
        funcs: All native functions parsed from one YAML document.

    Raises:
        AssertionError: If any of the invariants above is violated.
    """
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            assert delegate_func is not None, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is missing."
            )
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if (
            "inplace_view" in f.tags
            and str(f.func.name) != "resize_"
            and str(f.func.name) != "resize_as_"
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            # .get() avoids inserting an empty list into the defaultdict as a
            # side effect of the lookup.  The message names the out-of-place
            # op that was expected, not the in-place op being checked.
            assert len(base_func_map.get(out_of_place_base_name, ())) > 0, (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{out_of_place_base_name}' and matching schema, but it didn't find one. "
            )
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker
# Characters that must be written as backslash escapes inside a C++ string
# literal.  Applied in one pass, so the backslashes added for one character
# are never escaped again.
_CPP_STRING_ESCAPES = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\v": "\\v",
        "\t": "\\t",
    }
)


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal

    Backslash, double quote and the control characters \\a \\b \\f \\n \\r
    \\v \\t are replaced by their escape sequences, and the result is wrapped
    in double quotes.  Carriage return is escaped as well so that a raw CR
    never ends up inside the generated literal.
    """
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
311*da0073e9SAndroid Build Coastguard Worker#
312*da0073e9SAndroid Build Coastguard Worker#                        C++ CODE GENERATION
313*da0073e9SAndroid Build Coastguard Worker#
314*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker# Most functions in this section are curried: they consist of a function
317*da0073e9SAndroid Build Coastguard Worker# that takes some parameters (e.g., what is to be generated) which itself
318*da0073e9SAndroid Build Coastguard Worker# returns a function that actually maps NativeFunction to the code
319*da0073e9SAndroid Build Coastguard Worker# to be generated.  This pattern makes it convenient to use map, concatMap
320*da0073e9SAndroid Build Coastguard Worker# and similar functional combinators.
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker
def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
    """List the dispatch keys to consider for static dispatch.

    Returns an empty list when no backends are given.  Otherwise returns the
    dispatch key of each backend, in order, followed by the four composite
    keys.
    """
    if len(backends) == 0:
        return []

    keys = [backend.dispatch_key for backend in backends]
    keys.extend(
        (
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        )
    )
    return keys
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker
def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    """Pick the dispatch key that a static-dispatch call to `f` should target.

    Returns `backend_index.dispatch_key` when `f` has a structured delegate
    or the backend has a kernel for it; otherwise the first composite key for
    which `f` reports a kernel (explicit, explicit non-functional, implicit,
    implicit nested tensor, in that order); otherwise None.
    """
    # TODO: for ops with structured_delegate it should check the dispatch table of
    # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
    # so we always dispatch to the `backend`, but this could be wrong when we
    # migrate math/default_backend ops to use structured delegate.
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        return backend_index.dispatch_key

    if f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    if f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    if f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    if f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker
def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    """Build the per-operator `#include` lines needed for static dispatch of `f`.

    Returns None when `backend_index` is None or `f` uses manual kernel
    registration.  Otherwise returns one include line per backend for which
    get_static_dispatch_backend yields a key, joined by newlines (an empty
    string if there are none).
    """
    if backend_index is None or f.manual_kernel_registration:
        return None

    resolved_keys = (get_static_dispatch_backend(f, index) for index in backend_index)
    return "\n".join(
        f"#include <ATen/ops/{f.root_name}_{key.lower()}_dispatch.h>"
        for key in resolved_keys
        if key is not None
    )
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker
def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
    """Return one `<ATen/{key}Functions.h>` include line per static dispatch key."""
    headers: list[str] = []
    for dispatch_key in static_dispatch_keys(backends):
        headers.append(f"#include <ATen/{dispatch_key}Functions.h>")
    return headers
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker
# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for `memory_format` argument and this case is not covered by
# torchgen.api.translate.translate() yet as its application is limited to static dispatch.
def translate_args(
    sig: CppSignature | DispatcherSignature,
    cpp_sig: CppSignature,
) -> str:
    """Render the argument expressions for calling `cpp_sig` from `sig`.

    The bindings of `sig` are the source and the bindings of `cpp_sig` are
    the goal; the translated expressions are joined with ", ".
    """

    # Re-types every `memory_format` binding as
    # SpecialArgName.possibly_redundant_memory_format; other bindings pass through.
    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
        return [
            Binding(
                nctype=NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    binding.nctype.type,
                ),
                name=binding.name,
                default=binding.default,
                argument=binding.argument,
            )
            if binding.name == "memory_format"
            else binding
            for binding in input_bindings
        ]

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    # When any argument of the CPP signature has the
    # SpecialArgName.possibly_redundant_memory_format NCType, give the
    # memory_format bindings of the source signature the same NCType.
    wants_special_memory_format = any(
        goal.nctype.name == SpecialArgName.possibly_redundant_memory_format
        for goal in goal_bindings
    )
    if wants_special_memory_format:
        src_bindings = add_spl_memory_format_binding(src_bindings)
    translated = translate(src_bindings, goal_bindings)
    return ", ".join(item.expr for item in translated)
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker
def generate_static_dispatch_backend_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    """Build the C++ `return` statement that calls `f` in one backend's namespace.

    Arguments:
        sig: Signature whose arguments are forwarded to the backend function.
        f: NativeFunction being dispatched.
        backend_index: Backend whose function should be called.
    Return:
        A statement of the form "return <ns>::<backend>::<name>(<args>);".
    """
    target_sig = gen_static_dispatch_backend_call_signature(sig, f)
    fn_name = target_sig.name()
    call_args = translate_args(sig, target_sig)

    # Use the namespace recorded for the kernel when there is one; otherwise
    # fall back to the default kernel namespace.
    metadata = backend_index.get_kernel(f)
    if metadata and metadata.cpp_namespace:
        kernel_ns = metadata.cpp_namespace
    else:
        kernel_ns = DEFAULT_KERNEL_NAMESPACE

    # The call is emitted against the namespace with "::native" removed.
    api_ns = kernel_ns.replace("::native", "")
    backend_ns = backend_index.dispatch_key.lower()
    return f"return {api_ns}::{backend_ns}::{fn_name}({call_args});"
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker
def generate_static_dispatch_fallback_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """Build the C++ statement used when no backend-specific call applies.

    If `f` has one of the composite kernels, the statement calls it in the
    matching composite namespace. Otherwise it is a `TORCH_CHECK(false, ...)`
    that names the function and the dispatch keys in `backend_indices`.

    Arguments:
        sig: Signature whose arguments are forwarded to the composite function.
        f: NativeFunction being dispatched.
        backend_indices: Backends listed in the error message.
    Return:
        A single C++ statement (either a `return ...;` or a `TORCH_CHECK`).
    """
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")

    # Pick the first composite kernel the function provides; the order of the
    # checks is significant and matches the original if/elif chain.
    composite_key: DispatchKey | None
    if f.has_composite_explicit_autograd_kernel:
        composite_key = DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        composite_key = DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        composite_key = DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        composite_key = DispatchKey.CompositeImplicitAutogradNestedTensor
    else:
        composite_key = None

    if composite_key is not None:
        return f"return {ns}::{composite_key.lower()}::{name}({exprs});"

    # Previously a backslash line-continuation inside the f-string glued "for"
    # directly onto the first dispatch key ("... forCPU, CUDA "). Build the
    # message on one logical line so the words are separated correctly.
    unsupported = ", ".join(str(index.dispatch_key) for index in backend_indices)
    return (
        f'TORCH_CHECK(false, "Static dispatch does not support {name} for '
        f'{unsupported}");'
    )
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker
def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the corresponding backend and dispatch to it. If more than
    one backend exists, emit a `switch` that determines the dispatch key from the inputs at runtime.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);".
        An empty string when there are no backends or `f` uses manual kernel registration.
    """
    if not backend_indices or f.manual_kernel_registration:
        return ""

    # Backends that can serve `f`: either they have a kernel for it, or `f`
    # delegates to a structured op and the backend is a structured dispatch key.
    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    # Several candidate backends: the dispatch key must be computed from the
    # tensor-like arguments and/or the tensor options of the call.
    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or (isinstance(a.argument, Argument) and a.argument.type.is_tensor_like())
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: list[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    # NOTE(review): if a function has neither tensor options nor tensor-like
    # arguments, `subexprs` is empty and the initializer below is blank.
    # Presumably that never happens for multi-backend ops -- confirm.
    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        # The backend call is already a complete statement ending in ';', so no
        # extra terminator is appended (that used to produce a stray ';;').
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)}"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """
534*da0073e9SAndroid Build Coastguard Worker
535*da0073e9SAndroid Build Coastguard Worker
# Generates RegisterSchema.cpp.  Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    selector: SelectiveBuilder
    # Maps a C++ tag initializer list (e.g. "{at::Tag::a, at::Tag::b}") to the
    # index of the `tags_<idx>` vector already emitted for it. The dict is
    # mutated in place across calls so each distinct tag set is declared once.
    known_tags: dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the `m.def(...)` registration line for `f`, or None if not selected."""
        if not self.selector.is_native_function_selected(f):
            return None

        schema = cpp_string(str(f.func))
        tag_names = [f"at::Tag::{tag}" for tag in sorted(f.tags)]
        if not tag_names:
            # Untagged operators register with an empty initializer list.
            return f"m.def({schema}, {{}});\n"

        tag_list = "{" + ", ".join(tag_names) + "}"
        idx = self.known_tags.get(tag_list)
        declaration = ""
        if idx is None:
            # First time this tag set is seen: declare a vector for it.
            idx = len(self.known_tags)
            self.known_tags[tag_list] = idx
            declaration = f"const std::vector<at::Tag> tags_{idx} = {tag_list};\n"
        return f"{declaration}m.def({schema}, tags_{idx});\n"
557*da0073e9SAndroid Build Coastguard Worker
558*da0073e9SAndroid Build Coastguard Worker
# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        """Return the struct declaration or the out-of-line definitions for `f`."""
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator,
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name="call", is_redispatching_fn=False)};
  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};"""

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        # Out-of-line string constants plus the helper that looks up the typed
        # operator handle; the call()/redispatch() definitions are appended below.
        pieces = [
            f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})

// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow({name}::name, {name}::overload_name)
      .typed<{name}::schema>();
}}
"""
        ]

        arg_names = [a.name for a in sig.arguments()]
        use_static_dispatch = len(self.static_dispatch_backend_indices) > 0

        # Emit call() first, then redispatch().
        for is_redispatching_fn in (False, True):
            if is_redispatching_fn:
                method_base = "redispatch"
                # redispatch() forwards the key set ahead of the operator arguments.
                dispatcher_exprs_str = ", ".join(["dispatchKeySet", *arg_names])
            else:
                method_base = "call"
                dispatcher_exprs_str = ", ".join(arg_names)

            if use_static_dispatch and not is_redispatching_fn:
                # call() should go through static dispatch
                fn_body = static_dispatch(
                    sig, f, backend_indices=self.static_dispatch_backend_indices
                )
            else:
                fn_body = f"""
    static auto op = create_{name}_typed_handle();
    return op.{method_base}({dispatcher_exprs_str});"""

            method_name = f"{name}::{method_base}"
            pieces.append(
                f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    {fn_body}
}}
"""
            )
        return "".join(pieces)
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the inline C++ wrappers for `f`, one set per C++ signature."""
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        emit_symint_template = f.func.has_symint()
        emit_function = Variant.function in f.variants
        # See Note [The ATen Operators API]
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        op_call = f"at::_ops::{f.func.name.unambiguous_name()}::call"

        chunks: list[str] = []
        for sig in sig_group.signatures():
            translated = translate(sig.arguments(), dispatcher_sig.arguments())
            exprs_str = ", ".join(e.expr for e in translated)

            if emit_function:
                chunks.append(
                    f"""
// aten::{f.func}
inline {sig.decl()} {{
    return {op_call}({exprs_str});
}}"""
                )

            # The template function can be used from template situations
            # where you want to switch between the symint or not version
            # depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods.  But we put it in
            # this header so it can take advantage of per-op headers
            if emit_symint_template:
                intlike_t = "c10::SymInt" if sig.symint else "int64_t"
                chunks.append(
                    f"""
namespace symint {{
  template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
  {sig.decl(suppress_symint_suffix=True)} {{
    return {op_call}({exprs_str});
  }}
}}
"""
                )
        return "".join(chunks)
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker
# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the Tensor method declarations or definitions for `f`.

        Returns None when `f` has no method variant.
        """
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            # One const member declaration per C++ signature.
            return "".join(
                f"{sig.decl()} const;\n" for sig in sig_group.signatures()
            )

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        op_call = f"at::_ops::{f.func.name.unambiguous_name()}::call"
        definitions: list[str] = []
        for sig in sig_group.signatures():
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            translated = translate(
                sig.arguments(), dispatcher_sig.arguments(), method=True
            )
            exprs_str = ", ".join(e.expr for e in translated)
            definitions.append(
                f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return {op_call}({exprs_str});
}}
"""
            )
        return "".join(definitions)
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker
# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return the inline redispatch wrappers for `f`, one per C++ signature."""
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        redispatch_call = f"at::_ops::{f.func.name.unambiguous_name()}::redispatch"

        wrappers: list[str] = []
        for sig in sig_group.signatures():
            dispatcher_sig = DispatcherSignature.from_schema(f.func)
            translated = translate(sig.arguments(), dispatcher_sig.arguments())
            # The key set is always forwarded as the first argument.
            exprs_str = ", ".join(["dispatchKeySet", *(a.expr for a in translated)])
            wrappers.append(
                f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
    return {redispatch_call}({exprs_str});
}}
"""
            )
        return "".join(wrappers)
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker# Generates ATenOpList.cpp, a runtime accessible list of all aten
782*da0073e9SAndroid Build Coastguard Worker# operators.
783*da0073e9SAndroid Build Coastguard Worker# TODO: This was historically used to help some JIT interop code
784*da0073e9SAndroid Build Coastguard Worker# figure out whether or not to treat aten namespace'd operators
785*da0073e9SAndroid Build Coastguard Worker# one way or another, we should reevaluate if this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    """Return the ``{"aten::<name>", "<overload name>"},`` list entry for ``f``."""
    op = f.func.name
    qualified_name = f"aten::{op.name}"
    return f'{{"{qualified_name}", "{op.overload_name}"}},'
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
    """Render the C++ ``structured_<name>`` struct declaration for ``g``.

    Returns ``None`` when ``g`` is not structured.  Otherwise the returned
    text declares a struct that derives from ``g.out.structured_inherits``
    (``at::impl::MetaBase`` when that is unset) and declares ``meta(...)``.

    When ``g.out.precomputed`` is set, the struct additionally contains a
    ``precompute_out`` class template with one ``set_<element>`` method per
    precomputed element, and ``meta`` is declared to return
    ``meta_return_ty`` instead of ``void``.
    """
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        # The guard at the top already returned for unstructured groups, so
        # there is no need to test g.structured again here.
        precomputed = g.out.precomputed

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]

            precomputed_elements_decl = ";\n".join(
                f"{typed_elem.cpp_type(strip_ref=True)} {typed_elem.name}"
                for typed_elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{elem.name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = [f"{return_ty} ret;"]

                # Use a distinct loop variable so that the outer `elem` (the
                # element this setter is for) is not rebound by this loop.
                for j, member in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{member.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{member.name} = this->{member.name};"
                        )

                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
                    {signature} {{
                        {assert_stmt}
                        {construction_block}
                    }}
                """
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
                {precompute_template_decl}
                struct TORCH_API precompute_out {{
                    {setter_methods_decl}
                    {precomputed_elements_decl};
            }};"""
        else:
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
    {precomputed_decl}
    {meta_return_typedef}
    {meta_return} meta({args_str});
}};
"""
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker
def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    """Decide whether ``f`` gets a BackendSelect kernel.

    The answer is ``False`` for operators whose base name ends in ``_like``
    or starts with ``new_``, and for operators without tensor options.  For
    everything else the decision is delegated to
    ``selector.is_native_function_selected``.
    """
    op_name = str(f.func.name.name)
    is_like_or_new = op_name.endswith("_like") or op_name.startswith("new_")
    if is_like_or_new or f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker# Generates RegisterBackendSelect.cpp, a series of kernels which provide
926*da0073e9SAndroid Build Coastguard Worker# specialized computation of dispatch key for operator signatures which cannot
927*da0073e9SAndroid Build Coastguard Worker# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    """Callable that renders the BackendSelect kernel pieces for one operator.

    Depending on ``target`` the call returns either the kernel definition or
    the ``m.impl(...)`` registration line.  Operators rejected by
    ``needs_backend_select`` produce ``None``.
    """

    # Which piece to render: the kernel definition or its registration.
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)

        native_tensor_args = [
            binding
            for binding in native_sig.arguments()
            if isinstance(binding.argument, Argument)
            and binding.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        elif self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )

            redispatch_args = ", ".join(a.expr for a in dispatcher_exprs)
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{dispatcher_sig.defn(name)} {{
  {compute_dk}
  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
      _dk, {redispatch_args});
}}
"""
        else:
            assert_never(self.target)
988*da0073e9SAndroid Build Coastguard Worker
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
991*da0073e9SAndroid Build Coastguard Worker#
992*da0073e9SAndroid Build Coastguard Worker#                       YAML CODE GENERATION
993*da0073e9SAndroid Build Coastguard Worker#
994*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker
def format_yaml(data: object) -> str:
    """Serialize ``data`` to block-style YAML using ``YamlDumper``.

    Note that this configures ``YamlDumper`` itself (alias handling and an
    ``OrderedDict`` representer), so the settings persist after the call.
    """

    # Ignore alias in Dumper
    def never_alias(self: Any, data: Any) -> bool:
        return True

    # Support serializing OrderedDict
    def represent_ordered_dict(dumper: Any, mapping: Any) -> Any:
        return dumper.represent_dict(mapping.items())

    YamlDumper.ignore_aliases = never_alias  # type: ignore[assignment]
    YamlDumper.add_representer(OrderedDict, represent_ordered_dict)  # type: ignore[no-untyped-call]

    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    return yaml.dump(  # type: ignore[no-any-return, call-overload]
        data,
        default_flow_style=False,
        Dumper=YamlDumper,
        width=1e9,
    )
1010*da0073e9SAndroid Build Coastguard Worker
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker# For some reason, some defaults we write to YAML are written as native
1013*da0073e9SAndroid Build Coastguard Worker# YAML objects, rather than doing them uniformly as strings.  This
1014*da0073e9SAndroid Build Coastguard Worker# function detects those cases and converts them into native Python
1015*da0073e9SAndroid Build Coastguard Worker# objects.
def pythonify_default(s: str) -> object:
    """Convert a default written as text into a native Python value.

    ``"true"`` / ``"false"`` become booleans; anything ``int`` accepts
    becomes an int; failing that, anything ``float`` accepts becomes a
    float; every other string is returned unchanged.
    """
    if s in ("true", "false"):
        return s == "true"

    # Try the narrower numeric type first so "3" stays an int.
    for convert in (int, float):
        try:
            return convert(s)
        except ValueError:
            continue
    return s
1029*da0073e9SAndroid Build Coastguard Worker
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker# What is a dynamic type?  Over time, the semantic meaning of
1032*da0073e9SAndroid Build Coastguard Worker# dynamic type has degraded to meaninglessness (in the old days,
1033*da0073e9SAndroid Build Coastguard Worker# it captured dtype-ness of types, but that has gone away with
1034*da0073e9SAndroid Build Coastguard Worker# the removal of TH).  These days, it's mostly the same thing as
1035*da0073e9SAndroid Build Coastguard Worker# the C++ API argument type, except that Tensor and Tensor?
1036*da0073e9SAndroid Build Coastguard Worker# arguments simply present as Tensor.
1037*da0073e9SAndroid Build Coastguard Worker#
1038*da0073e9SAndroid Build Coastguard Worker# TODO: Get rid of dynamic_type, after getting tools/autograd
1039*da0073e9SAndroid Build Coastguard Worker# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    """Return the legacy "dynamic type" string for ``t``.

    Optional wrappers are ignored, a plain ``Tensor`` maps to
    ``at::Tensor``, and every other type is rendered through
    ``cpp.argumenttype_type`` with ``symint=False``.
    """
    # Optionality is invisible in the dynamic type: peel off every wrapper.
    while isinstance(t, OptionalType):
        t = t.elem

    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) == "Tensor":
        return "at::Tensor"

    # This is a legacy concept, so never report SymInt
    legacy_ctype = cpp.argumenttype_type(
        t, mutable=False, binds="__placeholder__", symint=False
    )
    return legacy_ctype.cpp_type()
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker
def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
    """Return the ``method_of`` list for the given variants.

    The list always starts with ``"Type"``; ``"Tensor"`` follows when the
    method variant is present, then ``"namespace"`` when the function
    variant is present.
    """
    # The pairs are spelled out in a fixed sequence so that Tensor and
    # namespace always land in the list in the right order.
    labelled_variants = (
        (Variant.method, "Tensor"),
        (Variant.function, "namespace"),
    )
    return [
        "Type",
        *(label for variant, label in labelled_variants if variant in variants),
    ]
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker
def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
    """Build the ``returns`` YAML entries for ``f``.

    Returns a pair: the list of per-return dicts, and a mapping from out
    argument name to return field name (filled only for out functions whose
    returns are named; see Note [name and field_name] below).
    """
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}

    # Compute the returns field of the YAML entry
    returns = []
    named_returns = zip(f.func.returns, cpp.return_names(f))
    for index, (ret, cpp_name) in enumerate(named_returns):
        entry = {
            "dynamic_type": dynamic_type(ret.type),
            "name": cpp_name,
            # legacy, report ints
            "type": cpp.return_type(ret, symint=False).cpp_type(),
        }

        if ret.name:
            # See Note [name and field_name]
            entry["field_name"] = ret.name
            if f.func.is_out_fn():
                out_arg_name = f.func.arguments.out[index].name
                name_to_field_name[out_arg_name] = ret.name

        returns.append(entry)

    return returns, name_to_field_name
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker
1130*da0073e9SAndroid Build Coastguard Worker# arguments in yaml roughly corresponds to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Build the YAML entry for one C++ API binding.

    A ``TensorOptionsArguments`` binding gets a fixed-shape entry built
    here; a plain ``Argument`` is delegated to ``compute_argument_yaml``
    with the keyword options forwarded unchanged.

    Raises:
        AssertionError: if the binding wraps a ``SelfArgument``.
    """
    schema_arg = cpp_a.argument

    if isinstance(schema_arg, TensorOptionsArguments):
        options_entry: dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        # The "default" key is only present when the binding has a default.
        if cpp_a.default is not None:
            options_entry["default"] = cpp_a.default
        return options_entry

    if isinstance(schema_arg, SelfArgument):
        raise AssertionError

    if isinstance(schema_arg, Argument):
        return compute_argument_yaml(
            schema_arg,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )

    # Any other kind of binding argument yields no entry.
    return None
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker
def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Build the yaml entry describing a single schema argument.

    The entry always has ``annotation``, ``dynamic_type``, ``is_nullable``,
    ``name`` and ``type``.  ``default``, ``kwarg_only``, ``output`` /
    ``allocate`` / ``field_name`` and ``size`` are added only when they apply.
    ``schema_order`` is accepted but not consulted here.
    """
    annotation = str(a.annotation) if a.annotation else None
    dyn_type = dynamic_type(a.type)
    nullable = a.type.is_nullable()
    # legacy, report ints
    cpp_type = cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type()

    entry: dict[str, object] = {
        "annotation": annotation,
        "dynamic_type": dyn_type,
        "is_nullable": nullable,
        "name": a.name,
        "type": cpp_type,
    }

    if a.default is not None:
        default_expr = cpp.default_expr(a.default, a.type, symint=False)
        entry["default"] = pythonify_default(default_expr)

    if a.name in kwarg_only_set:
        entry["kwarg_only"] = True

    if a.name in out_arg_set:
        entry.update(output=True, allocate=True)
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            entry["field_name"] = name_to_field_name[a.name]

    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    list_type = a.type.is_list_like()
    if list_type is not None and list_type.size is not None:
        if str(list_type.elem) != "bool":
            entry["size"] = list_type.size

    return entry
1197*da0073e9SAndroid Build Coastguard Worker
1198*da0073e9SAndroid Build Coastguard Worker
@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Assemble the ordered yaml declaration record for one native function.

    The record lists the arguments twice: once following the non-method C++
    signature (``arguments``) and once following the schema order
    (``schema_order_arguments``).  ``mode`` is always ``"native"`` and
    ``with_gil`` / ``deprecated`` are always ``False``.
    """
    returns, name_to_field_name = compute_returns_yaml(f)

    # Name sets that make "is this argument kwarg-only / an out argument?"
    # a simple membership test.
    schema_args = f.func.arguments
    kwarg_only_set = {arg.name for arg in schema_args.flat_kwarg_only}
    out_arg_set = {arg.name for arg in schema_args.out}

    cpp_args = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    ).signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            binding,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for binding in cpp_args
    ]

    schema_order_jit_arguments = list(f.func.schema_order_arguments())
    schema_order_arguments = [
        compute_argument_yaml(
            jit_arg,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for jit_arg in schema_order_jit_arguments
    ]

    cpp_schema_order_types = []
    for jit_arg in schema_order_jit_arguments:
        # NB: method here doesn't matter
        bindings = cpp.argument(
            jit_arg,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
        cpp_schema_order_types.extend(binding.type for binding in bindings)

    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    joined_types = ", ".join(cpp_schema_order_types)
    schema_order_cpp_signature = f"{cpp_returns} ({joined_types})"

    # Factory method: the C++ signature takes TensorOptions and the operator
    # has no method variant.
    takes_tensor_options = any(
        isinstance(binding.argument, TensorOptionsArguments) for binding in cpp_args
    )
    is_factory_method = takes_tensor_options and Variant.method not in f.variants

    # Keys are inserted in the exact order they must appear in the output.
    decl: OrderedDict[str, object] = OrderedDict()
    decl["name"] = cpp.name(f.func)
    decl["operator_name"] = str(f.func.name.name)
    decl["overload_name"] = str(f.func.name.overload_name)
    decl["manual_kernel_registration"] = f.manual_kernel_registration
    decl["category_override"] = (
        "" if f.category_override is None else f.category_override
    )
    decl["schema_string"] = f"aten::{f.func}"
    decl["arguments"] = arguments
    decl["schema_order_cpp_signature"] = schema_order_cpp_signature
    decl["schema_order_arguments"] = schema_order_arguments
    decl["method_of"] = compute_method_of_yaml(f.variants)
    decl["mode"] = "native"
    decl["python_module"] = "" if f.python_module is None else f.python_module
    decl["returns"] = returns
    decl["inplace"] = f.func.name.name.inplace
    decl["is_factory_method"] = is_factory_method
    decl["abstract"] = f.is_abstract
    decl["device_guard"] = f.device_guard
    decl["with_gil"] = False
    decl["deprecated"] = False
    decl["has_math_kernel"] = f.has_composite_implicit_autograd_kernel
    return decl
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker
# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    """Return True for a structured (or structured-delegating) function whose
    schema kind is functional or inplace."""
    is_structured = f.structured or f.structured_delegate is not None
    if not is_structured:
        return False
    return f.func.kind() in (SchemaKind.functional, SchemaKind.inplace)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker
@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
) -> str:
    """Render one registration-declaration line for ``f``.

    The line is a C++ declaration built from the dispatcher name, return type
    and default-free arguments, followed by a ``//`` comment holding a JSON
    object with the keys ``schema``, ``dispatch`` and ``default``, and a
    trailing newline.

    ``dispatch`` is ``"True"`` unless the set of dispatch keys that have a
    kernel for ``f`` is exactly ``{CompositeImplicitAutograd}`` or exactly
    ``{CompositeImplicitAutograd, CompositeImplicitAutogradNestedTensor}``.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
        f.func.returns
    ).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)

    # Build the set of keys that have a kernel once, so that both comparisons
    # below share it instead of re-scanning every backend index.
    keys_with_kernel = {k for k, v in backend_indices.items() if v.has_kernel(f)}
    composite_implicit_only = {DispatchKey.CompositeImplicitAutograd}
    composite_implicit_with_nested = {
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    }
    has_dispatch = (
        keys_with_kernel != composite_implicit_only
        and keys_with_kernel != composite_implicit_with_nested
    )

    comment_data: dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
        "dispatch": str(has_dispatch),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"{returns_type} {name}({args_str}); // {json.dumps(comment_data)}\n"
1321*da0073e9SAndroid Build Coastguard Worker
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1324*da0073e9SAndroid Build Coastguard Worker#
1325*da0073e9SAndroid Build Coastguard Worker#                           RUN IT ALL
1326*da0073e9SAndroid Build Coastguard Worker#
1327*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker
def get_custom_build_selector(
    provided_op_registration_allowlist: list[str] | None,
    op_selection_yaml_path: str | None,
) -> SelectiveBuilder:
    """Pick the ``SelectiveBuilder`` matching the supplied selection inputs.

    An allowlist yields a selector built from the legacy op-registration
    allow list, a yaml path yields one loaded from that file, and when neither
    is given the no-op selector is returned.  Supplying both is rejected by an
    assertion.
    """
    has_allowlist = provided_op_registration_allowlist is not None
    has_yaml_path = op_selection_yaml_path is not None
    assert not (has_allowlist and has_yaml_path), (
        "Both provided_op_registration_allowlist and "
        "op_selection_yaml_path can NOT be provided at the "
        "same time."
    )

    if provided_op_registration_allowlist is not None:
        return SelectiveBuilder.from_legacy_op_registration_allow_list(
            set(provided_op_registration_allowlist),
            True,
            False,
        )

    if op_selection_yaml_path is not None:
        return SelectiveBuilder.from_yaml_path(op_selection_yaml_path)

    return SelectiveBuilder.get_nop_selector()
1359*da0073e9SAndroid Build Coastguard Worker
1360*da0073e9SAndroid Build Coastguard Worker
def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
    """Bundle functions that share a view signature into view groups.

    Functions are bucketed by ``f.func.view_signature()``.  A bucket holding an
    aliasing view op becomes one ``NativeFunctionsViewGroup`` (with its
    optional inplace and functional-copy companions) followed by whatever else
    was in the bucket; any other bucket is emitted as its plain functions.
    """

    def maybe_create_view_group(
        d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
        if ViewSchemaKind.aliasing not in d:
            return list(d.values())

        # Popping removes the grouped members from ``d`` so that only the
        # leftovers are emitted separately below.
        view = d.pop(ViewSchemaKind.aliasing)
        view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
        view_copy = d.pop(SchemaKind.functional, None)
        group = NativeFunctionsViewGroup(
            view=view,
            view_copy=view_copy,
            view_inplace=view_inplace,
        )
        return [group, *d.values()]

    grouped_by_views: dict[
        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        schema = f.func.view_signature()
        view_kind: ViewSchemaKind = f.view_schema_kind
        bucket = grouped_by_views[schema]
        # Ops relevant to the same "view" are keyed as follows:
        #   view op         -> ViewSchemaKind.aliasing
        #   view_inplace op -> ViewSchemaKind.aliasing_inplace
        #   view_copy op    -> SchemaKind.functional
        # Non-aliasing ops are keyed by their ordinary schema kind.
        if view_kind == ViewSchemaKind.non_aliasing:
            kind = f.func.kind()
            assert kind not in bucket
            bucket[kind] = f
        else:
            assert view_kind not in bucket, f"{view_kind} already in {bucket.keys()}"
            bucket[view_kind] = f

    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker
def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
    """Group native functions into ``NativeFunctionsGroup`` objects.

    Each pre-grouped bucket that ``NativeFunctionsGroup.from_dict`` can turn
    into a group is emitted as that group; otherwise its functions are emitted
    individually.
    """

    def flatten_pre_group(
        d: dict[SchemaKind, NativeFunction]
    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
        group = NativeFunctionsGroup.from_dict(d)
        if group is not None:
            return [group]
        # Invariant: any NativeFunctions that are code-generated
        # should have been grouped into NativeFunctionsGroup objects
        assert all("generated" not in f.tags for f in d.values())
        return list(d.values())

    pre_grouped_native_functions = pre_group_native_functions(native_functions)
    # TODO: how come ValuesView isn't a Sequence lol
    buckets = list(pre_grouped_native_functions.values())
    return list(concatMap(flatten_pre_group, buckets))
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker
def get_ns_grouped_kernels(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> dict[str, list[str]]:
    """Collect generated kernel declarations keyed by C++ namespace.

    For every function and every backend index, the output of
    ``native_function_decl_gen`` is appended under the kernel's
    ``cpp_namespace``, or under ``DEFAULT_KERNEL_NAMESPACE`` when that backend
    has no kernel for the function.  An assertion rejects a function whose
    kernels are spread over more than one namespace.
    """
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    for f in grouped_native_functions:
        seen_namespaces: set[str] = set()
        seen_dispatch_keys: set[DispatchKey] = set()
        for dispatch_key, backend_idx in backend_indices.items():
            backend_metadata = backend_idx.get_kernel(f)
            if not backend_metadata:
                namespace = DEFAULT_KERNEL_NAMESPACE
            else:
                namespace = backend_metadata.cpp_namespace
                seen_dispatch_keys.add(dispatch_key)
                seen_namespaces.add(namespace)
            assert (
                len(seen_namespaces) <= 1
            ), f"Codegen only supports one namespace per operator, got {seen_namespaces} from {seen_dispatch_keys}"
            generated = native_function_decl_gen(f, backend_idx)
            ns_grouped_kernels[namespace].extend(generated)
    return ns_grouped_kernels
1457*da0073e9SAndroid Build Coastguard Worker
1458*da0073e9SAndroid Build Coastguard Worker
def get_native_function_declarations_from_ns_grouped_kernels(
    *,
    ns_grouped_kernels: dict[str, list[str]],
) -> list[str]:
    """Flatten namespace-grouped kernel declarations into a list of lines.

    For each namespace the kernels are de-duplicated (first occurrence wins,
    order kept), wrapped in the prologue and epilogue of a ``NamespaceHelper``
    and split on newlines.  The lines of all namespaces are concatenated.
    """
    newline = "\n"
    declarations: list[str] = []
    for namespace, kernels in ns_grouped_kernels.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=4,
        )
        # Backends are allowed to repeat kernel names; only generate the
        # declaration once, keeping the first-seen order.
        unique_kernels = list(dict.fromkeys(kernels))
        kernel_text = newline.join(unique_kernels)
        # The rendered text starts with a newline and ends with a newline
        # followed by eight spaces, so the split yields a leading empty line
        # and a trailing whitespace-only line for every namespace.
        rendered = (
            f"\n{ns_helper.prologue}\n{kernel_text}\n{ns_helper.epilogue}\n        "
        )
        declarations.extend(rendered.split(newline))
    return declarations
1484*da0073e9SAndroid Build Coastguard Worker
1485*da0073e9SAndroid Build Coastguard Worker
# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> list[str]:
    """
    Produce the kernel declaration lines for `NativeFunction(s).h`.

    The kernels are first grouped by namespace with `get_ns_grouped_kernels`
    and then rendered by `get_native_function_declarations_from_ns_grouped_kernels`.

    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param backend_indices: kernel collections keyed by dispatch key.
    :param native_function_decl_gen: callable producing the declarations for one
        function (or group) and one backend index.
    :return: the declarations of all namespaces as a list of lines.
    """
    return get_native_function_declarations_from_ns_grouped_kernels(
        ns_grouped_kernels=get_ns_grouped_kernels(
            grouped_native_functions=grouped_native_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=native_function_decl_gen,
        )
    )
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker
def get_kernel_namespace(
    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
) -> str:
    """Return the C++ namespace of the kernel ``backend_idx`` has for ``f``.

    Falls back to ``DEFAULT_KERNEL_NAMESPACE`` when the backend has no kernel
    for ``f``.  Otherwise the kernel's namespace must contain ``"::native"``,
    which is enforced by an assertion.
    """
    backend_metadata = backend_idx.get_kernel(f)
    if not backend_metadata:
        return DEFAULT_KERNEL_NAMESPACE

    assert "::native" in backend_metadata.cpp_namespace, (
        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
        f"with dispatch key {backend_idx.dispatch_key}"
        f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
    )
    return backend_metadata.cpp_namespace
1525*da0073e9SAndroid Build Coastguard Worker
1526*da0073e9SAndroid Build Coastguard Worker
# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp and similar generated files.
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> list[str]:
    """Render kernel definitions and registrations for one dispatch key.

    Every function (or function group) is bucketed by its kernel namespace,
    i.e. the result of ``get_kernel_namespace`` with ``"::native"`` removed.
    Within a kernel namespace, registrations are additionally grouped by the
    operator namespace and wrapped in one ``TORCH_LIBRARY_IMPL`` block each.

    Args:
        fm: File manager used to substitute the
            ``RegisterDispatchDefinitions.ini`` template.
        grouped_native_functions: Functions / function groups to generate for.
        dispatch_key: Dispatch key named in the ``TORCH_LIBRARY_IMPL`` blocks;
            its lower-cased form is passed as ``dispatch_namespace``.
        backend_idx: Backend index handed to ``dest.RegisterDispatchKey`` and
            ``get_kernel_namespace``.
        selector: Selective-build filter forwarded to the generators.
        rocm: Forwarded to ``dest.RegisterDispatchKey``.
        symint: Forwarded to ``dest.RegisterDispatchKey``.
        skip_dispatcher_op_registration: When true, the registration section
            of the template is left empty.
        gen_dispatch_helpers: When true, registration helpers from
            ``dest.gen_registration_helpers`` are included.

    Returns:
        The substituted template text for every kernel namespace that has at
        least one namespaced definition, split into lines and concatenated.
    """
    newline = "\n"

    def make_generator(target: Target) -> dest.RegisterDispatchKey:
        # The three generators differ only in the target they emit.
        return dest.RegisterDispatchKey(
            backend_idx,
            target,
            selector,
            rocm=rocm,
            symint=symint,
            class_method_name=None,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
        )

    ns_gen = make_generator(Target.NAMESPACED_DEFINITION)
    anonymous_gen = make_generator(Target.ANONYMOUS_DEFINITION)
    reg_gen = make_generator(Target.REGISTRATION)

    definitions: list[str] = []
    ns_definitions: dict[str, list[str]] = defaultdict(list)
    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
    # kernel namespace -> operator namespace -> registration lines.
    # A nested defaultdict keeps the registrations of every operator
    # namespace; re-creating the inner mapping when a new operator namespace
    # shows up would drop the lines already collected for the same kernel
    # namespace.
    registrations: dict[str, dict[str, list[str]]] = defaultdict(
        lambda: defaultdict(list)
    )

    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )

        ns_definitions[kernel_namespace].extend(ns_gen(f))
        anonymous_definitions[kernel_namespace].extend(anonymous_gen(f))
        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        registrations[kernel_namespace][namespace].extend(reg_gen(f))

    for kernel_namespace, namespaced_definitions in ns_definitions.items():
        # Nothing to emit for a kernel namespace without namespaced definitions.
        if not namespaced_definitions:
            continue
        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        registration_body = ""
        for namespace, kernel_registrations in registrations[kernel_namespace].items():
            if not kernel_registrations:
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(kernel_registrations)}
}};"""
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
                    if gen_dispatch_helpers
                    else [],
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker
# Return native function declarations grouped by dispatch key and custom namespace.
# Used for the per-dispatch-key *Functions_inl.h headers (e.g. CPUFunctions_inl.h).
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> list[str]:
    """Build namespaced declarations for one dispatch key.

    The kernel namespace of each function is rewritten by replacing
    ``"native"`` with the lower-cased dispatch key; declarations are grouped
    under the resulting namespace, de-duplicated in first-seen order and
    wrapped in that namespace's prologue and epilogue.

    Returns:
        The rendered text of every non-empty namespace, split into lines.
    """
    newline = "\n"
    declaration_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
        symint=symint,
    )

    kernels_by_namespace: dict[str, list[str]] = defaultdict(list)
    for f in grouped_native_functions:
        kernel_ns = get_kernel_namespace(f=f, backend_idx=backend_idx)
        target_ns = kernel_ns.replace("native", dispatch_key.lower())
        kernels_by_namespace[target_ns].extend(declaration_gen(f))

    declarations: list[str] = []
    for target_ns, kernels in kernels_by_namespace.items():
        if not kernels:
            continue
        ns_helper = NamespaceHelper(
            namespace_str=target_ns, entity_name="", max_level=3
        )
        # Drop repeated declarations while keeping the first occurrence order.
        unique_kernels = list(OrderedDict.fromkeys(kernels))
        rendered = f"""
{ns_helper.prologue}
{newline.join(unique_kernels)}
{ns_helper.epilogue}
        """
        declarations.extend(rendered.split(newline))
    return declarations
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker
# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
    """Build schema registration code, split into aten and non-aten parts.

    Returns:
        A pair ``(aten_registrations, other_registrations)``: the list of
        schema registrations for the ``aten`` namespace, and a single string
        holding one ``TORCH_LIBRARY`` / ``TORCH_LIBRARY_FRAGMENT`` block for
        each remaining namespace.
    """
    functions_by_namespace: dict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_namespace[fn.namespace].append(fn)

    tab = "\t"
    aten_registrations: list[str] = []
    custom_blocks: list[str] = []
    for namespace, funcs in functions_by_namespace.items():
        body = list(mapMaybe(RegisterSchema(schema_selector), funcs))
        # NB: aten registrations are returned separately from the other
        # namespaces, because the template already hardcodes an operator
        # library for ATen.
        if namespace == "aten":
            aten_registrations = body
            continue
        # A predefined namespace must be extended with a library fragment
        # rather than defined as a new library.
        if namespace in FRAGMENT_NAMESPACES:
            macro = "TORCH_LIBRARY_FRAGMENT"
        else:
            macro = "TORCH_LIBRARY"
        custom_blocks.append(
            f"""
{macro}({namespace}, m) {{
  {tab.join(body)}
}};"""
        )
    return (aten_registrations, "".join(custom_blocks))
1716*da0073e9SAndroid Build Coastguard Worker
1717*da0073e9SAndroid Build Coastguard Worker
def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write the aggregated (one file per category) operator headers.

    Through ``cpu_fm`` this writes ``NativeMetaFunctions.h``,
    ``MethodOperators.h``, ``Operators.h``, ``Functions.h`` and
    ``NativeFunctions.h``. Then, for every dispatch key that is also in
    ``functions_keys``, it writes ``{dispatch_key}Functions.h`` and
    ``{dispatch_key}Functions_inl.h`` through ``cuda_fm`` when
    ``is_cuda_dispatch_key`` says so, otherwise through ``cpu_fm``.
    """

    def operator_declarations(fns: Sequence[NativeFunction]) -> list[str]:
        # Shared by MethodOperators.h and Operators.h, which differ only in
        # the set of functions they declare.
        return list(
            mapMaybe(
                ComputeOperators(
                    Target.DECLARATION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                ),
                fns,
            )
        )

    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )

    # Partition in a single pass. Testing `fn not in method_native_functions`
    # instead would scan the list with equality comparisons for every
    # function, which is quadratic in the number of native functions.
    method_native_functions: list[NativeFunction] = []
    non_method_native_functions: list[NativeFunction] = []
    for fn in native_functions:
        if Variant.method in fn.variants:
            method_native_functions.append(fn)
        else:
            non_method_native_functions.append(fn)

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": operator_declarations(
                method_native_functions
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": operator_declarations(
                non_method_native_functions
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )

    for dispatch_key in dispatch_keys:
        # Only keys listed in functions_keys get *Functions.h headers.
        if dispatch_key not in functions_keys:
            continue

        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "DispatchKeyFunctions_inl_includes": [],
                "dispatch_namespace": dispatch_key.lower(),
                "dispatch_namespaced_declarations": get_namespaced_declaration(
                    grouped_native_functions=grouped_native_functions,
                    dispatch_key=dispatch_key,
                    backend_idx=backend_indices[dispatch_key],
                    selector=selector,
                    rocm=rocm,
                    symint=True,
                ),
            },
        )
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker
1839*da0073e9SAndroid Build Coastguard Workerdef gen_per_operator_headers(
1840*da0073e9SAndroid Build Coastguard Worker    *,
1841*da0073e9SAndroid Build Coastguard Worker    native_functions: Sequence[NativeFunction],
1842*da0073e9SAndroid Build Coastguard Worker    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1843*da0073e9SAndroid Build Coastguard Worker    static_dispatch_idx: list[BackendIndex],
1844*da0073e9SAndroid Build Coastguard Worker    selector: SelectiveBuilder,
1845*da0073e9SAndroid Build Coastguard Worker    backend_indices: dict[DispatchKey, BackendIndex],
1846*da0073e9SAndroid Build Coastguard Worker    cpu_fm: FileManager,
1847*da0073e9SAndroid Build Coastguard Worker    cuda_fm: FileManager,
1848*da0073e9SAndroid Build Coastguard Worker    ops_fm: FileManager,
1849*da0073e9SAndroid Build Coastguard Worker    functions_keys: set[DispatchKey],
1850*da0073e9SAndroid Build Coastguard Worker    dispatch_keys: Sequence[DispatchKey],
1851*da0073e9SAndroid Build Coastguard Worker    rocm: bool,
1852*da0073e9SAndroid Build Coastguard Worker) -> None:
1853*da0073e9SAndroid Build Coastguard Worker    # For CMake builds, split operator declarations into separate headers in
1854*da0073e9SAndroid Build Coastguard Worker    # the ATen/ops folder to split up header dependencies
1855*da0073e9SAndroid Build Coastguard Worker    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
1856*da0073e9SAndroid Build Coastguard Worker    for fn in native_functions:
1857*da0073e9SAndroid Build Coastguard Worker        functions_by_root_name[fn.root_name].append(fn)
1858*da0073e9SAndroid Build Coastguard Worker
1859*da0073e9SAndroid Build Coastguard Worker    grouped_functions_by_root_name: dict[
1860*da0073e9SAndroid Build Coastguard Worker        str, list[NativeFunction | NativeFunctionsGroup]
1861*da0073e9SAndroid Build Coastguard Worker    ] = defaultdict(list)
1862*da0073e9SAndroid Build Coastguard Worker    for group in grouped_native_functions:
1863*da0073e9SAndroid Build Coastguard Worker        name = group.root_name
1864*da0073e9SAndroid Build Coastguard Worker        grouped_functions_by_root_name[name].append(group)
1865*da0073e9SAndroid Build Coastguard Worker
1866*da0073e9SAndroid Build Coastguard Worker    for name, functions in functions_by_root_name.items():
1867*da0073e9SAndroid Build Coastguard Worker        ops_fm.write_with_template(
1868*da0073e9SAndroid Build Coastguard Worker            f"{name}_ops.h",
1869*da0073e9SAndroid Build Coastguard Worker            "Operator.h",
1870*da0073e9SAndroid Build Coastguard Worker            lambda: {
1871*da0073e9SAndroid Build Coastguard Worker                "declarations": list(
1872*da0073e9SAndroid Build Coastguard Worker                    mapMaybe(
1873*da0073e9SAndroid Build Coastguard Worker                        ComputeOperators(
1874*da0073e9SAndroid Build Coastguard Worker                            Target.DECLARATION,
1875*da0073e9SAndroid Build Coastguard Worker                            static_dispatch_backend_indices=static_dispatch_idx,
1876*da0073e9SAndroid Build Coastguard Worker                        ),
1877*da0073e9SAndroid Build Coastguard Worker                        functions,
1878*da0073e9SAndroid Build Coastguard Worker                    )
1879*da0073e9SAndroid Build Coastguard Worker                ),
1880*da0073e9SAndroid Build Coastguard Worker            },
1881*da0073e9SAndroid Build Coastguard Worker        )
1882*da0073e9SAndroid Build Coastguard Worker
1883*da0073e9SAndroid Build Coastguard Worker        ops_fm.write_with_template(
1884*da0073e9SAndroid Build Coastguard Worker            f"{name}.h",
1885*da0073e9SAndroid Build Coastguard Worker            "Function.h",
1886*da0073e9SAndroid Build Coastguard Worker            lambda: {
1887*da0073e9SAndroid Build Coastguard Worker                "static_dispatch_ops_headers": list(
1888*da0073e9SAndroid Build Coastguard Worker                    mapMaybe(
1889*da0073e9SAndroid Build Coastguard Worker                        lambda fn: static_dispatch_ops_header(
1890*da0073e9SAndroid Build Coastguard Worker                            fn, backend_index=static_dispatch_idx
1891*da0073e9SAndroid Build Coastguard Worker                        ),
1892*da0073e9SAndroid Build Coastguard Worker                        functions,
1893*da0073e9SAndroid Build Coastguard Worker                    )
1894*da0073e9SAndroid Build Coastguard Worker                ),
1895*da0073e9SAndroid Build Coastguard Worker                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
1896*da0073e9SAndroid Build Coastguard Worker                "function_definitions": list(
1897*da0073e9SAndroid Build Coastguard Worker                    mapMaybe(
1898*da0073e9SAndroid Build Coastguard Worker                        ComputeFunction(),
1899*da0073e9SAndroid Build Coastguard Worker                        functions,
1900*da0073e9SAndroid Build Coastguard Worker                    )
1901*da0073e9SAndroid Build Coastguard Worker                ),
1902*da0073e9SAndroid Build Coastguard Worker            },
1903*da0073e9SAndroid Build Coastguard Worker        )
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker        grouped_functions = grouped_functions_by_root_name.get(name, [])
1906*da0073e9SAndroid Build Coastguard Worker        structured_functions = [
1907*da0073e9SAndroid Build Coastguard Worker            fn
1908*da0073e9SAndroid Build Coastguard Worker            for fn in grouped_functions
1909*da0073e9SAndroid Build Coastguard Worker            if isinstance(fn, NativeFunctionsGroup) and fn.structured
1910*da0073e9SAndroid Build Coastguard Worker        ]
1911*da0073e9SAndroid Build Coastguard Worker        is_structured = len(structured_functions) > 0
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker        if is_structured:
1914*da0073e9SAndroid Build Coastguard Worker            ops_fm.write_with_template(
1915*da0073e9SAndroid Build Coastguard Worker                f"{name}_meta.h",
1916*da0073e9SAndroid Build Coastguard Worker                "NativeMetaFunction.h",
1917*da0073e9SAndroid Build Coastguard Worker                lambda: {
1918*da0073e9SAndroid Build Coastguard Worker                    "meta_function_declarations": list(
1919*da0073e9SAndroid Build Coastguard Worker                        mapMaybe(
1920*da0073e9SAndroid Build Coastguard Worker                            compute_meta_function_declaration, structured_functions
1921*da0073e9SAndroid Build Coastguard Worker                        )
1922*da0073e9SAndroid Build Coastguard Worker                    ),
1923*da0073e9SAndroid Build Coastguard Worker                },
1924*da0073e9SAndroid Build Coastguard Worker            )
1925*da0073e9SAndroid Build Coastguard Worker        declarations = get_native_function_declarations(
1926*da0073e9SAndroid Build Coastguard Worker            grouped_native_functions=grouped_functions,
1927*da0073e9SAndroid Build Coastguard Worker            backend_indices=backend_indices,
1928*da0073e9SAndroid Build Coastguard Worker            native_function_decl_gen=dest.compute_native_function_declaration,
1929*da0073e9SAndroid Build Coastguard Worker        )
1930*da0073e9SAndroid Build Coastguard Worker        ops_fm.write_with_template(
1931*da0073e9SAndroid Build Coastguard Worker            f"{name}_native.h",
1932*da0073e9SAndroid Build Coastguard Worker            "NativeFunction.h",
1933*da0073e9SAndroid Build Coastguard Worker            lambda: {
1934*da0073e9SAndroid Build Coastguard Worker                "extra_includes": (
1935*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
1936*da0073e9SAndroid Build Coastguard Worker                ),
1937*da0073e9SAndroid Build Coastguard Worker                "native_function_declarations": declarations,
1938*da0073e9SAndroid Build Coastguard Worker            },
1939*da0073e9SAndroid Build Coastguard Worker        )
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker    for category, suffix in [
1942*da0073e9SAndroid Build Coastguard Worker        ("Functions", ""),
1943*da0073e9SAndroid Build Coastguard Worker        ("Operators", "_ops"),
1944*da0073e9SAndroid Build Coastguard Worker        ("NativeMetaFunctions", "_meta"),
1945*da0073e9SAndroid Build Coastguard Worker        ("NativeFunctions", "_native"),
1946*da0073e9SAndroid Build Coastguard Worker    ]:
1947*da0073e9SAndroid Build Coastguard Worker        cpu_fm.write(
1948*da0073e9SAndroid Build Coastguard Worker            f"{category}.h",
1949*da0073e9SAndroid Build Coastguard Worker            lambda: {
1950*da0073e9SAndroid Build Coastguard Worker                f"{category}_includes": [
1951*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{name}{suffix}.h>"
1952*da0073e9SAndroid Build Coastguard Worker                    for name in sorted(functions_by_root_name.keys())
1953*da0073e9SAndroid Build Coastguard Worker                ],
1954*da0073e9SAndroid Build Coastguard Worker                f"{category}_declarations": [],
1955*da0073e9SAndroid Build Coastguard Worker            },
1956*da0073e9SAndroid Build Coastguard Worker        )
1957*da0073e9SAndroid Build Coastguard Worker
1958*da0073e9SAndroid Build Coastguard Worker    for dispatch_key in dispatch_keys:
1959*da0073e9SAndroid Build Coastguard Worker        if dispatch_key not in functions_keys:
1960*da0073e9SAndroid Build Coastguard Worker            continue
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker        dispatch_namespace = dispatch_key.lower()
1963*da0073e9SAndroid Build Coastguard Worker        dispatch_names = []
1964*da0073e9SAndroid Build Coastguard Worker
1965*da0073e9SAndroid Build Coastguard Worker        for name, functions in functions_by_root_name.items():
1966*da0073e9SAndroid Build Coastguard Worker            grouped_functions = grouped_functions_by_root_name.get(name, [])
1967*da0073e9SAndroid Build Coastguard Worker            declarations = list(
1968*da0073e9SAndroid Build Coastguard Worker                concatMap(
1969*da0073e9SAndroid Build Coastguard Worker                    dest.RegisterDispatchKey(
1970*da0073e9SAndroid Build Coastguard Worker                        backend_indices[dispatch_key],
1971*da0073e9SAndroid Build Coastguard Worker                        Target.NAMESPACED_DECLARATION,
1972*da0073e9SAndroid Build Coastguard Worker                        selector,
1973*da0073e9SAndroid Build Coastguard Worker                        rocm=rocm,
1974*da0073e9SAndroid Build Coastguard Worker                        symint=True,
1975*da0073e9SAndroid Build Coastguard Worker                        class_method_name=None,
1976*da0073e9SAndroid Build Coastguard Worker                        skip_dispatcher_op_registration=False,
1977*da0073e9SAndroid Build Coastguard Worker                    ),
1978*da0073e9SAndroid Build Coastguard Worker                    grouped_functions,
1979*da0073e9SAndroid Build Coastguard Worker                )
1980*da0073e9SAndroid Build Coastguard Worker            )
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker            if len(declarations) == 0:
1983*da0073e9SAndroid Build Coastguard Worker                continue
1984*da0073e9SAndroid Build Coastguard Worker
1985*da0073e9SAndroid Build Coastguard Worker            dispatch_names.append(name)
1986*da0073e9SAndroid Build Coastguard Worker            ops_fm.write_with_template(
1987*da0073e9SAndroid Build Coastguard Worker                f"{name}_{dispatch_namespace}_dispatch.h",
1988*da0073e9SAndroid Build Coastguard Worker                "DispatchKeyFunction.h",
1989*da0073e9SAndroid Build Coastguard Worker                lambda: {
1990*da0073e9SAndroid Build Coastguard Worker                    "dispatch_namespace": dispatch_namespace,
1991*da0073e9SAndroid Build Coastguard Worker                    "dispatch_namespaced_declarations": declarations,
1992*da0073e9SAndroid Build Coastguard Worker                },
1993*da0073e9SAndroid Build Coastguard Worker            )
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
1996*da0073e9SAndroid Build Coastguard Worker        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
1997*da0073e9SAndroid Build Coastguard Worker
1998*da0073e9SAndroid Build Coastguard Worker        fm.write_with_template(
1999*da0073e9SAndroid Build Coastguard Worker            f"{dispatch_key}Functions.h",
2000*da0073e9SAndroid Build Coastguard Worker            "DispatchKeyFunctions.h",
2001*da0073e9SAndroid Build Coastguard Worker            lambda: {
2002*da0073e9SAndroid Build Coastguard Worker                "dispatch_key": str(dispatch_key),
2003*da0073e9SAndroid Build Coastguard Worker                "inline_headers": inl_headers,
2004*da0073e9SAndroid Build Coastguard Worker            },
2005*da0073e9SAndroid Build Coastguard Worker        )
2006*da0073e9SAndroid Build Coastguard Worker        fm.write_with_template(
2007*da0073e9SAndroid Build Coastguard Worker            f"{dispatch_key}Functions_inl.h",
2008*da0073e9SAndroid Build Coastguard Worker            "DispatchKeyFunctions_inl.h",
2009*da0073e9SAndroid Build Coastguard Worker            lambda: {
2010*da0073e9SAndroid Build Coastguard Worker                "dispatch_namespace": dispatch_namespace,
2011*da0073e9SAndroid Build Coastguard Worker                "DispatchKeyFunctions_inl_includes": [
2012*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
2013*da0073e9SAndroid Build Coastguard Worker                    for name in sorted(dispatch_names)
2014*da0073e9SAndroid Build Coastguard Worker                ],
2015*da0073e9SAndroid Build Coastguard Worker                "dispatch_namespaced_declarations": [],
2016*da0073e9SAndroid Build Coastguard Worker            },
2017*da0073e9SAndroid Build Coastguard Worker        )
2018*da0073e9SAndroid Build Coastguard Worker        del fm
2019*da0073e9SAndroid Build Coastguard Worker
2020*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write(
2021*da0073e9SAndroid Build Coastguard Worker        "MethodOperators.h",
2022*da0073e9SAndroid Build Coastguard Worker        lambda: {
2023*da0073e9SAndroid Build Coastguard Worker            "MethodOperators_includes": sorted(
2024*da0073e9SAndroid Build Coastguard Worker                f"#include <ATen/ops/{name}_ops.h>"
2025*da0073e9SAndroid Build Coastguard Worker                for name, functions in functions_by_root_name.items()
2026*da0073e9SAndroid Build Coastguard Worker                if any(Variant.method in fn.variants for fn in functions)
2027*da0073e9SAndroid Build Coastguard Worker            ),
2028*da0073e9SAndroid Build Coastguard Worker            "MethodOperators_declarations": [],
2029*da0073e9SAndroid Build Coastguard Worker        },
2030*da0073e9SAndroid Build Coastguard Worker    )
2031*da0073e9SAndroid Build Coastguard Worker
2032*da0073e9SAndroid Build Coastguard Worker
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: set[str],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    """Write the generated ATen header files.

    The operator headers are delegated to ``gen_per_operator_headers`` or
    ``gen_aggregated_headers`` depending on ``per_operator_headers``.  After
    that, the following files are written, in this order:

    * ``TensorBody.h`` (via ``core_fm``)
    * ``RedispatchFunctions.h`` (via ``cpu_fm``)
    * ``RegistrationDeclarations.h`` (via ``cpu_fm``)
    * ``VmapGeneratedPlumbing.h`` (via ``cpu_fm``)
    * ``aten_interned_strings.h`` (via ``core_fm``)
    * ``enum_tag.h`` (via ``core_fm``)

    The content of every file is handed to the file manager as a
    zero-argument callable rather than being computed eagerly here.
    """
    # Only the per-operator layout needs `ops_fm`; only the aggregated layout
    # needs `structured_native_functions`.
    if not per_operator_headers:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    def tensor_body_env() -> dict[str, Any]:
        # Declarations are computed before definitions, and the keys are
        # inserted in that same order.
        env: dict[str, Any] = {}
        for env_key, target in (
            ("tensor_method_declarations", Target.DECLARATION),
            ("tensor_method_definitions", Target.DEFINITION),
        ):
            compute_method = ComputeTensorMethod(
                target=target,
                static_dispatch_backend_indices=static_dispatch_idx,
            )
            env[env_key] = list(mapMaybe(compute_method, native_functions))
        return env

    core_fm.write("TensorBody.h", tensor_body_env)

    def redispatch_env() -> dict[str, Any]:
        redispatch_definitions = list(
            mapMaybe(ComputeRedispatchFunction(), native_functions)
        )
        return {"function_redispatch_definitions": redispatch_definitions}

    cpu_fm.write("RedispatchFunctions.h", redispatch_env)

    def registration_env() -> dict[str, Any]:
        registration_declarations = [
            compute_registration_declarations(fn, backend_indices)
            for fn in native_functions
        ]
        return {"registration_declarations": registration_declarations}

    cpu_fm.write("RegistrationDeclarations.h", registration_env)

    # Zero-argument callable equivalent to
    # `lambda: gen_all_vmap_plumbing(native_functions)`.
    cpu_fm.write(
        "VmapGeneratedPlumbing.h",
        functools.partial(gen_all_vmap_plumbing, native_functions),
    )

    def gen_aten_interned_strings() -> dict[str, str]:
        # The C++ alternative operator spellings are keywords, so they cannot
        # be used as symbol names:
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        cpp_alternative_tokens = {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }

        op_names: set[str] = set()  # every ATen function name
        arg_names: set[str] = set()  # every function argument name
        for native_fn in native_functions:
            op_name = native_fn.func.name.name
            # Besides the full name, always register the base name: some
            # operators have no functional variant, yet a symbol without the
            # underscore is still created for them.
            op_names.update((str(op_name), op_name.base))
            arg_names.update(
                arg.name for arg in native_fn.func.schema_order_arguments()
            )

        op_names.difference_update(cpp_alternative_tokens)

        # Entries are joined with a trailing backslash + newline.
        continuation = " \\\n"
        aten_entries = [f"_(aten, {symbol})" for symbol in sorted(op_names)]
        attr_entries = [f"_(attr, {symbol})" for symbol in sorted(arg_names)]
        return {
            "aten_symbols": continuation.join(aten_entries),
            "attr_symbols": continuation.join(attr_entries),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> dict[str, str]:
        # One tag per line, comma separated, in sorted order.
        ordered_tags = sorted(valid_tags)
        return {"enum_of_valid_tags": ",\n".join(ordered_tags)}

    core_fm.write("enum_tag.h", gen_tags_enum)
2169*da0073e9SAndroid Build Coastguard Worker
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Workerdef gen_source_files(
2172*da0073e9SAndroid Build Coastguard Worker    *,
2173*da0073e9SAndroid Build Coastguard Worker    native_functions: Sequence[NativeFunction],
2174*da0073e9SAndroid Build Coastguard Worker    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
2175*da0073e9SAndroid Build Coastguard Worker    structured_native_functions: Sequence[NativeFunctionsGroup],
2176*da0073e9SAndroid Build Coastguard Worker    view_groups: Sequence[NativeFunctionsViewGroup],
2177*da0073e9SAndroid Build Coastguard Worker    selector: SelectiveBuilder,
2178*da0073e9SAndroid Build Coastguard Worker    static_dispatch_idx: list[BackendIndex],
2179*da0073e9SAndroid Build Coastguard Worker    backend_indices: dict[DispatchKey, BackendIndex],
2180*da0073e9SAndroid Build Coastguard Worker    aoti_fm: FileManager,
2181*da0073e9SAndroid Build Coastguard Worker    core_fm: FileManager,
2182*da0073e9SAndroid Build Coastguard Worker    cpu_fm: FileManager,
2183*da0073e9SAndroid Build Coastguard Worker    cpu_vec_fm: FileManager,
2184*da0073e9SAndroid Build Coastguard Worker    cuda_fm: FileManager,
2185*da0073e9SAndroid Build Coastguard Worker    dispatch_keys: Sequence[DispatchKey],
2186*da0073e9SAndroid Build Coastguard Worker    functions_keys: set[DispatchKey],
2187*da0073e9SAndroid Build Coastguard Worker    rocm: bool,
2188*da0073e9SAndroid Build Coastguard Worker    force_schema_registration: bool,
2189*da0073e9SAndroid Build Coastguard Worker    per_operator_headers: bool,
2190*da0073e9SAndroid Build Coastguard Worker    skip_dispatcher_op_registration: bool,
2191*da0073e9SAndroid Build Coastguard Worker    update_aoti_c_shim: bool,
2192*da0073e9SAndroid Build Coastguard Worker) -> None:
2193*da0073e9SAndroid Build Coastguard Worker    extra_cuda_headers = """\
2194*da0073e9SAndroid Build Coastguard Worker#include <c10/cuda/CUDAGuard.h>
2195*da0073e9SAndroid Build Coastguard Worker#include <ATen/cuda/ATenCUDAGeneral.h>
2196*da0073e9SAndroid Build Coastguard Worker#include <ATen/cuda/CUDADevice.h>
2197*da0073e9SAndroid Build Coastguard Worker#include <ATen/cuda/CUDAContext.h>"""
2198*da0073e9SAndroid Build Coastguard Worker    if rocm:
2199*da0073e9SAndroid Build Coastguard Worker        extra_cuda_headers = """\
2200*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
2201*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/ATenHIPGeneral.h>
2202*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/HIPDevice.h>
2203*da0073e9SAndroid Build Coastguard Worker#include <ATen/hip/HIPContext.h>"""
2204*da0073e9SAndroid Build Coastguard Worker
2205*da0073e9SAndroid Build Coastguard Worker    for dispatch_key in dispatch_keys:
2206*da0073e9SAndroid Build Coastguard Worker        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
2207*da0073e9SAndroid Build Coastguard Worker
2208*da0073e9SAndroid Build Coastguard Worker        if per_operator_headers:
2209*da0073e9SAndroid Build Coastguard Worker
2210*da0073e9SAndroid Build Coastguard Worker            def operator_headers() -> list[str]:
2211*da0073e9SAndroid Build Coastguard Worker                headers = []
2212*da0073e9SAndroid Build Coastguard Worker                for g in grouped_native_functions:
2213*da0073e9SAndroid Build Coastguard Worker                    is_registered = False
2214*da0073e9SAndroid Build Coastguard Worker                    if backend_index.has_kernel(g):
2215*da0073e9SAndroid Build Coastguard Worker                        is_registered = True
2216*da0073e9SAndroid Build Coastguard Worker                    # The above has_kernel test on a group will only test for
2217*da0073e9SAndroid Build Coastguard Worker                    # the existence of out dispatch, because that's how
2218*da0073e9SAndroid Build Coastguard Worker                    # structured kernels work. But sometimes functions can be
2219*da0073e9SAndroid Build Coastguard Worker                    # grouped but not be structured, and then you need to check
2220*da0073e9SAndroid Build Coastguard Worker                    # each individual piece, as they may have manual dispatch
2221*da0073e9SAndroid Build Coastguard Worker                    # entries.
2222*da0073e9SAndroid Build Coastguard Worker                    elif isinstance(g, NativeFunctionsGroup) and any(
2223*da0073e9SAndroid Build Coastguard Worker                        backend_index.has_kernel(fn) for fn in g.functions()
2224*da0073e9SAndroid Build Coastguard Worker                    ):
2225*da0073e9SAndroid Build Coastguard Worker                        is_registered = True
2226*da0073e9SAndroid Build Coastguard Worker                    # TODO: this condition is a bit questionable
2227*da0073e9SAndroid Build Coastguard Worker                    # (It has to do with the fact that structured kernels get generated kernels
2228*da0073e9SAndroid Build Coastguard Worker                    # to the Meta + CompositeExplicitAutogradNonFunctional keys).
2229*da0073e9SAndroid Build Coastguard Worker                    elif g.structured and dispatch_key in (
2230*da0073e9SAndroid Build Coastguard Worker                        DispatchKey.Meta,
2231*da0073e9SAndroid Build Coastguard Worker                        DispatchKey.CompositeExplicitAutogradNonFunctional,
2232*da0073e9SAndroid Build Coastguard Worker                    ):
2233*da0073e9SAndroid Build Coastguard Worker                        is_registered = True
2234*da0073e9SAndroid Build Coastguard Worker                    if not is_registered:
2235*da0073e9SAndroid Build Coastguard Worker                        continue
2236*da0073e9SAndroid Build Coastguard Worker
2237*da0073e9SAndroid Build Coastguard Worker                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
2238*da0073e9SAndroid Build Coastguard Worker                    if (
2239*da0073e9SAndroid Build Coastguard Worker                        dispatch_key
2240*da0073e9SAndroid Build Coastguard Worker                        == DispatchKey.CompositeExplicitAutogradNonFunctional
2241*da0073e9SAndroid Build Coastguard Worker                    ):
2242*da0073e9SAndroid Build Coastguard Worker                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
2243*da0073e9SAndroid Build Coastguard Worker                    if dispatch_key in functions_keys:
2244*da0073e9SAndroid Build Coastguard Worker                        headers.append(
2245*da0073e9SAndroid Build Coastguard Worker                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
2246*da0073e9SAndroid Build Coastguard Worker                        )
2247*da0073e9SAndroid Build Coastguard Worker
2248*da0073e9SAndroid Build Coastguard Worker                return sorted(set(headers))
2249*da0073e9SAndroid Build Coastguard Worker
2250*da0073e9SAndroid Build Coastguard Worker        else:
2251*da0073e9SAndroid Build Coastguard Worker
2252*da0073e9SAndroid Build Coastguard Worker            def operator_headers() -> list[str]:
2253*da0073e9SAndroid Build Coastguard Worker                headers = ["#include <ATen/NativeFunctions.h>"]
2254*da0073e9SAndroid Build Coastguard Worker                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
2255*da0073e9SAndroid Build Coastguard Worker                    headers.append("#include <ATen/Functions.h>")
2256*da0073e9SAndroid Build Coastguard Worker                if dispatch_key in functions_keys:
2257*da0073e9SAndroid Build Coastguard Worker                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
2258*da0073e9SAndroid Build Coastguard Worker                return headers
2259*da0073e9SAndroid Build Coastguard Worker
2260*da0073e9SAndroid Build Coastguard Worker        backend_index = backend_indices[dispatch_key]
2261*da0073e9SAndroid Build Coastguard Worker        ns_grouped_native_functions = defaultdict(list)
2262*da0073e9SAndroid Build Coastguard Worker        for grouped_native_function in grouped_native_functions:
2263*da0073e9SAndroid Build Coastguard Worker            namespace = (
2264*da0073e9SAndroid Build Coastguard Worker                grouped_native_function.namespace
2265*da0073e9SAndroid Build Coastguard Worker                if isinstance(grouped_native_function, NativeFunction)
2266*da0073e9SAndroid Build Coastguard Worker                else grouped_native_function.functional.namespace
2267*da0073e9SAndroid Build Coastguard Worker            )
2268*da0073e9SAndroid Build Coastguard Worker            ns_grouped_native_functions[namespace].append(grouped_native_function)
2269*da0073e9SAndroid Build Coastguard Worker
2270*da0073e9SAndroid Build Coastguard Worker        dispatch_namespace = str(dispatch_key).lower()
2271*da0073e9SAndroid Build Coastguard Worker
        # CompositeImplicitAutogradNestedTensor does not currently use the generated
        # helpers, so generating them would make compilation fail when the
        # `-Werror=unused-function` flag is set.
2274*da0073e9SAndroid Build Coastguard Worker        gen_dispatch_helpers: bool = (
2275*da0073e9SAndroid Build Coastguard Worker            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
2276*da0073e9SAndroid Build Coastguard Worker        )
2277*da0073e9SAndroid Build Coastguard Worker
2278*da0073e9SAndroid Build Coastguard Worker        dispatch_definitions = get_native_function_definitions(
2279*da0073e9SAndroid Build Coastguard Worker            fm=fm,
2280*da0073e9SAndroid Build Coastguard Worker            grouped_native_functions=grouped_native_functions,
2281*da0073e9SAndroid Build Coastguard Worker            dispatch_key=dispatch_key,
2282*da0073e9SAndroid Build Coastguard Worker            backend_idx=backend_index,
2283*da0073e9SAndroid Build Coastguard Worker            selector=selector,
2284*da0073e9SAndroid Build Coastguard Worker            rocm=rocm,
2285*da0073e9SAndroid Build Coastguard Worker            symint=True,
2286*da0073e9SAndroid Build Coastguard Worker            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
2287*da0073e9SAndroid Build Coastguard Worker            gen_dispatch_helpers=gen_dispatch_helpers,
2288*da0073e9SAndroid Build Coastguard Worker        )
2289*da0073e9SAndroid Build Coastguard Worker        fm.write_with_template(
2290*da0073e9SAndroid Build Coastguard Worker            f"Register{dispatch_key}.cpp",
2291*da0073e9SAndroid Build Coastguard Worker            "RegisterDispatchKey.cpp",
2292*da0073e9SAndroid Build Coastguard Worker            lambda: {
2293*da0073e9SAndroid Build Coastguard Worker                "extra_cuda_headers": extra_cuda_headers
2294*da0073e9SAndroid Build Coastguard Worker                if is_cuda_dispatch_key(dispatch_key)
2295*da0073e9SAndroid Build Coastguard Worker                else "",
2296*da0073e9SAndroid Build Coastguard Worker                "external_backend_headers": "",
2297*da0073e9SAndroid Build Coastguard Worker                "dispatch_headers": dest.gen_registration_headers(
2298*da0073e9SAndroid Build Coastguard Worker                    backend_index, per_operator_headers, rocm
2299*da0073e9SAndroid Build Coastguard Worker                ),
2300*da0073e9SAndroid Build Coastguard Worker                "ops_headers": operator_headers(),
2301*da0073e9SAndroid Build Coastguard Worker                "dispatch_helpers": "",
2302*da0073e9SAndroid Build Coastguard Worker                "dispatch_definitions": dispatch_definitions,
2303*da0073e9SAndroid Build Coastguard Worker            },
2304*da0073e9SAndroid Build Coastguard Worker        )
2305*da0073e9SAndroid Build Coastguard Worker
2306*da0073e9SAndroid Build Coastguard Worker        for g in structured_native_functions:
2307*da0073e9SAndroid Build Coastguard Worker            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
2308*da0073e9SAndroid Build Coastguard Worker                continue
2309*da0073e9SAndroid Build Coastguard Worker            name = g.functional.func.name.name
2310*da0073e9SAndroid Build Coastguard Worker            if dispatch_key is DispatchKey.CPU:
2311*da0073e9SAndroid Build Coastguard Worker                assert fm is cpu_fm
2312*da0073e9SAndroid Build Coastguard Worker                fm.write_with_template(
2313*da0073e9SAndroid Build Coastguard Worker                    f"UfuncCPU_{name}.cpp",
2314*da0073e9SAndroid Build Coastguard Worker                    "UfuncCPU.cpp",
2315*da0073e9SAndroid Build Coastguard Worker                    lambda: {
2316*da0073e9SAndroid Build Coastguard Worker                        "meta_declaration": compute_meta_function_declaration(g),
2317*da0073e9SAndroid Build Coastguard Worker                        "native_declaration": dest.compute_native_function_declaration(
2318*da0073e9SAndroid Build Coastguard Worker                            g, backend_indices[dispatch_key]
2319*da0073e9SAndroid Build Coastguard Worker                        ),
2320*da0073e9SAndroid Build Coastguard Worker                        "native_definitions": dest.compute_ufunc_cpu(g),
2321*da0073e9SAndroid Build Coastguard Worker                    },
2322*da0073e9SAndroid Build Coastguard Worker                )
2323*da0073e9SAndroid Build Coastguard Worker                cpu_vec_fm.write_with_template(
2324*da0073e9SAndroid Build Coastguard Worker                    f"UfuncCPUKernel_{name}.cpp",
2325*da0073e9SAndroid Build Coastguard Worker                    "UfuncCPUKernel.cpp",
2326*da0073e9SAndroid Build Coastguard Worker                    lambda: {
2327*da0073e9SAndroid Build Coastguard Worker                        "name": name,
2328*da0073e9SAndroid Build Coastguard Worker                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
2329*da0073e9SAndroid Build Coastguard Worker                    },
2330*da0073e9SAndroid Build Coastguard Worker                )
2331*da0073e9SAndroid Build Coastguard Worker            elif dispatch_key is DispatchKey.CUDA:
2332*da0073e9SAndroid Build Coastguard Worker                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
2333*da0073e9SAndroid Build Coastguard Worker                if rocm:
2334*da0073e9SAndroid Build Coastguard Worker                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
2335*da0073e9SAndroid Build Coastguard Worker                fm.write_with_template(
2336*da0073e9SAndroid Build Coastguard Worker                    f"UfuncCUDA_{name}.cu",
2337*da0073e9SAndroid Build Coastguard Worker                    "UfuncCUDA.cu",
2338*da0073e9SAndroid Build Coastguard Worker                    lambda: {
2339*da0073e9SAndroid Build Coastguard Worker                        "name": name,
2340*da0073e9SAndroid Build Coastguard Worker                        "cuda_headers": cuda_headers,
2341*da0073e9SAndroid Build Coastguard Worker                        "meta_declaration": compute_meta_function_declaration(g),
2342*da0073e9SAndroid Build Coastguard Worker                        "native_declaration": dest.compute_native_function_declaration(
2343*da0073e9SAndroid Build Coastguard Worker                            g, backend_indices[dispatch_key]
2344*da0073e9SAndroid Build Coastguard Worker                        ),
2345*da0073e9SAndroid Build Coastguard Worker                        "native_definitions": dest.compute_ufunc_cuda(g),
2346*da0073e9SAndroid Build Coastguard Worker                    },
2347*da0073e9SAndroid Build Coastguard Worker                )
2348*da0073e9SAndroid Build Coastguard Worker            else:
2349*da0073e9SAndroid Build Coastguard Worker                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
2350*da0073e9SAndroid Build Coastguard Worker
2351*da0073e9SAndroid Build Coastguard Worker        structured_func_group_dict = {}
2352*da0073e9SAndroid Build Coastguard Worker        for func_group in structured_native_functions:
2353*da0073e9SAndroid Build Coastguard Worker            for func in func_group.functions():
2354*da0073e9SAndroid Build Coastguard Worker                if func.structured_delegate is not None:
2355*da0073e9SAndroid Build Coastguard Worker                    structured_func_group_dict[func.structured_delegate] = func_group
2356*da0073e9SAndroid Build Coastguard Worker                    break
2357*da0073e9SAndroid Build Coastguard Worker
2358*da0073e9SAndroid Build Coastguard Worker        if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
2359*da0073e9SAndroid Build Coastguard Worker            fallbacks = {}
2360*da0073e9SAndroid Build Coastguard Worker            for func in native_functions:
2361*da0073e9SAndroid Build Coastguard Worker                op_name = get_fallback_op_name(func)
2362*da0073e9SAndroid Build Coastguard Worker                if op_name in inductor_fallback_ops:
2363*da0073e9SAndroid Build Coastguard Worker                    fallbacks[op_name] = func
2364*da0073e9SAndroid Build Coastguard Worker            fallback_native_functions = tuple(
2365*da0073e9SAndroid Build Coastguard Worker                value for _, value in sorted(fallbacks.items())
2366*da0073e9SAndroid Build Coastguard Worker            )
2367*da0073e9SAndroid Build Coastguard Worker
            # header files were checked in for ABI-compatibility checking
2369*da0073e9SAndroid Build Coastguard Worker            header_file_name = f"c_shim_{dispatch_key.lower()}.h"
2370*da0073e9SAndroid Build Coastguard Worker            new_header = gen_aoti_c_shim(
2371*da0073e9SAndroid Build Coastguard Worker                fallback_native_functions,
2372*da0073e9SAndroid Build Coastguard Worker                structured_func_group_dict,
2373*da0073e9SAndroid Build Coastguard Worker                dispatch_key,
2374*da0073e9SAndroid Build Coastguard Worker                backend_indices,
2375*da0073e9SAndroid Build Coastguard Worker                header=True,
2376*da0073e9SAndroid Build Coastguard Worker                includes="",
2377*da0073e9SAndroid Build Coastguard Worker            )
2378*da0073e9SAndroid Build Coastguard Worker            if update_aoti_c_shim:
2379*da0073e9SAndroid Build Coastguard Worker                aoti_fm.write(
2380*da0073e9SAndroid Build Coastguard Worker                    header_file_name,
2381*da0073e9SAndroid Build Coastguard Worker                    lambda: new_header,
2382*da0073e9SAndroid Build Coastguard Worker                )
2383*da0073e9SAndroid Build Coastguard Worker            else:
2384*da0073e9SAndroid Build Coastguard Worker                try:
2385*da0073e9SAndroid Build Coastguard Worker                    with open(
2386*da0073e9SAndroid Build Coastguard Worker                        os.path.join(aoti_fm.install_dir, header_file_name)
2387*da0073e9SAndroid Build Coastguard Worker                    ) as old_file:
2388*da0073e9SAndroid Build Coastguard Worker                        old_header = old_file.read()
2389*da0073e9SAndroid Build Coastguard Worker                        assert (
2390*da0073e9SAndroid Build Coastguard Worker                            old_header == new_header
2391*da0073e9SAndroid Build Coastguard Worker                        ), """
2392*da0073e9SAndroid Build Coastguard Worker
2393*da0073e9SAndroid Build Coastguard WorkerWARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
2394*da0073e9SAndroid Build Coastguard Workerindicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
2395*da0073e9SAndroid Build Coastguard WorkerOnly in a limited number of situations, this is allowed:
2396*da0073e9SAndroid Build Coastguard Worker
2397*da0073e9SAndroid Build Coastguard Worker1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
2398*da0073e9SAndroid Build Coastguard WorkerIf that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
2399*da0073e9SAndroid Build Coastguard WorkerC shim header files.
2400*da0073e9SAndroid Build Coastguard Worker
2401*da0073e9SAndroid Build Coastguard Worker2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
2402*da0073e9SAndroid Build Coastguard Workerchange in the AOTInductor land. In this case, you need to keep a manual copy of that existing
2403*da0073e9SAndroid Build Coastguard Workerfallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
2404*da0073e9SAndroid Build Coastguard Workernumber of that fallback op in the newly generated C shim files, and update the cpp wrapper
2405*da0073e9SAndroid Build Coastguard Workercodegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.
2406*da0073e9SAndroid Build Coastguard Worker
2407*da0073e9SAndroid Build Coastguard Worker                        """
2408*da0073e9SAndroid Build Coastguard Worker                except FileNotFoundError:
2409*da0073e9SAndroid Build Coastguard Worker                    print(
2410*da0073e9SAndroid Build Coastguard Worker                        f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
2411*da0073e9SAndroid Build Coastguard Worker                    )
2412*da0073e9SAndroid Build Coastguard Worker
2413*da0073e9SAndroid Build Coastguard Worker            # cpp files are always generated on-the-fly
2414*da0073e9SAndroid Build Coastguard Worker            def headers_for_aoti() -> str:
2415*da0073e9SAndroid Build Coastguard Worker                headers = []
2416*da0073e9SAndroid Build Coastguard Worker                for func in fallback_native_functions:
2417*da0073e9SAndroid Build Coastguard Worker                    header = get_header_for_aoti(
2418*da0073e9SAndroid Build Coastguard Worker                        func, structured_func_group_dict, dispatch_key, backend_indices
2419*da0073e9SAndroid Build Coastguard Worker                    )
2420*da0073e9SAndroid Build Coastguard Worker                    if header is not None:
2421*da0073e9SAndroid Build Coastguard Worker                        headers.append(header)
2422*da0073e9SAndroid Build Coastguard Worker                return "\n".join(sorted(set(headers)))
2423*da0073e9SAndroid Build Coastguard Worker
2424*da0073e9SAndroid Build Coastguard Worker            extra_headers = (
2425*da0073e9SAndroid Build Coastguard Worker                extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
2426*da0073e9SAndroid Build Coastguard Worker            )
2427*da0073e9SAndroid Build Coastguard Worker
2428*da0073e9SAndroid Build Coastguard Worker            aoti_fm.write(
2429*da0073e9SAndroid Build Coastguard Worker                f"c_shim_{dispatch_key.lower()}.cpp",
2430*da0073e9SAndroid Build Coastguard Worker                lambda: gen_aoti_c_shim(
2431*da0073e9SAndroid Build Coastguard Worker                    fallback_native_functions,
2432*da0073e9SAndroid Build Coastguard Worker                    structured_func_group_dict,
2433*da0073e9SAndroid Build Coastguard Worker                    dispatch_key,
2434*da0073e9SAndroid Build Coastguard Worker                    backend_indices,
2435*da0073e9SAndroid Build Coastguard Worker                    header=False,
2436*da0073e9SAndroid Build Coastguard Worker                    includes=headers_for_aoti() + "\n" + extra_headers,
2437*da0073e9SAndroid Build Coastguard Worker                ),
2438*da0073e9SAndroid Build Coastguard Worker            )
2439*da0073e9SAndroid Build Coastguard Worker
2440*da0073e9SAndroid Build Coastguard Worker        del fm
2441*da0073e9SAndroid Build Coastguard Worker
2442*da0073e9SAndroid Build Coastguard Worker    # BackendSelect is generated specially
2443*da0073e9SAndroid Build Coastguard Worker    def gen_backend_select() -> dict[str, list[str]]:
2444*da0073e9SAndroid Build Coastguard Worker        relevant_fns = [
2445*da0073e9SAndroid Build Coastguard Worker            fn for fn in native_functions if needs_backend_select(fn, selector)
2446*da0073e9SAndroid Build Coastguard Worker        ]
2447*da0073e9SAndroid Build Coastguard Worker        return {
2448*da0073e9SAndroid Build Coastguard Worker            "ops_headers": [
2449*da0073e9SAndroid Build Coastguard Worker                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
2450*da0073e9SAndroid Build Coastguard Worker            ],
2451*da0073e9SAndroid Build Coastguard Worker            "backend_select_method_definitions": list(
2452*da0073e9SAndroid Build Coastguard Worker                mapMaybe(
2453*da0073e9SAndroid Build Coastguard Worker                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
2454*da0073e9SAndroid Build Coastguard Worker                )
2455*da0073e9SAndroid Build Coastguard Worker            ),
2456*da0073e9SAndroid Build Coastguard Worker            "backend_select_function_registrations": list(
2457*da0073e9SAndroid Build Coastguard Worker                mapMaybe(
2458*da0073e9SAndroid Build Coastguard Worker                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
2459*da0073e9SAndroid Build Coastguard Worker                )
2460*da0073e9SAndroid Build Coastguard Worker            ),
2461*da0073e9SAndroid Build Coastguard Worker        }
2462*da0073e9SAndroid Build Coastguard Worker
2463*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)
2464*da0073e9SAndroid Build Coastguard Worker
2465*da0073e9SAndroid Build Coastguard Worker    schema_selector = selector
2466*da0073e9SAndroid Build Coastguard Worker    if force_schema_registration:
2467*da0073e9SAndroid Build Coastguard Worker        schema_selector = SelectiveBuilder.get_nop_selector()
2468*da0073e9SAndroid Build Coastguard Worker
2469*da0073e9SAndroid Build Coastguard Worker    (
2470*da0073e9SAndroid Build Coastguard Worker        aten_schema_registrations,
2471*da0073e9SAndroid Build Coastguard Worker        schema_registrations,
2472*da0073e9SAndroid Build Coastguard Worker    ) = get_native_function_schema_registrations(
2473*da0073e9SAndroid Build Coastguard Worker        native_functions=native_functions, schema_selector=schema_selector
2474*da0073e9SAndroid Build Coastguard Worker    )
2475*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write(
2476*da0073e9SAndroid Build Coastguard Worker        "RegisterSchema.cpp",
2477*da0073e9SAndroid Build Coastguard Worker        lambda: {
2478*da0073e9SAndroid Build Coastguard Worker            "aten_schema_registrations": []
2479*da0073e9SAndroid Build Coastguard Worker            if skip_dispatcher_op_registration
2480*da0073e9SAndroid Build Coastguard Worker            else aten_schema_registrations,
2481*da0073e9SAndroid Build Coastguard Worker            "schema_registrations": []
2482*da0073e9SAndroid Build Coastguard Worker            if skip_dispatcher_op_registration
2483*da0073e9SAndroid Build Coastguard Worker            else schema_registrations,
2484*da0073e9SAndroid Build Coastguard Worker        },
2485*da0073e9SAndroid Build Coastguard Worker    )
2486*da0073e9SAndroid Build Coastguard Worker
2487*da0073e9SAndroid Build Coastguard Worker    def key_func(
2488*da0073e9SAndroid Build Coastguard Worker        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2489*da0073e9SAndroid Build Coastguard Worker    ) -> str:
2490*da0073e9SAndroid Build Coastguard Worker        return fn.root_name
2491*da0073e9SAndroid Build Coastguard Worker
2492*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write_sharded(
2493*da0073e9SAndroid Build Coastguard Worker        "Operators.cpp",
2494*da0073e9SAndroid Build Coastguard Worker        native_functions,
2495*da0073e9SAndroid Build Coastguard Worker        key_fn=key_func,
2496*da0073e9SAndroid Build Coastguard Worker        env_callable=lambda fn: {
2497*da0073e9SAndroid Build Coastguard Worker            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
2498*da0073e9SAndroid Build Coastguard Worker            "definitions": [
2499*da0073e9SAndroid Build Coastguard Worker                ComputeOperators(
2500*da0073e9SAndroid Build Coastguard Worker                    Target.DEFINITION,
2501*da0073e9SAndroid Build Coastguard Worker                    static_dispatch_backend_indices=static_dispatch_idx,
2502*da0073e9SAndroid Build Coastguard Worker                )(fn)
2503*da0073e9SAndroid Build Coastguard Worker            ],
2504*da0073e9SAndroid Build Coastguard Worker        },
2505*da0073e9SAndroid Build Coastguard Worker        base_env={
2506*da0073e9SAndroid Build Coastguard Worker            "static_dispatch_extra_headers": static_dispatch_extra_headers(
2507*da0073e9SAndroid Build Coastguard Worker                static_dispatch_idx
2508*da0073e9SAndroid Build Coastguard Worker            ),
2509*da0073e9SAndroid Build Coastguard Worker        },
2510*da0073e9SAndroid Build Coastguard Worker        num_shards=5,
2511*da0073e9SAndroid Build Coastguard Worker        sharded_keys={
2512*da0073e9SAndroid Build Coastguard Worker            "operator_headers",
2513*da0073e9SAndroid Build Coastguard Worker            "definitions",
2514*da0073e9SAndroid Build Coastguard Worker            "static_dispatch_extra_headers",
2515*da0073e9SAndroid Build Coastguard Worker        },
2516*da0073e9SAndroid Build Coastguard Worker    )
2517*da0073e9SAndroid Build Coastguard Worker
2518*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write("Functions.cpp", dict)
2519*da0073e9SAndroid Build Coastguard Worker
2520*da0073e9SAndroid Build Coastguard Worker    core_fm.write("TensorMethods.cpp", dict)
2521*da0073e9SAndroid Build Coastguard Worker
2522*da0073e9SAndroid Build Coastguard Worker    core_fm.write(
2523*da0073e9SAndroid Build Coastguard Worker        "ATenOpList.cpp",
2524*da0073e9SAndroid Build Coastguard Worker        lambda: {
2525*da0073e9SAndroid Build Coastguard Worker            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
2526*da0073e9SAndroid Build Coastguard Worker        },
2527*da0073e9SAndroid Build Coastguard Worker    )
2528*da0073e9SAndroid Build Coastguard Worker
2529*da0073e9SAndroid Build Coastguard Worker    def functionalization_env_callable(
2530*da0073e9SAndroid Build Coastguard Worker        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2531*da0073e9SAndroid Build Coastguard Worker    ) -> dict[str, list[str]]:
2532*da0073e9SAndroid Build Coastguard Worker        def gen_op_headers(
2533*da0073e9SAndroid Build Coastguard Worker            g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2534*da0073e9SAndroid Build Coastguard Worker        ) -> list[str]:
2535*da0073e9SAndroid Build Coastguard Worker            if isinstance(g, NativeFunctionsViewGroup):
2536*da0073e9SAndroid Build Coastguard Worker                # view ops always get a functionalization kernel
2537*da0073e9SAndroid Build Coastguard Worker                headers = [
2538*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
2539*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
2540*da0073e9SAndroid Build Coastguard Worker                ]
2541*da0073e9SAndroid Build Coastguard Worker                if g.view_copy is not None:
2542*da0073e9SAndroid Build Coastguard Worker                    headers += [
2543*da0073e9SAndroid Build Coastguard Worker                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
2544*da0073e9SAndroid Build Coastguard Worker                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
2545*da0073e9SAndroid Build Coastguard Worker                    ]
2546*da0073e9SAndroid Build Coastguard Worker                return headers
2547*da0073e9SAndroid Build Coastguard Worker            elif isinstance(g, NativeFunctionsGroup):
2548*da0073e9SAndroid Build Coastguard Worker                headers = [
2549*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.functional.root_name}_native.h>",
2550*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
2551*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.out.root_name}_native.h>",
2552*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.out.root_name}_ops.h>",
2553*da0073e9SAndroid Build Coastguard Worker                ]
2554*da0073e9SAndroid Build Coastguard Worker                if g.inplace is not None:
2555*da0073e9SAndroid Build Coastguard Worker                    headers += [
2556*da0073e9SAndroid Build Coastguard Worker                        f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
2557*da0073e9SAndroid Build Coastguard Worker                        f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
2558*da0073e9SAndroid Build Coastguard Worker                    ]
2559*da0073e9SAndroid Build Coastguard Worker                if g.mutable is not None:
2560*da0073e9SAndroid Build Coastguard Worker                    headers += [
2561*da0073e9SAndroid Build Coastguard Worker                        f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
2562*da0073e9SAndroid Build Coastguard Worker                        f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
2563*da0073e9SAndroid Build Coastguard Worker                    ]
2564*da0073e9SAndroid Build Coastguard Worker                return headers
2565*da0073e9SAndroid Build Coastguard Worker            else:
2566*da0073e9SAndroid Build Coastguard Worker                return [
2567*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.root_name}_native.h>",
2568*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{g.root_name}_ops.h>",
2569*da0073e9SAndroid Build Coastguard Worker                ]
2570*da0073e9SAndroid Build Coastguard Worker
2571*da0073e9SAndroid Build Coastguard Worker        return {
2572*da0073e9SAndroid Build Coastguard Worker            "ops_headers": gen_op_headers(g),
2573*da0073e9SAndroid Build Coastguard Worker            "func_definitions": gen_functionalization_definition(
2574*da0073e9SAndroid Build Coastguard Worker                selector,
2575*da0073e9SAndroid Build Coastguard Worker                g,
2576*da0073e9SAndroid Build Coastguard Worker            ),
2577*da0073e9SAndroid Build Coastguard Worker            "func_registrations": gen_functionalization_registration(
2578*da0073e9SAndroid Build Coastguard Worker                selector,
2579*da0073e9SAndroid Build Coastguard Worker                g,
2580*da0073e9SAndroid Build Coastguard Worker                backend_indices[DispatchKey.CompositeImplicitAutograd],
2581*da0073e9SAndroid Build Coastguard Worker            ),
2582*da0073e9SAndroid Build Coastguard Worker        }
2583*da0073e9SAndroid Build Coastguard Worker
2584*da0073e9SAndroid Build Coastguard Worker    all_groups: list[
2585*da0073e9SAndroid Build Coastguard Worker        NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
2586*da0073e9SAndroid Build Coastguard Worker    ] = list(structured_native_functions) + list(
2587*da0073e9SAndroid Build Coastguard Worker        view_groups  # type: ignore[assignment, arg-type, operator]
2588*da0073e9SAndroid Build Coastguard Worker    )
2589*da0073e9SAndroid Build Coastguard Worker    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
2590*da0073e9SAndroid Build Coastguard Worker    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
2591*da0073e9SAndroid Build Coastguard Worker    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
2592*da0073e9SAndroid Build Coastguard Worker    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
2593*da0073e9SAndroid Build Coastguard Worker    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
2594*da0073e9SAndroid Build Coastguard Worker    structured_map: dict[OperatorName, NativeFunction] = {
2595*da0073e9SAndroid Build Coastguard Worker        f.func.name: f
2596*da0073e9SAndroid Build Coastguard Worker        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
2597*da0073e9SAndroid Build Coastguard Worker    }
2598*da0073e9SAndroid Build Coastguard Worker    view_map: dict[OperatorName, NativeFunction] = {
2599*da0073e9SAndroid Build Coastguard Worker        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
2600*da0073e9SAndroid Build Coastguard Worker    }
2601*da0073e9SAndroid Build Coastguard Worker    for f in native_functions:
2602*da0073e9SAndroid Build Coastguard Worker        if f.func.name not in structured_map and f.func.name not in view_map:
2603*da0073e9SAndroid Build Coastguard Worker            all_groups.append(f)
2604*da0073e9SAndroid Build Coastguard Worker
2605*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write_sharded(
2606*da0073e9SAndroid Build Coastguard Worker        "RegisterFunctionalization.cpp",
2607*da0073e9SAndroid Build Coastguard Worker        all_groups,
2608*da0073e9SAndroid Build Coastguard Worker        key_fn=key_func,
2609*da0073e9SAndroid Build Coastguard Worker        env_callable=functionalization_env_callable,
2610*da0073e9SAndroid Build Coastguard Worker        num_shards=4,
2611*da0073e9SAndroid Build Coastguard Worker        sharded_keys={
2612*da0073e9SAndroid Build Coastguard Worker            "ops_headers",
2613*da0073e9SAndroid Build Coastguard Worker            "func_definitions",
2614*da0073e9SAndroid Build Coastguard Worker            "func_registrations",
2615*da0073e9SAndroid Build Coastguard Worker            "func_add_back_views_definitions",
2616*da0073e9SAndroid Build Coastguard Worker            "func_add_back_views_registrations",
2617*da0073e9SAndroid Build Coastguard Worker        },
2618*da0073e9SAndroid Build Coastguard Worker    )
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write(
2621*da0073e9SAndroid Build Coastguard Worker        "FunctionalInverses.h",
2622*da0073e9SAndroid Build Coastguard Worker        lambda: {
2623*da0073e9SAndroid Build Coastguard Worker            "view_inverse_declarations": list(
2624*da0073e9SAndroid Build Coastguard Worker                mapMaybe(
2625*da0073e9SAndroid Build Coastguard Worker                    lambda g: gen_functionalization_view_inverse_declaration(
2626*da0073e9SAndroid Build Coastguard Worker                        selector, g
2627*da0073e9SAndroid Build Coastguard Worker                    ),
2628*da0073e9SAndroid Build Coastguard Worker                    view_groups,
2629*da0073e9SAndroid Build Coastguard Worker                )
2630*da0073e9SAndroid Build Coastguard Worker            )
2631*da0073e9SAndroid Build Coastguard Worker        },
2632*da0073e9SAndroid Build Coastguard Worker    )
2633*da0073e9SAndroid Build Coastguard Worker
2634*da0073e9SAndroid Build Coastguard Worker    # Note [view_copy NativeFunctions]
2635*da0073e9SAndroid Build Coastguard Worker    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
2636*da0073e9SAndroid Build Coastguard Worker    # needs to have a corresponding non-aliasing {view}_copy variant.
2637*da0073e9SAndroid Build Coastguard Worker    # Backends that use functionalization and don't know how to handle aliasing ops
2638*da0073e9SAndroid Build Coastguard Worker    # are expected to implement kernels for these {view}_copy kernels instead.
2639*da0073e9SAndroid Build Coastguard Worker    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
2640*da0073e9SAndroid Build Coastguard Worker    # so we codegen the following:
2641*da0073e9SAndroid Build Coastguard Worker    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
2642*da0073e9SAndroid Build Coastguard Worker    #     These are never explicitly invoked by the functionalization pass,
2643*da0073e9SAndroid Build Coastguard Worker    #     but they could theoretically be called from user code (I added these kernels for completeness,
2644*da0073e9SAndroid Build Coastguard Worker    #     since the ops are part of the public API).
2645*da0073e9SAndroid Build Coastguard Worker    # (2) A derivative formula for every {view}_copy operator
2646*da0073e9SAndroid Build Coastguard Worker    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
2647*da0073e9SAndroid Build Coastguard Worker    #     so rather than stamping all of the entries out in derivatives.yaml,
2648*da0073e9SAndroid Build Coastguard Worker    #     we codegen them in.
2649*da0073e9SAndroid Build Coastguard Worker    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
2650*da0073e9SAndroid Build Coastguard Worker    cpu_fm.write(
2651*da0073e9SAndroid Build Coastguard Worker        "CompositeViewCopyKernels.cpp",
2652*da0073e9SAndroid Build Coastguard Worker        lambda: {
2653*da0073e9SAndroid Build Coastguard Worker            "ops_headers": [
2654*da0073e9SAndroid Build Coastguard Worker                "\n".join(
2655*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
2656*da0073e9SAndroid Build Coastguard Worker                    # NB: this include is important as it ensures we
2657*da0073e9SAndroid Build Coastguard Worker                    # set the visibility on generated view_copy kernels
2658*da0073e9SAndroid Build Coastguard Worker                    # correctly
2659*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{f.root_name}_native.h>"
2660*da0073e9SAndroid Build Coastguard Worker                    for f in (
2661*da0073e9SAndroid Build Coastguard Worker                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
2662*da0073e9SAndroid Build Coastguard Worker                    )
2663*da0073e9SAndroid Build Coastguard Worker                )
2664*da0073e9SAndroid Build Coastguard Worker                for g in view_groups
2665*da0073e9SAndroid Build Coastguard Worker            ]
2666*da0073e9SAndroid Build Coastguard Worker            + [
2667*da0073e9SAndroid Build Coastguard Worker                "\n".join(
2668*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
2669*da0073e9SAndroid Build Coastguard Worker                    # NB: this include is also important for correct visibility
2670*da0073e9SAndroid Build Coastguard Worker                    f"#include <ATen/ops/{f.root_name}_native.h>"
2671*da0073e9SAndroid Build Coastguard Worker                    for f in [g.inplace, g.mutable, g.functional]
2672*da0073e9SAndroid Build Coastguard Worker                    if f is not None and "generated" not in f.tags
2673*da0073e9SAndroid Build Coastguard Worker                )
2674*da0073e9SAndroid Build Coastguard Worker                for g in structured_native_functions
2675*da0073e9SAndroid Build Coastguard Worker            ],
2676*da0073e9SAndroid Build Coastguard Worker            "CompositeViewCopyKernel_Definitions": list(
2677*da0073e9SAndroid Build Coastguard Worker                mapMaybe(
2678*da0073e9SAndroid Build Coastguard Worker                    GenCompositeViewCopyKernel(
2679*da0073e9SAndroid Build Coastguard Worker                        backend_indices[
2680*da0073e9SAndroid Build Coastguard Worker                            DispatchKey.CompositeExplicitAutogradNonFunctional
2681*da0073e9SAndroid Build Coastguard Worker                        ]
2682*da0073e9SAndroid Build Coastguard Worker                    ),
2683*da0073e9SAndroid Build Coastguard Worker                    view_groups,
2684*da0073e9SAndroid Build Coastguard Worker                )
2685*da0073e9SAndroid Build Coastguard Worker            ),
2686*da0073e9SAndroid Build Coastguard Worker            "GeneratedCompositeFunctional_Definitions": list(
2687*da0073e9SAndroid Build Coastguard Worker                mapMaybe(
2688*da0073e9SAndroid Build Coastguard Worker                    gen_composite_functional_kernel,
2689*da0073e9SAndroid Build Coastguard Worker                    structured_native_functions,
2690*da0073e9SAndroid Build Coastguard Worker                )
2691*da0073e9SAndroid Build Coastguard Worker            ),
2692*da0073e9SAndroid Build Coastguard Worker            "GeneratedCompositeOut_Definitions": list(
2693*da0073e9SAndroid Build Coastguard Worker                mapMaybe(
2694*da0073e9SAndroid Build Coastguard Worker                    gen_composite_out_kernel,
2695*da0073e9SAndroid Build Coastguard Worker                    structured_native_functions,
2696*da0073e9SAndroid Build Coastguard Worker                )
2697*da0073e9SAndroid Build Coastguard Worker            ),
2698*da0073e9SAndroid Build Coastguard Worker        },
2699*da0073e9SAndroid Build Coastguard Worker    )
2700*da0073e9SAndroid Build Coastguard Worker
2701*da0073e9SAndroid Build Coastguard Worker
def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    """Write Declarations.yaml through `cpu_fm`.

    The content is format_yaml applied to the list of
    compute_declaration_yaml results, one per entry of `native_functions`
    in input order. Rendering is wrapped in a callable that is handed to
    `cpu_fm.write` rather than being evaluated here.
    """

    def render_declarations() -> str:
        declarations = list(map(compute_declaration_yaml, native_functions))
        return format_yaml(declarations)

    cpu_fm.write("Declarations.yaml", render_declarations)
2709*da0073e9SAndroid Build Coastguard Worker
2710*da0073e9SAndroid Build Coastguard Worker
def get_torchgen_root() -> Path:
    """
    Return the resolved directory that contains this module.

    Useful when depending on torchgen out-of-tree: the returned root can be
    used to figure out the path to native_functions.yaml.
    """
    module_file = Path(__file__)
    # Take the parent first and resolve afterwards (same order as before), so
    # a symlinked module file is not followed before its directory is taken.
    containing_dir = module_file.parent
    return containing_dir.resolve()
2717*da0073e9SAndroid Build Coastguard Worker
2718*da0073e9SAndroid Build Coastguard Worker
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser consumed by ``main``."""
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "--aoti-install-dir",
        "--aoti_install_dir",
        help="output directory for AOTInductor shim",
        default="torch/csrc/inductor/aoti_torch/generated",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    # TODO: --op-registration-whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend-whitelist",
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip-dispatcher-op-registration",
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        # The two literals are concatenated, so the first one must end with a
        # space; otherwise --help prints "includingthose".
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op-registration-whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )
    parser.add_argument(
        "--update-aoti-c-shim",
        action="store_true",
        help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
        "WARNING: Do not use this unless you are sure what you are doing!!!",
    )
    return parser


def main() -> None:
    """
    Command-line entry point: parse the ATen operator YAML and run the generators.

    Steps, in order:
      1. Parse the command line and build the custom-build selector.
      2. Parse ``native/native_functions.yaml`` and ``native/tags.yaml`` found
         under ``--source-path``.
      3. Create the ``core`` / ``ops`` / AOTInductor output directories and one
         file manager per output area.
      4. Depending on ``--generate``, call ``gen_source_files``,
         ``gen_headers`` and/or ``gen_declarations_yaml``.
      5. If ``--output-dependencies`` is given, have each file manager (except
         the AOTInductor one) write its outputs list next to that path.
    """
    options = _build_arg_parser().parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")

    from torchgen.model import dispatch_keys

    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys = set()
    if not options.mps:
        ignore_keys.add(DispatchKey.MPS)

        # NB: this removes the entry from the list object imported from
        # torchgen.model itself, not from a local copy.
        if DispatchKey.MPS in dispatch_keys:
            dispatch_keys.remove(DispatchKey.MPS)

    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    # NOTE(review): presumably parse_native_yaml populates this cache for
    # tags_yaml_path as a side effect; keep this lookup after the call above.
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions = parsed_yaml.native_functions
    backend_indices = parsed_yaml.backend_indices

    grouped_native_functions = get_grouped_native_functions(native_functions)

    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes.  If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
    aoti_install_dir = options.aoti_install_dir
    Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)

    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
    }
    if options.mps:
        functions_keys.add(DispatchKey.MPS)

    if options.backend_whitelist:
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]

    # --static-dispatch-backend is None when the flag is absent and [] when it
    # is given without values; both mean "no static dispatch backends".
    static_dispatch_keys = [
        DispatchKey.parse(key) for key in options.static_dispatch_backend or []
    ]
    static_dispatch_idx: list[BackendIndex] = [
        backend_indices[dp_key] for dp_key in static_dispatch_keys
    ]
    # Every statically dispatched backend also gets its functions header.
    functions_keys.update(static_dispatch_keys)

    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            aoti_fm=aoti_fm,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cpu_vec_fm=cpu_vec_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
            update_aoti_c_shim=options.update_aoti_c_shim,
        )

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )

    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        # One outputs file per file manager, written beside the requested
        # path; the prefix distinguishes both the file name and the variable
        # name passed to write_outputs.  aoti_fm is not part of this list.
        for fm, prefix in [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (cuda_fm, "cuda_"),
            (ops_fm, "ops_"),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
2983*da0073e9SAndroid Build Coastguard Worker
2984*da0073e9SAndroid Build Coastguard Worker
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
2987