xref: /aosp_15_r20/external/pytorch/torchgen/api/lazy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Any
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
6*da0073e9SAndroid Build Coastguard Worker    BaseCppType,
7*da0073e9SAndroid Build Coastguard Worker    BaseCType,
8*da0073e9SAndroid Build Coastguard Worker    boolT,
9*da0073e9SAndroid Build Coastguard Worker    CType,
10*da0073e9SAndroid Build Coastguard Worker    deviceT,
11*da0073e9SAndroid Build Coastguard Worker    doubleT,
12*da0073e9SAndroid Build Coastguard Worker    generatorT,
13*da0073e9SAndroid Build Coastguard Worker    layoutT,
14*da0073e9SAndroid Build Coastguard Worker    ListCType,
15*da0073e9SAndroid Build Coastguard Worker    longT,
16*da0073e9SAndroid Build Coastguard Worker    memoryFormatT,
17*da0073e9SAndroid Build Coastguard Worker    NamedCType,
18*da0073e9SAndroid Build Coastguard Worker    OptionalCType,
19*da0073e9SAndroid Build Coastguard Worker    scalarT,
20*da0073e9SAndroid Build Coastguard Worker    scalarTypeT,
21*da0073e9SAndroid Build Coastguard Worker    stringT,
22*da0073e9SAndroid Build Coastguard Worker    SymIntT,
23*da0073e9SAndroid Build Coastguard Worker    VectorCType,
24*da0073e9SAndroid Build Coastguard Worker)
25*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
26*da0073e9SAndroid Build Coastguard Worker    Argument,
27*da0073e9SAndroid Build Coastguard Worker    BaseTy,
28*da0073e9SAndroid Build Coastguard Worker    BaseType,
29*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
30*da0073e9SAndroid Build Coastguard Worker    ListType,
31*da0073e9SAndroid Build Coastguard Worker    OperatorName,
32*da0073e9SAndroid Build Coastguard Worker    OptionalType,
33*da0073e9SAndroid Build Coastguard Worker    Return,
34*da0073e9SAndroid Build Coastguard Worker    TensorOptionsArguments,
35*da0073e9SAndroid Build Coastguard Worker    Type,
36*da0073e9SAndroid Build Coastguard Worker)
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
# Backend-selected C++ type for lazy IR values; stays None until setValueT()
# is called.
_valueT: BaseCppType | None = None


# A ValueT is the IR type that stands for the computation of a Tensor: every
# lazy tensor a user operates on tracks a ValueT naming the IR node that would
# actually produce that tensor's data.
#
# It is configurable because lazy tensor backends (LTC vs XLA) use different
# IR representations (though, arguably, after unification they shouldn't).
def getValueT() -> BaseCppType:
    """Return the value type registered with ``setValueT``.

    Raises:
        NotImplementedError: if no (truthy) value type has been registered.
    """
    current = _valueT
    if current:
        return current
    raise NotImplementedError(
        "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
    )


