from __future__ import annotations

from typing import NoReturn, Sequence

from torchgen.api.types import (
    ArrayRefCType,
    BaseCType,
    Binding,
    boolT,
    ConstRefCType,
    deviceT,
    Expr,
    intArrayRefT,
    iOptTensorListRefT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    MutRefCType,
    NamedCType,
    opmath_t,
    OptionalCType,
    optionalIntArrayRefT,
    optionalScalarRefT,
    optionalSymIntArrayRefT,
    optionalTensorRefT,
    scalar_t,
    scalarT,
    scalarTypeT,
    SpecialArgName,
    symIntArrayRefT,
    SymIntT,
    tensorOptionsT,
    tensorT,
    VectorCType,
)


# This file implements a small program synthesis engine that implements
# conversions from one API to another.
#
# The key data type in this file is NamedCType, short for Named C++ semantic type. A NamedCType
# represents a C++ type, plus semantic information about what it represents.
# For example, consider the argument "bool pin_memory"; its normal C++ type is
# "bool", but its C++ semantic type also keeps track that this represents a
# "pin_memory"; you can't just use a random other boolean in a context where you
# need a "pin_memory"!
#
# The translator takes a list of needed NamedCTypes, and then figures out how
# to construct expressions with these NamedCTypes from the given bindings. Many
# of these expressions are trivial (I need a Tensor other; there's a Tensor
# other in scope); others are nontrivial and may require packing/unpacking.
# Some examples of nontrivial actions:
#
# - Need the "dtype" binding? Well, maybe "dtype" isn't available
#   in the context; instead, "options" is, and you need to extract
#   it from there. (Gather)
#
# - Need the "options" binding? Well, maybe "options" isn't available
#   in the context, and you need to construct it from "dtype", "device",
#   etc. (Scatter)
#
# - Need the "memory_format" binding? Well, actually, it's available
#   from both "memory_format" and "options", so you had better make sure
#   they are consistent. (Join)

options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))

out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))


class UnsatError(RuntimeError):
    pass
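

# Illustrative sketch (hypothetical helper, not part of torchgen and never
# called by translate()): the Gather and Scatter actions described above,
# modeled on a toy context. Goals and bindings are plain (name, type) string
# tuples standing in for NamedCTypes, and the emitted C++ snippets copy the
# templates used by the "dtype" and "options" rules in solve() below.
# A missing binding surfaces as KeyError, standing in for UnsatError.
def _toy_gather_scatter(
    goal: tuple[str, str], ctx: dict[tuple[str, str], str]
) -> str:
    if goal in ctx:
        # Trivial: the goal is already in scope.
        return ctx[goal]
    if goal == ("dtype", "optional<ScalarType>"):
        # Gather: project the dtype out of a TensorOptions binding.
        options = ctx[("options", "TensorOptions")]
        return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
    if goal == ("options", "TensorOptions"):
        # Scatter: assemble a TensorOptions from its individual fields.
        dtype = ctx[("dtype", "optional<ScalarType>")]
        layout = ctx[("layout", "optional<Layout>")]
        device = ctx[("device", "optional<Device>")]
        pin_memory = ctx[("pin_memory", "optional<bool>")]
        return (
            f"TensorOptions().dtype({dtype}).layout({layout})"
            f".device({device}).pinned_memory({pin_memory})"
        )
    # No toy rule applies to this goal.
    raise KeyError(goal)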


# Given a set of in-scope bindings and a set of target bindings, synthesize
# a list of expressions that use only the in-scope bindings (bindings) and
# that have the types of the goals (goals). You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
# Typically, a list of Bindings is convenient to get (you usually call something
# like arguments() to get them); but technically less information is needed:
# for 'bindings' an (unordered) list of Exprs is sufficient; similarly, for
# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
def translate(
    bindings: Sequence[Expr | Binding],
    goals: Sequence[NamedCType | Binding],
    *,
    method: bool = False,
    allow_expensive_conversions: bool = False,
) -> list[Expr]:
    binding_exprs: list[Expr] = []
    for b in bindings:
        if isinstance(b, Binding):
            binding_exprs.append(
                Expr(
                    expr=b.name,
                    type=b.nctype,
                )
            )
        else:
            binding_exprs.append(b)

    goal_ctypes: list[NamedCType] = []
    for g in goals:
        if isinstance(g, Binding):
            goal_ctypes.append(g.nctype)
        else:
            goal_ctypes.append(g)

    # Add all the bindings to the context
    ctx: dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination. (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about proof
        # search is these lecture notes:
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
        # Fix this by implementing some sort of sharing so that if multiple
        # goals share the same expression, we only compute it once. This seems
        # to matter in practice as the compiler is often unwilling to CSE nontrivial
        # expressions like scalar.to<scalar_t>()
        t = b.type
        # NB: t is a NamedCType, so isinstance(t, ConstRefCType) below never
        # holds and this first rule never fires; the rules after it inspect
        # t.type instead.
        if (
            isinstance(t, ConstRefCType)
            and isinstance(t.elem, OptionalCType)
            and isinstance(t.elem.elem, BaseCType)
            and str(t.elem.elem.type) == "at::Tensor"
        ):
            ctx[
                NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
            ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"

        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalTensorRefT))
            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"

        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"

        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalScalarRefT))
            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"

        if t.type == BaseCType(scalar_t):
            ctx[
                NamedCType(t.name, BaseCType(opmath_t))
            ] = f"static_cast<opmath_t>({b.expr})"

        # [Note: IOptTensorListRef]
        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
            ctx[
                NamedCType(t.name, BaseCType(iOptTensorListRefT))
            ] = f"at::IOptTensorListRef({b.expr})"

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[
            NamedCType("self", MutRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        ctx[
            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        # This would be better, but the const_cast above is kept for
        # byte-for-byte compatibility with previously generated code:
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        ctx_desc = "\n".join(
            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
        )
        raise UnsatError(
            f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
        )

    # A shitty backtracking search implementation. It's shitty because it
    # does backtracking via stack (bad idea!) and for the most part tries to
    # avoid backtracking. In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a"). So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(
                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
                )
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        # TODO: These types are equal (IntArrayRef is a synonym for
        # ArrayRef<int64_t>), so we shouldn't have to do this; ensuring we
        # don't use the type synonym IntArrayRef in codegen would help
        if goal.type == ArrayRefCType(BaseCType(longT)):
            return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)

        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    OptionalCType(BaseCType(memoryFormatT)),
                )
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
            )
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
            )
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT)))
            )
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
            )
            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            try:
                options = direct_solve(options_ctype)
                return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.scalar_type()"

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.layout_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.layout()"

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.device_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.device()"

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.pinned_memory_opt()"
            except UnsatError:
                # If we're calling a factory op from its out= variant,
                # we don't actually care about the value of pin_memory.
                out_tensor = direct_solve(out_tensor_ctype)
                return "::std::nullopt"

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, longVec_ctype))
            except UnsatError:
                # We can also go SymIntArrayRef -> IntArrayRef
                symIntArrayRef_type = direct_solve(
                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
                )
                return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
        elif goal.type == BaseCType(symIntArrayRefT):
            try:
                r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
                return f"c10::fromIntArrayRefSlow({r})"
            except UnsatError:
                return direct_solve(NamedCType(goal.name, longSymVec_ctype))
        elif goal.type == BaseCType(SymIntT):
            return direct_solve(NamedCType(goal.name, BaseCType(longT)))
        elif goal.type == OptionalCType(BaseCType(SymIntT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(longT)))
            )
            return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(longT):
            symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
            return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
        elif goal.type == OptionalCType(BaseCType(longT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
            )
            return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
        elif goal.type == BaseCType(optionalIntArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
            except UnsatError:
                argname = direct_solve(
                    NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
                )
                return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(optionalSymIntArrayRefT):
            # TODO: You might also want to solve this from longSymVec_ctype or
            # an optional version of it
            argname = direct_solve(
                NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
            )
            return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type is expensive
        # (O(n) to convert from IntArrayRef to vector<int>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == VectorCType(BaseCType(SymIntT)):
                symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
                argname = direct_solve(symIntArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalIntArrayRefT)
                )
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalScalarRefT)
                )
                argname = direct_solve(optionalScalarRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
            # TODO: This condition duplicates the previous branch, so this branch is
            # unreachable; it was presumably meant to match OptionalCType(BaseCType(tensorT)).
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalTensorRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalTensorRefT)
                )
                argname = direct_solve(optionalTensorRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # with arguments like ::std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
        # We could probably generalize this to non-tensor types too.
        if goal.type == MutRefCType(BaseCType(tensorT)):
            const_ref_tensor_ctype = NamedCType(
                goal.name, ConstRefCType(BaseCType(tensorT))
            )
            argname = direct_solve(const_ref_tensor_ctype)
            return f"const_cast<Tensor&>({argname})"

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
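

# Illustrative sketch (hypothetical, not part of torchgen and never called by
# translate()): the reference rules at the top of solve() in miniature. A toy
# type is either a base-type string such as "Tensor" or a wrapper tuple
# ("const&", inner) / ("mut&", inner). A const& goal is retried as mut&, and a
# mut& goal is retried as the bare value, with exceptions providing the
# backtracking, as solve() does with UnsatError.
class _ToyUnsat(RuntimeError):
    pass


def _toy_solve_ref(
    goal: tuple[str, object], ctx: dict[tuple[str, object], str]
) -> str:
    if goal in ctx:
        # Trivial: the goal is already in scope.
        return ctx[goal]
    name, ty = goal
    if isinstance(ty, tuple) and ty[0] == "const&":
        try:
            # Not strictly decreasing (const& -> mut& keeps the size), so there
            # must be no rule that turns a mut& goal back into a const& goal.
            return _toy_solve_ref((name, ("mut&", ty[1])), ctx)
        except _ToyUnsat:
            pass
    if isinstance(ty, tuple) and ty[0] == "mut&":
        try:
            # Strictly decreasing: strip the reference and look for a value.
            return _toy_solve_ref((name, ty[1]), ctx)
        except _ToyUnsat:
            pass
    raise _ToyUnsat(f"no rule produces {goal!r}")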