from __future__ import annotations

from typing import Sequence

from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import (
    ArgName,
    BaseCType,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    deviceT,
    layoutT,
    ListCType,
    MutRefCType,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    tensorT,
)
from torchgen.model import (
    Argument,
    FunctionSchema,
    Return,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never


# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels. The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# NB: this is symint aware, you will get the non-SymInt variant for some
# dispatch entries and SymInt for others.


def name(func: FunctionSchema) -> str:
    """Return the native-API function name for ``func``.

    The schema's base name is suffixed with ``_out`` when ``func`` is an out
    variant, and then with ``_<overload_name>`` when the schema has a
    non-empty overload name.
    """
    # Named ``result`` (not ``name``) so the local does not shadow this function.
    result = str(func.name.name)
    # TODO: delete this!
    if func.is_out_fn():
        result += "_out"
    if func.name.overload_name:
        result += f"_{func.name.overload_name}"
    return result


def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Translate a JIT schema type into its native-API C++ type.

    ``Tensor?``, ``Tensor?[]``, ``Scalar`` and ``Scalar?`` are spelled
    differently in the native API and are handled here. Every other type is
    delegated to :func:`cpp.argumenttype_type`.

    Args:
        t: The schema type to translate.
        mutable: Whether the argument is written to by the operator.
        binds: The name the resulting ``NamedCType`` binds to.
        symint: Forwarded to the C++ API translation for SymInt awareness.
    """
    # Render the schema spelling once instead of once per comparison.
    type_str = str(t)
    if type_str == "Tensor?":
        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
        # Mutable optional tensors are passed by non-const reference unless the
        # codegen is configured to use const refs for mutable tensors.
        if mutable and not local.use_const_ref_for_mutable_tensors():
            return NamedCType(binds, MutRefCType(tensor_type))
        return NamedCType(binds, ConstRefCType(tensor_type))
    if type_str == "Tensor?[]":
        return NamedCType(
            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
        )
    if type_str == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    if type_str == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)


def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ return type for ``rs``; identical to the C++ API's."""
    return cpp.returns_type(rs, symint=symint)


def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Translate ``a``'s type, treating it as mutable when ``a.is_write``."""
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)


# (binding name, base C++ type) for each piece a TensorOptionsArguments is
# expanded into; the order here is the order of the generated parameters.
_TENSOR_OPTIONS_FIELDS = (
    ("dtype", scalarTypeT),
    ("layout", layoutT),
    ("device", deviceT),
    ("pin_memory", boolT),
)


def argument(
    a: Argument | SelfArgument | TensorOptionsArguments,
    *,
    is_out: bool,
    symint: bool,
) -> list[Binding]:
    """Translate one schema argument into its native-API bindings.

    A plain ``Argument`` yields one binding. A ``SelfArgument`` is unwrapped
    and translated like its inner argument. ``TensorOptionsArguments``
    expands into four optional bindings: ``dtype``, ``layout``, ``device``
    and ``pin_memory``.

    Args:
        a: The argument to translate.
        is_out: True when the enclosing function is an out variant; in that
            case no default values are emitted.
        symint: Forwarded to the C++ API translation for SymInt awareness.
    """
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default).
    should_default = not is_out
    if isinstance(a, Argument):
        default: str | None = None
        if should_default and a.default is not None:
            default = cpp.default_expr(a.default, a.type, symint=symint)
        return [
            Binding(
                nctype=argument_type(a, binds=a.name, symint=symint),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out, symint=symint)
    elif isinstance(a, TensorOptionsArguments):
        default = "{}" if should_default else None
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces. It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType(field_name, OptionalCType(BaseCType(field_type))),
                name=field_name,
                default=default,
                argument=a,
            )
            for field_name, field_type in _TENSOR_OPTIONS_FIELDS
        ]
    else:
        assert_never(a)


def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
    """Return the native-API bindings for all of ``func``'s arguments.

    Non-out arguments come first, followed by the out arguments.
    """
    # Whether ``func`` is an out variant is a property of the whole schema;
    # evaluate it once rather than once per argument.
    is_out = func.is_out_fn()
    args: list[Argument | TensorOptionsArguments | SelfArgument] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    return [r for arg in args for r in argument(arg, symint=symint, is_out=is_out)]