from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING

from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    ListType,
    NativeFunction,
    OptionalType,
    Type,
)


if TYPE_CHECKING:
    from torchgen.api.types import Binding, CType, NamedCType


# Separator used when splicing per-element unboxing statements into a loop body.
connector = "\n\t"


def name(f: NativeFunction) -> str:
    """Return the unboxing function name for a NativeFunction."""
    return f.func.name.unambiguous_name()


@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues to these Bindings.
    Returns generated C++ code that performs the unboxing.

    A sample of generated code:
    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);


    }
    """

    # This is a callable that converts a JIT argument into its C++ type.
    # Translates (type, mutability, binds) to NamedCType.
    # E.g., torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> tuple[list[Binding], list[str]]:
        """
        Convert all the arguments of a NativeFunction to C++ unboxing code.

        :param args: bindings whose ``argument`` must be an ``Argument``
        :return: a tuple of
            (1) the bindings renamed to their unboxed variable names
            (2) the generated C++ lines: first one ``EValue&`` alias per stack
                slot, then, per argument, its declarations followed by its
                conversion code
        :raises Exception: if a binding does not wrap an ``Argument``
        """
        # Stack slot i holds the i-th argument; alias it under the argument's name.
        code_list = [
            f"EValue& {binding.name} = *stack[{i}];" for i, binding in enumerate(args)
        ]
        binding_list = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(  # noqa: TRY002
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Declarations go first so they precede the code that fills them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> tuple[str, CType, list[str], list[str]]:
        """
        Generate the C++ code that unboxes one argument from its EValue.

        :param t: a `Type` of an argument
        :param arg_name: argument name (also the name of the EValue variable)
        :param mutable: boolean for whether this argument type is mutable
        :return: a tuple of
            (1) the name of the newly created unboxed variable
            (2) the CType produced by ``argument_type_gen`` for this argument
            (3) the C++ lines that perform the unboxing
            (4) the C++ declaration lines to emit before (3)
        :raises Exception: if ``t`` is not a BaseType, OptionalType or ListType
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(  # noqa: TRY002
                f"Cannot handle type {t}. arg_name: {arg_name}"
            )
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> tuple[list[str], list[str]]:
        """Emit a single ``.to<T>()`` conversion; base types need no declarations."""
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> tuple[list[str], list[str]]:
        """
        Emit a ``.toOptional<T>()`` conversion for an optional argument.

        The element type is converted only to obtain its CType and any
        declarations; its generated code and variable name are not used.
        """
        in_name = f"{arg_name}_opt_in"
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        return (
            f"""
    auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
            """.split("\n"),
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> tuple[list[str], list[str]]:
        """
        Emit the conversion for a list argument, dispatching on the element type.

        Tensor, int/SymInt, float, bool and optional-Tensor element types map to
        dedicated EValue accessors; anything else falls back to filling a
        ``std::vector`` element by element and constructing the target type
        from it.

        :return: (conversion code lines, declaration lines)
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code = []
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )
        elem = t.elem

        if isinstance(elem, BaseType) and elem.name == BaseTy.Tensor:
            code.extend(
                f"""
    auto {out_name} = {arg_name}.toTensorList();
                """.split("\n")
            )
        elif isinstance(elem, BaseType) and (
            elem.name == BaseTy.int or elem.name == BaseTy.SymInt
        ):
            code.extend(
                f"""
    auto {out_name} = {arg_name}.toIntList();
                """.split("\n")
            )
        elif isinstance(elem, BaseType) and elem.name == BaseTy.float:
            code.extend(
                f"""
    auto {out_name} = {arg_name}.toDoubleList();
                """.split("\n")
            )
        elif isinstance(elem, BaseType) and elem.name == BaseTy.bool:
            # handle list type with size, e.g., bool[4]
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
std::array<bool, {t.size}> {out_name};
auto {in_name} = {arg_name}.toBoolList();
size_t _i = 0;
for (auto {elem_name}: {in_name}) {{
    {out_name}[_i++] = {elem_name};
}}
#else
auto {out_name} = {arg_name}.toBoolList();
#endif
    """.split("\n")
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
        elif (
            isinstance(elem, OptionalType)
            and isinstance(elem.elem, BaseType)
            and elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
auto {in_name} = {arg_name}.toListOptionalTensor();
c10::List<::std::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
    {out_name}.push_back({elem_name});
}}
#else
auto {out_name} = {arg_name}.toListOptionalTensor();
#endif
    """.split("\n")
            )
        else:
            # use ArrayRef as default.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            # NOTE(review): the loop below iterates `{in_name}`, but nothing
            # emitted by this branch declares that variable. Confirm which
            # EValue accessor should populate it before relying on this path.
            code.extend(
                f"""
    for (EValue {elem_name}: {in_name}) {{
        {connector.join(res_code)}
        {vec_name}.push_back({res_name});
    }}
    {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
    """.split("\n")
            )
        return code, decl