1from __future__ import annotations 2 3from typing import NoReturn, Sequence 4 5from torchgen.api.types import ( 6 ArrayRefCType, 7 BaseCType, 8 Binding, 9 boolT, 10 ConstRefCType, 11 deviceT, 12 Expr, 13 intArrayRefT, 14 iOptTensorListRefT, 15 layoutT, 16 ListCType, 17 longT, 18 memoryFormatT, 19 MutRefCType, 20 NamedCType, 21 opmath_t, 22 OptionalCType, 23 optionalIntArrayRefT, 24 optionalScalarRefT, 25 optionalSymIntArrayRefT, 26 optionalTensorRefT, 27 scalar_t, 28 scalarT, 29 scalarTypeT, 30 SpecialArgName, 31 symIntArrayRefT, 32 SymIntT, 33 tensorOptionsT, 34 tensorT, 35 VectorCType, 36) 37 38 39# This file implements a small program synthesis engine that implements 40# conversions between one API to another. 41# 42# The key data type in this file in NamedCType, short for Named C++ semantic type. A NamedCType 43# represents a C++ type, plus semantic information about what it represents. 44# For example, consider the argument "bool pin_memory"; its normal C++ type is 45# "bool", but its C++ semantic type also keeps track that this represents a 46# "pin_memory"; you can't just use a random other boolean in a context where you 47# need a "pin_memory"! 48# 49# The translator takes a list of needed NamedCTypes, and then figures out how 50# to construct expressions with these NamedCTypes from the given bindings. Many 51# of these expressions are trivial (I need a Tensor other; there's a Tensor 52# other scope); others are more nontrivial and may require packing/unpacking. 53# Some examples of non-trivial action: 54# 55# - Need the "dtype" binding? Well, maybe "dtype" isn't available 56# in the context, instead, "options" is, and you need to extract 57# it from there. (Gather) 58# 59# - Need the "context" binding? Well, maybe "context" isn't available 60# in the context, and you need to construct it from "dtype", "device", 61# etc. (Scatter) 62# 63# - Need the "memory_format" binding? 
# Well, actually, it's available
#   from both "memory_format" and "options", so you had better make sure
#   they are consistent.  (Join)

# Commonly-queried context entries, hoisted here so the solver below can
# refer to them by name.
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

