xref: /aosp_15_r20/external/pytorch/torchgen/dest/lazy_ir.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import itertools
4from abc import ABC
5from dataclasses import dataclass
6from typing import Any
7
8import torchgen.api.dispatcher as dispatcher
9from torchgen.api.lazy import (
10    getValueT,
11    isValueType,
12    LazyArgument,
13    LazyIrProperties,
14    LazyIrSchema,
15    tensorListValueT,
16)
17from torchgen.api.translate import translate
18from torchgen.api.types import (
19    BaseCType,
20    Binding,
21    deviceT,
22    DispatcherSignature,
23    kernel_signature,
24    NativeSignature,
25    OptionalCType,
26    VectorCType,
27)
28from torchgen.context import method_with_native_function
29from torchgen.dest.lazy_ts_lowering import ts_lowering_body
30from torchgen.model import (
31    Argument,
32    BackendIndex,
33    BackendMetadata,
34    BaseTy,
35    BaseType,
36    FunctionSchema,
37    ListType,
38    NativeFunction,
39    NativeFunctionsGroup,
40)
41
42
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
    """
    Produce the C++ rvalue expression used to pass ``arg`` into a lazy Node
    constructor.

    Value-typed arguments refer to locals that the generated kernel declares
    before building the node (``lazy_<name>``, ``node_<name>`` or
    ``lazy_<name>_tensorlist``), or are converted through ``GetSymIntValue``.
    Non-value arguments are forwarded, converting list types to owning
    ``std::vector`` / optional-vector values where needed.

    Raises:
        AssertionError: if a value-typed argument is neither a BaseCType nor
            an OptionalCType.
    """
    name = arg.name
    lazy_type = arg.lazy_type

    # TODO: Matching on CType seems wrong; should be matching on Type
    if not isValueType(lazy_type):
        # NB: this is here because right now we aren't treating SymInt[] as a
        # value type; when we do this needs to move into the value branch.
        # NB: we cannot test arg.lazy_type here: it is already int64_t, so
        # SymInt and int64_t are indistinguishable at this point.
        orig_type = arg.orig_type
        if isinstance(orig_type, ListType) and orig_type.elem == BaseType(
            BaseTy.SymInt
        ):
            if not arg.symint:
                return f"std::vector<int64_t>({name}.begin(), {name}.end())"
            return f"GetSymIntArrayRefValue({name})"
        if isinstance(lazy_type, VectorCType) and isinstance(
            lazy_type.elem, BaseCType
        ):
            return f"std::vector<{lazy_type.elem.type}>({name}.begin(), {name}.end())"
        if (
            isinstance(lazy_type, OptionalCType)
            and isinstance(lazy_type.elem, VectorCType)
            and isinstance(lazy_type.elem.elem, BaseCType)
        ):
            return f"torch::lazy::ToOptionalVector<{lazy_type.elem.elem.type}>({name})"
        return f"{name}"

    if isinstance(lazy_type, BaseCType):
        if arg.is_wrapped_scalar:
            return f"node_{name}"
        if lazy_type.type is tensorListValueT:
            return f"lazy_{name}_tensorlist"
        if arg.is_symint_or_list:
            return f"GetSymIntValue({name})"
        return f"lazy_{name}->GetIrValue()"

    if isinstance(lazy_type, OptionalCType):
        if arg.is_symint_or_list:
            # TODO: I don't understand when you should put lazy_ in the name
            # or not
            return f"{name} ? std::make_optional(GetSymIntValue(*{name})) : ::std::nullopt"
        if arg.is_wrapped_scalar:
            return f"node_{name}"
        return f"lazy_{name} ? std::make_optional(lazy_{name}->GetIrValue()) : ::std::nullopt"

    raise AssertionError(
        f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
    )
100
101
def node_ctor_inputs(schema: LazyIrSchema) -> str:
    """
    Return the comma-separated C++ argument list passed to a node constructor:
    one rvalue expression per entry of ``schema.filtered_args()``, in order.
    """
    return ", ".join(map(node_ctor_arg_rvalue_string, schema.filtered_args()))
