from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING

from torchgen.api import cpp, dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    DispatcherSignature,
    FunctionalizationLambda,
    iTensorListRefT,
    NativeSignature,
    OptionalCType,
    optionalSymIntArrayRefT,
    symIntArrayRefT,
    SymIntT,
    tensorListT,
    tensorT,
    VectorCType,
    ViewInverseSignature,
)
from torchgen.context import (
    method_with_native_function,
    native_function_manager,
    with_native_function,
    with_native_function_and,
)
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Return,
    SchemaKind,
    SelfArgument,
    TensorOptionsArguments,
)
from torchgen.native_function_generation import (
    INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
    MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.utils import dataclass_repr


if TYPE_CHECKING:
    from torchgen.selective_build.selector import SelectiveBuilder


# Note: [Mutable Ops Not Using Functionalization]
# Ops in this list currently do not work with functionalization and should be fixed.
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
    + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
    + INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
    + [
        # It will be BC-breaking, but we should fix their schemas.
        # should be inplace?
        "record_stream",
        # See Note [resize_ in Functionalization]
        "resize_",
        "resize_as_",
        # This function is used for testing purposes only.
        "_fill_mem_eff_dropout_mask_",
    ]
)

# This file contains codegen that relates to the functionalization pass.
# It includes:
# - gen_functionalization_definition
#     Generates dispatcher kernel definitions for the functionalization pass.
# - gen_functionalization_registration
#     Generates dispatcher kernel registrations for the functionalization pass.
# - gen_functionalization_view_inverse_declaration
#     Generates a declaration for an "inverse view", for every view op
#     that is needed in functionalization. We manually implement their definitions.
# - gen_composite_view_copy_kernel
#     Generates view_copy() composite kernels for all view_copy operators.
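#
# For orientation, the registration output (see emit_registration_helper below)
# is a list of m.impl(...) lines; a single generated registration looks roughly
# like the following (the op name here is only an illustrative example):
#
#   m.impl("transpose_", TORCH_FN(functionalization::transpose_));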


# Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction
# See Note [view_copy NativeFunctions]
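#
# For intuition only: for a hypothetical view op `foo` grouped with a `foo_copy`
# variant, the default kernel generated below is roughly
#
#   at::Tensor foo_copy(const at::Tensor & self, ...) {
#     auto output = at::_ops::foo::call(self, ...);
#     return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
#   }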
@dataclass(frozen=True)
class GenCompositeViewCopyKernel:
    backend_index: BackendIndex

    @method_with_native_function
    def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
        if g.view_copy is None:
            return None
        elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
            # If the view_copy doesn't match the standard naming scheme of <op>_copy,
            # assume it already exists and doesn't need to be generated.
            # Example: slice_inverse() with the copy variant named slice_scatter()
            # instead of slice_inverse_copy()
            return None

        metadata = self.backend_index.get_kernel(g.view_copy)
        assert metadata is not None

        # We can make view_copy work in more cases by using reshape()
        # when a normal view call would ordinarily fail.
        # This also makes LTC more efficient, because it doesn't need to include
        # clone() calls in its graph (which is normally needed by reshape).
        if str(g.view_copy.func.name) == "view_copy":
            assert metadata.kernel == "view_copy_symint"
            return """\
at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
  c10::SymDimVector shape = infer_size_dv(size, self.sym_numel());
  if (!at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape).has_value()) {
    return self.reshape_symint(size);
  } else {
    auto output = at::_ops::view::call(self, size);
    return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
}
"""
        # view_copy is a native signature, since we're generating an at::native:: kernel
        # Functionalization always operates on symints though
        view_copy_sig = NativeSignature(
            g.view_copy.func, symint=metadata.supports_symint()
        )

        # view is a dispatcher signature, since we're calling into the at::_ops API
        view_sig = DispatcherSignature(g.view.func)

        view_api_name = g.view.func.name.unambiguous_name()
        exprs = ", ".join(
            [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
        )

        # view ops today always return either a Tensor or a list of Tensors
        assert len(g.view.func.returns) == 1
        assert g.view.func.returns[0].type == BaseType(
            BaseTy.Tensor
        ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

        if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
            return_cloned_output = """\
  return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
        else:
            # If the return type is a list, we need to clone each tensor in the list.
            return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
  }}
  return out_clone;"""

        # The default generated composite kernel for {view}_copy() operators just runs
        # the underlying view op and clones the output.
        return f"""
{view_copy_sig.defn(name=metadata.kernel)} {{
  auto output = at::_ops::{view_api_name}::call({exprs});
  {return_cloned_output}
}}
"""


def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
    assert len(rets) == len(names)
    if len(rets) == 0:
        return ""
    elif len(rets) == 1:
        return f"return {names[0]};"
    else:
        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"


def modifies_arguments(f: NativeFunction) -> bool:
    return any(
        a.annotation is not None and a.annotation.is_write
        for a in f.func.arguments.flat_all
    )


def wrapper_name(func: FunctionSchema) -> str:
    if func.name.overload_name:
        return f"{cpp.name(func)}_{func.name.overload_name}"
    else:
        return cpp.name(func)


def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
    return isinstance(a, SelfArgument) or (
        isinstance(a, Argument) and a.type.is_tensor_like()
    )


# We need to wrap / unwrap various arguments from the op in the functionalization kernels.
# Some op schemas include non-owning types though (like TensorList),
# and when we unwrap them we expect to get out an owning type!
# We also return a lambda that tells you how to convert the non-owning type argument into the owning type.
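# For example (illustrative): get_owning_type(BaseCType(tensorListT)) returns
# (VectorCType(BaseCType(tensorT)), fn), where fn("self") produces the C++
# expression "self.vec()"; types that don't involve tensors are passed through unchanged.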
def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
    if t == BaseCType(tensorListT):
        return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
    if t == BaseCType(iTensorListRefT):
        return VectorCType(BaseCType(tensorT)), lambda x: f"{{{x}.begin(), {x}.end()}}"
    # There are technically other non-owning types out there (like IntArrayRef),
    # but functionalization only actually cares about the ones involving tensors.
    return t, lambda x: x


# unwraps all tensor-like arguments, returning:
# (1) a string containing all of the logic that does the unwrapping
# (2) a context, to be used by translate(), with all of the relevant bindings.
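#
# As a rough illustration, assuming a plain Tensor argument named "self" and
# is_view_op=False, the generated unwrapping logic looks like:
#
#   at::Tensor self_;
#   if (at::functionalization::impl::isFunctionalTensor(self)) {
#     at::functionalization::impl::sync(self);
#     self_ = at::functionalization::impl::from_functional_tensor(self);
#   } else {
#     self_ = self;
#   }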
def unwrap_tensor_args(
    sig: DispatcherSignature, *, is_view_op: bool
) -> tuple[str, list[Binding]]:
    context: list[Binding] = []
    unwrapped_tensor_args: list[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            unwrapped_name = f"{arg.name}_"
            # For most ops, the functionalization pass needs to sync any pending updates on the input tensors
            # before calling the operator, since otherwise the operator will act on stale data.
            # For view ops though, we can continue to defer syncing until the tensor is used by
            # a non-view operator.
            maybe_sync_input = (
                "" if is_view_op else f"at::functionalization::impl::sync({arg.name});"
            )
            unwrapped_type, conversion_fn = get_owning_type(
                arg.nctype.remove_const_ref().type
            )
            unwrapped_tensor_args.append(
                f"""
      {unwrapped_type.cpp_type()} {unwrapped_name};
      if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
        {maybe_sync_input}
        {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
      }} else {{
        {unwrapped_name} = {conversion_fn(arg.name)};
      }}"""
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n      ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context


# converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
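#
# For instance (illustrative), a Tensor argument named "self" becomes the line
#   auto self_meta = to_meta(self);
# in the generated kernel.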
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
    context: list[Binding] = []
    unwrapped_tensor_args: list[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to convert them to meta tensors before passing them into the redispatch calls.
            a_ = arg.name
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({a_});")
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context


# The functionalization codegen currently expects view op schemas to have this form:
# foo(Tensor(a), ...) -> Tensor(a) (e.g. transpose)
# foo(Tensor(a!), ...) -> Tensor(a!) (e.g. transpose_)
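# For example, the (abridged) schema of a real view op looks roughly like:
#   transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)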
def assert_view_op_properties(func: FunctionSchema) -> None:
    def is_alias(a: Argument) -> bool:
        return a.annotation is not None

    args = func.arguments.flat_non_out
    # The first argument is a tensor with alias semantics (annotations)
    assert len(args) > 0 and args[0].type == BaseType(
        BaseTy.Tensor
    ), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
    # No other arguments have aliasing semantics
    assert is_alias(args[0]) and not any(
        is_alias(a) for a in args[1:]
    ), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint."""


# One-liner expression for checking if an expression `expr` of type `type` has any
# symbolic values.
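#
# As an illustration (derived from the cases below): for a SymInt expression
# "storage_offset" this emits `storage_offset.is_symbolic()`, and for a
# SymIntArrayRef expression "size" it emits
#   std::any_of(size.begin(), size.end(), [=](auto& arg) { return arg.is_symbolic(); })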
def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
    if type == BaseCType(SymIntT):
        return f"{expr}.is_symbolic()"

    if isinstance(type, OptionalCType):
        innerexpr = f"(*{expr})"
        return f"{expr}.has_value() ? {emit_expr_has_symbolic_values(innerexpr, type.elem)} : false"

    if type == BaseCType(optionalSymIntArrayRefT):
        return emit_expr_has_symbolic_values(
            expr, OptionalCType(BaseCType(symIntArrayRefT))
        )

    if type in (BaseCType(symIntArrayRefT), VectorCType(BaseCType(SymIntT))):
        argname = "arg"
        lambda_check = emit_expr_has_symbolic_values(argname, BaseCType(SymIntT))
        return (
            "std::any_of("
            f"{expr}.begin(), {expr}.end(), "
            f"[=](auto& {argname}) {{ return {lambda_check}; }})"
        )

    raise ValueError(
        "unsupported type for has_symbolic_values check. "
        "It should be a SymInt or a collection of those. "
        f"Got: {type.cpp_type()}"
    )


# Detects whether any of the SymInt arguments are, in fact, symbolic values.
# This is used in the constructor of ViewMeta.
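#
# A rough sketch of the generated C++, assuming a single SymIntArrayRef argument
# named "size":
#
#   bool has_symbolic_inputs = false;
#   has_symbolic_inputs = has_symbolic_inputs | (std::any_of(size.begin(), size.end(), [=](auto& arg) { return arg.is_symbolic(); }));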
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
    name = "has_symbolic_inputs"
    statements = [
        f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
        for binding in sig.arguments()
        if (
            isinstance(binding.argument, Argument)
            and binding.argument.type.is_symint_like()
        )
    ]
    body = "\n      ".join(statements)
    return (
        name,
        f"""
      bool {name} = false;
      {body}""",
    )


# Generates the Functionalization kernel for:
# - ops that create aliases (e.g. transpose())
# - ops that are views AND mutations (e.g. transpose_())
def emit_view_functionalization_body(
    g: NativeFunctionsViewGroup, *, view_inplace: bool
) -> str:
    if view_inplace:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have a NativeFunctionsGroup, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert g.view_inplace is not None
        f = g.view_inplace
    else:
        f = g.view

    assert g.view_copy is not None
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

        # the "view_copy" op name that the functionalization kernels need to call
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors).
        # "no-op"ing in this context is just redispatching to the original op.
        noop_api_name = f.func.name.unambiguous_name()

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        assert_view_op_properties(f.func)
        view_tensor_name = dispatcher_sig.arguments()[0].name

        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
            dispatcher_sig, is_view_op=True
        )
        view_redispatch_args = [
            e.expr
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)

        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
        ]

        (
            symbolic_inputs_varname,
            symbolic_inputs_check,
        ) = emit_has_symbolic_inputs(call_sig)

        if "inplace_view" in f.tags:
            # See Note [Functionalization Pass - Inplace View Ops] for more details
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      auto inverse_return_mode = (
          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
            : at::functionalization::InverseReturnMode::NeverView
      );
      {symbolic_inputs_check}
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }},
        /*has_symbolic_inputs=*/{symbolic_inputs_varname}
      );
      auto compute_reference_meta =
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
      {return_type} reference_tensor_output;
      if (compute_reference_meta) {{
        {meta_conversion_str}
        at::AutoDispatchSkipFunctionalize func_guard;
        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      // This function adds the above view meta to the current tensor and replays it off the base,
      // mutating the size/stride info of the current FunctionalTensorWrapper.
      // Because of this, we need to make sure to run the reference shape function above,
      // BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides)
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
      // on a reference implementation here (instead of relying on the output from the forward lambda
      // having the correct stride info)
      if (compute_reference_meta) {{
        at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      }}
      return {view_tensor_name};
    }}
"""

        else:
            is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      {unwrap_tensor_args_str}
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      auto inverse_return_mode = (
          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
            : at::functionalization::InverseReturnMode::NeverView
      );
      auto compute_reference_meta =
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
      {return_type} reference_tensor_output;
      if (compute_reference_meta) {{
        {meta_conversion_str}
        at::AutoDispatchSkipFunctionalize func_guard;
        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      {return_type} tmp_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        if (reapply_views) {{
          tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
        }} else {{
          tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)});
        }}
      }}
      {symbolic_inputs_check}
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }},
        /*has_symbolic_inputs=*/{symbolic_inputs_varname},
        /*is_multi_output=*/{str(is_multi_output_view).lower()},
        /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
      );
      auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      if (compute_reference_meta) {{
        at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
      }}
      return out;
    }}
"""


def maybe_create_output(f: NativeFunction, var_name: str) -> str:
    if len(f.func.returns) == 0:
        return ""
    return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
    return f"{return_type} {var_name} = "


# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function,
# this returns two lists of names, consisting of:
# - the names of returns corresponding to the original (mutable) inputs of the outer function
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
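#
# Hypothetical examples: for an out= op with a single aliased return named "out",
# this yields (["out"], []); for a functional op with two fresh returns and
# inner_return_var="tmp_output", it yields
# ([], ["std::get<0>(tmp_output)", "std::get<1>(tmp_output)"]).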
def get_mutable_redispatch_return_names(
    f: NativeFunction, inner_return_var: str
) -> tuple[list[str], list[str]]:
    aliased_returns = []
    non_aliased_returns = []
    for i, name in enumerate(f.func.aliased_return_names()):
        if name is not None:
            aliased_returns.append(name)
        else:
            non_aliased_returns.append(
                inner_return_var
                if len(f.func.returns) == 1
                else f"std::get<{i}>({inner_return_var})"
            )
    return aliased_returns, non_aliased_returns


# When functionalization "no-ops" and redispatches on a mutable operator, we need to take care so that:
#  - For fresh outputs, we return the result of the redispatch (without wrapping outputs)
#  - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped)
def return_from_mutable_noop_redispatch(
    f: NativeFunction, inner_return_var: str
) -> str:
    aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var)
    # Just get all of the return names, and immediately return them
    return return_str(f.func.returns, aliased + non_aliased)


def wrap_propagate_mutations_and_return(
    f: NativeFunction, functional_op: NativeFunction, inner_return_var: str
) -> str:
    mutable_arg_names = f.func.arguments.mutable_arg_names()
    (
        aliased_outer_rets,
        non_aliased_outer_rets,
    ) = get_mutable_redispatch_return_names(f, inner_return_var)
    _, non_aliased_inner_rets = get_mutable_redispatch_return_names(
        functional_op, inner_return_var
    )
    # The outer function may have a mix of aliased and non-aliased outputs,
    # but the inner functional op that we're transforming to should only have non-aliased outputs.
    assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
        non_aliased_inner_rets
    )

    # First, take all of the newly created outputs from the inner call and wrap them into functional tensors
    updates = []
    non_aliased_wrapped_ret_names = []
    for i, inner_ret in enumerate(
        non_aliased_inner_rets[: len(non_aliased_outer_rets)]
    ):
        ret_name = f"output_{i}"
        updates.append(
            f"""\
  auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
        )
        non_aliased_wrapped_ret_names.append(ret_name)

    # Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
    # and propagate the mutations
    for outer_arg, inner_ret in zip(
        mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
    ):
        updates.append(
            f"""\
  auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
  at::functionalization::impl::replace_({outer_arg}, {inner_ret});
  at::functionalization::impl::commit_update({outer_arg});
  at::functionalization::impl::sync({outer_arg});
  auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
  at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
        )

    # Finally, we return:
    # - Any mutable arguments that are also returned
    # - Any immutable returns that were created by wrapping the outputs from the inner call
    returns_str = return_str(
        f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
    )
    updates_str = "\n".join(updates)
    return f"""\
{updates_str}
    {returns_str}"""


# Generates the Functionalization kernel for:
# - mutation ops (inplace and out= ops)
@with_native_function_and
def emit_inplace_functionalization_body(
    f: NativeFunction, g: NativeFunctionsGroup
) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False
    )

    mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    non_mutated_tensor_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    check_all_mutated_args_are_functional = " && ".join(
        ["true"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in mutated_names
        ]
    )
    check_any_non_mutated_args_are_functional = " || ".join(
        ["false"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in non_mutated_names
        ]
    )

    check_any_non_mutated_tensors_are_xla = " || ".join(
        ["false"]
        + [
            f"{a}.device().type() == c10::DeviceType::XLA"
            for a in non_mutated_tensor_names
        ]
    )
    # These are used in the cases where we don't functionalize and redispatch to the inplace op.
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace op but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    # call the out-of-place variant of the op
    return_type = (
        dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
    )
    functional_sig = DispatcherSignature.from_schema(g.functional.func)
    functional_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    if f.func.is_out_fn():
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
                for (i, a) in enumerate(f.func.arguments.out)
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )
    else:
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
                for a in f.func.arguments.flat_all
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )

    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
    # We don't want to run the inplace meta func for ops like .set_(), because:
    # (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
    #     where broadcasting will work for the out-of-place case but should fail on the inplace call
    # (2) They'll also fail without adding extra infra: we'd need to convert the input storage argument
    #     into a meta storage
    any_storage_args = any(
        a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
    )

    return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
        // Before converting the mutable op to its functional variant, run meta tensors through the original op.
        // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
        // (We can only do this for inplace ops today though, because they technically all support meta tensors).
        {meta_conversion_str}
        at::AutoDispatchSkipFunctionalize func_guard;
        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
        at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
      }}
      {unwrap_tensor_args_str}
      if (!({check_all_mutated_args_are_functional})) {{
        // We want to disable this check if there are any XLA tensors.
        // cpu_tensor.copy_(xla_tensor) is valid code.
        if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
         // case 1: trying to mutate a non-functional tensor with a functional tensor is an error
         TORCH_INTERNAL_ASSERT(false,
           "mutating a non-functional tensor with a functional tensor is not allowed.",
           " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
        }} else {{
         // case 2: arguments are not functional tensors, so we no-op and redispatch.
         at::AutoDispatchSkipFunctionalize guard;
         {maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
         {return_from_mutable_noop_redispatch(f, 'tmp_output')}
        }}
      }} else {{
        {return_type} tmp_output;
        {{
          at::AutoDispatchSkipFunctionalize guard;
          tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
        }}
        {wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
      }}
    }}"""



# The below functions generate RegisterFunctionalization.cpp
# These files provide the kernels that run the functionalization pass, which can be opted into
# per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch).


# See Note [Functionalization Pass: View Inverses].
def gen_functionalization_view_inverse_declaration(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> str | None:
    # For every (non-composite) view op, we need a corresponding "inverse view" function.
    # This generates the declarations so we get a good compiler error when someone adds a new view.
    @with_native_function
    def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
        if g.view.has_composite_implicit_autograd_kernel:
            return None
        view_inverse_sig = ViewInverseSignature(g)
        return view_inverse_sig.decl()

    return emit_decl_helper(g)


def gen_functionalization_registration(
    selector: SelectiveBuilder,
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    composite_implicit_autograd_index: BackendIndex,
) -> list[str]:
    @with_native_function
    def emit_registration_helper(f: NativeFunction) -> str:
        assert not f.has_composite_implicit_autograd_kernel
        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
        return f'm.impl("{f.func.name}", {registration_str});'

    # Don't generate kernels in mobile build
    if not selector.include_all_operators:
        return []

    if isinstance(g, NativeFunctionsViewGroup):
        # functionalization needs to register kernels for view + view_inplace ops
        # See Note [Functionalization <> torch.Tensor constructor]
        if str(g.view.func.name) == "lift_fresh":
            return []
        view_str = []
        if not g.view.has_composite_implicit_autograd_kernel:
            view_str.append(emit_registration_helper(g.view))
        if (
            g.view_inplace is not None
            and not g.view_inplace.has_composite_implicit_autograd_kernel
        ):
            assert g.view_inplace.is_view_op
            view_str.append(emit_registration_helper(g.view_inplace))
        return view_str

    elif isinstance(g, NativeFunctionsGroup):
        # Gets a hand-written functionalization kernel
        if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor":
            fns = []
        else:
            fns = list(g.functions())
    else:
        if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
            return []
        fns = [g]

    registrations = []
    for f in fns:
        if f.has_composite_implicit_autograd_kernel:
            continue
        if str(f.func.name) == "lift":
            # See Note [Functionalization <> torch.Tensor constructor]
            return []
        if str(f.func.name) == "resize_":
            # See Note [resize_ in Functionalization]
            return []
        if str(f.func.name.name) != "set_":
            assert not f.is_view_op
        # functionalization needs to generate and register kernels for inplace ops.
        # We *also* need to directly register CompositeImplicitAutograd kernels
        # so that they decompose properly before functionalization.
        if modifies_arguments(f):
            registrations.append(emit_registration_helper(f))
    return registrations


def gen_functionalization_definition(
    selector: SelectiveBuilder,
    # Note: Ideally this code should never have to look at NativeFunction
    # (and instead only need to operate on grouped NativeFunctions).
    # The only reason it currently does is that we need to emit direct dispatch registrations
    # for CompositeImplicitAutograd operators, which are potentially ungrouped.
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
    # Don't generate kernels in mobile build
    if not selector.include_all_operators:
        return []

    if isinstance(g, NativeFunctionsViewGroup):
        # Case 1: emit view -> view_copy kernels for the functionalization pass
        view_defs = []
        if not g.composite:
            # invariant: a NativeFunctionsViewGroup always has a view_copy operator
            # if the view is not composite (implicit autograd)
            assert g.view_copy is not None, dataclass_repr(g, indent=1)
            view_defs.append(emit_view_functionalization_body(g, view_inplace=False))
            if g.view_inplace is not None:
                view_defs.append(emit_view_functionalization_body(g, view_inplace=True))
        return view_defs
    elif isinstance(g, NativeFunction):
        # Invariant: all mutable operators that we need to handle in functionalization
        # should have been properly grouped up.
        # TODO: The below ops all have "problematic" schemas that prevent them from
        # getting functionalized. Instead of bending over backwards to get things to work,
        # I think we should either:
        # (1) fix their schemas (BC-breaking)
        # (2) hand-write their functionalization kernels
        if (
            str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
            and str(g.func.name.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
        ):
            assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g)
        return []
    else:
        # Case 2: emit inplace -> out-of-place kernels for the functionalization pass
        mutation_defs = []
        mutation_defs.append(emit_inplace_functionalization_body(g.out, g))
        if g.inplace is not None:
            mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g))
        if g.mutable is not None:
            mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g))
        return mutation_defs
    return []