1from __future__ import annotations 2 3from torchgen.api import cpp 4from torchgen.api.types import Binding, CppSignatureGroup, CType 5from torchgen.model import ( 6 Argument, 7 BaseTy, 8 BaseType, 9 ListType, 10 NativeFunction, 11 OptionalType, 12 Type, 13) 14 15 16# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the 17# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is 18# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the 19# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register 20# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase. 21# 22# Here's an example on how the codegen works: 23# 24# - Function Schema (source of truth) 25# 26# aten::empty.names(int[] size, *, Dimname[]? names, 27# ScalarType? dtype=None, Layout? layout=None, 28# Device? device=None, bool? pin_memory=None, 29# MemoryFormat? memory_format=None) -> Tensor 30# - Argument Conversion 31# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type. 32# - int[] size 33# ```cpp 34# const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList(); 35# 36# std::vector<int64_t> size_vec; 37# for (c10::IValue size_elem: size_list_in) { 38# int64_t size_base = size_elem.to<int64_t>(); 39# size_vec.push_back(size_base); 40# } 41# at::ArrayRef<int64_t> size_list_out(size_vec); 42# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack. 43# Will be passed to unboxed kernel. 44# ``` 45# - Dimname[]? 
names 46# ```cpp 47# ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>(); 48# ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out; 49# if (names_opt.has_value()) { 50# ~~~~~~~~~~~ <-- Unwrapping optional shell 51# const c10::IValue names_opt_in = names_opt.value(); 52# const c10::List<c10::IValue> names_list_in = names_opt_in.toList(); 53# 54# std::vector<at::Dimname> names_vec; 55# for (c10::IValue names_elem: names_list_in) { 56# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one. 57# at::Dimname names_base = names_elem.to<at::Dimname>(); 58# names_vec.push_back(names_base); 59# } 60# at::ArrayRef<at::Dimname> names_list_out(names_vec); 61# 62# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out); 63# } else { 64# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(); 65# } 66# ``` 67# - ScalarType? dtype (similarly for the rest of the arguments) 68# ```cpp 69# ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>(); 70# ::std::optional<at::ScalarType> dtype_opt_out; 71# if (dtype_opt.has_value()) { 72# const c10::IValue dtype_opt_in = dtype_opt.value(); 73# at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>(); 74# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it 75# directly using ".to<T>()" API. 
#        dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base);
#    } else {
#        dtype_opt_out = ::std::optional<at::ScalarType>();
#    }
#    ```
#
# - Unboxed Kernel Call
#   ```cpp
#   auto result_ = torch::empty(
#       size_list_out,
#       names_opt_out,
#       options,
#       memory_format_opt_out
#   );
#   ```
#
# - Push Result Back to Stack
#   ```cpp
#   drop(stack, 7);
#   pack(stack, std::move(result_));
#   ```

# Separator used to splice a nested conversion's C++ statements into the body
# of an enclosing generated block (a newline followed by a tab).
connector = "\n\t"


def name(f: NativeFunction) -> str:
    """Return the unboxing function name for ``f`` (its unambiguous schema name)."""
    return f.func.name.unambiguous_name()


def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
    """Generate the C++ code that unboxes every argument of ``f`` from the stack.

    Returns:
        A pair ``(bindings, code)``. ``bindings`` holds one ``Binding`` per
        argument, renamed to the C++ variable that carries its unboxed value.
        ``code`` is the list of C++ source lines: one ``peek`` per argument,
        then a blank line, then for each argument its hoisted declarations
        followed by its conversion statements.

    Raises:
        Exception: if a signature argument is not a plain ``Argument``.
    """
    # We need the 'self' argument, so ``method`` must be False.
    args = (
        CppSignatureGroup.from_native_function(f, method=False)
        .most_faithful_signature()
        .arguments()
    )
    num_args = len(args)
    code_list = [
        f"c10::IValue {arg.name} = std::move(peek(stack, {i}, {num_args}));"
        for i, arg in enumerate(args)
    ]
    code_list.append("")
    binding_list: list[Binding] = []
    for arg in args:
        # Only plain `Argument`s are expected here.
        if not isinstance(arg.argument, Argument):
            raise Exception(  # noqa: TRY002
                f"Unexpected argument type, expecting `Argument` but got {arg}"
            )
        argument: Argument = arg.argument
        unboxed_name, _, code, decl = argumenttype_ivalue_convert(
            argument.type,
            argument.name,
            mutable=argument.is_write,
        )
        # Declarations are emitted before the conversion code, outside any
        # if/for block that code opens, so storage they introduce (e.g. the
        # std::vector backing an ArrayRef in `_gen_code_list_type`) stays alive.
        code_list.extend(decl)
        code_list.extend(code)
        binding_list.append(arg.with_name(unboxed_name))
    return binding_list, code_list