110
111
def gen_fallback_code(
    schema: LazyIrSchema,
    sig: DispatcherSignature | NativeSignature,
    overload_name: str,
) -> str:
    """
    Generate a C++ snippet that, when ``force_eager_fallback`` is true for the
    op's symbol, returns the result of the eager fallback
    (``call_fallback_fn_symint<&ltc_eager_fallback, ...>``) instead.

    The kernel arguments described by ``sig`` are translated into the
    dispatcher calling convention before being forwarded. A non-empty
    ``overload_name`` selects ``ATEN_OP2``; otherwise ``ATEN_OP`` is used.
    """
    symbol = aten_symbol(schema)
    target_sig = DispatcherSignature.from_schema(schema.func)
    forwarded = [
        binding.expr for binding in translate(sig.arguments(), target_sig.arguments())
    ]
    fallback_args = ",\n                ".join(forwarded)
    aten_op_str = (
        f"ATEN_OP2({schema.aten_name}, {overload_name})"
        if overload_name
        else f"ATEN_OP({schema.aten_name})"
    )
    lines = [
        "",
        f"        if (force_eager_fallback({symbol})) {{",
        f"            return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(",
        f"                {fallback_args}",
        "            );",
        "        }",
        "",
    ]
    return "\n".join(lines)
134
135
def aten_symbol(schema: LazyIrSchema) -> str:
    """
    Return the C++ expression naming the ATen symbol for ``schema``.

    Names listed as missing interned strings are resolved at runtime via
    ``c10::Symbol::fromQualString``. Names already qualified with ``at::``
    are returned unchanged. Everything else is prefixed with ``at::aten::``.
    """
    name = schema.aten_name
    # Ops with no pre-interned at::aten:: symbol.
    if name in ("sigmoid_backward",):
        return f'c10::Symbol::fromQualString("aten::{name}")'
    return name if name.startswith("at::") else f"at::aten::{name}"
147
148
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
    """
    Convert all tensor-like arguments of ``sig`` to meta tensors.

    Returns:
        A pair of:
        (1) a string with one ``auto <name>_meta = to_meta(<name>);``
            statement per tensor-like argument; and
        (2) a context for ``translate()`` in which each tensor-like binding is
            renamed to its ``_meta`` counterpart. Other bindings are kept as is.
    """
    bindings: list[Binding] = []
    conversions: list[str] = []
    for binding in sig.arguments():
        argument = binding.argument
        if not (isinstance(argument, Argument) and argument.type.is_tensor_like()):
            bindings.append(binding)
            continue
        meta_name = f"{binding.name}_meta"
        conversions.append(f"auto {meta_name} = to_meta({binding.name});")
        bindings.append(binding.with_name(meta_name))
    return "\n        ".join(conversions), bindings
