from __future__ import annotations

from typing import Any

from torchgen.api.types import (
    BaseCppType,
    BaseCType,
    boolT,
    CType,
    deviceT,
    doubleT,
    generatorT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    stringT,
    SymIntT,
    VectorCType,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    OperatorName,
    OptionalType,
    Return,
    TensorOptionsArguments,
    Type,
)


_valueT: BaseCppType | None = None


# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    global _valueT
    if not _valueT:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )

    return _valueT


def setValueT(val: BaseCppType) -> None:
    global _valueT
    _valueT = val


# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
      (1) changing at::Tensors into lazy::Values
      (2) wrapping everything in a BaseCType
      (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors).
    There is special handling for Optional[Tensor], List[Tensor], etc., hence 'tensor-like'.

    This is incomplete: unsupported types hit assertions, and more types are expected
    to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::scalar has special handling,
            # and is wrapped in a lazy::Value just like at::tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            if symint:
                return BaseCType(getValueT())
            else:
                return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Generator:
            return BaseCType(generatorT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        elif typ.elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type. The problem here is analogous to
            # the problem with tensorListValueT: if you have SymInt[] you
            # cannot conveniently save the list of Value directly, as nodes
            # expect to save values as a vector for ALL arguments. So you
            # need a separate IR node that represents all of the size nodes
            # assembled into a list. I'm not an LTC dev so I don't want to
            # figure it out right now. Y'all figure it out...
            return VectorCType(BaseCType(longT))

        else:
            return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")


# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False


def isSymIntType(typ: Type) -> bool:
    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt


def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information.
    """
    if isinstance(typ, BaseType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        return isWrappedScalarType(typ.elem)
    return False


# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, OptionalType):
        return isGeneratorType(typ.elem)
    return False


# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
    name: str
    orig_type: Type
    lazy_type_: CType | None
    is_wrapped_scalar: bool
    is_generator: bool
    is_optional: bool
    # TODO: this is inaccurate: it is False for SymInt lists
    is_symint_or_list: bool

    # Whether SymInts are treated as SymInts (True) or as plain ints (False)
    symint: bool

    # True if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(
        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
    ) -> None:
        self.name = arg.name
        self.orig_type = arg.type
        self.symint = symint
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = symint and (
            isSymIntType(arg.type)
            or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
            # TODO: lists of symints are not currently treated as value types
            # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
        )

        self.is_lazy_value = isValueType(self.lazy_type, properties)

    @property
    def lazy_type(self) -> CType:
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_


class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    Properties: tuple[tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str) -> None:
        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
            LazyIrProperties.Properties
        )
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserves the original argument names.
308*da0073e9SAndroid Build Coastguard Worker# 309*da0073e9SAndroid Build Coastguard Worker# TODO: This is not idiomatic with how other torchgen APIs transform on schema. 310*da0073e9SAndroid Build Coastguard Workerclass LazyIrSchema: 311*da0073e9SAndroid Build Coastguard Worker # The name of the operator this function schema describes. 312*da0073e9SAndroid Build Coastguard Worker name: OperatorName 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker positional_args: tuple[LazyArgument, ...] 315*da0073e9SAndroid Build Coastguard Worker keyword_args: tuple[LazyArgument, ...] 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker # TODO: Need to handle collisions with argument names at some point 318*da0073e9SAndroid Build Coastguard Worker returns: tuple[Return, ...] 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker # if this schema has a Generator arg, list its orig ctype/name but don't 321*da0073e9SAndroid Build Coastguard Worker # build a LazyArgument since lazy IR doesn't support it 322*da0073e9SAndroid Build Coastguard Worker generator_arg: NamedCType | None = None 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker # original function schema 325*da0073e9SAndroid Build Coastguard Worker func: FunctionSchema 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker # Whether or not we are code-genning for SymInt or not 328*da0073e9SAndroid Build Coastguard Worker symint: bool 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker properties: LazyIrProperties = LazyIrProperties( 331*da0073e9SAndroid Build Coastguard Worker # default properties 332*da0073e9SAndroid Build Coastguard Worker "ShapePrecompute", 333*da0073e9SAndroid Build Coastguard Worker "Lower", 334*da0073e9SAndroid Build Coastguard Worker "CanBeReused", 335*da0073e9SAndroid Build 
Coastguard Worker ) 336*da0073e9SAndroid Build Coastguard Worker opkind: str | None = None 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker def __init__( 339*da0073e9SAndroid Build Coastguard Worker self, 340*da0073e9SAndroid Build Coastguard Worker func: FunctionSchema, 341*da0073e9SAndroid Build Coastguard Worker properties: LazyIrProperties | None = None, 342*da0073e9SAndroid Build Coastguard Worker *, 343*da0073e9SAndroid Build Coastguard Worker symint: bool, 344*da0073e9SAndroid Build Coastguard Worker ) -> None: 345*da0073e9SAndroid Build Coastguard Worker if properties: 346*da0073e9SAndroid Build Coastguard Worker self.properties = properties 347*da0073e9SAndroid Build Coastguard Worker 348*da0073e9SAndroid Build Coastguard Worker self.func = func 349*da0073e9SAndroid Build Coastguard Worker self.symint = symint 350*da0073e9SAndroid Build Coastguard Worker positional_args: list[LazyArgument] = [] 351*da0073e9SAndroid Build Coastguard Worker for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]: 352*da0073e9SAndroid Build Coastguard Worker if arg_field == "self_arg" and func.arguments.self_arg is not None: 353*da0073e9SAndroid Build Coastguard Worker arg = func.arguments.self_arg.argument 354*da0073e9SAndroid Build Coastguard Worker positional_args.append( 355*da0073e9SAndroid Build Coastguard Worker LazyArgument(arg, self.properties, symint=symint) 356*da0073e9SAndroid Build Coastguard Worker ) 357*da0073e9SAndroid Build Coastguard Worker elif getattr(func.arguments, arg_field) is not None: 358*da0073e9SAndroid Build Coastguard Worker positional_args.extend( 359*da0073e9SAndroid Build Coastguard Worker LazyArgument(arg, self.properties, symint=symint) 360*da0073e9SAndroid Build Coastguard Worker for arg in getattr(func.arguments, arg_field) 361*da0073e9SAndroid Build Coastguard Worker ) 362*da0073e9SAndroid Build Coastguard Worker self.positional_args = tuple(positional_args) 

        keyword_args: list[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert (
                            self.generator_arg is None
                        ), "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(
                            arg.name, arg.type  # type:ignore[arg-type]
                        )
                keyword_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in curr_args
                )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return the CamelCase version of the op name for use in the IR node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() or "" for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = True,
    ) -> list[LazyArgument]:
        # This function preserves the original argument order but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special cased, as they are needed for fallback/shape-inference but not supported
        # in TS lowerings and therefore also omitted from lazy IR.
        args: list[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]

        return []

    @property
    def positional_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )
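# Illustrative sketch, not used by the codegen: a stdlib-only mirror of the
# decision table in process_ir_type() above. _sketch_process_ir_type is a
# hypothetical helper that works on schema type strings such as "Tensor?" or
# "int[]" rather than on torchgen Type objects. It returns a readable label
# instead of a CType: "ValueT" stands for BaseCType(getValueT()), other base
# types are named after the torchgen constants they map to (e.g. "longT"), and
# Optional[...], List[...] and Vector[...] stand for OptionalCType, ListCType
# and VectorCType. Fixed-size list syntax such as "int[2]" is not handled.
_SKETCH_BASE_TYPES = {
    "ScalarType": "scalarTypeT",
    "int": "longT",
    "bool": "boolT",
    "float": "doubleT",
    "str": "stringT",
    "Device": "deviceT",
    "Generator": "generatorT",
    "Layout": "layoutT",
    "MemoryFormat": "memoryFormatT",
}


def _sketch_process_ir_type(
    typ: str, *, symint: bool, treat_scalars_as_constants: bool = False
) -> str:
    def recurse(inner: str) -> str:
        return _sketch_process_ir_type(
            inner, symint=symint, treat_scalars_as_constants=treat_scalars_as_constants
        )

    if typ.endswith("?"):
        # Optional types wrap the converted element type.
        return f"Optional[{recurse(typ[:-1])}]"
    if typ.endswith("[]"):
        elem = typ[:-2]
        if elem == "Tensor?":
            return "List[Optional[ValueT]]"
        if elem == "Tensor":
            # A TensorList arrives as a single Value (tensorListValueT).
            return "tensorListValueT"
        if elem == "SymInt":
            # SymInt[] is not yet a value type, so it becomes a vector of longs.
            return "Vector[longT]"
        return f"Vector[{recurse(elem)}]"
    if typ == "Tensor":
        return "ValueT"
    if typ == "Scalar":
        # Scalars become Values unless TreatScalarsAsConstants is set.
        return "scalarT" if treat_scalars_as_constants else "ValueT"
    if typ == "SymInt":
        # SymInt is a Value only when code-genning for SymInt.
        return "ValueT" if symint else "longT"
    if typ in _SKETCH_BASE_TYPES:
        return _SKETCH_BASE_TYPES[typ]
    raise AssertionError(f"TODO add support for type {typ!r}")


# For example, _sketch_process_ir_type("int[]?", symint=False) returns
# "Optional[Vector[longT]]", matching OptionalCType(VectorCType(BaseCType(longT))).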
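# Illustrative sketch, not used by the codegen: a stripped-down, stdlib-only
# mirror of LazyIrProperties above. It shows how keeping a single slot per
# property group makes the flags in that group mutually exclusive.
# _SketchProperties and its two groups are hypothetical; the real class reads
# its groups from LazyIrProperties.Properties.
class _SketchProperties:
    GROUPS: tuple[tuple[str, ...], ...] = (
        ("ShapePrecompute", "ShapeCompute", "ShapeCache"),
        ("Lower", "LowerDeclOnly"),
    )

    def __init__(self, *defaults: str) -> None:
        # One slot per group, holding the name of the active flag (or None).
        # Writing through __dict__ bypasses the custom __setattr__ below.
        self.__dict__["slots"] = dict.fromkeys(self.GROUPS)
        for name in defaults:
            setattr(self, name, True)

    def __getattr__(self, key: str) -> bool:
        # Only called when normal lookup fails, i.e. for flag names.
        for group in self.GROUPS:
            if key in group:
                return self.__dict__["slots"][group] == key
        raise AttributeError(key)

    def __setattr__(self, key: str, value: bool) -> None:
        for group in self.GROUPS:
            if key in group:
                # Setting a flag overwrites the whole group's slot, so at most one
                # flag per group is True. Assigning False clears the group.
                self.__dict__["slots"][group] = key if value else None
                return
        raise KeyError(f"Invalid property: {key}")


# For example, after props = _SketchProperties("ShapePrecompute", "Lower") and
# props.ShapeCache = True, ShapeCache is True and ShapePrecompute is False,
# while Lower (a different group) is unaffected.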
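# Illustrative sketch, not used by the codegen: the CamelCase rule from
# LazyIrSchema.node_name above, written as a hypothetical helper that takes
# plain strings in place of an OperatorName. The op name and overload name are
# joined with "_", lowercased, and each "_"-separated word is capitalized. An
# empty overload name leaves a trailing "_", and the resulting empty word
# contributes nothing.
def _sketch_node_name(name: str, overload_name: str = "") -> str:
    op_name = f"{name}_{overload_name}".lower()
    return "".join(word.capitalize() for word in op_name.split("_"))


# For example, ("bitwise_and", "Tensor") gives "BitwiseAndTensor", and "abs"
# with no overload name gives "Abs".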