xref: /aosp_15_r20/external/pytorch/torchgen/dest/native_functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import torchgen.api.meta as meta
4import torchgen.api.structured as structured
5from torchgen.api.types import kernel_signature
6from torchgen.context import with_native_function_and_index
7from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
8from torchgen.utils import mapMaybe
9
10
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
    """Return the C++ forward declaration of ``f``'s kernel on ``backend_index``.

    Returns ``None`` when the backend registers no kernel for ``f``, or when
    the registered kernel name contains ``"legacy::"``. Otherwise returns
    ``"<prefix> <signature>;"``, where the prefix is ``static`` for external
    backends and ``TORCH_API`` for in-tree ones.
    """
    signature = kernel_signature(f, backend_index)
    kernel_meta = backend_index.get_kernel(f)
    # Nothing to declare: either no kernel is registered for this backend,
    # or the kernel is a legacy one.
    if kernel_meta is None or "legacy::" in kernel_meta.kernel:
        return None
    storage = "static" if backend_index.external else "TORCH_API"
    return f"{storage} {signature.decl(name=kernel_meta.kernel)};"
22
23
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
    """Return the C++ struct declaration for a structured kernel group.

    The emitted struct is named ``structured_<kernel>``, derives from
    ``at::meta::structured_<meta name>``, and declares an ``impl`` method
    whose parameters are the group's structured implementation arguments.
    In-tree backends mark the struct ``TORCH_API``; external backends get no
    export macro. Returns an empty list when the backend registers no kernel
    for ``g``.
    """
    parent_meta = meta.name(g)
    impl_args = structured.impl_arguments(g)
    kernel_meta = backend_index.get_kernel(g)
    if kernel_meta is None:
        return []
    export = "" if backend_index.external else "TORCH_API "
    params = ", ".join(arg.decl() for arg in impl_args)
    decl_lines = [
        f"struct {export}structured_{kernel_meta.kernel} : "
        f"public at::meta::structured_{parent_meta} {{",
        f"void impl({params});",
        "};",
    ]
    # Each line, including the closing "};", is newline-terminated.
    return ["\n".join(decl_lines) + "\n"]
39
40
# Produces the entries of NativeFunctions.h: forward declarations of the
# actual kernel definitions we keep in aten/src/ATen/native/.
@with_native_function_and_index
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
    """Return the NativeFunctions.h declarations for ``g`` on ``backend_index``.

    - A lone ``NativeFunction`` yields at most one unstructured declaration.
    - A group whose kernel metadata is marked structured yields the
      structured struct declaration.
    - Any other group yields one unstructured declaration per member
      function, skipping members for which no declaration is produced.

    Raises:
        AssertionError: for a structured group on an external backend,
            which is not supported yet.
    """
    kernel_meta = backend_index.get_kernel(g)

    if not isinstance(g, NativeFunctionsGroup):
        single = gen_unstructured(g, backend_index)
        return [single] if single is not None else []

    is_structured = kernel_meta is not None and kernel_meta.structured
    if not is_structured:
        return list(
            mapMaybe(lambda fn: gen_unstructured(fn, backend_index), g.functions())
        )

    if backend_index.external:
        # Structured hasn't been tested with external backends yet.
        raise AssertionError(
            "Structured external backend functions are not implemented yet."
        )
    return gen_structured(g, backend_index)
64