166
167
@dataclass(frozen=True)
class GenLazyIR(ABC):
    """
    Generate the C++ declaration of a lazy IR Node subclass for an op.

    The base class emits the class skeleton: ``ClassOpKind``, the constructor,
    ``ToString`` and the scalar members. Backends subclass it and override
    ``create_function``, ``can_be_reused_function`` and ``lowering_function``
    to add backend-specific members.
    """

    backend_index: BackendIndex
    backend_name: str
    # C++ class the generated node derives from and whose ctor it calls.
    node_base: str
    # When set (and the op has ShapePrecompute) the ctor also takes `shapes`.
    use_lazy_shape: bool

    @method_with_native_function
    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
        """Build the IR class for ``f``, or for the functional variant of a group."""
        native_fn = f.functional if isinstance(f, NativeFunctionsGroup) else f
        metadata = self.backend_index.get_kernel(native_fn)
        supports_symint = metadata is not None and metadata.supports_symint()
        return self.gen(LazyIrSchema(native_fn.func, symint=supports_symint))

    # No lowering is generated unless a backend subclass overrides this and
    # implements the node as a backend-specific node.
    def lowering_function(self, schema: LazyIrSchema) -> str:
        """Base hook: emit no lowering member."""
        return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Base hook: emit no Create factory."""
        return ""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Base hook: emit a CanBeReused member that never permits reuse."""
        return (
            f"bool CanBeReused({node_ctor_args}) const {{\n"
            "    return false;\n"
            "    }"
        )

    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
        """
        Return the base-class constructor call used in the node's initializer.

        Backends may customize this, provided every argument can be derived
        from the schema.

        Raises:
            AssertionError: for value arguments of an unsupported CType.
        """
        value_args = schema.filtered_args(values=True, scalars=False)

        def operand_expr(arg: LazyArgument) -> str:
            # Optional values collapse to kNullValue when absent.
            if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
                return f"{arg.name}"
            if isinstance(arg.lazy_type, OptionalCType):
                return f"{arg.name}.value_or(kNullValue)"
            raise AssertionError(
                f"Unsupported type ({arg.lazy_type}) - add support if necessary"
            )

        operands = ", ".join([operand_expr(arg) for arg in value_args])

        scalar_args = schema.filtered_args(values=False, scalars=True)
        scalar_names = [a.name for a in scalar_args]

        # The shape argument depends on which shape property the op declares.
        props = schema.properties
        if props.ShapePrecompute:
            shape_ctor_arg = "std::move(shapes),"
        elif props.ShapeCompute:
            call_args = ", ".join([a.name for a in value_args] + scalar_names)
            shape_ctor_arg = f"compute_shape_{schema.name}({call_args}),"
        elif props.ShapeCache:
            operand_refs = [f"operand({idx})" for idx in range(len(value_args))]
            call_args = ", ".join(operand_refs + scalar_names)
            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({call_args})[0]; }},"
        else:
            shape_ctor_arg = ""

        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)

        pad = " " * 14
        return (
            f"{self.node_base}(\n"
            f"{pad}{schema.node_name}::ClassOpKind(),\n"
            f"{pad}OpList{{{operands}}},\n"
            f"{pad}{shape_ctor_arg}\n"
            f"{pad}/* num_outputs */ {len(schema.returns)},\n"
            f"{pad}torch::lazy::MHash({scalar_hashes}))"
        )

    def gen(self, schema: LazyIrSchema) -> list[str]:
        """Return a one-element list holding the C++ node class for ``schema``."""
        opkind = schema.opkind or aten_symbol(schema)

        # One IR class per op, built from the functional (not out/inplace) form.
        all_args = schema.filtered_args()
        scalar_args = schema.filtered_args(values=False, scalars=True)

        ctor_params = [f"const {arg.lazy_type.cpp_type()}& {arg.name}" for arg in all_args]
        reuse_ctor_args = ", ".join(ctor_params)
        if self.use_lazy_shape and schema.properties.ShapePrecompute:
            ctor_params = ctor_params + ["std::vector<torch::lazy::Shape>&& shapes"]
        node_ctor_args = ", ".join(ctor_params)

        optional_string_view = "::std::optional<c10::string_view>"

        def initializer(arg: LazyArgument) -> str:
            # string_view scalars are stored as owning std::string members.
            if arg.lazy_type.cpp_type() == optional_string_view:
                return f"{arg.name}({arg.name}.has_value() ? ::std::make_optional(std::string(*{arg.name})) : ::std::nullopt)"
            return f"{arg.name}({arg.name})"

        def member_decl(arg: LazyArgument) -> str:
            cpp_type = arg.lazy_type.cpp_type()
            if cpp_type == "c10::string_view":
                return f"std::string {arg.name};"
            if cpp_type == optional_string_view:
                return f"::std::optional<std::string> {arg.name};"
            return f"{cpp_type} {arg.name};"

        def member_to_string(arg: LazyArgument) -> str:
            if not isinstance(arg.lazy_type, OptionalCType):
                return f'ss << ", {arg.name}=" << {arg.name};'
            # Generators are printed as a placeholder rather than their value.
            shown = '"torch.Generator()"' if arg.is_generator else f"{arg.name}.value()"
            return (
                f"if ({arg.name}.has_value()) {{\n"
                f'      ss << ", {arg.name}=" << {shown};\n'
                "    } else {\n"
                f'      ss << ", {arg.name}=null";\n'
                "    }"
            )

        scalar_initializers = ",\n        ".join([initializer(a) for a in scalar_args])
        if scalar_initializers:
            scalar_initializers = ",\n        " + scalar_initializers
        scalar_decls = "\n  ".join([member_decl(a) for a in scalar_args])

        optional_values = [
            arg.name
            for arg in schema.filtered_args(values=True, scalars=False)
            if isinstance(arg.lazy_type, OptionalCType)
        ]
        has_optional_decls = "\n  ".join([f"bool has_{v}: 1;" for v in optional_values])
        has_optional_defs = "\n    ".join([f"has_{v} = !!{v};" for v in optional_values])
        members_to_string_str = "\n    ".join([member_to_string(a) for a in scalar_args])

        # Hooks are invoked in the same order the class template lists them.
        base_ctor_call = self.node_base_ctor_call(schema)
        create_fn = self.create_function(schema, reuse_ctor_args)
        reuse_fn = self.can_be_reused_function(schema, reuse_ctor_args)
        lowering_fn = self.lowering_function(schema)

        node_class = f"""\
class {schema.node_name} : public {self.node_base} {{
 public:
  static torch::lazy::OpKind ClassOpKind() {{
    return torch::lazy::OpKind({opkind});
  }}

  {schema.node_name}({node_ctor_args})
      : {base_ctor_call}{scalar_initializers}
  {{
    {has_optional_defs}
  }}

  std::string ToString() const override {{
    std::stringstream ss;
    ss << {self.node_base}::ToString();
    {members_to_string_str}
    return ss.str();
  }}

  {create_fn}

  {reuse_fn}

  {lowering_fn}

  {scalar_decls}
  {has_optional_decls}

}};

"""
        return [node_class]
