xref: /aosp_15_r20/external/pytorch/torchgen/api/native.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerfrom torchgen import local
6*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
7*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
8*da0073e9SAndroid Build Coastguard Worker    ArgName,
9*da0073e9SAndroid Build Coastguard Worker    BaseCType,
10*da0073e9SAndroid Build Coastguard Worker    Binding,
11*da0073e9SAndroid Build Coastguard Worker    boolT,
12*da0073e9SAndroid Build Coastguard Worker    ConstRefCType,
13*da0073e9SAndroid Build Coastguard Worker    CType,
14*da0073e9SAndroid Build Coastguard Worker    deviceT,
15*da0073e9SAndroid Build Coastguard Worker    layoutT,
16*da0073e9SAndroid Build Coastguard Worker    ListCType,
17*da0073e9SAndroid Build Coastguard Worker    MutRefCType,
18*da0073e9SAndroid Build Coastguard Worker    NamedCType,
19*da0073e9SAndroid Build Coastguard Worker    OptionalCType,
20*da0073e9SAndroid Build Coastguard Worker    scalarT,
21*da0073e9SAndroid Build Coastguard Worker    scalarTypeT,
22*da0073e9SAndroid Build Coastguard Worker    tensorT,
23*da0073e9SAndroid Build Coastguard Worker)
24*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
25*da0073e9SAndroid Build Coastguard Worker    Argument,
26*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
27*da0073e9SAndroid Build Coastguard Worker    Return,
28*da0073e9SAndroid Build Coastguard Worker    SelfArgument,
29*da0073e9SAndroid Build Coastguard Worker    TensorOptionsArguments,
30*da0073e9SAndroid Build Coastguard Worker    Type,
31*da0073e9SAndroid Build Coastguard Worker)
32*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker
# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was that you wrote native functions to implement functions in the C++
# API), but over time we have evolved the C++ API without actually changing our
# native:: kernels.  The intention is to make the native API and the dispatcher
# API line up as closely as possible, since this results in the least overhead
# (no translation is needed from the dispatcher API to the native API).
#
# NB: this is SymInt-aware: depending on the dispatch entry, you will get either
# the non-SymInt variant or the SymInt variant.
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
def name(func: FunctionSchema) -> str:
    """Return the native-API function name for ``func``.

    Starts from the operator's base name, appends ``_out`` when ``func`` is an
    out= function, and then appends ``_<overload_name>`` when the schema has a
    non-empty overload name.
    """
    parts = [str(func.name.name)]
    # TODO: delete this!
    if func.is_out_fn():
        parts.append("_out")
    if func.name.overload_name:
        parts.append(f"_{func.name.overload_name}")
    return "".join(parts)
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Map a schema type to its native-API C++ type, bound to ``binds``.

    A handful of types are special-cased here ("Tensor?", "Tensor?[]",
    "Scalar", "Scalar?"); everything else falls through to the C++ API's
    translation in ``cpp.argumenttype_type``.

    For "Tensor?", the optional tensor is passed by mutable reference only when
    ``mutable`` is set and ``local.use_const_ref_for_mutable_tensors()`` is
    false; otherwise it is passed by const reference.
    """
    type_str = str(t)

    if type_str == "Tensor?":
        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
        ref_ctype = (
            MutRefCType
            if mutable and not local.use_const_ref_for_mutable_tensors()
            else ConstRefCType
        )
        return NamedCType(binds, ref_ctype(tensor_type))

    if type_str == "Tensor?[]":
        optional_tensor_list = ListCType(OptionalCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(optional_tensor_list))

    if type_str == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if type_str == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))

    # No native-specific override: defer to the C++ API translation.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ type used for the native-API return value(s) ``rs``.

    The native API does not customize return types; this delegates directly
    to ``cpp.returns_type`` with the same ``symint`` setting.
    """
    ctype = cpp.returns_type(rs, symint=symint)
    return ctype
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the native-API C++ type for argument ``a``, bound to ``binds``.

    Mutability is taken from ``a.is_write``; the actual mapping is done by
    ``argumenttype_type``.
    """
    arg_type = a.type
    return argumenttype_type(
        arg_type, mutable=a.is_write, binds=binds, symint=symint
    )
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker
def argument(
    a: Argument | SelfArgument | TensorOptionsArguments,
    *,
    is_out: bool,
    symint: bool,
) -> list[Binding]:
    """Translate one schema argument into its native-API bindings.

    - ``Argument`` yields a single binding, with a C++ default expression when
      defaults are allowed and the argument declares one.
    - ``SelfArgument`` is unwrapped and translated as its inner argument.
    - ``TensorOptionsArguments`` expands into four optional bindings:
      ``dtype``, ``layout``, ``device`` and ``pin_memory`` (in that order),
      each defaulted to ``"{}"`` when defaults are allowed.

    Defaults are allowed only when ``is_out`` is false.
    """
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out

    if isinstance(a, Argument):
        default: str | None = (
            cpp.default_expr(a.default, a.type, symint=symint)
            if should_default and a.default is not None
            else None
        )
        return [
            Binding(
                nctype=argument_type(a, binds=a.name, symint=symint),
                name=a.name,
                default=default,
                argument=a,
            )
        ]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out, symint=symint)

    if isinstance(a, TensorOptionsArguments):
        default = "{}" if should_default else None
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        option_pieces = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        return [
            Binding(
                nctype=NamedCType(piece_name, OptionalCType(BaseCType(piece_ty))),
                name=piece_name,
                default=default,
                argument=a,
            )
            for piece_name, piece_ty in option_pieces
        ]

    assert_never(a)
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker
def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
    """Translate every argument of ``func`` into native-API bindings.

    Non-out arguments come first, followed by out arguments.  Whether ``func``
    is an out= function is passed to ``argument`` as ``is_out``, which
    suppresses default values for all of its bindings.
    """
    args: list[Argument | TensorOptionsArguments | SelfArgument] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    # is_out_fn() describes the whole schema, not an individual argument, so
    # evaluate it once rather than once per argument inside the comprehension.
    is_out = func.is_out_fn()
    return [r for arg in args for r in argument(arg, symint=symint, is_out=is_out)]
156