from __future__ import annotations

from typing import Sequence

from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import (
    ArgName,
    BaseCType,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    deviceT,
    layoutT,
    ListCType,
    MutRefCType,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    tensorT,
)
from torchgen.model import (
    Argument,
    FunctionSchema,
    Return,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never


# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels. The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# NB: this is SymInt-aware: you will get the non-SymInt variant for some
# dispatch entries and the SymInt variant for others.


def name(func: FunctionSchema) -> str:
    """Return the native:: kernel name for ``func``.

    The base operator name gets an ``_out`` suffix when ``func`` is an out
    function, followed by ``_<overload_name>`` when the schema has a
    non-empty overload name.
    """
    result = str(func.name.name)
    # TODO: delete this!
    if func.is_out_fn():
        result += "_out"
    if func.name.overload_name:
        result += f"_{func.name.overload_name}"
    return result


def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Map a JIT schema type to the C++ type used in native:: signatures.

    Only ``Tensor?``, ``Tensor?[]``, ``Scalar`` and ``Scalar?`` are treated
    specially here; every other type is delegated to
    ``cpp.argumenttype_type`` with the same ``mutable``/``binds``/``symint``.
    Among the special cases, ``mutable`` is only consulted for ``Tensor?``.
    """
    # Compute the schema spelling once instead of once per comparison.
    type_str = str(t)
    if type_str == "Tensor?":
        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
        # Mutable optional tensors are passed by non-const reference unless
        # the current `local` context asks for const refs to mutable tensors.
        if mutable and not local.use_const_ref_for_mutable_tensors():
            return NamedCType(binds, MutRefCType(tensor_type))
        return NamedCType(binds, ConstRefCType(tensor_type))
    if type_str == "Tensor?[]":
        return NamedCType(
            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
        )
    if type_str == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if type_str == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)


def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ return type for ``rs``; identical to the C++ API's."""
    return cpp.returns_type(rs, symint=symint)


def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the native C++ type of argument ``a``, mutable iff ``a.is_write``."""
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)


def argument(
    a: Argument | SelfArgument | TensorOptionsArguments,
    *,
    is_out: bool,
    symint: bool,
) -> list[Binding]:
    """Expand one schema argument into its native:: ``Binding`` list.

    A plain ``Argument`` yields one binding, ``SelfArgument`` is unwrapped
    to its underlying argument, and ``TensorOptionsArguments`` is split into
    four optional bindings: dtype, layout, device and pin_memory, in that
    order. Defaults are emitted only when ``is_out`` is False.
    """
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default).
    should_default = not is_out
    default: str | None = None
    if isinstance(a, Argument):
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type, symint=symint)
        return [
            Binding(
                nctype=argument_type(a, binds=a.name, symint=symint),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out, symint=symint)
    elif isinstance(a, TensorOptionsArguments):
        if should_default:
            default = "{}"
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces. It seems
        # to matter.
        # (binding name, element C type); order is part of the signature.
        fields = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        return [
            Binding(
                nctype=NamedCType(field_name, OptionalCType(BaseCType(field_type))),
                name=field_name,
                default=default,
                argument=a,
            )
            for field_name, field_type in fields
        ]
    else:
        assert_never(a)


def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
    """Return the native:: bindings for ``func``: non-out args, then out args."""
    # Invariant across all arguments; evaluate once instead of per argument.
    is_out = func.is_out_fn()
    args: list[Argument | TensorOptionsArguments | SelfArgument] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    return [r for arg in args for r in argument(arg, is_out=is_out, symint=symint)]