338
339
@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
    """
    TorchScript-backend specialization of GenLazyIR.

    Adds a ``Lower`` member (TSOpVector / TSLoweringContext), an optional
    ``Create`` factory and a field-by-field ``CanBeReused`` comparison,
    each controlled by the schema's properties.
    """

    def lowering_function(self, schema: LazyIrSchema) -> str:
        """Emit the Lower member: declaration only, full body, or nothing."""
        signature = """
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override"""

        if schema.properties.LowerDeclOnly:
            return f"{signature};"
        elif schema.properties.Lower:
            return f"""{signature} {{
    {ts_lowering_body(schema)}
  }}
            """
        else:
            return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """
        Emit a static Create factory that reuses a cached node or makes one.

        ``node_ctor_args`` is the parameter list built from
        ``schema.filtered_args()``. The factory forwards those parameters, by
        name and in the same order, to ``ReuseOrMakeNode``. (Previously it
        forwarded an undeclared identifier ``data``.)
        """
        signature = f"static NodePtr Create({node_ctor_args})"
        if schema.properties.CreateFnDeclOnly:
            return f"{signature};"
        elif not schema.properties.CreateFn:
            return ""
        forwarded = ", ".join(f"{arg.name}" for arg in schema.filtered_args())
        return f"""{signature} {{
    return ReuseOrMakeNode<{schema.node_name}>({forwarded});
  }}"""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """
        Emit CanBeReused, comparing every operand and scalar member with the
        candidate constructor arguments.

        Operands are walked in order with a running index ``i``. Optional
        scalars match when both are empty or both hold equal values.
        """
        signature = f"bool CanBeReused({node_ctor_args}) const"
        if schema.properties.CanBeReusedDeclOnly:
            return f"{signature};"
        elif not schema.properties.CanBeReused:
            return ""
        value_comparison = []
        for arg in itertools.chain(schema.positional_values, schema.keyword_values):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
                )
            else:
                value_comparison.append(f"operand(i++) == {arg.name}")
        # The operand cursor is only needed if some comparison advances it.
        uses_operand_index = bool(value_comparison)
        for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
                )
            else:
                value_comparison.append(f"this->{arg.name} == {arg.name}")
        # With no operands and no scalars there is nothing to compare; emit
        # `true` rather than the invalid C++ expression `return ();`.
        value_comparison_str = " &&\n        ".join(value_comparison) or "true"
        index_decl = "size_t i = 0;\n    " if uses_operand_index else ""

        return f"""{signature} {{
    {index_decl}return ({value_comparison_str});
  }}"""