def setValueT(val: BaseCppType) -> None:
    """Register ``val`` as the value type later returned by ``getValueT``."""
    global _valueT
    _valueT = val
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker
# HACK: a TensorList argument is represented by a single torch::lazy::Value
# (see the ``Tensor[]`` case in process_ir_type).  The proper fix is to model
# each schema argument as an object so that special properties like this one
# can live on the argument itself.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
    """
    Convert a native-function schema type into the C++ type used by lazy
    tensor codegen.

    The conversion:
      (1) replaces Tensor-like types with the backend value type from
          getValueT() -- Scalar too, unless the node sets
          TreatScalarsAsConstants, and SymInt too, when ``symint`` is True;
      (2) wraps the result in BaseCType / OptionalCType / VectorCType /
          ListCType;
      (3) turns reference-like list types into owning vectors (e.g. a vector
          instead of an IntArrayRef).

    ``Tensor[]`` is special-cased to a single tensorListValueT, and
    ``Tensor?[]`` becomes ListCType(OptionalCType(value)).

    This is incomplete: unsupported types raise, and more types are expected
    to be added as the codegen is used with more operators.

    Raises:
        AssertionError: for schema types that are not supported yet.
    """
    if isinstance(typ, BaseType):
        base = typ.name
        if base == BaseTy.Tensor:
            return BaseCType(getValueT())
        if base == BaseTy.Scalar:
            # at::Scalar is normally wrapped in a lazy Value just like a
            # Tensor, unless the node opts into keeping Scalars as constants.
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            return BaseCType(getValueT())
        if base == BaseTy.SymInt:
            return BaseCType(getValueT() if symint else longT)
        # The remaining supported base types map 1:1 onto a fixed C++ type.
        direct = {
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Generator: generatorT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        cpp_type = direct.get(base)
        if cpp_type is None:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
        return BaseCType(cpp_type)

    if isinstance(typ, OptionalType):
        # Optional-ness is preserved around the converted element type.
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))

    if isinstance(typ, ListType):
        elem = typ.elem
        elem_text = str(elem)
        if elem_text == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector
            # like the other list types?
            return ListCType(OptionalCType(BaseCType(getValueT())))
        if elem_text == "Tensor":
            # A TensorList comes in from GetTensorList as a single Value.
            return BaseCType(tensorListValueT)
        if elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type.  As with tensorListValueT, nodes
            # store a vector of Values for ALL arguments, so SymInt[] would
            # need a separate IR node assembling the individual size nodes
            # into a list.  Until then it is lowered to plain integers.
            return VectorCType(BaseCType(longT))
        return VectorCType(process_ir_type(elem, properties, symint=symint))

    raise AssertionError(f"unrecognized type {repr(typ)}")
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker
# TODO: Deciding this from the CType is fragile; it should be computed from
# the schema Type directly, reusing the same logic as process_ir_type.
#
# Invariant: ``typ`` must be an *owning* CType (e.g. ArrayRef<Value> is
# reported as NOT a value type).
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
    """
    Return True if ``typ`` (an already-converted CType) holds a lazy Value.

    A BaseCType counts when it is the backend value type, SymIntT, or scalarT
    (the latter only if ``properties.TreatScalarsAsConstants`` is not set).
    Optional/List/Vector wrappers are unwrapped recursively, except that
    VectorCType(BaseCType(SymIntT)) is currently reported as False.

    getValueT() is consulted for every BaseCType, so this raises
    NotImplementedError if the value type has not been set yet.
    """
    if isinstance(typ, BaseCType):
        # at::Scalar is wrapped in a lazy Value unless the node asks for
        # Scalars to be kept as constants.
        scalars_are_constants = properties and properties.TreatScalarsAsConstants
        inner = typ.type
        if inner == getValueT():
            return True
        if inner == scalarT and not scalars_are_constants:
            return True
        return inner == SymIntT

    # TODO: report True for this
    if typ == VectorCType(BaseCType(SymIntT)):
        return False

    if isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)

    return False
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker
def isSymIntType(typ: Type) -> bool:
    """Return True iff ``typ`` is the bare schema type ``SymInt`` (not wrapped)."""
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.SymInt
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker
def isWrappedScalarType(typ: Type) -> bool:
    """
    Return True if ``typ`` is ``Scalar``, possibly nested inside Optional
    and/or List types.

    process_ir_type may replace such Scalars with the lazy value type, which
    erases the fact that they were Scalars; this predicate lets callers record
    which arguments were wrapped.
    """
    inner = typ
    while isinstance(inner, (OptionalType, ListType)):
        inner = inner.elem
    # Only ``Scalar`` itself is wrapped; other scalar-like base types
    # (int, float, ScalarType, ...) are not.
    return isinstance(inner, BaseType) and inner.name == BaseTy.Scalar
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker
# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    """Return True if ``typ`` is ``Generator`` or an Optional of it.

    List types are not unwrapped, so ``Generator[]`` yields False.
    """
    inner = typ
    while isinstance(inner, OptionalType):
        inner = inner.elem
    return isinstance(inner, BaseType) and inner.name == BaseTy.Generator
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker
# Caches a few derived properties computed from an Argument together with the
# LazyIrProperties of the IR node it belongs to.
class LazyArgument:
    name: str
    # Original schema Type of the argument (before lazy conversion).
    orig_type: Type
    # Result of process_ir_type(); read it through the ``lazy_type`` property.
    lazy_type_: CType | None
    is_wrapped_scalar: bool
    is_generator: bool
    # True when orig_type is an OptionalType.
    is_optional: bool
    # NOTE: despite the name, this is False for SymInt[] lists; only a bare
    # SymInt or a SymInt? counts, and only when ``symint`` is True.
    is_symint_or_list: bool

    # Whether or not we are treating this as symint or not
    symint: bool

    # True if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(
        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
    ) -> None:
        """Derive the lazy-codegen view of ``arg``.

        Args:
            arg: schema argument to wrap.
            properties: properties of the owning IR node; among other things
                they decide whether Scalars become lazy Values.
            symint: whether SymInt arguments are treated as SymInt.
        """
        self.name = arg.name
        self.orig_type = arg.type
        self.symint = symint
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        # TODO: SymInt[] (a ListType of SymInt) is not yet treated as a value
        # type, so it is excluded here.
        self.is_symint_or_list = symint and (
            isSymIntType(arg.type)
            or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
        )

        self.is_lazy_value = isValueType(self.lazy_type, properties)

    @property
    def lazy_type(self) -> CType:
        """The converted C++ type of this argument.

        Raises:
            AssertionError: if no lazy type was recorded for this argument.
        """
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker
class LazyIrProperties:
    """Mutually exclusive groups of boolean flags describing an IR node.

    Each tuple in ``Properties`` is one group; at most one member of a group
    is True at any time.  Flags are read and written like normal attributes,
    and the mutual exclusivity is handled automatically::

        props = LazyIrProperties("ShapePrecompute", "Lower")
        props.ShapeCache = True    # ShapePrecompute becomes False
        props.CreateFn = False     # no-op: CreateFn was not active
        props.ShapeCache = False   # the shape group has no active member

    Setting a name that belongs to no group raises KeyError.
    """

    Properties: tuple[tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str) -> None:
        """Start with every group inactive, then enable ``default_properties``."""
        # Maps each group (tuple of names) to its active member, or None.
        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
            LazyIrProperties.Properties
        )
        # Write through __dict__ because our __setattr__ only accepts
        # property names.
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        # Only reached when normal attribute lookup has already failed.
        properties = self.__dict__.get("properties")
        if properties is None:
            # Instance built without __init__ (e.g. copy/pickle probing for
            # __setstate__ via hasattr): report a missing attribute instead of
            # leaking a KeyError, so hasattr()/getattr(default) keep working.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                if value:
                    # Activating a property deactivates the rest of its group.
                    properties[values] = key
                elif properties[values] == key:
                    # Only clear the group when ``key`` is its active member;
                    # turning off an inactive property must not disturb the
                    # property that is actually set.
                    properties[values] = None
                return value

        raise KeyError(f"Invalid property: {key}")
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker
# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserves the original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: OperatorName

    # Arguments converted to LazyArgument, in schema order.  Positional:
    # pre-self, self, post-self.  Keyword: kwarg-only args before tensor
    # options, tensor options, kwarg-only args after tensor options, then out.
    positional_args: tuple[LazyArgument, ...]
    keyword_args: tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: tuple[Return, ...]

    # If this schema has a Generator arg among its keyword arguments, record
    # its name and original schema Type here (the schema Type is stored in a
    # NamedCType slot, hence the type: ignore where it is assigned).
    # NOTE: __init__ still also wraps that argument in a LazyArgument inside
    # keyword_args; this field is an additional record, not a replacement.
    generator_arg: NamedCType | None = None

    # original function schema
    func: FunctionSchema

    # Whether or not we are code-genning for SymInt or not
    symint: bool

    # Default node properties, kept when __init__ receives no properties.
    # NOTE: this single instance is created once, when the class is defined,
    # and is shared by every schema built without explicit properties, so
    # mutating it through one schema affects all of them.
    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    # NOTE(review): not assigned in __init__; presumably set by callers of
    # this class -- confirm.
    opkind: str | None = None
338*da0073e9SAndroid Build Coastguard Worker    def __init__(
339*da0073e9SAndroid Build Coastguard Worker        self,
340*da0073e9SAndroid Build Coastguard Worker        func: FunctionSchema,
341*da0073e9SAndroid Build Coastguard Worker        properties: LazyIrProperties | None = None,
342*da0073e9SAndroid Build Coastguard Worker        *,
343*da0073e9SAndroid Build Coastguard Worker        symint: bool,
344*da0073e9SAndroid Build Coastguard Worker    ) -> None:
345*da0073e9SAndroid Build Coastguard Worker        if properties:
346*da0073e9SAndroid Build Coastguard Worker            self.properties = properties
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker        self.func = func
349*da0073e9SAndroid Build Coastguard Worker        self.symint = symint
350*da0073e9SAndroid Build Coastguard Worker        positional_args: list[LazyArgument] = []
351*da0073e9SAndroid Build Coastguard Worker        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
352*da0073e9SAndroid Build Coastguard Worker            if arg_field == "self_arg" and func.arguments.self_arg is not None:
353*da0073e9SAndroid Build Coastguard Worker                arg = func.arguments.self_arg.argument
354*da0073e9SAndroid Build Coastguard Worker                positional_args.append(
355*da0073e9SAndroid Build Coastguard Worker                    LazyArgument(arg, self.properties, symint=symint)
356*da0073e9SAndroid Build Coastguard Worker                )
357*da0073e9SAndroid Build Coastguard Worker            elif getattr(func.arguments, arg_field) is not None:
358*da0073e9SAndroid Build Coastguard Worker                positional_args.extend(
359*da0073e9SAndroid Build Coastguard Worker                    LazyArgument(arg, self.properties, symint=symint)
360*da0073e9SAndroid Build Coastguard Worker                    for arg in getattr(func.arguments, arg_field)
361*da0073e9SAndroid Build Coastguard Worker                )
362*da0073e9SAndroid Build Coastguard Worker        self.positional_args = tuple(positional_args)
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker        keyword_args: list[LazyArgument] = []
365*da0073e9SAndroid Build Coastguard Worker        for arg_field in [
366*da0073e9SAndroid Build Coastguard Worker            "pre_tensor_options_kwarg_only",
367*da0073e9SAndroid Build Coastguard Worker            "tensor_options",
368*da0073e9SAndroid Build Coastguard Worker            "post_tensor_options_kwarg_only",
369*da0073e9SAndroid Build Coastguard Worker            "out",
370*da0073e9SAndroid Build Coastguard Worker        ]:
371*da0073e9SAndroid Build Coastguard Worker            curr_args = getattr(func.arguments, arg_field)
372*da0073e9SAndroid Build Coastguard Worker            if curr_args is not None:
373*da0073e9SAndroid Build Coastguard Worker                if isinstance(curr_args, TensorOptionsArguments):
374*da0073e9SAndroid Build Coastguard Worker                    curr_args = curr_args.all()
375*da0073e9SAndroid Build Coastguard Worker                for arg in curr_args:
376*da0073e9SAndroid Build Coastguard Worker                    if isGeneratorType(arg.type):
377*da0073e9SAndroid Build Coastguard Worker                        assert (
378*da0073e9SAndroid Build Coastguard Worker                            self.generator_arg is None
379*da0073e9SAndroid Build Coastguard Worker                        ), "We expect there is only one generator arg"
380*da0073e9SAndroid Build Coastguard Worker                        self.generator_arg = NamedCType(
381*da0073e9SAndroid Build Coastguard Worker                            arg.name, arg.type  # type:ignore[arg-type]
382*da0073e9SAndroid Build Coastguard Worker                        )
383*da0073e9SAndroid Build Coastguard Worker                keyword_args.extend(
384*da0073e9SAndroid Build Coastguard Worker                    LazyArgument(arg, self.properties, symint=symint)
385*da0073e9SAndroid Build Coastguard Worker                    for arg in curr_args
386*da0073e9SAndroid Build Coastguard Worker                )
387*da0073e9SAndroid Build Coastguard Worker        self.keyword_args = tuple(keyword_args)
388*da0073e9SAndroid Build Coastguard Worker        self.name = func.name
389*da0073e9SAndroid Build Coastguard Worker        self.returns = func.returns
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker    @property
392*da0073e9SAndroid Build Coastguard Worker    def node_name(self) -> str:
393*da0073e9SAndroid Build Coastguard Worker        """
394*da0073e9SAndroid Build Coastguard Worker        Return camel-case version of op in node.
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker        Note: This function also appends any `overload_name` in the operation.
397*da0073e9SAndroid Build Coastguard Worker        For example, if the op is `bitwise_and.Tensor`, the returned name
398*da0073e9SAndroid Build Coastguard Worker        will be `BitwiseAndTensor`.
399*da0073e9SAndroid Build Coastguard Worker        """
400*da0073e9SAndroid Build Coastguard Worker        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
401*da0073e9SAndroid Build Coastguard Worker        return "".join(word.capitalize() or "" for word in op_name.split("_"))
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker    @property
404*da0073e9SAndroid Build Coastguard Worker    def aten_name(self) -> str:
405*da0073e9SAndroid Build Coastguard Worker        return str(self.name.name)
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker    @property
408*da0073e9SAndroid Build Coastguard Worker    def base_name(self) -> str:
409*da0073e9SAndroid Build Coastguard Worker        return f"{self.name.name.base}"
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker    def filtered_args(
412*da0073e9SAndroid Build Coastguard Worker        self,
413*da0073e9SAndroid Build Coastguard Worker        positional: bool = True,
414*da0073e9SAndroid Build Coastguard Worker        keyword: bool = True,
415*da0073e9SAndroid Build Coastguard Worker        values: bool = True,
416*da0073e9SAndroid Build Coastguard Worker        scalars: bool = True,
417*da0073e9SAndroid Build Coastguard Worker        generator: bool = True,
418*da0073e9SAndroid Build Coastguard Worker    ) -> list[LazyArgument]:
419*da0073e9SAndroid Build Coastguard Worker        # This function maintains the sorted order of arguments but provides different filtered views.
420*da0073e9SAndroid Build Coastguard Worker        # Some parts of the code care about kwargs vs args (TS lowerings),
421*da0073e9SAndroid Build Coastguard Worker        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
422*da0073e9SAndroid Build Coastguard Worker        # Generators are special cased, as they are needed for fallback/shape-inference but not supported
423*da0073e9SAndroid Build Coastguard Worker        # in TS lowerings and therefore also omitted from lazy IR.
424*da0073e9SAndroid Build Coastguard Worker        args: list[LazyArgument] = []
425*da0073e9SAndroid Build Coastguard Worker        if positional:
426*da0073e9SAndroid Build Coastguard Worker            args.extend(self.positional_args)
427*da0073e9SAndroid Build Coastguard Worker        if keyword:
428*da0073e9SAndroid Build Coastguard Worker            args.extend(self.keyword_args)
429*da0073e9SAndroid Build Coastguard Worker
430*da0073e9SAndroid Build Coastguard Worker        if values and scalars and generator:
431*da0073e9SAndroid Build Coastguard Worker            return args
432*da0073e9SAndroid Build Coastguard Worker        elif values and scalars:
433*da0073e9SAndroid Build Coastguard Worker            return [a for a in args if not a.is_generator]
434*da0073e9SAndroid Build Coastguard Worker        elif values:
435*da0073e9SAndroid Build Coastguard Worker            return [a for a in args if a.is_lazy_value]
436*da0073e9SAndroid Build Coastguard Worker        elif scalars:
437*da0073e9SAndroid Build Coastguard Worker            return [
438*da0073e9SAndroid Build Coastguard Worker                a
439*da0073e9SAndroid Build Coastguard Worker                for a in args
440*da0073e9SAndroid Build Coastguard Worker                if not a.is_lazy_value and (generator or not a.is_generator)
441*da0073e9SAndroid Build Coastguard Worker            ]
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker        return []
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker    @property
446*da0073e9SAndroid Build Coastguard Worker    def positional_values(self) -> list[LazyArgument]:
447*da0073e9SAndroid Build Coastguard Worker        return self.filtered_args(
448*da0073e9SAndroid Build Coastguard Worker            positional=True, keyword=False, values=True, scalars=False
449*da0073e9SAndroid Build Coastguard Worker        )
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker    @property
452*da0073e9SAndroid Build Coastguard Worker    def positional_scalars(self) -> list[LazyArgument]:
453*da0073e9SAndroid Build Coastguard Worker        return self.filtered_args(
454*da0073e9SAndroid Build Coastguard Worker            positional=True, keyword=False, values=False, scalars=True
455*da0073e9SAndroid Build Coastguard Worker        )
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker    @property
458*da0073e9SAndroid Build Coastguard Worker    def keyword_values(self) -> list[LazyArgument]:
459*da0073e9SAndroid Build Coastguard Worker        return self.filtered_args(
460*da0073e9SAndroid Build Coastguard Worker            positional=False, keyword=True, values=True, scalars=False
461*da0073e9SAndroid Build Coastguard Worker        )
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker    @property
464*da0073e9SAndroid Build Coastguard Worker    def keyword_scalars(self) -> list[LazyArgument]:
465*da0073e9SAndroid Build Coastguard Worker        return self.filtered_args(
466*da0073e9SAndroid Build Coastguard Worker            positional=False, keyword=True, values=False, scalars=True
467*da0073e9SAndroid Build Coastguard Worker        )
468