xref: /aosp_15_r20/external/pytorch/torchgen/api/structured.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
4*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
5*da0073e9SAndroid Build Coastguard Worker    ArgName,
6*da0073e9SAndroid Build Coastguard Worker    ArrayRefCType,
7*da0073e9SAndroid Build Coastguard Worker    BaseCType,
8*da0073e9SAndroid Build Coastguard Worker    Binding,
9*da0073e9SAndroid Build Coastguard Worker    ConstRefCType,
10*da0073e9SAndroid Build Coastguard Worker    dimnameListT,
11*da0073e9SAndroid Build Coastguard Worker    intArrayRefT,
12*da0073e9SAndroid Build Coastguard Worker    iOptTensorListRefT,
13*da0073e9SAndroid Build Coastguard Worker    iTensorListRefT,
14*da0073e9SAndroid Build Coastguard Worker    NamedCType,
15*da0073e9SAndroid Build Coastguard Worker    OptionalCType,
16*da0073e9SAndroid Build Coastguard Worker    optionalIntArrayRefT,
17*da0073e9SAndroid Build Coastguard Worker    optionalScalarRefT,
18*da0073e9SAndroid Build Coastguard Worker    optionalTensorRefT,
19*da0073e9SAndroid Build Coastguard Worker    scalarT,
20*da0073e9SAndroid Build Coastguard Worker    tensorT,
21*da0073e9SAndroid Build Coastguard Worker)
22*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
23*da0073e9SAndroid Build Coastguard Worker    Argument,
24*da0073e9SAndroid Build Coastguard Worker    BaseTy,
25*da0073e9SAndroid Build Coastguard Worker    BaseType,
26*da0073e9SAndroid Build Coastguard Worker    ListType,
27*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsGroup,
28*da0073e9SAndroid Build Coastguard Worker    OptionalType,
29*da0073e9SAndroid Build Coastguard Worker    SelfArgument,
30*da0073e9SAndroid Build Coastguard Worker    TensorOptionsArguments,
31*da0073e9SAndroid Build Coastguard Worker    Type,
32*da0073e9SAndroid Build Coastguard Worker)
33*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
# This file describes the translation of JIT schema to the structured functions API.
# It is similar to the native API, but a number of historical problems with the
# native API have been fixed.
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a JIT argument type into the C++ type used by structured kernels.

    Anything ``cpp.valuetype_type`` can translate is delegated to it, always with
    SymInt disabled. The remaining reference-like types (Tensor, Scalar, and
    Optional/List wrappers around them) are mapped explicitly here. Any other
    Optional/List element type is translated recursively and then wrapped.

    Args:
        t: the JIT schema type to translate.
        mutable: forwarded to ``cpp.valuetype_type`` and to recursive calls.
            No branch in this function consults it directly. (Historically:
            "for now, mutable doesn't do anything; but it could if we make
            some more nominal types".)
        binds: the argument name the resulting ``NamedCType`` is bound to.

    Returns:
        A ``NamedCType`` bound to ``binds``.

    Raises:
        AssertionError: if ``t`` is a ``BaseType`` other than Tensor/Scalar that
            was not handled as a value type, or if ``t`` is not a ``BaseType``,
            ``OptionalType`` or ``ListType``.
    """
    # NB: structured kernels ALWAYS have symint off, since they involve actual
    # kernels that require real ints.  The one exception is the
    # CompositeExplicitAutograd and the meta function (which could
    # hypothetically be SymInt), but for simplicity we plan for these to just
    # be handled in Python
    value_ctype = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        # Only Tensor and Scalar may reach this point; every other base type
        # is expected to have been translated as a value type above.
        by_const_ref = {BaseTy.Tensor: tensorT, BaseTy.Scalar: scalarT}
        if t.name not in by_const_ref:
            raise AssertionError(f"base type should have been value type {t}")
        return NamedCType(binds, ConstRefCType(BaseCType(by_const_ref[t.name])))

    if isinstance(t, OptionalType):
        inner = t.elem
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if inner == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        if isinstance(inner, ListType) and str(inner.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Generic case: translate the payload, then wrap it in an optional.
        payload = argumenttype_type(inner, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(payload.type))

    if isinstance(t, ListType):
        inner = t.elem
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
        if inner == OptionalType(BaseType(BaseTy.Tensor)):
            return NamedCType(binds, BaseCType(iOptTensorListRefT))
        # TODO: delete these special cases; see torchgen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        inner_name = str(inner)
        if inner_name == "int":
            return NamedCType(binds, BaseCType(intArrayRefT))
        if inner_name == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        # Generic case: translate the element type, then wrap it in an ArrayRef.
        payload = argumenttype_type(inner, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(payload.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the structured-API C++ type for schema argument ``a``.

    The argument's write-ness (``a.is_write``) is passed through as the
    ``mutable`` flag of ``argumenttype_type``.
    """
    is_mutable = a.is_write
    return argumenttype_type(a.type, mutable=is_mutable, binds=binds)
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
# returns_type is intentionally omitted because structured kernels never "return";
# instead, they always report their outputs indirectly (in the case of a meta
# function, by calling set_output; in the case of an impl function, by writing
# directly into the provided out argument).
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
    """Translate one schema argument into its structured-API bindings.

    Structured kernels are never defaulted, so every binding has
    ``default=None``.

    Args:
        a: a plain ``Argument``, a ``SelfArgument`` wrapper (unwrapped and
            translated like its inner argument), or ``TensorOptionsArguments``.

    Returns:
        A single-element list holding the binding for the argument.

    Raises:
        AssertionError: for ``TensorOptionsArguments``, which structured
            kernels do not support.
    """
    if isinstance(a, Argument):
        binding = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=None,
            argument=a,
        )
        return [binding]
    if isinstance(a, SelfArgument):
        # `self` is bound exactly like the argument it wraps.
        return argument(a.argument)
    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    assert_never(a)
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker
def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for a structured kernel's impl function.

    The parameter list is built from the out overload ``g.out``:

    * its non-out arguments. If ``g.out.precomputed`` is set, each
      ``Argument`` whose name appears in ``precomputed.replace`` is substituted
      by the replacement parameters listed there, and the ``precomputed.add``
      parameters are appended after the non-out arguments;
    * followed by its out arguments.
    """
    schema_args = g.out.func.arguments
    precomputed = g.out.precomputed

    params: list[Argument | TensorOptionsArguments | SelfArgument] = []
    if precomputed:
        for candidate in schema_args.non_out:
            replaceable = (
                isinstance(candidate, Argument)
                and candidate.name in precomputed.replace
            )
            if replaceable:
                # Swap in the precomputed counterparts declared in
                # native_functions.yaml.
                params.extend(precomputed.replace[candidate.name])
            else:
                params.append(candidate)
        # Added (non-replacing) precomputed parameters go right before the
        # out arguments.
        params.extend(precomputed.add)
    else:
        params.extend(schema_args.non_out)

    params.extend(schema_args.out)

    bindings: list[Binding] = []
    for param in params:
        bindings.extend(argument(param))
    return bindings
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker
def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for a structured kernel's meta function.

    These are the non-out arguments of ``g.functional``'s schema, translated
    in order.
    """
    functional_inputs = g.functional.func.arguments.non_out
    return [binding for param in functional_inputs for binding in argument(param)]
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker
def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for the out arguments of ``g.out``'s schema, in order."""
    outputs = g.out.func.arguments.out
    return [binding for param in outputs for binding in argument(param)]
158