395
396
@dataclass(frozen=True)
class GenLazyNativeFuncDefinition:
    """
    Generate the C++ definition of a backend's lazy kernel for a native function.

    The emitted body runs, in order: an optional forced eager fallback, a
    metrics counter, common-device lookup, unwrapping of inputs into lazy
    tensors, IR node reuse/creation (with shape inference), and wrapping of
    the node back into ATen tensors.
    """

    # Qualifier placed before the kernel name in the emitted definition.
    class_method_name: str
    backend_index: BackendIndex
    tensor_class: str
    # Whether to emit the force_eager_fallback guard at the top of each kernel.
    gen_forced_fallback_code: bool
    # Namespace prefix used for the backend helper functions below.
    backend_namespace: str
    get_tensorlist: str
    get_tensor_or_wrap_number: str
    try_get_tensor: str
    # C++ statement (emitted with a trailing `;`) that bumps a metrics counter.
    metrics_counter: str
    create_tensor: str
    # If True, lazy tensors are created via `<first tensor>.<create_tensor>`.
    create_from_first_tensor: bool
    create_aten_from_ltc_tensor: str
    tuple_aten_from_ltc_tensors: str
    lazy_tensor_ptr: str
    get_device_fn: str

    def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Emit declarations that wrap each value argument for IR construction:
        ``node_<name>`` for wrapped scalars, ``lazy_<name>_tensorlist`` for
        tensor lists and ``lazy_<name>`` for (optional) tensors. SymInt
        arguments are skipped; node_ctor_arg_rvalue_string converts them
        directly.

        Raises:
            AssertionError: for value arguments of an unsupported CType.
        """
        value_args = schema.filtered_args(values=True, scalars=False)
        # Generates lazy_{name} variables for LazyTensors wrapping input tensors
        lazy_tensor_decls: list[str] = []
        for arg in value_args:
            if arg.is_wrapped_scalar:
                if isinstance(arg.lazy_type, OptionalCType):
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = {arg.name} ?
                std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
                    GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
                ::std::nullopt;"""
                    )
                else:
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
                            GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
                    )
            elif arg.is_symint_or_list:
                continue  # values are extracted in isValueType
            elif isinstance(arg.lazy_type, BaseCType):
                if arg.lazy_type.type is tensorListValueT:
                    lazy_tensor_decls.append(
                        f"auto lazy_{arg.name}_tensorlist = "
                        f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
                    )
                else:
                    lazy_tensor_decls.append(
                        f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                        f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
                    )
            elif isinstance(arg.lazy_type, OptionalCType):
                assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
                # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
                # until we encounter a real world example.
                lazy_tensor_decls.append(
                    f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                    f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
                )
            else:
                raise AssertionError(
                    f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
                )
        return ("\n        ").join(lazy_tensor_decls)

    def force_eager_fallback(
        self,
        func: NativeFunction,
        schema: LazyIrSchema,
        metadata: BackendMetadata,
        sig: DispatcherSignature | NativeSignature,
    ) -> str:
        """Emit the eager-fallback guard, or "" when it is disabled."""
        if self.gen_forced_fallback_code:
            return gen_fallback_code(
                schema, sig, overload_name=func.func.name.overload_name
            )
        return ""

    def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit the metrics-counter statement."""
        return f"{self.metrics_counter};"

    def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Emit the ``common_device`` lookup over all non-wrapped-scalar value
        arguments and any ``optional<Device>`` scalar arguments.

        Raises:
            AssertionError: if there is neither a value nor a device argument.
        """
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        optional_device = OptionalCType(BaseCType(deviceT))
        optional_devices = [
            a.name for a in scalar_args if a.lazy_type == optional_device
        ]
        assert (
            len(value_types_names) > 0 or len(optional_devices) > 0
        ), "Expected at least one Value or Device type"
        get_device_str = (
            f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
        )
        return f"""auto common_device = {get_device_str};
        TORCH_INTERNAL_ASSERT(common_device);
        """

    def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Emit code that computes ``shapes`` for the op's outputs.

        Structured and view_copy ops call a meta / composite kernel on
        meta-converted inputs. All other ops call a hand-written
        ``compute_shape_*`` function. Symbolic shapes are applied afterwards
        when enabled at runtime.
        """
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        all_args = schema.filtered_args()
        returns_length = len(schema.returns)
        # call the meta kernel if it exists, to compute output shape/dtype for our IR
        # Note [Generated LTC Shape Functions]
        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
        # we generate a shape function declaration that needs to be manually implemented.
        # How do we detect which ops are eligible to use meta tensors?
        # In general we should be able to use meta tensors not just on structured operators,
        # but also on composite operators that are implemented in terms of structured kernels.
        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
        # This is the case for all view and view_copy operators however, so we're going to
        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
        is_view_copy_op = "view_copy" in func.tags
        is_structured = func.structured or func.structured_delegate is not None
        if is_structured or is_view_copy_op:
            meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
            if returns_length > 1:

                def this_shape(i: int) -> str:
                    return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

                shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

            # Convert tensor args to the meta device and call it.
            # (We can't pass in the input tensors directly, because they are "functional wrappers".
            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
            dispatcher_sig = DispatcherSignature.from_schema(func.func)
            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
            meta_call_args = [
                e.expr
                for e in translate(
                    meta_call_ctx, dispatcher_sig.arguments(), method=False
                )
            ]
            if is_view_copy_op:
                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
                assert func.has_composite_explicit_autograd_non_functional_kernel
                dispatch_ns = "compositeexplicitautogradnonfunctional"
            else:
                dispatch_ns = "meta"
            aten_name = schema.aten_name
            # TODO: this is trolling
            if func.func.has_symint() and metadata.supports_symint():
                aten_name += "_symint"
            shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
        else:
            shape_sig = ComputeShapeSignature(
                metadata.kernel, func, symint=metadata.supports_symint()
            )
            shape_str = f"""
            auto shapes = {shape_sig.shape_call};"""

        shape_str += f"""
            TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

        # Calculating which dimensions are symbolic
        func_schema_str = "aten::" + str(func.func)
        shape_str += f"""
            if(torch::lazy::symbolicShapeEnabled()){{
                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
                const char* schema_str = "{func_schema_str}";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }}
        """
        return shape_str

    def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Emit code that reuses a cached IR node, or infers shapes, builds one and caches it."""
        node_ctor_input_str = node_ctor_inputs(schema)
        return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
        if (!node) {{
            {self.shape_inference(func, schema)}
            node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
            CacheNode(node);
        }}
        """

    def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
        """
        Return the C++ callee used to create a lazy tensor.

        Raises:
            AssertionError: if ``create_from_first_tensor`` is set but no
                tensor name is given.
        """
        # xla uses an instance method for tensor creation, for the time being
        if self.create_from_first_tensor:
            # TODO(whc) remove this if XLA switches to using static method for creation
            assert (
                first_tensor_name is not None
            ), "Requires first tensor to create lazy tensor"
            return f"{first_tensor_name}.{self.create_tensor}"
        return f"{self.backend_namespace}::{self.create_tensor}"

    def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Emit code that wraps ``node`` back into the ATen result and returns it.

        Single outputs become one tensor; multiple outputs become a tuple.
        In-place and out= variants instead update the first tensor argument's
        IR value and return that argument.

        Raises:
            AssertionError: when a tuple or in-place result needs a tensor
                argument and the op has none, or an in-place/out= op has more
                than one output.
        """
        returns_length = len(schema.returns)
        value_args = schema.filtered_args(values=True, scalars=False)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
        bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
                {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""

        if returns_length > 1:
            assert (
                len(value_types_names) > 0
            ), "Code below assumes there is at least one tensor arg"
            bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
        for (int i = 0; i < {returns_length}; i++) {{
            lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
        }}
        auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""

        if schema.name.name.inplace or func.func.is_out_fn():
            assert returns_length == 1, (
                "We assumed there was no such case where an op is an in-place variant "
                f"and has tuple outputs, but got tuple of len {returns_length}."
            )
            # Without a tensor argument the template below would reference
            # `lazy_None`; fail at codegen time instead.
            assert (
                first_tensor_name is not None
            ), "In-place/out= variants require a tensor argument to update in place"
            bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
        auto& result = {first_tensor_name};"""

        bridge_str += """
        return result;"""
        return bridge_str

    @method_with_native_function
    def __call__(self, func: NativeFunction) -> list[str]:
        """Return a one-element list holding the full kernel definition for ``func``."""
        sig = kernel_signature(func, self.backend_index)
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
        return [
            f"""\
    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
        {self.force_eager_fallback(func, schema, metadata, sig)}
        {self.metrics(func, schema)}
        {self.get_device(func, schema)}
        {self.lazy_tensor_decls(func, schema)}
        {self.build_ir_node(func, schema)}
        {self.return_aten_tensor(func, schema)}
    }}\n
    """
        ]
