xref: /aosp_15_r20/external/pytorch/torchgen/api/functionalization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from torchgen.api import dispatcher
4from torchgen.api.types import (
5    BaseCppType,
6    BaseCType,
7    Binding,
8    boolT,
9    ConstRefCType,
10    CType,
11    longT,
12    NamedCType,
13    tensorT,
14)
15from torchgen.model import (
16    Argument,
17    BaseTy,
18    BaseType,
19    FunctionSchema,
20    NativeFunction,
21    NativeFunctionsViewGroup,
22)
23
24
# This file describes the translation of JIT schema to APIs used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These APIs mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
# - While the forward lambda just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse lambda
#   is responsible for generating both the call-site and the declarations
#   (which are implemented manually in the at::functionalization::impl namespace).
34
35# The lambdas generated for each view op in the functionalization pass are of the form
36# [capture_arguments](outer_arguments) -> returns_type {
37#     return name(inner_arguments);
38# }
39
40# Define some specific lambda input arguments.
# `base`: a const-reference tensor binding.
base_binding = Binding(
    name="base",
    default=None,
    nctype=NamedCType(type=ConstRefCType(BaseCType(tensorT)), name="base"),
    argument=Argument(
        type=BaseType(BaseTy.Tensor), name="base", annotation=None, default=None
    ),
)
# `mutated_view`: a const-reference tensor binding.
# NOTE(review): the schema-level Argument below is still named "base";
# presumably only `name`/`nctype` matter for these hand-built bindings - confirm.
mutated_view_binding = Binding(
    name="mutated_view",
    default=None,
    nctype=NamedCType(type=ConstRefCType(BaseCType(tensorT)), name="mutated_view"),
    argument=Argument(
        type=BaseType(BaseTy.Tensor), name="base", annotation=None, default=None
    ),
)
# `mutated_view_idx`: a `longT` binding.
# NOTE(review): the schema-level Argument below is a Tensor named "base", which
# does not match the C++ type above - presumably unused; confirm.
mutated_view_idx_binding = Binding(
    name="mutated_view_idx",
    default=None,
    nctype=NamedCType(type=BaseCType(longT), name="mutated_view_idx"),
    argument=Argument(
        type=BaseType(BaseTy.Tensor), name="base", annotation=None, default=None
    ),
)
# `reapply_views`: a `boolT` binding backed by a bool schema argument.
reapply_views_binding = Binding(
    name="reapply_views",
    default=None,
    nctype=NamedCType(type=BaseCType(boolT), name="reapply_views"),
    argument=Argument(
        type=BaseType(BaseTy.bool), name="reapply_views", annotation=None, default=None
    ),
)
73
# C++ type `InverseReturnMode` in the `at::functionalization` namespace.
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
# `inverse_return_mode`: a binding whose C++ type is InverseReturnModeT.
inverse_return_mode_binding = Binding(
    name="inverse_return_mode",
    default=None,
    nctype=NamedCType(type=BaseCType(InverseReturnModeT), name="inverse_return_mode"),
    argument=Argument(
        # NB: not actually a bool but it doesn't matter because this isn't used
        type=BaseType(BaseTy.bool),
        name="inverse_return_mode",
        annotation=None,
        default=None,
    ),
)
87
88
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the name of the inner function called by the lambda.

    The lambda capture itself has no name; this is the callee's name.

    Reverse lambdas delegate to ``reverse_name(g.view, include_namespace)``.
    Forward lambdas call straight into the ``at::_ops`` API, so the result is
    ``at::_ops::<op>::call`` where ``<op>`` is the unambiguous name of
    ``g.view`` when ``reapply_views`` is truthy and of ``g.view_copy`` otherwise.

    Asserts (forward case only): ``reapply_views`` is not ``None``,
    ``include_namespace`` is truthy and ``g.view_copy`` is not ``None``.
    """
    if is_reverse:
        # The runtime "reapply_views" argument is always plumbed into the reverse
        # function, so the codegen-time flag is irrelevant here.
        return reverse_name(g.view, include_namespace)

    # reapply_views is only important for the fwd lambda, where it must be given.
    assert reapply_views is not None
    # The at::_ops API is always referenced by its fully qualified name.
    assert include_namespace
    assert g.view_copy is not None

    if reapply_views:
        op = g.view.func.name.unambiguous_name()
    else:
        op = g.view_copy.func.name.unambiguous_name()
    return f"at::_ops::{op}::call"
113
114
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the view-inverse function for ``f``.

    The name is ``<unambiguous op name>_inverse``. With ``include_namespace``
    it is prefixed with ``at::functionalization::FunctionalInverses::``.

    The "reapply_views" flag is plumbed into the inverse function at runtime,
    so a single inverse serves both the copy and non-copy variants (otherwise
    twice as many view inverse functions would have to be written).
    """
    op = f.func.name.unambiguous_name()
    # Call-sites need the full namespace; declarations use the bare name.
    return (
        f"at::functionalization::FunctionalInverses::{op}_inverse"
        if include_namespace
        else f"{op}_inverse"
    )
