from __future__ import annotations

import argparse
import functools
import json
import os
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar

import yaml

import torchgen.api.dispatcher as dispatcher
import torchgen.api.meta as meta
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
    Binding,
    CppSignature,
    CppSignatureGroup,
    DispatcherSignature,
    NamedCType,
    NativeSignature,
    SpecialArgName,
)
from torchgen.context import (
    method_with_native_function,
    native_function_manager,
    with_native_function,
    with_native_function_and_indices,
)
from torchgen.gen_aoti_c_shim import (
    gen_aoti_c_shim,
    gen_static_dispatch_backend_call_signature,
    get_fallback_op_name,
    get_header_for_aoti,
)
from torchgen.gen_functionalization_type import (
    gen_functionalization_definition,
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration,
    GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
from torchgen.model import (
    Argument,
    BackendIndex,
    BackendMetadata,
    BaseOperatorName,
    DEFAULT_KERNEL_NAMESPACE,
    DispatchKey,
    FRAGMENT_NAMESPACES,
    FunctionSchema,
    is_cuda_dispatch_key,
    is_generic_dispatch_key,
    is_ufunc_dispatch_key,
    is_xpu_dispatch_key,
    Location,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OperatorName,
    OptionalType,
    SchemaKind,
    SelfArgument,
    STRUCTURED_DISPATCH_KEYS,
    TensorOptionsArguments,
    Type,
    Variant,
    ViewSchemaKind,
)
from torchgen.native_function_generation import (
    add_generated_native_functions,
    gen_composite_functional_kernel,
    gen_composite_out_kernel,
    pre_group_native_functions,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
    assert_never,
    concatMap,
    context,
    FileManager,
    make_file_manager,
    mapMaybe,
    NamespaceHelper,
    Target,
)
from torchgen.yaml_utils import YamlDumper, YamlLoader


T = TypeVar("T")

# Welcome to the ATen code generator v2!  The ATen code generator is
# responsible for parsing native_functions.yaml and then generating
# various generated files (e.g., TypeDefault.cpp) based on the operators
# defined in this file.  This means that the code generator knows how to
# parse function schema, and then translate this into various C++ types
# and boilerplate code.
#
# Some things to know about this file when you modify it:
#
# - This file has STRICT mypy typechecking.  Typecheck it with
#   `mypy --config mypy-strict.ini` in the root source directory
#
# - Most of the heavy lifting lives in external modules:
#   - 'model' has the data model for native_functions.yaml.  The classes
#     in that file represent what you see when you look at
#     a native_functions.yaml
#   - 'api' has conversions for how to translate JIT schema into
#     the various C++ APIs that the codegen interacts with.  There
#     are in fact THREE different C++ APIs: the public C++ API,
#     the dispatcher API, and the legacy dispatcher API.  See each
#     of these respective files for more information.

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Add 1 so line numbering starts at 1
        mapping["__line__"] = node.start_mark.line + 1
        return mapping

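# For example (a minimal sketch; the inline YAML document is illustrative):
#
#   yaml.load("- func: add.Tensor", Loader=LineLoader)
#
# returns [{"func": "add.Tensor", "__line__": 1}], so every parsed entry
# remembers which line it came from for use in error messages.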

# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])


_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}


def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    assert isinstance(es, list)
    rs: list[NativeFunction] = []
    bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for e in es:
        assert isinstance(e, dict), f"expected to be dict: {e}"
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        funcs = e.get("func")
        assert funcs is not None, f"missing 'func' in {e}"
        with context(lambda: f"in {loc}:\n  {funcs}"):
            func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
            rs.append(func)
            BackendIndex.grow_index(bs, m)
    error_check_native_functions(rs)
    # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
    indices: dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )
    )
    if not skip_native_fns_gen:
        add_generated_native_functions(rs, bs)
    for k, v in bs.items():
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[k] = BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k),
            index=v,
        )
    return ParsedYaml(rs, indices)


def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
    assert isinstance(es, list)
    rs: set[str] = set()
    for e in es:
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        tags = e.get("tag")
        with context(lambda: f"in {loc}:\n  {tags}"):
            e_i = e.copy()
            name = e_i.pop("tag")
            desc = e_i.pop("desc", "")
            # ensure that each tag has a non-empty description
            assert desc != ""
            rs.add(name)
    return rs

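# For example, a tags.yaml entry shaped like the following (the description
# text is illustrative) contributes "inplace_view" to the returned set:
#
#   - tag: inplace_view
#     desc: This op is an inplace view operation.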

@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
    global _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)
            _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)

    return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]


def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
) -> ParsedYaml:
    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        valid_tags = parse_tags_yaml(tags_yaml_path)

        # if a loaded yaml is provided, use that instead of reading from path
        if loaded_yaml is None:
            with open(path) as f:
                es = yaml.load(f, Loader=LineLoader)
        else:
            es = loaded_yaml

        _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
            es,
            valid_tags,
            ignore_keys,
            path=path,
            skip_native_fns_gen=skip_native_fns_gen,
        )

    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]

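# Typical usage, with paths as they appear in an in-tree checkout (illustrative):
#
#   parsed = parse_native_yaml(
#       "aten/src/ATen/native/native_functions.yaml",
#       "aten/src/ATen/native/tags.yaml",
#   )
#   parsed.native_functions  # list[NativeFunction]
#   parsed.backend_indices   # dict[DispatchKey, BackendIndex]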

# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            assert delegate_func is not None, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is missing."
            )
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if (
            "inplace_view" in f.tags
            and str(f.func.name) != "resize_"
            and str(f.func.name) != "resize_as_"
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            assert len(base_func_map[out_of_place_base_name]) > 0, (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
            )


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal"""
    s = s.replace("\\", "\\\\")
    s = s.replace('"', '\\"')
    s = s.replace("\a", "\\a")
    s = s.replace("\b", "\\b")
    s = s.replace("\f", "\\f")
    s = s.replace("\n", "\\n")
    s = s.replace("\v", "\\v")
    s = s.replace("\t", "\\t")
    return f'"{s}"'

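# For example, cpp_string('add.Tensor(Tensor self) -> Tensor') returns the text
# "add.Tensor(Tensor self) -> Tensor" including the surrounding double quotes,
# and any embedded quotes or newlines come out as \" and \n so the result is a
# valid C++ string literal.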

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated.  This pattern makes it convenient to use map, concatMap
# and similar functional combinators.


def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
    if len(backends) == 0:
        return []
    else:
        return [backend.dispatch_key for backend in backends] + [
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]


def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return backend_index.dispatch_key
    elif f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None


def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    if backend_index is None or f.manual_kernel_registration:
        return None

    output = []
    for index in backend_index:
        dispatch_key = get_static_dispatch_backend(f, index)
        if dispatch_key is not None:
            output.append(
                f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
            )
    return "\n".join(output)


def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
    return [
        f"#include <ATen/{dispatch_key}Functions.h>"
        for dispatch_key in static_dispatch_keys(backends)
    ]


# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for the `memory_format` argument; this case is not covered by
# torchgen.api.translate() yet as its application is limited to static dispatch.
def translate_args(
    sig: CppSignature | DispatcherSignature,
    cpp_sig: CppSignature,
) -> str:
    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
        output_bindings: list[Binding] = []
        for binding in input_bindings:
            if binding.name == "memory_format":
                spl_mem_format_binding = Binding(
                    nctype=NamedCType(
                        SpecialArgName.possibly_redundant_memory_format,
                        binding.nctype.type,
                    ),
                    name=binding.name,
                    default=binding.default,
                    argument=binding.argument,
                )
                output_bindings.append(spl_mem_format_binding)
            else:
                output_bindings.append(binding)
        return output_bindings

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    # When the last argument of the CPP signature has the SpecialArgName.possibly_redundant_memory_format NCType,
    # get memory_format bindings of the dispatcher signature to have the same NCType as well
    for arg in goal_bindings:
        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
            src_bindings = add_spl_memory_format_binding(src_bindings)
            break
    exprs = translate(src_bindings, goal_bindings)
    return ", ".join(a.expr for a in exprs)


def generate_static_dispatch_backend_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    backend_metadata = backend_index.get_kernel(f)
    kernel_ns = (
        backend_metadata.cpp_namespace
        if backend_metadata and backend_metadata.cpp_namespace
        else DEFAULT_KERNEL_NAMESPACE
    )
    ns = kernel_ns.replace("::native", "")
    return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"


def generate_static_dispatch_fallback_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
    if f.has_composite_explicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
    else:
        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
{', '.join([str(index.dispatch_key) for index in backend_indices])} ");"""


def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
    backend exists, fall back to determining the dispatch key from the inputs at runtime.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch for.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or isinstance(a.argument, Argument)
        and a.argument.type.is_tensor_like()
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: list[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """

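# When several backends are plausible, the generated C++ looks roughly like
# the following (a sketch for a hypothetical op with CPU and CUDA kernels;
# exact spacing and the fallback case depend on the op):
#
#   DispatchKeySet _dk_set = c10::detail::multi_dispatch_key_set(self, other);
#   DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);
#   switch (_dk) {
#       case DispatchKey::CPU:
#           return at::cpu::add(self, other);
#       case DispatchKey::CUDA:
#           return at::cuda::add(self, other);
#       default:
#           return at::compositeexplicitautograd::add(self, other);
#   }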

# Generates RegisterSchema.cpp.  Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    selector: SelectiveBuilder
    known_tags: dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not self.selector.is_native_function_selected(f):
            return None
        tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
        if tags == "{}":
            return f"m.def({cpp_string(str(f.func))}, {{}});\n"
        maybe_tags = ""
        if tags not in self.known_tags:
            idx = len(self.known_tags)
            self.known_tags[tags] = idx
            maybe_tags = f"const std::vector<at::Tag> tags_{idx} = {tags};\n"
        return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n"

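# For an op tagged with, say, {core, pointwise} (illustrative), the first
# occurrence of that tag set emits both the vector and the registration:
#
#   const std::vector<at::Tag> tags_0 = {at::Tag::core, at::Tag::pointwise};
#   m.def("mul.Tensor(Tensor self, Tensor other) -> Tensor", tags_0);
#
# Later ops with the identical tag set reuse tags_0 via known_tags instead of
# materializing a fresh vector.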

# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name="call", is_redispatching_fn=False)};
  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};"""

        elif self.target is Target.DEFINITION:
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})

// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow({name}::name, {name}::overload_name)
      .typed<{name}::schema>();
}}
"""
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ", ".join(
                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                    )
                    method_base = "redispatch"
                else:
                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
                    method_base = "call"

                dispatcher_call = method_base
                method_name = f"{name}::{method_base}"

                fn_body = f"""
    static auto op = create_{name}_typed_handle();
    return op.{dispatcher_call}({dispatcher_exprs_str});"""

                if (
                    not is_redispatching_fn
                    and len(self.static_dispatch_backend_indices) > 0
                ):
                    # call() should go through static dispatch
                    fn_body = static_dispatch(
                        sig, f, backend_indices=self.static_dispatch_backend_indices
                    )
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    {fn_body}
}}
"""
            return defns
        else:
            assert_never(self.target)


# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        has_symint = f.func.has_symint()

        result = ""
        for sig in sig_group.signatures():
            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join([e.expr for e in exprs])

            if sig.symint:
                intlike_t = "c10::SymInt"
            else:
                intlike_t = "int64_t"

            if Variant.function in f.variants:
                result += f"""
// aten::{f.func}
inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}"""

            # The template function can be used in template contexts
            # where you want to switch between the symint and non-symint
            # versions depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods.  But we put it in
            # this header so it can take advantage of per-op headers
            if has_symint:
                result += f"""
namespace symint {{
  template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
  {sig.decl(suppress_symint_suffix=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  }}
}}
"""
        return result


# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            result = ""
            for sig in sig_group.signatures():
                result += f"{sig.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        result = ""

        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ", ".join([e.expr for e in exprs])

            result += f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""

        return result


# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )

        result = ""
        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])

            result += f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""

        return result


# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another; we should reevaluate if this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'


# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]

            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{elem.name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")

                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )

                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
                    {signature} {{
                        {assert_stmt}
                        {construction_block}
                    }}
                """
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
                {precompute_template_decl}
                struct TORCH_API precompute_out {{
                    {setter_methods_decl}
                    {precomputed_elements_decl};
            }};"""
        else:
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
    {precomputed_decl}
    {meta_return_typedef}
    {meta_return} meta({args_str});
}};
"""


def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    name = str(f.func.name.name)
    if name.endswith("_like") or name.startswith("new_"):
        return False
    if f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)

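# For example (illustrative): `empty.memory_format` takes TensorOptions and is
# neither a `new_*` nor a `*_like` op, so it needs backend select; `empty_like`
# and `new_empty` are skipped because their dispatch key can be computed from
# the self tensor they already take.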

# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)

        native_tensor_args = [
            a
            for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: NativeSignature | DispatcherSignature
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently.
            # The first case could probably be improved though: it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys; there should not be any by the time we reach backend select.
            if native_tensor_args:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  {compute_dk}
  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
      _dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       YAML CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def format_yaml(data: object) -> str:
    # Ignore aliases in the Dumper
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]

    # Support serializing OrderedDict
    def dict_representer(dumper: Any, data: Any) -> Any:
        return dumper.represent_dict(data.items())

    YamlDumper.add_representer(OrderedDict, dict_representer)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]


# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than doing them uniformly as strings.  This
# function detects those cases and converts them into native Python
# objects.
def pythonify_default(s: str) -> object:
    if s == "true":
        return True
    elif s == "false":
        return False

    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s

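# For example: pythonify_default("true") -> True, pythonify_default("1") -> 1,
# pythonify_default("0.5") -> 0.5, and a non-numeric default such as
# pythonify_default("contiguous_format") is passed through unchanged as a string.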

# What is a dynamic type?  Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH).  These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    if isinstance(t, OptionalType):
        return dynamic_type(t.elem)
    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) == "Tensor":
        return "at::Tensor"
    # This is a legacy concept, so never report SymInt
    return cpp.argumenttype_type(
        t, mutable=False, binds="__placeholder__", symint=False
    ).cpp_type()

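# For example, the JIT types `Tensor` and `Tensor?` both report a dynamic_type
# of "at::Tensor", while an `int` argument reports the ordinary C++ argument
# type int64_t (never c10::SymInt).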

def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
    # This is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order
    method_of = ["Type"]
    if Variant.method in variants:
        method_of.append("Tensor")
    if Variant.function in variants:
        method_of.append("namespace")
    return method_of

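# For example, an op declared with `variants: function, method` yields
# ["Type", "Tensor", "namespace"], while a function-only op yields
# ["Type", "namespace"].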

def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            "dynamic_type": dynamic_type(r.type),
            "name": name,
            # legacy, report ints
            "type": cpp.return_type(r, symint=False).cpp_type(),
        }

        if r.name:
            # See Note [name and field_name]
            ret["field_name"] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name

        returns.append(ret)

    return returns, name_to_field_name


# arguments in yaml roughly correspond to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    if isinstance(cpp_a.argument, TensorOptionsArguments):
        arg: dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            arg["default"] = cpp_a.default
        return arg
    elif isinstance(cpp_a.argument, SelfArgument):
        raise AssertionError
    elif isinstance(cpp_a.argument, Argument):
        return compute_argument_yaml(
            cpp_a.argument,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )


def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    arg: dict[str, object] = {
        "annotation": str(a.annotation) if a.annotation else None,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": a.name,
        # legacy, report ints
        "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
    }
    if a.default is not None:
        arg["default"] = pythonify_default(
            cpp.default_expr(a.default, a.type, symint=False)
        )
    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True
    if a.name in out_arg_set:
        arg["output"] = True
        arg["allocate"] = True
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            arg["field_name"] = name_to_field_name[a.name]
    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    l = a.type.is_list_like()
    if l is not None and l.size is not None and str(l.elem) != "bool":
        arg["size"] = l.size
    return arg


@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in f.func.arguments.out}

    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            cpp_a,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for cpp_a in cpp_args
    ]

    schema_order_jit_arguments = list(f.func.schema_order_arguments())

    schema_order_arguments = [
        compute_argument_yaml(
            a,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for a in schema_order_jit_arguments
    ]

    cpp_schema_order_types = [
        # NB: method here doesn't matter
        r.type
        for a in schema_order_jit_arguments
        for r in cpp.argument(
            a,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
    ]

    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    is_factory_method = (
        any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
        and Variant.method not in f.variants
    )

    return OrderedDict(
        [
            ("name", cpp.name(f.func)),
            ("operator_name", str(f.func.name.name)),
            ("overload_name", str(f.func.name.overload_name)),
            ("manual_kernel_registration", f.manual_kernel_registration),
            (
                "category_override",
                f.category_override if f.category_override is not None else "",
            ),
            ("schema_string", f"aten::{f.func}"),
            ("arguments", arguments),
            ("schema_order_cpp_signature", schema_order_cpp_signature),
            ("schema_order_arguments", schema_order_arguments),
            ("method_of", compute_method_of_yaml(f.variants)),
            ("mode", "native"),
            ("python_module", "" if f.python_module is None else f.python_module),
            ("returns", returns),
            ("inplace", f.func.name.name.inplace),
            ("is_factory_method", is_factory_method),
            ("abstract", f.is_abstract),
            ("device_guard", f.device_guard),
            ("with_gil", False),
            ("deprecated", False),
            ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
        ]
    )


# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    return (f.structured or f.structured_delegate is not None) and (
        f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
    )


@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
        f.func.returns
    ).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
    comment_data: dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly are the semantics of the 'dispatch' field?
        "dispatch": str(
            {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {DispatchKey.CompositeImplicitAutograd}
            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {
                DispatchKey.CompositeImplicitAutograd,
                DispatchKey.CompositeImplicitAutogradNestedTensor,
            }
        ),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""

1322
1323# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1324#
1325#                           RUN IT ALL
1326#
1327# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1328
1329
1330def get_custom_build_selector(
1331    provided_op_registration_allowlist: list[str] | None,
1332    op_selection_yaml_path: str | None,
1333) -> SelectiveBuilder:
1334    assert not (
1335        provided_op_registration_allowlist is not None
1336        and op_selection_yaml_path is not None
1337    ), (
1338        "Both provided_op_registration_allowlist and "
1339        + "op_selection_yaml_path can NOT be provided at the "
1340        + "same time."
1341    )
1342
1343    op_registration_allowlist: set[str] | None = None
1344    if provided_op_registration_allowlist is not None:
1345        op_registration_allowlist = set(provided_op_registration_allowlist)
1346
1347    if op_registration_allowlist is not None:
1348        selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
1349            op_registration_allowlist,
1350            True,
1351            False,
1352        )
1353    elif op_selection_yaml_path is not None:
1354        selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
1355    else:
1356        selector = SelectiveBuilder.get_nop_selector()
1357
1358    return selector
1359
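# Illustrative calls (hypothetical inputs; at most one selection source may be set):
#   get_custom_build_selector(["aten::add", "aten::mul"], None)  # allowlist-driven
#   get_custom_build_selector(None, "selected_operators.yaml")   # YAML-driven
#   get_custom_build_selector(None, None)                        # nop selector, keeps everything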
1360
1361def get_grouped_by_view_native_functions(
1362    native_functions: Sequence[NativeFunction],
1363) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
1364    def maybe_create_view_group(
1365        d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
1366    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
1367        funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
1368        if ViewSchemaKind.aliasing in d:
1369            view = d.pop(ViewSchemaKind.aliasing)
1370            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
1371            view_copy = d.pop(SchemaKind.functional, None)
1372
1373            funcs.append(
1374                NativeFunctionsViewGroup(
1375                    view=view,
1376                    view_copy=view_copy,
1377                    view_inplace=view_inplace,
1378                )
1379            )
1380        # Take the remaining functions that weren't part of the view group
1381        # and emit them separately
1382        funcs.extend(d.values())
1383        return funcs
1384
1385    grouped_by_views: dict[
1386        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
1387    ] = defaultdict(dict)
1388    for f in native_functions:
1389        schema = f.func.view_signature()
1390        view_kind: ViewSchemaKind = f.view_schema_kind
1391        # We need to group up ops relevant to the same "view", consisting of:
1392        # view op (ViewSchemaKind.aliasing)
1393        # view_inplace op (ViewSchemaKind.aliasing_inplace)
1394        # view_copy op (SchemaKind.functional)
1395        if view_kind == ViewSchemaKind.non_aliasing:
1396            kind = f.func.kind()
1397            assert kind not in grouped_by_views[schema]
1398            grouped_by_views[schema][kind] = f
1399        else:
1400            assert (
1401                view_kind not in grouped_by_views[schema]
1402            ), f"{view_kind} already in {grouped_by_views[schema].keys()}"
1403            grouped_by_views[schema][view_kind] = f
1404
1405    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
1406
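# Hedged example of the grouping (real op names, but simplified): `narrow`
# (ViewSchemaKind.aliasing) and `narrow_copy` (SchemaKind.functional) share a
# view signature, so they are bundled into
#   NativeFunctionsViewGroup(view=narrow, view_copy=narrow_copy, view_inplace=None)
# while remaining non-aliasing ops with the same schema are emitted as-is.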
1407
1408def get_grouped_native_functions(
1409    native_functions: Sequence[NativeFunction],
1410) -> Sequence[NativeFunction | NativeFunctionsGroup]:
1411    def flatten_pre_group(
1412        d: dict[SchemaKind, NativeFunction]
1413    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
1414        r = NativeFunctionsGroup.from_dict(d)
1415        if r is None:
1416            # Invariant: any NativeFunctions that are code-generated
1417            # should have been grouped into NativeFunctionsGroup objects
1418            assert not any("generated" in f.tags for f in d.values())
1419            return list(d.values())
1420        else:
1421            return [r]
1422
1423    # TODO: how come ValuesView isn't a Sequence lol
1424    pre_grouped_native_functions = pre_group_native_functions(native_functions)
1425    return list(
1426        concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
1427    )
1428
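# Hedged example: `add.Tensor` (functional), `add_.Tensor` (inplace), and
# `add.out` form one NativeFunctionsGroup; an op whose variants cannot be
# grouped (from_dict returns None) is emitted ungrouped, as long as none of
# its entries were code-generated.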
1429
1430def get_ns_grouped_kernels(
1431    *,
1432    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1433    backend_indices: dict[DispatchKey, BackendIndex],
1434    native_function_decl_gen: Callable[
1435        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
1436    ] = dest.compute_native_function_declaration,
1437) -> dict[str, list[str]]:
1438    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
1439    for f in grouped_native_functions:
1440        native_function_namespaces = set()
1441        dispatch_keys = set()
1442        for dispatch_key, backend_idx in backend_indices.items():
1443            backend_metadata = backend_idx.get_kernel(f)
1444            if backend_metadata:
1445                namespace = backend_metadata.cpp_namespace
1446                dispatch_keys.add(dispatch_key)
1447                native_function_namespaces.add(namespace)
1448            else:
1449                namespace = DEFAULT_KERNEL_NAMESPACE
1450            assert (
1451                len(native_function_namespaces) <= 1
1452            ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
1453            ns_grouped_kernels[namespace].extend(
1454                native_function_decl_gen(f, backend_idx)
1455            )
1456    return ns_grouped_kernels
1457
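# The mapping returned above is keyed by kernel namespace, e.g. (a sketch):
#   {"at::native": ["TORCH_API at::Tensor foo(...);", ...], "custom::native": [...]}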
1458
1459def get_native_function_declarations_from_ns_grouped_kernels(
1460    *,
1461    ns_grouped_kernels: dict[str, list[str]],
1462) -> list[str]:
1463    declarations: list[str] = []
1464    newline = "\n"
1465    for namespace, kernels in ns_grouped_kernels.items():
1466        ns_helper = NamespaceHelper(
1467            namespace_str=namespace,
1468            entity_name="",
1469            max_level=4,
1470        )
1471        # Dedupe kernel names while preserving order (backends are allowed to
1472        # repeat kernel names); only generate the declaration once!
1473        ordered_kernels = list(OrderedDict.fromkeys(kernels))
1474        declarations.extend(
1475            f"""
1476{ns_helper.prologue}
1477{newline.join(ordered_kernels)}
1478{ns_helper.epilogue}
1479        """.split(
1480                newline
1481            )
1482        )
1483    return declarations
1484
1485
1486# Return native function declarations grouped by their namespaces.
1487def get_native_function_declarations(
1488    *,
1489    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1490    backend_indices: dict[DispatchKey, BackendIndex],
1491    native_function_decl_gen: Callable[
1492        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
1493    ] = dest.compute_native_function_declaration,
1494) -> list[str]:
1495    """
1496    Generate kernel declarations for `NativeFunction(s).h`.
1497    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
1498    :param backend_indices: kernel collections grouped by dispatch key.
1499    :param native_function_decl_gen: callable that generates the kernel declaration for each `NativeFunction`.
1500    :return: a list of strings: all declarations, grouped by namespace and split by newline.
1501    """
1502
1503    ns_grouped_kernels = get_ns_grouped_kernels(
1504        grouped_native_functions=grouped_native_functions,
1505        backend_indices=backend_indices,
1506        native_function_decl_gen=native_function_decl_gen,
1507    )
1508    return get_native_function_declarations_from_ns_grouped_kernels(
1509        ns_grouped_kernels=ns_grouped_kernels
1510    )
1511
1512
1513def get_kernel_namespace(
1514    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
1515) -> str:
1516    backend_metadata = backend_idx.get_kernel(f)
1517    assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
1518        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
1519        f"with dispatch key {backend_idx.dispatch_key}"
1520        f" has namespace {backend_metadata.cpp_namespace}, which does not contain '::native'."
1521    )
1522    return (
1523        backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
1524    )
1525
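# Hedged examples: metadata with cpp_namespace "custom::native" yields
# "custom::native"; no metadata yields DEFAULT_KERNEL_NAMESPACE (at::native);
# a namespace without "::native" in it trips the assert above.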
1526
1527# Return native function definitions grouped by dispatch key and custom namespace.
1528# Used in RegisterDispatchKey.cpp, among others.
1529def get_native_function_definitions(
1530    *,
1531    fm: FileManager,
1532    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1533    dispatch_key: DispatchKey,
1534    backend_idx: BackendIndex,
1535    selector: SelectiveBuilder,
1536    rocm: bool,
1537    symint: bool,
1538    skip_dispatcher_op_registration: bool,
1539    gen_dispatch_helpers: bool,
1540) -> list[str]:
1541    definitions: list[str] = []
1542    ns_definitions: dict[str, list[str]] = defaultdict(list)
1543    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
1544    registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
1545    newline = "\n"
1546    ns_gen = dest.RegisterDispatchKey(
1547        backend_idx,
1548        Target.NAMESPACED_DEFINITION,
1549        selector,
1550        rocm=rocm,
1551        symint=symint,
1552        class_method_name=None,
1553        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
1554    )
1555    anonymous_gen = dest.RegisterDispatchKey(
1556        backend_idx,
1557        Target.ANONYMOUS_DEFINITION,
1558        selector,
1559        rocm=rocm,
1560        symint=symint,
1561        class_method_name=None,
1562        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
1563    )
1564    reg_gen = dest.RegisterDispatchKey(
1565        backend_idx,
1566        Target.REGISTRATION,
1567        selector,
1568        rocm=rocm,
1569        symint=symint,
1570        class_method_name=None,
1571        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
1572    )
1573    for f in grouped_native_functions:
1574        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
1575            "::native", ""
1576        )
1577
1578        ns_definitions[kernel_namespace].extend(
1579            ns_gen(f),
1580        )
1581        anonymous_definitions[kernel_namespace].extend(
1582            anonymous_gen(f),
1583        )
1584        namespace = (
1585            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
1586        )
1587        # NB: setdefault preserves registrations already collected for other namespaces
1588        registrations[kernel_namespace].setdefault(namespace, [])
1589        registrations[kernel_namespace][namespace].extend(
1590            reg_gen(f),
1591        )
1592
1593    for kernel_namespace in ns_definitions:
1594        if len(ns_definitions[kernel_namespace]) == 0:
1595            continue
1596        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
1597        registration_body = ""
1598        for namespace in registrations[kernel_namespace]:
1599            if not registrations[kernel_namespace][namespace]:
1600                continue
1601            registration_body += f"""
1602TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
1603    {newline.join(registrations[kernel_namespace][namespace])}
1604}};"""
1605        definitions.extend(
1606            fm.substitute_with_template(
1607                "RegisterDispatchDefinitions.ini",
1608                lambda: {
1609                    "ns_prologue": ns_helper.prologue,
1610                    "ns_epilogue": ns_helper.epilogue,
1611                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
1612                    if gen_dispatch_helpers
1613                    else [],
1614                    "dispatch_anonymous_definitions": anonymous_definitions[
1615                        kernel_namespace
1616                    ],
1617                    "static_init_dispatch_registrations": ""
1618                    if skip_dispatcher_op_registration
1619                    else registration_body,
1620                    "deferred_dispatch_registrations": "",
1621                    "dispatch_namespace": dispatch_key.lower(),
1622                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
1623                },
1624            ).split(newline)
1625        )
1626
1627    return definitions
1628
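# Each non-empty namespace produces a registration block shaped roughly like
# (namespace and dispatch key illustrative):
#   TORCH_LIBRARY_IMPL(aten, CPU, m) {
#       m.impl(...);  // one entry per selected operator
#   };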
1629
1630# Return native function declarations grouped by dispatch key and custom namespace.
1631# Used in CPUFunctions_inl.h, among others.
1632def get_namespaced_declaration(
1633    *,
1634    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1635    dispatch_key: DispatchKey,
1636    backend_idx: BackendIndex,
1637    selector: SelectiveBuilder,
1638    rocm: bool,
1639    symint: bool,
1640) -> list[str]:
1641    declarations: list[str] = []
1642    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
1643    newline = "\n"
1644    func = dest.RegisterDispatchKey(
1645        backend_idx,
1646        Target.NAMESPACED_DECLARATION,
1647        selector,
1648        rocm=rocm,
1649        class_method_name=None,
1650        skip_dispatcher_op_registration=False,
1651        symint=symint,
1652    )
1653    for f in grouped_native_functions:
1654        namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
1655            "native", dispatch_key.lower()
1656        )
1657
1658        ns_grouped_kernels[namespace].extend(
1659            func(f),
1660        )
1661
1662    for namespace, kernels in ns_grouped_kernels.items():
1663        if len(kernels) == 0:
1664            continue
1665        ns_helper = NamespaceHelper(
1666            namespace_str=namespace, entity_name="", max_level=3
1667        )
1668        ordered_kernels = list(OrderedDict.fromkeys(kernels))
1669        declarations.extend(
1670            f"""
1671{ns_helper.prologue}
1672{newline.join(ordered_kernels)}
1673{ns_helper.epilogue}
1674        """.split(
1675                newline
1676            )
1677        )
1678    return declarations
1679
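# Hedged example: for DispatchKey.CPU the kernel namespace "at::native" is
# rewritten to "at::cpu", so the declarations land in the at::cpu namespace
# of CPUFunctions_inl.h.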
1680
1681# Return native function schema registration code for aten and other namespaces.
1682def get_native_function_schema_registrations(
1683    *,
1684    native_functions: Sequence[NativeFunction],
1685    schema_selector: SelectiveBuilder,
1686) -> tuple[list[str], str]:
1687    ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
1688    for native_function in native_functions:
1689        ns_native_functions[native_function.namespace].append(native_function)
1690    schema_registrations = ""
1691    aten_schema_registrations = []
1692    custom_namespace = None
1693    for namespace, funcs in ns_native_functions.items():
1694        schema_registrations_body = list(
1695            mapMaybe(RegisterSchema(schema_selector), funcs)
1696        )
1697        # NB: we have to separate aten namespace registration from other namespaces,
1698        # because in the template we hardcoded an operator for ATen already.
1699        if namespace == "aten":
1700            aten_schema_registrations = schema_registrations_body
1701        else:
1702            custom_namespace = namespace
1703            tab = "\t"
1704            # If the namespace is predefined, define a library fragment
1705            # instead of a new library.
1706            torch_library_macro = (
1707                "TORCH_LIBRARY_FRAGMENT"
1708                if namespace in FRAGMENT_NAMESPACES
1709                else "TORCH_LIBRARY"
1710            )
1711            schema_registrations += f"""
1712{torch_library_macro}({custom_namespace}, m) {{
1713  {tab.join(schema_registrations_body)}
1714}};"""
1715    return (aten_schema_registrations, schema_registrations)
1716
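# A sketch of the non-aten output (namespace and schema illustrative):
#   TORCH_LIBRARY(custom, m) {
#     m.def("foo(Tensor self) -> Tensor");
#   };
# TORCH_LIBRARY_FRAGMENT is used instead when the namespace is in FRAGMENT_NAMESPACES.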
1717
1718def gen_aggregated_headers(
1719    *,
1720    native_functions: Sequence[NativeFunction],
1721    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1722    structured_native_functions: Sequence[NativeFunctionsGroup],
1723    static_dispatch_idx: list[BackendIndex],
1724    selector: SelectiveBuilder,
1725    backend_indices: dict[DispatchKey, BackendIndex],
1726    cpu_fm: FileManager,
1727    cuda_fm: FileManager,
1728    functions_keys: set[DispatchKey],
1729    dispatch_keys: Sequence[DispatchKey],
1730    rocm: bool,
1731) -> None:
1732    # Buck doesn't support dynamic output files, so we aggregate all operator
1733    # headers into a single file
1734    cpu_fm.write(
1735        "NativeMetaFunctions.h",
1736        lambda: {
1737            "NativeMetaFunctions_includes": [],
1738            "NativeMetaFunctions_declarations": list(
1739                mapMaybe(compute_meta_function_declaration, structured_native_functions)
1740            ),
1741        },
1742    )
1743    method_native_functions = [
1744        fn for fn in native_functions if Variant.method in fn.variants
1745    ]
1746    non_method_native_functions = [
1747        fn for fn in native_functions if fn not in method_native_functions
1748    ]
1749    cpu_fm.write(
1750        "MethodOperators.h",
1751        lambda: {
1752            "MethodOperators_includes": [],
1753            "MethodOperators_declarations": list(
1754                mapMaybe(
1755                    ComputeOperators(
1756                        Target.DECLARATION,
1757                        static_dispatch_backend_indices=static_dispatch_idx,
1758                    ),
1759                    method_native_functions,
1760                )
1761            ),
1762        },
1763    )
1764    cpu_fm.write(
1765        "Operators.h",
1766        lambda: {
1767            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
1768            "Operators_declarations": list(
1769                mapMaybe(
1770                    ComputeOperators(
1771                        Target.DECLARATION,
1772                        static_dispatch_backend_indices=static_dispatch_idx,
1773                    ),
1774                    non_method_native_functions,
1775                )
1776            ),
1777        },
1778    )
1779    cpu_fm.write(
1780        "Functions.h",
1781        lambda: {
1782            "static_dispatch_extra_headers": static_dispatch_extra_headers(
1783                static_dispatch_idx
1784            ),
1785            "Functions_includes": ["#include <ATen/Operators.h>"],
1786            "Functions_declarations": list(
1787                mapMaybe(
1788                    ComputeFunction(),
1789                    native_functions,
1790                )
1791            ),
1792        },
1793    )
1794    declarations = get_native_function_declarations(
1795        grouped_native_functions=grouped_native_functions,
1796        backend_indices=backend_indices,
1797    )
1798    cpu_fm.write(
1799        "NativeFunctions.h",
1800        lambda: {
1801            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
1802            "NativeFunctions_declarations": declarations,
1803        },
1804    )
1805
1806    for dispatch_key in dispatch_keys:
1807        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
1808        if dispatch_key in functions_keys:
1809            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
1810
1811            fm.write_with_template(
1812                f"{dispatch_key}Functions.h",
1813                "DispatchKeyFunctions.h",
1814                lambda: {
1815                    "dispatch_key": str(dispatch_key),
1816                    "inline_headers": inl_headers,
1817                },
1818            )
1819            fm.write_with_template(
1820                f"{dispatch_key}Functions_inl.h",
1821                "DispatchKeyFunctions_inl.h",
1822                lambda: {
1823                    "DispatchKeyFunctions_inl_includes": [],
1824                    "dispatch_namespace": dispatch_key.lower(),
1825                    "dispatch_namespaced_declarations": get_namespaced_declaration(
1826                        grouped_native_functions=grouped_native_functions,
1827                        dispatch_key=dispatch_key,
1828                        backend_idx=backend_indices[dispatch_key],
1829                        selector=selector,
1830                        rocm=rocm,
1831                        symint=True,
1832                    ),
1833                },
1834            )
1835
1836        del fm
1837
1838
1839def gen_per_operator_headers(
1840    *,
1841    native_functions: Sequence[NativeFunction],
1842    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1843    static_dispatch_idx: list[BackendIndex],
1844    selector: SelectiveBuilder,
1845    backend_indices: dict[DispatchKey, BackendIndex],
1846    cpu_fm: FileManager,
1847    cuda_fm: FileManager,
1848    ops_fm: FileManager,
1849    functions_keys: set[DispatchKey],
1850    dispatch_keys: Sequence[DispatchKey],
1851    rocm: bool,
1852) -> None:
1853    # For CMake builds, split operator declarations into separate headers in
1854    # the ATen/ops folder to break up header dependencies
1855    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
1856    for fn in native_functions:
1857        functions_by_root_name[fn.root_name].append(fn)
1858
1859    grouped_functions_by_root_name: dict[
1860        str, list[NativeFunction | NativeFunctionsGroup]
1861    ] = defaultdict(list)
1862    for group in grouped_native_functions:
1863        name = group.root_name
1864        grouped_functions_by_root_name[name].append(group)
1865
1866    for name, functions in functions_by_root_name.items():
1867        ops_fm.write_with_template(
1868            f"{name}_ops.h",
1869            "Operator.h",
1870            lambda: {
1871                "declarations": list(
1872                    mapMaybe(
1873                        ComputeOperators(
1874                            Target.DECLARATION,
1875                            static_dispatch_backend_indices=static_dispatch_idx,
1876                        ),
1877                        functions,
1878                    )
1879                ),
1880            },
1881        )
1882
1883        ops_fm.write_with_template(
1884            f"{name}.h",
1885            "Function.h",
1886            lambda: {
1887                "static_dispatch_ops_headers": list(
1888                    mapMaybe(
1889                        lambda fn: static_dispatch_ops_header(
1890                            fn, backend_index=static_dispatch_idx
1891                        ),
1892                        functions,
1893                    )
1894                ),
1895                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
1896                "function_definitions": list(
1897                    mapMaybe(
1898                        ComputeFunction(),
1899                        functions,
1900                    )
1901                ),
1902            },
1903        )
1904
1905        grouped_functions = grouped_functions_by_root_name.get(name, [])
1906        structured_functions = [
1907            fn
1908            for fn in grouped_functions
1909            if isinstance(fn, NativeFunctionsGroup) and fn.structured
1910        ]
1911        is_structured = len(structured_functions) > 0
1912
1913        if is_structured:
1914            ops_fm.write_with_template(
1915                f"{name}_meta.h",
1916                "NativeMetaFunction.h",
1917                lambda: {
1918                    "meta_function_declarations": list(
1919                        mapMaybe(
1920                            compute_meta_function_declaration, structured_functions
1921                        )
1922                    ),
1923                },
1924            )
1925        declarations = get_native_function_declarations(
1926            grouped_native_functions=grouped_functions,
1927            backend_indices=backend_indices,
1928            native_function_decl_gen=dest.compute_native_function_declaration,
1929        )
1930        ops_fm.write_with_template(
1931            f"{name}_native.h",
1932            "NativeFunction.h",
1933            lambda: {
1934                "extra_includes": (
1935                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
1936                ),
1937                "native_function_declarations": declarations,
1938            },
1939        )
1940
1941    for category, suffix in [
1942        ("Functions", ""),
1943        ("Operators", "_ops"),
1944        ("NativeMetaFunctions", "_meta"),
1945        ("NativeFunctions", "_native"),
1946    ]:
1947        cpu_fm.write(
1948            f"{category}.h",
1949            lambda: {
1950                f"{category}_includes": [
1951                    f"#include <ATen/ops/{name}{suffix}.h>"
1952                    for name in sorted(functions_by_root_name.keys())
1953                ],
1954                f"{category}_declarations": [],
1955            },
1956        )
1957
1958    for dispatch_key in dispatch_keys:
1959        if dispatch_key not in functions_keys:
1960            continue
1961
1962        dispatch_namespace = dispatch_key.lower()
1963        dispatch_names = []
1964
1965        for name, functions in functions_by_root_name.items():
1966            grouped_functions = grouped_functions_by_root_name.get(name, [])
1967            declarations = list(
1968                concatMap(
1969                    dest.RegisterDispatchKey(
1970                        backend_indices[dispatch_key],
1971                        Target.NAMESPACED_DECLARATION,
1972                        selector,
1973                        rocm=rocm,
1974                        symint=True,
1975                        class_method_name=None,
1976                        skip_dispatcher_op_registration=False,
1977                    ),
1978                    grouped_functions,
1979                )
1980            )
1981
1982            if len(declarations) == 0:
1983                continue
1984
1985            dispatch_names.append(name)
1986            ops_fm.write_with_template(
1987                f"{name}_{dispatch_namespace}_dispatch.h",
1988                "DispatchKeyFunction.h",
1989                lambda: {
1990                    "dispatch_namespace": dispatch_namespace,
1991                    "dispatch_namespaced_declarations": declarations,
1992                },
1993            )
1994
1995        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
1996        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
1997
1998        fm.write_with_template(
1999            f"{dispatch_key}Functions.h",
2000            "DispatchKeyFunctions.h",
2001            lambda: {
2002                "dispatch_key": str(dispatch_key),
2003                "inline_headers": inl_headers,
2004            },
2005        )
2006        fm.write_with_template(
2007            f"{dispatch_key}Functions_inl.h",
2008            "DispatchKeyFunctions_inl.h",
2009            lambda: {
2010                "dispatch_namespace": dispatch_namespace,
2011                "DispatchKeyFunctions_inl_includes": [
2012                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
2013                    for name in sorted(dispatch_names)
2014                ],
2015                "dispatch_namespaced_declarations": [],
2016            },
2017        )
2018        del fm
2019
2020    cpu_fm.write(
2021        "MethodOperators.h",
2022        lambda: {
2023            "MethodOperators_includes": sorted(
2024                f"#include <ATen/ops/{name}_ops.h>"
2025                for name, functions in functions_by_root_name.items()
2026                if any(Variant.method in fn.variants for fn in functions)
2027            ),
2028            "MethodOperators_declarations": [],
2029        },
2030    )
2031
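# Hedged example of the per-operator layout: for root name "add" this emits
# ATen/ops/add_ops.h, add.h, add_native.h, add_meta.h (structured groups only),
# and add_cpu_dispatch.h for each enabled dispatch key such as CPU.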
2032
2033def gen_headers(
2034    *,
2035    native_functions: Sequence[NativeFunction],
2036    valid_tags: set[str],
2037    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
2038    structured_native_functions: Sequence[NativeFunctionsGroup],
2039    static_dispatch_idx: list[BackendIndex],
2040    selector: SelectiveBuilder,
2041    backend_indices: dict[DispatchKey, BackendIndex],
2042    core_fm: FileManager,
2043    cpu_fm: FileManager,
2044    cuda_fm: FileManager,
2045    ops_fm: FileManager,
2046    dispatch_keys: Sequence[DispatchKey],
2047    functions_keys: set[DispatchKey],
2048    rocm: bool,
2049    per_operator_headers: bool,
2050) -> None:
2051    if per_operator_headers:
2052        gen_per_operator_headers(
2053            native_functions=native_functions,
2054            grouped_native_functions=grouped_native_functions,
2055            static_dispatch_idx=static_dispatch_idx,
2056            selector=selector,
2057            backend_indices=backend_indices,
2058            cpu_fm=cpu_fm,
2059            cuda_fm=cuda_fm,
2060            ops_fm=ops_fm,
2061            dispatch_keys=dispatch_keys,
2062            functions_keys=functions_keys,
2063            rocm=rocm,
2064        )
2065    else:
2066        gen_aggregated_headers(
2067            native_functions=native_functions,
2068            grouped_native_functions=grouped_native_functions,
2069            structured_native_functions=structured_native_functions,
2070            static_dispatch_idx=static_dispatch_idx,
2071            selector=selector,
2072            backend_indices=backend_indices,
2073            cpu_fm=cpu_fm,
2074            cuda_fm=cuda_fm,
2075            dispatch_keys=dispatch_keys,
2076            functions_keys=functions_keys,
2077            rocm=rocm,
2078        )
2079
2080    core_fm.write(
2081        "TensorBody.h",
2082        lambda: {
2083            "tensor_method_declarations": list(
2084                mapMaybe(
2085                    ComputeTensorMethod(
2086                        target=Target.DECLARATION,
2087                        static_dispatch_backend_indices=static_dispatch_idx,
2088                    ),
2089                    native_functions,
2090                )
2091            ),
2092            "tensor_method_definitions": list(
2093                mapMaybe(
2094                    ComputeTensorMethod(
2095                        target=Target.DEFINITION,
2096                        static_dispatch_backend_indices=static_dispatch_idx,
2097                    ),
2098                    native_functions,
2099                )
2100            ),
2101        },
2102    )
2103
2104    cpu_fm.write(
2105        "RedispatchFunctions.h",
2106        lambda: {
2107            "function_redispatch_definitions": list(
2108                mapMaybe(ComputeRedispatchFunction(), native_functions)
2109            ),
2110        },
2111    )
2112
2113    cpu_fm.write(
2114        "RegistrationDeclarations.h",
2115        lambda: {
2116            "registration_declarations": [
2117                compute_registration_declarations(f, backend_indices)
2118                for f in native_functions
2119            ],
2120        },
2121    )
2122
2123    cpu_fm.write(
2124        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
2125    )
2126
2127    def gen_aten_interned_strings() -> dict[str, str]:
2128        attrs: set[str] = set()  # All function argument names
2129        names = set()  # All ATen function names
2130        for func in native_functions:
2131            names.add(str(func.func.name.name))
2132            # Some operators don't have a functional variant but we still create a
2133            # symbol without the underscore
2134            names.add(func.func.name.name.base)
2135
2136            attrs.update(arg.name for arg in func.func.schema_order_arguments())
2137
2138        # These are keywords in C++, so aren't valid symbol names
2139        # https://en.cppreference.com/w/cpp/language/operator_alternative
2140        names -= {
2141            "and",
2142            "and_eq",
2143            "bitand",
2144            "bitor",
2145            "compl",
2146            "not",
2147            "not_eq",
2148            "or",
2149            "or_eq",
2150            "xor",
2151            "xor_eq",
2152        }
2153
2154        return {
2155            "aten_symbols": " \\\n".join(
2156                [f"_(aten, {name})" for name in sorted(names)]
2157            ),
2158            "attr_symbols": " \\\n".join(
2159                [f"_(attr, {name})" for name in sorted(attrs)]
2160            ),
2161        }
2162
2163    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
2164
2165    def gen_tags_enum() -> dict[str, str]:
2166        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}
2167
2168    core_fm.write("enum_tag.h", gen_tags_enum)
2169
2170
2171def gen_source_files(
2172    *,
2173    native_functions: Sequence[NativeFunction],
2174    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
2175    structured_native_functions: Sequence[NativeFunctionsGroup],
2176    view_groups: Sequence[NativeFunctionsViewGroup],
2177    selector: SelectiveBuilder,
2178    static_dispatch_idx: list[BackendIndex],
2179    backend_indices: dict[DispatchKey, BackendIndex],
2180    aoti_fm: FileManager,
2181    core_fm: FileManager,
2182    cpu_fm: FileManager,
2183    cpu_vec_fm: FileManager,
2184    cuda_fm: FileManager,
2185    dispatch_keys: Sequence[DispatchKey],
2186    functions_keys: set[DispatchKey],
2187    rocm: bool,
2188    force_schema_registration: bool,
2189    per_operator_headers: bool,
2190    skip_dispatcher_op_registration: bool,
2191    update_aoti_c_shim: bool,
2192) -> None:
2193    extra_cuda_headers = """\
2194#include <c10/cuda/CUDAGuard.h>
2195#include <ATen/cuda/ATenCUDAGeneral.h>
2196#include <ATen/cuda/CUDADevice.h>
2197#include <ATen/cuda/CUDAContext.h>"""
2198    if rocm:
2199        extra_cuda_headers = """\
2200#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
2201#include <ATen/hip/ATenHIPGeneral.h>
2202#include <ATen/hip/HIPDevice.h>
2203#include <ATen/hip/HIPContext.h>"""
2204
2205    for dispatch_key in dispatch_keys:
2206        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
2207
2208        if per_operator_headers:
2209
2210            def operator_headers() -> list[str]:
2211                headers = []
2212                for g in grouped_native_functions:
2213                    is_registered = False
2214                    if backend_index.has_kernel(g):
2215                        is_registered = True
2216                    # The above has_kernel test on a group will only test for
2217                    # the existence of out dispatch, because that's how
2218                    # structured kernels work. But sometimes functions can be
2219                    # grouped but not be structured, and then you need to check
2220                    # each individual piece, as they may have manual dispatch
2221                    # entries.
2222                    elif isinstance(g, NativeFunctionsGroup) and any(
2223                        backend_index.has_kernel(fn) for fn in g.functions()
2224                    ):
2225                        is_registered = True
2226                    # TODO: this condition is a bit questionable
2227                    # (It has to do with the fact that structured kernels get kernels generated
2228                    # for the Meta + CompositeExplicitAutogradNonFunctional keys).
2229                    elif g.structured and dispatch_key in (
2230                        DispatchKey.Meta,
2231                        DispatchKey.CompositeExplicitAutogradNonFunctional,
2232                    ):
2233                        is_registered = True
2234                    if not is_registered:
2235                        continue
2236
2237                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
2238                    if (
2239                        dispatch_key
2240                        == DispatchKey.CompositeExplicitAutogradNonFunctional
2241                    ):
2242                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
2243                    if dispatch_key in functions_keys:
2244                        headers.append(
2245                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
2246                        )
2247
2248                return sorted(set(headers))
2249
2250        else:
2251
2252            def operator_headers() -> list[str]:
2253                headers = ["#include <ATen/NativeFunctions.h>"]
2254                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
2255                    headers.append("#include <ATen/Functions.h>")
2256                if dispatch_key in functions_keys:
2257                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
2258                return headers
2259
2260        backend_index = backend_indices[dispatch_key]
2261        ns_grouped_native_functions = defaultdict(list)
2262        for grouped_native_function in grouped_native_functions:
2263            namespace = (
2264                grouped_native_function.namespace
2265                if isinstance(grouped_native_function, NativeFunction)
2266                else grouped_native_function.functional.namespace
2267            )
2268            ns_grouped_native_functions[namespace].append(grouped_native_function)
2269
2270        dispatch_namespace = str(dispatch_key).lower()
2271
2272        # CompositeImplicitAutogradNestedTensor does not currently use the generated helpers,
2273        # so compilation would fail when the `-Werror=unused-function` flag is set.
2274        gen_dispatch_helpers: bool = (
2275            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
2276        )
2277
2278        dispatch_definitions = get_native_function_definitions(
2279            fm=fm,
2280            grouped_native_functions=grouped_native_functions,
2281            dispatch_key=dispatch_key,
2282            backend_idx=backend_index,
2283            selector=selector,
2284            rocm=rocm,
2285            symint=True,
2286            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
2287            gen_dispatch_helpers=gen_dispatch_helpers,
2288        )
2289        fm.write_with_template(
2290            f"Register{dispatch_key}.cpp",
2291            "RegisterDispatchKey.cpp",
2292            lambda: {
2293                "extra_cuda_headers": extra_cuda_headers
2294                if is_cuda_dispatch_key(dispatch_key)
2295                else "",
2296                "external_backend_headers": "",
2297                "dispatch_headers": dest.gen_registration_headers(
2298                    backend_index, per_operator_headers, rocm
2299                ),
2300                "ops_headers": operator_headers(),
2301                "dispatch_helpers": "",
2302                "dispatch_definitions": dispatch_definitions,
2303            },
2304        )
2305
2306        for g in structured_native_functions:
2307            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
2308                continue
2309            name = g.functional.func.name.name
2310            if dispatch_key is DispatchKey.CPU:
2311                assert fm is cpu_fm
2312                fm.write_with_template(
2313                    f"UfuncCPU_{name}.cpp",
2314                    "UfuncCPU.cpp",
2315                    lambda: {
2316                        "meta_declaration": compute_meta_function_declaration(g),
2317                        "native_declaration": dest.compute_native_function_declaration(
2318                            g, backend_indices[dispatch_key]
2319                        ),
2320                        "native_definitions": dest.compute_ufunc_cpu(g),
2321                    },
2322                )
2323                cpu_vec_fm.write_with_template(
2324                    f"UfuncCPUKernel_{name}.cpp",
2325                    "UfuncCPUKernel.cpp",
2326                    lambda: {
2327                        "name": name,
2328                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
2329                    },
2330                )
2331            elif dispatch_key is DispatchKey.CUDA:
2332                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
2333                if rocm:
2334                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
2335                fm.write_with_template(
2336                    f"UfuncCUDA_{name}.cu",
2337                    "UfuncCUDA.cu",
2338                    lambda: {
2339                        "name": name,
2340                        "cuda_headers": cuda_headers,
2341                        "meta_declaration": compute_meta_function_declaration(g),
2342                        "native_declaration": dest.compute_native_function_declaration(
2343                            g, backend_indices[dispatch_key]
2344                        ),
2345                        "native_definitions": dest.compute_ufunc_cuda(g),
2346                    },
2347                )
2348            else:
2349                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
2350
2351        structured_func_group_dict = {}
2352        for func_group in structured_native_functions:
2353            for func in func_group.functions():
2354                if func.structured_delegate is not None:
2355                    structured_func_group_dict[func.structured_delegate] = func_group
2356                    break
2357
2358        if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
2359            fallbacks = {}
2360            for func in native_functions:
2361                op_name = get_fallback_op_name(func)
2362                if op_name in inductor_fallback_ops:
2363                    fallbacks[op_name] = func
2364            fallback_native_functions = tuple(
2365                value for _, value in sorted(fallbacks.items())
2366            )
2367
2368            # Header files were checked in for ABI-compatibility checking.
2369            header_file_name = f"c_shim_{dispatch_key.lower()}.h"
2370            new_header = gen_aoti_c_shim(
2371                fallback_native_functions,
2372                structured_func_group_dict,
2373                dispatch_key,
2374                backend_indices,
2375                header=True,
2376                includes="",
2377            )
2378            if update_aoti_c_shim:
2379                aoti_fm.write(
2380                    header_file_name,
2381                    lambda: new_header,
2382                )
2383            else:
2384                try:
2385                    with open(
2386                        os.path.join(aoti_fm.install_dir, header_file_name)
2387                    ) as old_file:
2388                        old_header = old_file.read()
2389                        assert (
2390                            old_header == new_header
2391                        ), """
2392
2393WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
2394indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
2395This is allowed only in a limited number of situations:
2396
23971. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
2398If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
2399C shim header files.
2400
24012. You added a new default argument to an existing fallback op. This is clearly a BC-breaking
2402change for AOTInductor. In this case, you need to keep a manual copy of that existing
2403fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
2404number of that fallback op in the newly generated C shim files, and update the cpp wrapper
2405codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.
2406
2407                        """
2408                except FileNotFoundError:
2409                    print(
2410                        f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
2411                    )
2412
2413            # cpp files are always generated on-the-fly
2414            def headers_for_aoti() -> str:
2415                headers = []
2416                for func in fallback_native_functions:
2417                    header = get_header_for_aoti(
2418                        func, structured_func_group_dict, dispatch_key, backend_indices
2419                    )
2420                    if header is not None:
2421                        headers.append(header)
2422                return "\n".join(sorted(set(headers)))
2423
2424            extra_headers = (
2425                extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
2426            )
2427
2428            aoti_fm.write(
2429                f"c_shim_{dispatch_key.lower()}.cpp",
2430                lambda: gen_aoti_c_shim(
2431                    fallback_native_functions,
2432                    structured_func_group_dict,
2433                    dispatch_key,
2434                    backend_indices,
2435                    header=False,
2436                    includes=headers_for_aoti() + "\n" + extra_headers,
2437                ),
2438            )
2439
2440        del fm
2441
2442    # BackendSelect is generated specially
2443    def gen_backend_select() -> dict[str, list[str]]:
2444        relevant_fns = [
2445            fn for fn in native_functions if needs_backend_select(fn, selector)
2446        ]
2447        return {
2448            "ops_headers": [
2449                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
2450            ],
2451            "backend_select_method_definitions": list(
2452                mapMaybe(
2453                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
2454                )
2455            ),
2456            "backend_select_function_registrations": list(
2457                mapMaybe(
2458                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
2459                )
2460            ),
2461        }
2462
2463    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)
2464
2465    schema_selector = selector
2466    if force_schema_registration:
2467        schema_selector = SelectiveBuilder.get_nop_selector()
2468
2469    (
2470        aten_schema_registrations,
2471        schema_registrations,
2472    ) = get_native_function_schema_registrations(
2473        native_functions=native_functions, schema_selector=schema_selector
2474    )
2475    cpu_fm.write(
2476        "RegisterSchema.cpp",
2477        lambda: {
2478            "aten_schema_registrations": []
2479            if skip_dispatcher_op_registration
2480            else aten_schema_registrations,
2481            "schema_registrations": []
2482            if skip_dispatcher_op_registration
2483            else schema_registrations,
2484        },
2485    )
2486
2487    def key_func(
2488        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2489    ) -> str:
2490        return fn.root_name
2491
2492    cpu_fm.write_sharded(
2493        "Operators.cpp",
2494        native_functions,
2495        key_fn=key_func,
2496        env_callable=lambda fn: {
2497            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
2498            "definitions": [
2499                ComputeOperators(
2500                    Target.DEFINITION,
2501                    static_dispatch_backend_indices=static_dispatch_idx,
2502                )(fn)
2503            ],
2504        },
2505        base_env={
2506            "static_dispatch_extra_headers": static_dispatch_extra_headers(
2507                static_dispatch_idx
2508            ),
2509        },
2510        num_shards=5,
2511        sharded_keys={
2512            "operator_headers",
2513            "definitions",
2514            "static_dispatch_extra_headers",
2515        },
2516    )
2517
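    # No template substitutions are needed here; `dict` is a callable that
    # returns an empty environment.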
2518    cpu_fm.write("Functions.cpp", dict)
2519
2520    core_fm.write("TensorMethods.cpp", dict)
2521
2522    core_fm.write(
2523        "ATenOpList.cpp",
2524        lambda: {
2525            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
2526        },
2527    )
2528
2529    def functionalization_env_callable(
2530        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2531    ) -> dict[str, list[str]]:
2532        def gen_op_headers(
2533            g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2534        ) -> list[str]:
2535            if isinstance(g, NativeFunctionsViewGroup):
2536                # view ops always get a functionalization kernel
2537                headers = [
2538                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
2539                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
2540                ]
2541                if g.view_copy is not None:
2542                    headers += [
2543                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
2544                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
2545                    ]
2546                return headers
2547            elif isinstance(g, NativeFunctionsGroup):
2548                headers = [
2549                    f"#include <ATen/ops/{g.functional.root_name}_native.h>",
2550                    f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
2551                    f"#include <ATen/ops/{g.out.root_name}_native.h>",
2552                    f"#include <ATen/ops/{g.out.root_name}_ops.h>",
2553                ]
2554                if g.inplace is not None:
2555                    headers += [
2556                        f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
2557                        f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
2558                    ]
2559                if g.mutable is not None:
2560                    headers += [
2561                        f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
2562                        f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
2563                    ]
2564                return headers
2565            else:
2566                return [
2567                    f"#include <ATen/ops/{g.root_name}_native.h>",
2568                    f"#include <ATen/ops/{g.root_name}_ops.h>",
2569                ]
2570
2571        return {
2572            "ops_headers": gen_op_headers(g),
2573            "func_definitions": gen_functionalization_definition(
2574                selector,
2575                g,
2576            ),
2577            "func_registrations": gen_functionalization_registration(
2578                selector,
2579                g,
2580                backend_indices[DispatchKey.CompositeImplicitAutograd],
2581            ),
2582        }
2583
2584    all_groups: list[
2585        NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
2586    ] = list(structured_native_functions) + list(
2587        view_groups  # type: ignore[assignment, arg-type, operator]
2588    )
2589    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
2590    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
2591    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
2592    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
2593    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
2594    structured_map: dict[OperatorName, NativeFunction] = {
2595        f.func.name: f
2596        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
2597    }
2598    view_map: dict[OperatorName, NativeFunction] = {
2599        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
2600    }
2601    for f in native_functions:
2602        if f.func.name not in structured_map and f.func.name not in view_map:
2603            all_groups.append(f)
2604
2605    cpu_fm.write_sharded(
2606        "RegisterFunctionalization.cpp",
2607        all_groups,
2608        key_fn=key_func,
2609        env_callable=functionalization_env_callable,
2610        num_shards=4,
2611        sharded_keys={
2612            "ops_headers",
2613            "func_definitions",
2614            "func_registrations",
2615            "func_add_back_views_definitions",
2616            "func_add_back_views_registrations",
2617        },
2618    )
2619
2620    cpu_fm.write(
2621        "FunctionalInverses.h",
2622        lambda: {
2623            "view_inverse_declarations": list(
2624                mapMaybe(
2625                    lambda g: gen_functionalization_view_inverse_declaration(
2626                        selector, g
2627                    ),
2628                    view_groups,
2629                )
2630            )
2631        },
2632    )
2633
2634    # Note [view_copy NativeFunctions]
2635    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
2636    # needs to have a corresponding non-aliasing {view}_copy variant.
2637    # Backends that use functionalization and don't know how to handle aliasing ops
2638    # are expected to implement kernels for these {view}_copy operators instead.
2639    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
2640    # so we codegen the following:
2641    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
2642    #     These are never explicitly invoked by the functionalization pass,
2643    #     but they could theoretically be called from user code (I added these kernels for completeness,
2644    #     since the ops are part of the public API).
2645    # (2) A derivative formula for every {view}_copy operator
2646    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
2647    #     so rather than stamping all of the entries out in derivatives.yaml,
2648    #     we codegen them in.
2649    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
2650    cpu_fm.write(
2651        "CompositeViewCopyKernels.cpp",
2652        lambda: {
2653            "ops_headers": [
2654                "\n".join(
2655                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
2656                    # NB: this include is important as it ensures we
2657                    # set the visibility on generated view_copy kernels
2658                    # correctly
2659                    f"#include <ATen/ops/{f.root_name}_native.h>"
2660                    for f in (
2661                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
2662                    )
2663                )
2664                for g in view_groups
2665            ]
2666            + [
2667                "\n".join(
2668                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
2669                    # NB: this include is also important for correct visibility
2670                    f"#include <ATen/ops/{f.root_name}_native.h>"
2671                    for f in [g.inplace, g.mutable, g.functional]
2672                    if f is not None and "generated" not in f.tags
2673                )
2674                for g in structured_native_functions
2675            ],
2676            "CompositeViewCopyKernel_Definitions": list(
2677                mapMaybe(
2678                    GenCompositeViewCopyKernel(
2679                        backend_indices[
2680                            DispatchKey.CompositeExplicitAutogradNonFunctional
2681                        ]
2682                    ),
2683                    view_groups,
2684                )
2685            ),
2686            "GeneratedCompositeFunctional_Definitions": list(
2687                mapMaybe(
2688                    gen_composite_functional_kernel,
2689                    structured_native_functions,
2690                )
2691            ),
2692            "GeneratedCompositeOut_Definitions": list(
2693                mapMaybe(
2694                    gen_composite_out_kernel,
2695                    structured_native_functions,
2696                )
2697            ),
2698        },
2699    )
2700
2701
2702def gen_declarations_yaml(
2703    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
2704) -> None:
2705    cpu_fm.write(
2706        "Declarations.yaml",
2707        lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
2708    )
2709
2710
2711def get_torchgen_root() -> Path:
2712    """
2713    If you're depending on torchgen out-of-tree, you can use the root to figure
2714    out the path to native_functions.yaml
2715    """
2716    return Path(__file__).parent.resolve()
2717
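# Illustrative out-of-tree usage (the packaged path below is an assumption):
#   yaml_path = get_torchgen_root() / "packaged" / "ATen" / "native" / "native_functions.yaml"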
2718
2719def main() -> None:
2720    parser = argparse.ArgumentParser(description="Generate ATen source files")
2721    parser.add_argument(
2722        "-s",
2723        "--source-path",
2724        help="path to source directory for ATen",
2725        default="aten/src/ATen",
2726    )
2727    parser.add_argument(
2728        "-o",
2729        "--output-dependencies",
2730        help="output a list of dependencies into the given file and exit",
2731    )
2732    parser.add_argument(
2733        "--dry-run",
2734        action="store_true",
2735        help="run without writing any files (still updates outputs)",
2736    )
2737    parser.add_argument(
2738        "--per-operator-headers",
2739        action="store_true",
2740        help="generate separate headers per operator in ATen/ops",
2741    )
2742    parser.add_argument(
2743        "-d",
2744        "--install-dir",
2745        "--install_dir",
2746        help="output directory",
2747        default="build/aten/src/ATen",
2748    )
2749    parser.add_argument(
2750        "--aoti-install-dir",
2751        "--aoti_install_dir",
2752        help="output directory for AOTInductor shim",
2753        default="torch/csrc/inductor/aoti_torch/generated",
2754    )
2755    parser.add_argument(
2756        "--rocm",
2757        action="store_true",
2758        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
2759    )
2760    parser.add_argument(
2761        "--mps",
2762        action="store_true",
2763        help="Generate MPS registration code when set",
2764    )
2765    # TODO: --op-registration-whitelist will be removed when all call-sites
2766    # for gen.py are moved over to using the operator YAML file for mobile
2767    # custom build.
2768    parser.add_argument(
2769        "--op-registration-whitelist",
2770        "--op_registration_whitelist",
2771        nargs="*",
2772        help="filter op registrations by the whitelist (if set); "
2773        "each item is `namespace`::`operator name` without overload name; "
2774        "e.g.: aten::empty aten::conv2d ...",
2775    )
2776    parser.add_argument(
2777        "--op-selection-yaml-path",
2778        "--op_selection_yaml_path",
2779        help="Provide a path to the operator selection (for custom build) YAML "
2780        "that contains the information about the set of selected operators "
2781        "and their categories (training, ...). Each operator is either a "
2782        "full operator name with overload or just a bare operator name. "
2783        "The operator names also contain the namespace prefix (e.g. aten::)",
2784    )
2785    parser.add_argument(
2786        "--backend-whitelist",
2787        "--backend_whitelist",
2788        nargs="*",
2789        help="filter dispatch backend by the whitelist (if set), "
2790        "e.g.: CPU CUDA QuantizedCPU ...",
2791    )
2792    parser.add_argument(
2793        "--static-dispatch-backend",
2794        "--static_dispatch_backend",
2795        nargs="*",
2796        help="generate static dispatch code for the specific backend (if set)",
2797    )
2798    parser.add_argument(
2799        "--skip-dispatcher-op-registration",
2800        "--skip_dispatcher_op_registration",
2801        action="store_true",
2802        help="Avoid registering operators into the dispatcher.",
2803    )
2804    parser.add_argument(
2805        "--force-schema-registration",
2806        "--force_schema_registration",
2807        action="store_true",
2808        help="force it to generate schema-only registrations for all ops, including"
2809        "those that are not listed on --op-registration-whitelist",
2810    )
2811    parser.add_argument(
2812        "--generate",
2813        type=str,
2814        nargs="*",
2815        choices=["headers", "sources", "declarations_yaml"],
2816        default=["headers", "sources", "declarations_yaml"],
2817        help="Generate only a subset of files",
2818    )
2819    parser.add_argument(
2820        "--update-aoti-c-shim",
2821        action="store_true",
2822        help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
2823        "WARNING: Do not use this unless you are sure what you are doing!!!",
2824    )
2825
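    # Illustrative invocation for an in-tree build (the exact paths are
    # assumptions; adjust them to your checkout):
    #
    #   python -m torchgen.gen -s aten/src/ATen -d build/aten/src/ATen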
    options = parser.parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")

    from torchgen.model import dispatch_keys

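    # NB: dispatch_keys is the shared module-level list from torchgen.model,
    # so the in-place deletion below is visible to every other module that
    # imported it.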
    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys = set()
    if not options.mps:
        ignore_keys.add(DispatchKey.MPS)

        if DispatchKey.MPS in dispatch_keys:
            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]

    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )

    grouped_native_functions = get_grouped_native_functions(native_functions)

    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes.  If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
    aoti_install_dir = options.aoti_install_dir
    Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
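    # (Hedged note: file managers constructed without an explicit install_dir
    # fall back to options.install_dir, so cpu_fm, cpu_vec_fm, and cuda_fm all
    # write into the main output directory.)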

    # Only a limited set of dispatch keys get Functions.h headers generated
    # for them (e.g. CPUFunctions.h); this is that set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
    }
    if options.mps:
        functions_keys.add(DispatchKey.MPS)

    if options.backend_whitelist:
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]
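    # e.g. `--backend-whitelist CPU QuantizedCPU` keeps the generic
    # (backend-agnostic) dispatch keys plus exactly those two backends.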

    static_dispatch_idx: list[BackendIndex] = []
    if options.static_dispatch_backend:
        static_dispatch_idx = [
            backend_indices[DispatchKey.parse(key)]
            for key in options.static_dispatch_backend
        ]
        for key in options.static_dispatch_backend:
            functions_keys.add(DispatchKey.parse(key))
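    # Illustrative effect (hedged): `--static-dispatch-backend CPU` makes the
    # generated code call the CPU kernels directly (e.g. at::cpu::add) instead
    # of routing each call through the dispatcher.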

    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            aoti_fm=aoti_fm,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cpu_vec_fm=cpu_vec_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
            update_aoti_c_shim=options.update_aoti_c_shim,
        )

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )

    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

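        # For example (illustrative): with --output-dependencies deps.cmake,
        # cpu_fm writes a variable "deps" into deps.cmake and cuda_fm writes
        # "cuda_deps" into cuda_deps.cmake alongside it.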
        for fm, prefix in [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (cuda_fm, "cuda_"),
            (ops_fm, "ops_"),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))


if __name__ == "__main__":
    main()