637
638
class ComputeShapeSignature:
    """
    Build the declaration and call strings for a hand-written
    ``compute_shape_<kernel>`` shape function.

    The kernel's base name is used as the signature suffix, so in-place
    variants do not get a separate shape function.
    """

    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
        self.__schema = LazyIrSchema(f.func, symint=symint)
        # Full C++ parameter declarations, used in the prototype.
        dispatch_params = dispatcher.arguments(f.func, symint=symint)
        self.__dispatch_args = ", ".join(param.decl() for param in dispatch_params)
        # Bare argument names (generators included), used at the call site.
        call_params = self.__schema.filtered_args(generator=True)
        self.__call_args = ", ".join(f"{param.name}" for param in call_params)
        self.__kernel_name = kernel_name

    def __suffix(self, args: str) -> str:
        # "<kernel>(<args>)", shared by the declaration and the call.
        return f"{self.__kernel_name}({args})"

    @property
    def shape_decl(self) -> str:
        """C++ prototype of the shape function (without trailing ``;``)."""
        return "TORCH_API std::vector<torch::lazy::Shape> compute_shape_" + self.__suffix(
            self.__dispatch_args
        )

    @property
    def shape_call(self) -> str:
        """C++ expression that calls the shape function."""
        return "torch::lazy::compute_shape_" + self.__suffix(self.__call_args)
