from __future__ import annotations

from collections import defaultdict
from typing import Sequence

import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import Binding, DispatcherSignature, Expr
from torchgen.context import with_native_function
from torchgen.model import (
    Annotation,
    Argument,
    BackendIndex,
    BackendMetadata,
    BaseOperatorName,
    BaseTy,
    BaseType,
    DEFAULT_KERNEL_NAMESPACE,
    DeviceCheckType,
    DispatchKey,
    FunctionSchema,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    Return,
    SchemaKind,
    Variant,
)
from torchgen.utils import concatMap


# See Note: [Out ops with functional variants that don't get grouped properly]
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
    # This has a functional variant, but it's currently marked private.
    # This function should be marked private as well (*_backward ops aren't exposed to python anyway).
    "adaptive_avg_pool3d_backward.grad_input",
    # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
    # Maybe we can kill this operator in favor of convolution_backward?
    "_slow_conv2d_backward.grad_input",
]


# See Note: [Mutable ops that cannot get an out variant]
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
    # should be out=?
    "_cummax_helper",
    # should be out=?
    "_cummin_helper",
]

# None of these operators have any tensor-like returns
FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
    "_assert_async",  # no return
    "_assert_async.msg",  # no return
    "_cslt_sparse_mm_search",  # returns an int
    "_assert_scalar",  # no return
    "_dimI",  # returns an int
    "_dimV",  # returns an int
    "_has_same_storage_numel",  # returns a boolean
    "_linalg_check_errors",  # no return
    "_local_scalar_dense",  # returns a Scalar
    "_nested_tensor_from_mask_left_aligned",  # returns a boolean
    "_nnz",  # returns an int
    "_use_cudnn_ctc_loss",  # returns a boolean
    "_use_cudnn_ctc_loss.Tensor",  # returns a boolean
    "_validate_compressed_sparse_indices",  # no return
    "allclose",  # returns a boolean
    "dense_dim",  # returns an int
    "equal",  # returns a boolean
    "is_coalesced",  # returns a boolean
    "is_pinned",  # returns a boolean
    "is_same_size",  # returns a boolean
    "is_set_to",  # returns a boolean
    "q_per_channel_axis",  # returns an int
    "q_scale",  # returns a float
    "q_zero_point",  # returns an int
    "qscheme",  # returns a QScheme
    "record_stream",  # no return
    "sparse_dim",  # returns an int
    "sym_constrain_range",  # no return
    "sym_constrain_range_for_size",  # no return
    "_nested_tensor_storage_offsets",  # returns a vector of ints
    "_chunk_grad_outputs_efficient_attention",  # returns a bool
    "_fused_sdp_choice",  # returns an int
    "_print",  # no return
    "_sink_tokens",  # no return
    "_nested_get_ragged_idx",  # returns an int
]

INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
    # polygamma and polygamma.out both exist, but have a
    # pre-self arg (while polygamma_ does not)
    # We should either fix this schema so it can be grouped properly,
    # or allow the codegen to generate new functional/out= NativeFunctions for this op
    # (which would require changing its overload name to prevent overload ambiguity).
    "polygamma_"
]

# Groups "similar" NativeFunctions together
# e.g. add.Tensor, add_.Tensor, add.out
# "similar" NativeFunctions are all expected to have an identical `signature()`,
# but differing SchemaKinds.
def pre_group_native_functions(
    native_functions: Sequence[NativeFunction],
) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
    pre_grouped_native_functions: dict[
        FunctionSchema, dict[SchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        d = pre_grouped_native_functions[f.func.signature()]
        assert f.func.kind() not in d
        d[f.func.kind()] = f
    return pre_grouped_native_functions


# Returns the out variant overload name given a base function overload name
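# e.g. a base op with no overload name gets "out", while overload "Scalar" gets "Scalar_out"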
def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
    return "out" if not overload_name else f"{overload_name}_out"


# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
# Example before:
#   _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
# Example after:
#   _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    # Generating an out= schema from an inplace schema.
    assert func.kind() == SchemaKind.inplace
    assert func.arguments.self_arg is not None
    # The new out= schema has:
    # - a new out argument with the same type as "self" (but with a mutable annotation)
    # - The returns (if any) now alias the out= argument instead of "self"
    # - an "out" overload name
    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            get_expected_out_variant_overload_name(func.name.overload_name)
        ),
        arguments=func.arguments.remove_self_annotation().with_out_args(
            [
                Argument(
                    name="out",
                    type=func.arguments.self_arg.argument.type,
                    default=None,
                    annotation=func.arguments.self_arg.argument.annotation,
                )
            ]
        ),
        returns=func.returns,
    )


# Helper function: given a functional FunctionSchema, generate its corresponding out= variant
# Example before:
#   _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#       bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
# Example after:
#   _to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None,
#       Tensor(a!) out) -> Tensor(a!)
def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    # Generating an out= schema from a functional schema.
    assert func.kind() == SchemaKind.functional

    new_returns, new_out_args = generate_out_args_from_schema(func)
    # The new out= schema has:
    # - one or more new out argument(s) with the same type as returns (but with a mutable annotation)
    # - The returns now alias the out= arguments
    # - an "out" overload name
    return FunctionSchema(
        name=func.name.with_overload(
            get_expected_out_variant_overload_name(func.name.overload_name)
        ),
        arguments=func.arguments.signature().with_out_args(
            new_out_args,
        ),
        returns=tuple(new_returns),
    )


# Helper function: given a function schema, generate the corresponding out= arguments, along with the updated returns (which alias them).
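# Illustrative examples (hypothetical schemas):
# - `foo(Tensor self) -> (Tensor, Tensor)` yields out args `Tensor(a!) out0, Tensor(b!) out1`
#   and returns `(Tensor(a!), Tensor(b!))`
# - `bar(Tensor self) -> (Tensor, int)` yields a single out arg `Tensor(a!) out0`; the Tensor
#   return is dropped and only the int return remains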
def generate_out_args_from_schema(
    func: FunctionSchema,
) -> tuple[list[Return], list[Argument]]:
    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    used_annotations = concatMap(
        lambda a: [] if a.annotation is None else a.annotation.alias_set,
        func.arguments.flat_all,
    )
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

    new_out_args: list[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
    new_returns: list[Return] = []
    for i, r in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                name="out" if len(func.returns) == 1 else f"out{i}",
                type=r.type,
                default=None,
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(
                    name=None, type=new_out.type, annotation=new_out.annotation
                )
                new_returns.append(new_ret)
        else:
            new_returns.append(r)
    return new_returns, new_out_args


# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
# Example before:
#   _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950
# Example after:
#   _fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))  # noqa: B950
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    # Generating an out= schema from a mutable schema.
    assert func.kind() == SchemaKind.mutable
    # The new out= schema has:
    # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
    #   (if the returns are plain Tensors, we also return the new out= arguments
    #   for method chaining; otherwise we return nothing)
    # - an "out" overload name
    #
    # Note that:
    # (1) This also means that we can *only* generate an out= variant from a mutable schema
    #     if the mutable schema has at least one tensor-like non-aliasing return.
    # (2) The generated out= variant still has mutable positional arguments,
    #     but if necessary we could probably add another out= variant that also
    #     functionalizes the mutable arguments (a functional_out variant)

    new_returns, new_out_args = generate_out_args_from_schema(func)

    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            get_expected_out_variant_overload_name(func.name.overload_name)
        ),
        arguments=func.arguments.with_out_args(new_out_args),
        returns=tuple(new_returns),
    )


# This function, given a function of one SchemaKind, as well as a target SchemaKind,
# generates a new NativeFunction with the same properties, but using the target SchemaKind.
# We only actually generate functions for either functional or out= SchemaKinds.
# This function returns a tuple, with:
# - The generated NativeFunction
# - a dictionary of `BackendIndex` objects, describing which dispatch keys
#   we will generate kernels for, for the new NativeFunction.
#   Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
        #   See Note [Overload Ambiguity With Functional Variants]
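        #   (e.g. a mutable op `foo` would get a generated functional variant named
        #   `foo_functional`, while an inplace-only op `foo_` simply maps back to `foo`)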
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        func = f.func.signature(keep_return_names=True).with_name(
            OperatorName(
                name=BaseOperatorName(
                    base=f.func.name.name.base,
                    inplace=False,
                    dunder_method=f.func.name.name.dunder_method,
                    # See Note [Overload Ambiguity With Functional Variants]
                    functional_overload=f.func.kind() == SchemaKind.mutable,
                ),
                overload_name=f.func.name.overload_name,
            )
        )
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        if f.func.kind() == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        elif f.func.kind() == SchemaKind.functional:
            func = functional_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable or functional variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    # Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
    # disambiguate operators with the same name but different overload names, e.g., `randn.names_out` and
    # `randn.generator_with_names_out`.
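    # (e.g. the generated `randn.names_out` would get the kernel name `randn_names_out`)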
    kernel_name = (
        func.name.unambiguous_name()
        if func.kind() == SchemaKind.out
        else cpp.name(func)
    )
    if f.func.has_symint():
        kernel_name += "_symint"
    backend_metadata = {
        DispatchKey.CompositeExplicitAutograd: {
            func.name: BackendMetadata(
                kernel=kernel_name,
                structured=False,
                cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
            )
        }
    }
    tags = {"generated"} | set(
        f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"}
    )

    return (
        NativeFunction(
            func=func,
            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
            # These generated fn's aren't meant to be user-friendly - don't generate methods.
            variants={Variant.function},
            structured=False,
            structured_delegate=None,
            structured_inherits=None,
            precomputed=None,
            autogen=[],
            ufunc_inner_loop={},
            manual_kernel_registration=False,
            manual_cpp_binding=False,
            python_module=None,
            category_override=None,
            device_guard=False,
            device_check=DeviceCheckType.NoCheck,
            loc=f.loc,
            cpp_no_default_args=set(),
            is_abstract=f.is_abstract,
            has_composite_implicit_autograd_kernel=False,
            has_composite_implicit_autograd_nested_tensor_kernel=False,
            has_composite_explicit_autograd_kernel=True,
            has_composite_explicit_autograd_non_functional_kernel=False,
            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
            # which NativeFunction objects did not come directly from native_functions.yaml.
            tags=tags,
            namespace=f.namespace,
        ),
        backend_metadata,
    )


# This function is responsible for adding generated NativeFunctions which don't appear
# explicitly in the codegen.
# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
# (Maybe we should make a friendly API for this)
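# Roughly (a sketch; the parser currently lives in torchgen.gen):
#   from torchgen.gen import parse_native_yaml
#   parsed = parse_native_yaml(
#       "aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml"
#   )
#   native_functions, backend_indices = parsed.native_functions, parsed.backend_indices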
#
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
    rs: list[NativeFunction],
    indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None:
    # The main code for generating new NativeFunctions
    # First we group NativeFunctions by schema kind,
    # then we detect which ones are missing and generate them.
    pre_grouped_native_functions = pre_group_native_functions(rs)
    for d in pre_grouped_native_functions.values():
        has_functional = SchemaKind.functional in d
        has_inplace = SchemaKind.inplace in d
        has_mutable = SchemaKind.mutable in d
        has_out = SchemaKind.out in d

        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
        #     a simple functional variant that the functionalization pass can consume.
        # (2) If an operator has an inplace or functional but no out= variant, we generate an out=
        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
        #     while maintaining the constraint that the out= variant is "required".
        if has_mutable or has_inplace or has_out or has_functional:
            # Don't bother generating function trios for native functions that bypass the dispatcher.
            are_manual = all(f.manual_cpp_binding for f in d.values())
            # Don't bother generating functional + out= variants for view operators
            # set_ is technically an inplace_view, but for now it is treated
            # as a normal inplace op in the codegen
            has_view_ops = any(
                f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
            )
            # Don't generate the other variants for CompositeImplicitAutograd operators.
            # We could probably do this, but the main benefit of generating the function triplets
            # is for transforms that need them, and transforms don't need to act directly
            # on CompositeImplicitAutograd operators (since we let them decompose).
            are_composite_implicit = all(
                f.has_composite_implicit_autograd_kernel for f in d.values()
            )
            if are_manual or has_view_ops or are_composite_implicit:
                continue
            if has_out and len(d.values()) == 1:
                # Note: [Out ops with functional variants that don't get grouped properly]
                # In theory we could validly have an out= operator in native_functions.yaml
                # that has no other variants.
                # But today, all of the operators where that's the case actually do have
                # functional variants that we are just unable to pair up properly.
                # I think banning this altogether is probably safer
                # (you can always add a functional variant yourself if you want to add a new out= operator).
                #
                # We should probably fix the existing cases; this check is to prevent us from adding more over time.
                if (
                    str(d[SchemaKind.out].func.name)
                    not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
                ):
                    raise AssertionError(
                        f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
                    )
                continue

            # Some inplace ops that have problematic schemas (that we should fix), which prevent us
            # from generating out= and functional variants
            if (
                has_inplace
                and str(d[SchemaKind.inplace].func.name)
                in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
            ):
                continue

            base_fn = (
                d[SchemaKind.inplace]
                if has_inplace
                else d[SchemaKind.mutable]
                if has_mutable
                else d[SchemaKind.out]
                if has_out
                else d[SchemaKind.functional]
            )

            # Note: [Mutable ops that cannot get an out variant]
            # We can only generate an out= variant if either:
            # - the original function has tensor-like returns (since we can convert them to out kwargs)
            # - or it's inplace (since we can convert `self` to an out kwarg)
            # There are only two functions that don't fit these criteria today though,
            # and they both look like they should be fixed to be out= variants,
            # so it feels safer to ban this schema altogether
            base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
                r.type.is_tensor_like() for r in base_fn.func.returns
            )
            # Note: [Loosen the assertion that all functional should have out variant]
            # By design, all functional operators should have out= variants. The needs_out check
            # loosens this requirement, so that we only generate the out= variant if there's
            # an `autogen` block in the native function; in the long run it should be removed.
            # FIXME: Remove this after figuring out CI job failures related to min, max, mean
            needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
            gets_out_variant = not has_out and base_fn_valid and needs_out
            if not has_out and not base_fn_valid:
                if (
                    str(base_fn.func.name)
                    not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
                    and str(base_fn.func.name)
                    not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
                ):
                    raise AssertionError(
                        f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
This type of operator doesn't have a tensor-like return, making it difficult to generate a proper out= variant. If
an out= variant is not needed, please add the function name to the FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
                    )

            # Generate an out= variant
            if gets_out_variant:
                fn, metadata = generate_function(base_fn, SchemaKind.out)
                d[SchemaKind.out] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)

            # Generate a functional variant, but only do it if the operator got an out= variant
            # (Functional variants are only useful if we can group up the variants,
            # which we can only do if they have an out= variant)
            if not has_functional and (has_out or gets_out_variant):
                fn, metadata = generate_function(base_fn, SchemaKind.functional)
                d[SchemaKind.functional] = fn
                BackendIndex.grow_index(indices, metadata)
                rs.append(fn)


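# Renders the C++ return statement for a generated kernel: nothing for zero returns,
# `return <name>;` for a single return, and a tuple construction over all names otherwise.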
def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
    assert len(rets) == len(names)
    if len(rets) == 0:
        return ""
    elif len(rets) == 1:
        return f"return {names[0]};"
    else:
        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"


# Given a function, and the name of a variable corresponding to the output of that function,
# gather up all of the individual returns that are not aliased
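# e.g. for a schema returning (Tensor(a!), Tensor) with out_var="output", only the second
# return is non-aliased, so this returns ["std::get<1>(output)"]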
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
    aliased_rets = func.aliased_return_names()
    non_aliased_names = []
    is_out_var_a_tuple = len(func.returns) > 1
    for i, r in enumerate(aliased_rets):
        if r is None:
            non_aliased_names.append(
                f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
            )
    return non_aliased_names


# Generates functional kernels in terms of their inplace/mutable counterparts.
# We only do this for "generated" NativeFunctions
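# As a rough illustration, for a hypothetical inplace-only op
# `foo_(Tensor(a!) self) -> Tensor(a!)`, the generated kernel would look something like:
#
#   at::Tensor foo(const at::Tensor & self) {
#     auto self_clone = clone_arg(self);
#     auto output = at::_ops::foo_::call(self_clone);
#     return self_clone;
#   }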
@with_native_function
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.functional.tags:
        return None
    # And we always write the kernel for a generated op in terms of a non-generated op.
    if g.inplace is not None and "generated" not in g.inplace.tags:
        target_f = g.inplace
    elif g.mutable is not None and "generated" not in g.mutable.tags:
        target_f = g.mutable
    else:
        # We should be guaranteed to have a valid inplace/mutable variant to call into.
        # See Note: [Mutable Ops Not Using Functionalization]
        raise AssertionError(str(g.functional.func))

    sig = DispatcherSignature(g.functional.func)
    target_sig = DispatcherSignature(target_f.func)

    context: list[Binding | Expr] = []
    clone_mutable_inputs = []
    cloned_return_names = []
    # We can't just directly pass all of the arguments from the functional op into the mutating op.
    # We need to check for which inputs to the mutating operator are mutable,
    # and clone those inputs first.
    for a_curr, a_tgt in zip(
        dispatcher.jit_arguments(g.functional.func),
        dispatcher.jit_arguments(target_f.func),
    ):
        if a_tgt.annotation is not None and a_tgt.annotation.is_write:
            clone_mutable_inputs.append(
                f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
            )
            context.append(
                Expr(
                    expr=f"{a_curr.name}_clone",
                    type=dispatcher.argument_type(a_curr, binds=a_curr.name),
                )
            )
            # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
            cloned_return_names.append(f"{a_curr.name}_clone")
        else:
            context.append(dispatcher.argument(a_curr))
    exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])

    out_name = "output"
    maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
    inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
    ret_str = return_str(
        g.functional.func.returns, inner_return_names + cloned_return_names
    )

    clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
    return f"""
{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  {clone_mutable_inputs_str}
  {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
  {ret_str}
}}
"""


# Generates out= kernels in terms of their functional counterparts.
# We only do this for "generated" NativeFunctions
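# As a rough illustration, for a hypothetical generated op
# `foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)` whose functional counterpart is
# `foo(Tensor self) -> Tensor`, the generated kernel would look something like:
#
#   at::Tensor & foo_out(const at::Tensor & self, at::Tensor & out) {
#     auto tmp_output = at::_ops::foo::call(self);
#     resize_out_helper(out, tmp_output);
#     copy_arg(out, tmp_output);
#     return out;
#   }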
@with_native_function
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.out.tags:
        return None
    # And we always write the kernel for the out= op in terms of the functional.
    # Note that the functional op might have also been generated, but we don't have to
    # worry about cycles, because the generated functional kernels are always implemented
    # in terms of non-generated kernels (see gen_composite_functional_kernel).

    sig = DispatcherSignature(g.out.func)
    target_sig = DispatcherSignature(g.functional.func)

    exprs = ", ".join(
        [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
    )

    copy_outs = []
    out_name = "tmp_output"
    for i, out_arg in enumerate(g.out.func.arguments.out):
        functional_return_name = (
            out_name
            if len(g.functional.func.returns) == 1
            else f"std::get<{i}>({out_name})"
        )
        copy_outs.append(
            f"""\
  resize_out_helper({out_arg.name}, {functional_return_name});
  copy_arg({out_arg.name}, {functional_return_name});"""
        )

    rets = []
    # For each return arg in the calling (out=) operator,
    # If it corresponds to an aliased input, return the input.
    # Otherwise, return the corresponding output from calling the functional operator.
    for i, ret_name in enumerate(g.out.func.aliased_return_names()):
        if ret_name is not None:
            rets.append(ret_name)
        else:
            functional_return_name = (
                out_name
                if len(g.functional.func.returns) == 1
                else f"std::get<{i}>({out_name})"
            )
            rets.append(functional_return_name)

    copy_outs_str = "\n".join(copy_outs)

    # Kernel name needs to follow the naming convention defined in `generate_function()`
    return f"""
{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
  {copy_outs_str}
  {return_str(g.out.func.returns, rets)}
}}
"""
