xref: /aosp_15_r20/external/pytorch/torchgen/api/native.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from typing import Sequence
4
5from torchgen import local
6from torchgen.api import cpp
7from torchgen.api.types import (
8    ArgName,
9    BaseCType,
10    Binding,
11    boolT,
12    ConstRefCType,
13    CType,
14    deviceT,
15    layoutT,
16    ListCType,
17    MutRefCType,
18    NamedCType,
19    OptionalCType,
20    scalarT,
21    scalarTypeT,
22    tensorT,
23)
24from torchgen.model import (
25    Argument,
26    FunctionSchema,
27    Return,
28    SelfArgument,
29    TensorOptionsArguments,
30    Type,
31)
32from torchgen.utils import assert_never
33
34
35# This file describes the translation of JIT schema to the native functions API.
36# This looks a lot like the C++ API (which makes historical sense, because the
37# idea was you wrote native functions to implement functions in the C++ API),
38# but over time we have evolved the C++ API without actually changing our
39# native:: kernels.  The intention is to make native API and dispatcher API
40# line up as closely as possible, since this results in the least overhead
41# (no translation is needed from dispatcher API to native API).
42#
43# NB: this is symint aware, you will get the non-SymInt variant for some
44# dispatch entries and SymInt for others.
45
46
def name(func: FunctionSchema) -> str:
    """Return the native-function name for ``func``.

    The result is the schema's base operator name, followed by ``_out`` when
    the schema is an out variant, followed by ``_<overload_name>`` when the
    schema has a non-empty overload name.
    """
    base = str(func.name.name)
    # TODO: delete this!
    out_suffix = "_out" if func.is_out_fn() else ""
    overload = func.name.overload_name
    overload_suffix = f"_{overload}" if overload else ""
    return f"{base}{out_suffix}{overload_suffix}"
55
56
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Translate a schema type into the C type of a native-function argument.

    Four schema types (compared by their string form) are handled here:
    ``Tensor?``, ``Tensor?[]``, ``Scalar`` and ``Scalar?``.  Every other type
    is delegated unchanged to ``cpp.argumenttype_type``.
    """
    schema_type = str(t)

    if schema_type == "Tensor?":
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        # A mutable reference is used only for a mutable argument when the
        # const-ref-for-mutable-tensors setting is off; the setting is only
        # consulted when the argument is mutable.
        by_const_ref = not mutable or local.use_const_ref_for_mutable_tensors()
        ref_wrapper = ConstRefCType if by_const_ref else MutRefCType
        return NamedCType(binds, ref_wrapper(optional_tensor))

    if schema_type == "Tensor?[]":
        optional_tensor_list = ListCType(OptionalCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(optional_tensor_list))

    if schema_type == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if schema_type == "Scalar?":
        optional_scalar = OptionalCType(BaseCType(scalarT))
        return NamedCType(binds, ConstRefCType(optional_scalar))

    # Everything else follows the C++ API translation.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
75
76
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C type for the returns ``rs``.

    This is a direct delegation to ``cpp.returns_type``.
    """
    return_ctype = cpp.returns_type(rs, symint=symint)
    return return_ctype
79
80
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the named C type for the schema argument ``a``.

    The argument's type and its ``is_write`` flag (used as ``mutable``) are
    forwarded to ``argumenttype_type``.
    """
    named_ctype = argumenttype_type(
        a.type,
        mutable=a.is_write,
        binds=binds,
        symint=symint,
    )
    return named_ctype
83
84
def argument(
    a: Argument | SelfArgument | TensorOptionsArguments,
    *,
    is_out: bool,
    symint: bool,
) -> list[Binding]:
    """Return the native-function bindings produced by one schema argument.

    An ``Argument`` yields a single binding, a ``SelfArgument`` is unwrapped
    to its inner argument, and a ``TensorOptionsArguments`` yields four
    bindings: ``dtype``, ``layout``, ``device`` and ``pin_memory``.
    """
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, defaults are generated for non-out variants (but
    # not for out variants, where it is impossible to generate an appropriate
    # default).
    with_defaults = not is_out

    if isinstance(a, Argument):
        default_expr: str | None = (
            cpp.default_expr(a.default, a.type, symint=symint)
            if with_defaults and a.default is not None
            else None
        )
        binding = Binding(
            nctype=argument_type(a, binds=a.name, symint=symint),
            name=a.name,
            default=default_expr,
            argument=a,
        )
        return [binding]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out, symint=symint)

    if isinstance(a, TensorOptionsArguments):
        options_default = "{}" if with_defaults else None
        # (binding name, base C type) for each piece of the tensor options;
        # every piece is bound as an optional of its base type.
        option_fields = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType(field, OptionalCType(BaseCType(base))),
                name=field,
                default=options_default,
                argument=a,
            )
            for field, base in option_fields
        ]

    assert_never(a)
147
148
def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
    """Return the flat list of native-function bindings for ``func``.

    The schema's non-out arguments are processed first, then its out
    arguments; each one is expanded through ``argument`` and the resulting
    bindings are concatenated in that order.
    """
    ordered_args: list[Argument | TensorOptionsArguments | SelfArgument] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    bindings: list[Binding] = []
    for schema_arg in ordered_args:
        bindings.extend(
            argument(schema_arg, symint=symint, is_out=func.is_out_fn())
        )
    return bindings
156