1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp 4*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import ( 5*da0073e9SAndroid Build Coastguard Worker ArgName, 6*da0073e9SAndroid Build Coastguard Worker ArrayRefCType, 7*da0073e9SAndroid Build Coastguard Worker BaseCType, 8*da0073e9SAndroid Build Coastguard Worker Binding, 9*da0073e9SAndroid Build Coastguard Worker ConstRefCType, 10*da0073e9SAndroid Build Coastguard Worker dimnameListT, 11*da0073e9SAndroid Build Coastguard Worker intArrayRefT, 12*da0073e9SAndroid Build Coastguard Worker iOptTensorListRefT, 13*da0073e9SAndroid Build Coastguard Worker iTensorListRefT, 14*da0073e9SAndroid Build Coastguard Worker NamedCType, 15*da0073e9SAndroid Build Coastguard Worker OptionalCType, 16*da0073e9SAndroid Build Coastguard Worker optionalIntArrayRefT, 17*da0073e9SAndroid Build Coastguard Worker optionalScalarRefT, 18*da0073e9SAndroid Build Coastguard Worker optionalTensorRefT, 19*da0073e9SAndroid Build Coastguard Worker scalarT, 20*da0073e9SAndroid Build Coastguard Worker tensorT, 21*da0073e9SAndroid Build Coastguard Worker) 22*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 23*da0073e9SAndroid Build Coastguard Worker Argument, 24*da0073e9SAndroid Build Coastguard Worker BaseTy, 25*da0073e9SAndroid Build Coastguard Worker BaseType, 26*da0073e9SAndroid Build Coastguard Worker ListType, 27*da0073e9SAndroid Build Coastguard Worker NativeFunctionsGroup, 28*da0073e9SAndroid Build Coastguard Worker OptionalType, 29*da0073e9SAndroid Build Coastguard Worker SelfArgument, 30*da0073e9SAndroid Build Coastguard Worker TensorOptionsArguments, 31*da0073e9SAndroid Build Coastguard Worker Type, 32*da0073e9SAndroid Build Coastguard Worker) 33*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never 34*da0073e9SAndroid Build Coastguard Worker 

# This module describes how JIT schema is translated into the structured
# functions C++ API. It resembles the native API, but a number of historical
# problems with the native API have been fixed here.


# Translate a type occurring in a JIT argument into a C++ argument type.
# NB: per the original note, `mutable` does not do anything for now; it is
# threaded through so it could matter if more nominal types are introduced.
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Return the structured-kernel C++ type for JIT argument type ``t``.

    Args:
        t: the JIT schema type of the argument.
        mutable: forwarded unchanged to ``cpp.valuetype_type`` and to the
            recursive calls for Optional/List element types.
        binds: the name the returned ``NamedCType`` is bound to.

    Returns:
        A ``NamedCType`` bound to ``binds``.

    Raises:
        AssertionError: if ``t`` is a ``BaseType`` other than Tensor/Scalar
            that ``cpp.valuetype_type`` did not translate, or if ``t`` is not
            a ``BaseType``, ``OptionalType`` or ``ListType``.
    """
    # Value types are translated by the shared C++ API first.
    # NB: structured kernels ALWAYS run with symint=False, because they are
    # real kernels that need concrete ints. CompositeExplicitAutograd and the
    # meta function could hypothetically take SymInt, but for simplicity those
    # are planned to be handled in Python.
    value_ty = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
    if value_ty is not None:
        return value_ty

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        inner = t.elem
        # Optionals with a dedicated "optional reference" C++ type.
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if inner == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        if isinstance(inner, ListType) and str(inner.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Otherwise wrap whatever the element type translates to.
        inner_ty = argumenttype_type(inner, mutable=mutable, binds=binds).type
        return NamedCType(binds, OptionalCType(inner_ty))

    if isinstance(t, ListType):
        inner = t.elem
        # Tensor lists use the type-erased list-reference C++ types.
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
        if inner == OptionalType(BaseType(BaseTy.Tensor)):
            return NamedCType(binds, BaseCType(iOptTensorListRefT))
        # TODO: delete these special cases; see torchgen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        inner_name = str(inner)
        if inner_name == "int":
            return NamedCType(binds, BaseCType(intArrayRefT))
        if inner_name == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        # Otherwise an ArrayRef of whatever the element type translates to.
        inner_ty = argumenttype_type(inner, mutable=mutable, binds=binds).type
        return NamedCType(binds, ArrayRefCType(inner_ty))

    raise AssertionError(f"unrecognized type {repr(t)}")


def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the structured-kernel C++ type of schema argument ``a``.

    ``a.is_write`` is passed through as the ``mutable`` flag.
    """
    return argumenttype_type(t=a.type, mutable=a.is_write, binds=binds)


# returns_type is intentionally omitted: structured kernels never "return".
# Instead they always report their outputs indirectly -- a meta function by
# calling set_output, and an impl function by writing directly into the
# provided out argument.


# Structured kernels never have defaulted arguments.
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
    """Return the C++ bindings for one schema argument of a structured kernel.

    A plain ``Argument`` yields exactly one ``Binding`` with ``default=None``.
    A ``SelfArgument`` is bound like the ``Argument`` it wraps.

    Raises:
        AssertionError: for ``TensorOptionsArguments``, which structured
            kernels do not support yet.
    """
    if isinstance(a, Argument):
        single = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=None,
            argument=a,
        )
        return [single]
    if isinstance(a, SelfArgument):
        return argument(a.argument)
    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    assert_never(a)


def _flatten_bindings(
    params: list[Argument | TensorOptionsArguments | SelfArgument],
) -> list[Binding]:
    """Concatenate the bindings produced by ``argument`` for each of ``params``."""
    flat: list[Binding] = []
    for param in params:
        flat.extend(argument(param))
    return flat


def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings of the impl function for group ``g``.

    The parameters are taken from ``g.out`` in this order:
      1. the non-out schema arguments, where an ``Argument`` whose name is a
         key of ``g.out.precomputed.replace`` is replaced by the parameters
         listed for it there (only when ``g.out.precomputed`` is truthy);
      2. the parameters in ``g.out.precomputed.add`` (same condition);
      3. the out arguments.
    """
    out_fn = g.out
    precomputed = out_fn.precomputed
    schema_args = out_fn.func.arguments

    params: list[Argument | TensorOptionsArguments | SelfArgument] = []
    if not precomputed:
        params.extend(schema_args.non_out)
    else:
        # Replace the parameters named in precomputed.replace with their
        # precomputed counterparts (as specified in native_functions.yaml);
        # every other non-out parameter is kept as-is.
        for a in schema_args.non_out:
            if isinstance(a, Argument) and a.name in precomputed.replace:
                params.extend(precomputed.replace[a.name])
            else:
                params.append(a)
        # precomputed.add lists parameters added without replacement, after
        # the non-out args and just before the out args.
        params.extend(precomputed.add)
    params.extend(schema_args.out)

    return _flatten_bindings(params)


def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings of the meta function: the functional variant's non-out args."""
    return _flatten_bindings(list(g.functional.func.arguments.non_out))


def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the bindings for the out arguments of the out variant of ``g``."""
    return _flatten_bindings(list(g.out.func.arguments.out))