# Value-typed counterparts of the reference types we translate from/to.
longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    """Raised when the translator cannot synthesize an expression for a goal
    from the bindings currently in scope."""


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that uses only the in-scope bindings (bindings) and
# has all of the types of the goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
# Typically, a list of Bindings is convenient to get (you usually call
# something like arguments() to get them); but technically you only need
# less information: for 'bindings' an (un-ordered) list of Exprs is
# sufficient; similarly, for 'goals', an (ordered) list of NamedCType goals
# is sufficient.  If you are doing something more complicated, e.g.,
# tracking the set of bindings in a context, you may find using these
# smaller types more convenient.
def translate(
    bindings: Sequence[Expr | Binding],
    goals: Sequence[NamedCType | Binding],
    *,
    method: bool = False,
    allow_expensive_conversions: bool = False,
) -> list[Expr]:
    """Synthesize one C++ expression per goal, using only the given bindings.

    Args:
        bindings: (un-ordered) expressions/bindings available in scope.
        goals: (ordered) NamedCTypes (or Bindings, from which the NamedCType
            is taken) that we must construct expressions for.
        method: if True, the generated code lives inside a Tensor method, so
            implicit "self" bindings derived from ``*this`` are added to the
            context.
        allow_expensive_conversions: if True, additionally permit O(n)
            reference-to-value conversions (e.g. IntArrayRef ->
            std::vector<int64_t>); callers opt in explicitly because these
            conversions copy.

    Returns:
        A list of Exprs, in the same order as ``goals``.

    Raises:
        UnsatError: if some goal cannot be synthesized from the bindings.
    """
    # Normalize bindings to Exprs.
    binding_exprs: list[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(
                Expr(
                    expr=b.name,
                    type=b.nctype,
                )
            )
        else:
            binding_exprs.append(b)

    # Normalize goals to NamedCTypes.
    goal_ctypes: list[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about
        # proof search are these lecture notes
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is
        # nontrivial.  Fix this by implementing some sort of sharing so that
        # if multiple goals share the same expression, we only compute it
        # once.  This seems to matter in practice as compiler is often
        # unwilling to CSE nontrivial expressions like scalar.to<scalar_t>()
        t = b.type
        # NOTE(review): t is a NamedCType here (Expr.type), so the isinstance
        # check below can never hold and this branch looks unreachable; the
        # const optional<Tensor> case is actually handled by the
        # optionalTensorRefT rule right after.  Kept as-is pending
        # confirmation.
        if (
            isinstance(t, ConstRefCType)
            and isinstance(t.elem, OptionalCType)
            and isinstance(t.elem.elem, BaseCType)
            and str(t.elem.elem.type) == "at::Tensor"
        ):
            ctx[
                NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
            ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"

        # const optional<Tensor>& is also usable as an OptionalTensorRef.
        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalTensorRefT))
            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"

        # const Scalar& can be widened to opmath_t.
        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"

        # const optional<Scalar>& is also usable as an OptionalScalarRef.
        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalScalarRefT))
            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"

        # scalar_t can be widened to opmath_t.
        if t.type == BaseCType(scalar_t):
            ctx[
                NamedCType(t.name, BaseCType(opmath_t))
            ] = f"static_cast<opmath_t>({b.expr})"

        # [Note: IOptTensorListRef]
        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
            ctx[
                NamedCType(t.name, BaseCType(iOptTensorListRefT))
            ] = f"at::IOptTensorListRef({b.expr})"

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[
            NamedCType("self", MutRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        ctx[
            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        # This is better!  Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        """Report failure to synthesize `goal`, dumping the context for debugging."""
        ctx_desc = "\n".join(
            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
        )
        raise UnsatError(
            f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
        )

    # A shitty backtracking search implementation.  It's shitty because it
    # does backtracking via stack (bad idea!) and for the most part tries to
    # avoid backtracking.  In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(
                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
                )
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        # TODO: These are referentially equal, shouldn't have to do this;
        # ensuring we don't use type synonym IntArrayRef in codegen would
        # help
        if goal.type == ArrayRefCType(BaseCType(longT)):
            return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    OptionalCType(BaseCType(memoryFormatT)),
                )
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            # Scatter: reassemble TensorOptions from its four components.
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
            )
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
            )
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT)))
            )
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
            )
            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            # Gather: extract from "options", or fall back to the out tensor.
            try:
                options = direct_solve(options_ctype)
                return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.scalar_type()"

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.layout_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.layout()"

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.device_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.device()"

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.pinned_memory_opt()"
            except UnsatError:
                # If we're calling a factory op from its out= variant,
                # We don't actually care about the value of pin_memory.
                out_tensor = direct_solve(out_tensor_ctype)
                return "::std::nullopt"

        # We can always do translations from value types to reference types,
        # like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, longVec_ctype))
            except UnsatError:
                # We can also go SymIntArrayRef -> IntArrayRef
                symIntArrayRef_type = direct_solve(
                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
                )
                return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
        elif goal.type == BaseCType(symIntArrayRefT):
            try:
                r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
                return f"c10::fromIntArrayRefSlow({r})"
            except UnsatError:
                return direct_solve(NamedCType(goal.name, longSymVec_ctype))
        elif goal.type == BaseCType(SymIntT):
            return direct_solve(NamedCType(goal.name, BaseCType(longT)))
        elif goal.type == OptionalCType(BaseCType(SymIntT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(longT)))
            )
            return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(longT):
            symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
            return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
        elif goal.type == OptionalCType(BaseCType(longT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
            )
            return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
        elif goal.type == BaseCType(optionalIntArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
            except UnsatError:
                argname = direct_solve(
                    NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
                )
                return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(optionalSymIntArrayRefT):
            # TODO: You might also want to solve this from longSymVec_ctype or
            # an optional version of it
            argname = direct_solve(
                NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
            )
            return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type is expensive
        # (O(n)) to convert from IntArrayRef to vector<int>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f"{argname}.vec()"
            if goal.type == VectorCType(BaseCType(SymIntT)):
                symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
                argname = direct_solve(symIntArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalIntArrayRefT)
                )
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalScalarRefT)
                )
                argname = direct_solve(optionalScalarRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
            # BUGFIX: this condition previously duplicated the scalarT case
            # above, making this branch unreachable; the branch body (and its
            # local name) clearly targets the optional<Tensor> goal solved
            # from an OptionalTensorRef binding.
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                optionalTensorRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalTensorRefT)
                )
                argname = direct_solve(optionalTensorRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # With arguments like ::std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
        # We could probably generalize this to non-tensor types too.
        if goal.type == MutRefCType(BaseCType(tensorT)):
            const_ref_tensor_ctype = NamedCType(
                goal.name, ConstRefCType(BaseCType(tensorT))
            )
            argname = direct_solve(const_ref_tensor_ctype)
            return f"const_cast<Tensor&>({argname})"

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]