def argumenttype_ivalue_convert(
    t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
    """Generate C++ code converting the ``c10::IValue`` variable ``arg_name`` to ``t``.

    Args:
        t: schema type of the argument.
        arg_name: name of the C++ ``c10::IValue`` variable holding the boxed value.
        mutable: whether the argument is written to; forwarded to
            ``cpp.argumenttype_type`` when computing the C++ type.

    Returns:
        A tuple ``(out_name, ctype, code, decl)``:
        ``out_name`` is the C++ variable that holds the unboxed value,
        ``ctype`` is its C++ type, ``code`` is the list of C++ statements that
        perform the conversion, and ``decl`` is a list of declarations that
        must be emitted before ``code`` (outside any block ``code`` opens).

    Raises:
        Exception: if ``t`` is not a BaseType, OptionalType or ListType.
    """
    # Unboxing is for mobile, which doesn't care about SymInts.
    ctype = cpp.argumenttype_type(
        t=t, mutable=mutable, binds=arg_name, symint=False
    ).type

    if isinstance(t, BaseType):
        out_name = f"{arg_name}_base"
        code, decl = _gen_code_base_type(
            arg_name=arg_name, out_name=out_name, ctype=ctype
        )
    elif isinstance(t, OptionalType):
        out_name = f"{arg_name}_opt_out"
        code, decl = _gen_code_optional_type(
            arg_name=arg_name,
            out_name=out_name,
            t=t,
            ctype=ctype,
        )
    elif isinstance(t, ListType):
        out_name = f"{arg_name}_list_out"
        code, decl = _gen_code_list_type(
            arg_name=arg_name,
            out_name=out_name,
            t=t,
            ctype=ctype,
        )
    else:
        raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")  # noqa: TRY002
    return out_name, ctype, code, decl


def _gen_code_base_type(
    arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
    """Unbox a base type directly with ``IValue::to<T>()``; needs no declarations."""
    cpp_type = ctype.cpp_type(strip_ref=True)
    return [f"{cpp_type} {out_name} = {arg_name}.to<{cpp_type}>();"], []


def _gen_code_optional_type(
    arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
    """Unbox an optional: convert the contained value if present, else leave it empty.

    The element's own declarations are returned unchanged so the caller can
    hoist them out of the generated ``if`` block.
    """
    in_name = f"{arg_name}_opt_in"
    res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
    opt_type = ctype.cpp_type(strip_ref=True)
    return (
        f"""
auto {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
{opt_type} {out_name};
if ({arg_name}_opt.has_value()) {{
    const c10::IValue {in_name} = {arg_name}_opt.value();
    {connector.join(res_code)}
    {out_name} = {opt_type}({res_name});
}} else {{
    {out_name} = {opt_type}();
}}
        """.split(
            "\n"
        ),
        decl,
    )


def _gen_code_list_type(
    arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
    """Unbox a list by converting each ``c10::IValue`` element in turn.

    Three shapes are generated:
      * sized ``bool[N]`` -> ``as_array<bool, N>`` over the input list;
      * lists of optionals -> the output container is pushed to directly
        (``c10::List`` is required for optional elements);
      * everything else -> elements are collected into a ``std::vector`` that
        is declared in ``decl`` (so it outlives the loop) and wrapped by the
        output ``ArrayRef``-like type.
    """
    in_name = f"{arg_name}_list_in"
    elem_name = f"{arg_name}_elem"
    code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]
    res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name)
    list_type = ctype.cpp_type(strip_ref=True)
    elem_type = res_ctype.cpp_type(strip_ref=True)
    # Handle list type with size, e.g., bool[4].
    if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size:
        code.extend(
            f"""
{list_type} {out_name} = as_array<{elem_type}, {t.size}>({in_name});
            """.split(
                "\n"
            )
        )
    # We have to use c10::List for optional element,
    # e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>.
    elif isinstance(t.elem, OptionalType):
        code.extend(
            f"""
{list_type} {out_name};
for (c10::IValue {elem_name}: {in_name}) {{
    {connector.join(res_code)}
    {out_name}.push_back({res_name});
}}
            """.split(
                "\n"
            )
        )
    else:
        # Use ArrayRef as default.
        vec_name = arg_name + "_vec"
        # The vector is declared outside the loop (via `decl`) so the ArrayRef
        # built from it still points at valid data after the loop ends.
        # NOTE(review): for a nested list element type, the inner element's
        # hoisted vector is declared once but filled inside this loop, so it
        # would accumulate across outer iterations; presumably no schema
        # reaches that path — confirm.
        decl.append(f"std::vector<{elem_type}> {vec_name};")
        code.extend(
            f"""
for (c10::IValue {elem_name}: {in_name}) {{
    {connector.join(res_code)}
    {vec_name}.push_back({res_name});
}}
{list_type} {out_name}({vec_name});
            """.split(
                "\n"
            )
        )
    return code, decl