125
126
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    """Return the bindings captured by the lambda for ``func``.

    The list starts with ``inverse_return_mode_binding`` (reverse) or
    ``reapply_views_binding`` (forward), followed by one binding for every
    argument of ``func`` except the first, which is asserted to be a Tensor.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)

    if is_reverse:
        mode_binding = inverse_return_mode_binding
    else:
        mode_binding = reapply_views_binding

    captured = [mode_binding]
    # Captures must not hold C++ reference types (or else we'll get a dangling
    # reference in the capture), so reference types (IntArrayRef) are converted
    # to value types (vector<int64_t>).
    for arg in flat_args[1:]:
        captured.append(dispatcher.argument(arg, remove_non_owning_ref_types=True))
    return captured
143
144
def returns_type(func: FunctionSchema) -> CType:
    """Return the lambda's return type, which is always a single tensor.

    Asserts that ``func`` has at least one return and that every return type
    is tensor-like. For multi-tensor outputs, each tensor needs to be tracked
    individually, hence the individual-tensor return type.
    """
    # Assertion: all view ops return tensor-like outputs
    assert len(func.returns) > 0
    assert all(r.type.is_tensor_like() for r in func.returns)
    return BaseCType(tensorT)
153
154
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    """Return the bindings for the lambda's own (outer) parameters.

    Forward: ``[base, mutated_view_idx]``.
    Reverse: ``[base, mutated_view, mutated_view_idx]``.
    """
    params = [base_binding]
    if is_reverse:
        # Only the reverse lambda receives the mutated view.
        params.append(mutated_view_binding)
    params.append(mutated_view_idx_binding)
    return params
160
161
def inner_call_index(func: FunctionSchema) -> Binding | None:
    """Return the index binding for multi-output view ops, else ``None``.

    For view ops that return multiple tensors (like `split`), a separate lambda
    is generated for each output, and replaying the op needs to index into the
    output. ``mutated_view_idx_binding`` is returned when ``func`` has more
    than one return, or exactly one return whose type is list-like.
    """
    rets = func.returns
    if len(rets) > 1:
        return mutated_view_idx_binding
    if len(rets) == 1 and rets[0].type.is_list_like():
        return mutated_view_idx_binding
    return None
170
171
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Return the bindings passed to the function called inside the lambda.

    Forward: ``[base, *rest]``.
    Reverse: ``[base, mutated_view, inverse_return_mode, (index,) *rest]``,
    where the index binding is present only when ``inner_call_index(func)``
    returns one.

    ``rest`` holds one dispatcher binding per argument of ``func`` except the
    first, which is asserted to be a Tensor.
    """
    flat_args = func.arguments.flat_all
    assert flat_args[0].type == BaseType(BaseTy.Tensor)
    # The forward lambda calls the at::_ops API, while the reverse lambda calls
    # the view inverse API. Both of these follow the dispatcher API.
    rest = [dispatcher.argument(arg) for arg in flat_args[1:]]

    if not is_reverse:
        # the forward lambda swaps out the original tensor argument with the lambda arg "base"
        return [base_binding] + rest

    # the reverse lambda does the same, but with an additional "mutated_view" arg
    leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    # calling convention: for view ops that return multiple tensor outputs,
    # the corresponding view_inverse function takes an additional index argument.
    index_binding = inner_call_index(func)
    if index_binding is not None:
        leading.append(index_binding)
    return leading + rest
200