667
668
@dataclass(frozen=True)
class GenLazyShapeInferenceDefinition:
    """
    Emit the ``compute_shape_*`` declaration for ops whose shapes must be
    computed by a hand-written function.
    """

    backend_index: BackendIndex
    tensor_class: str

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> list[str]:
        """Return ``[declaration]``, or ``[]`` when meta kernels supply shapes."""
        metadata = self.backend_index.get_kernel(f)
        assert metadata is not None

        # See Note [Generated LTC Shape Functions]: structured and view_copy
        # ops get their shapes from meta kernels, so nothing is declared.
        shapes_from_meta = (
            "view_copy" in f.tags or f.structured or f.structured_delegate is not None
        )
        if shapes_from_meta:
            return []
        shape_sig = ComputeShapeSignature(
            metadata.kernel, f, symint=metadata.supports_symint()
        )
        return [f"{shape_sig.shape_decl};"]
689
690
def generate_non_native_lazy_ir_nodes(
    non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
    """
    Generate the lazy IR node classes for non-native ops.

    Each entry of ``non_native`` provides a ``"func"`` schema string and,
    optionally, ``"properties"`` to enable on top of the defaults
    (ShapeCache, CanBeReused, LowerDeclOnly) and an ``"opkind"``.
    """

    def build_node(op: dict[str, Any]) -> str:
        props = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
        for prop_name in op.get("properties", []):
            setattr(props, prop_name, True)
        # Non-native ops are assumed to want SymInt bindings if they use SymInt.
        schema = LazyIrSchema(FunctionSchema.parse(op["func"]), props, symint=True)
        schema.opkind = op.get("opkind")
        return gen_lazy_ir.gen(schema)[0]

    return [build_node(op) for op in non_native]
707    return nodes
708