xref: /aosp_15_r20/external/pytorch/torchgen/api/autograd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport re
4*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
5*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, Sequence
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom torchgen import local
8*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
9*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
10*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
11*da0073e9SAndroid Build Coastguard Worker    BaseTy,
12*da0073e9SAndroid Build Coastguard Worker    BaseType,
13*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
14*da0073e9SAndroid Build Coastguard Worker    ListType,
15*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
16*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsViewGroup,
17*da0073e9SAndroid Build Coastguard Worker    SchemaKind,
18*da0073e9SAndroid Build Coastguard Worker    Type,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import IDENT_REGEX
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    """A value saved at forward time for use by a backward formula."""

    # The NamedCType holds the updated name and C++ type of the attribute.
    # For a derived property, a suffix is appended to the name, e.g.:
    # `other_scalar_type`.
    nctype: NamedCType

    # The expression that reads the (possibly derived) value at save time,
    # e.g.: `other.scalar_type()`.
    expr: str
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    """A backward formula together with the saved values it references."""

    # The formula string (a legal C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement.
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    # NOTE: this is a mutable `set`, so despite `frozen=True` calling hash() on
    # a Derivative raises TypeError (the generated __hash__ hashes every field).
    named_gradients: set[str]
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker
# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    """A forward-mode (JVP) formula for one or more outputs."""

    # The formula string (a legal C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Names of the output arguments for which this formula calculates forward
    # derivatives.
    var_names: tuple[str, ...]

    # Types of the output arguments for which this formula calculates forward
    # derivatives (parallel to `var_names`).
    var_types: tuple[Type, ...]

    # Inputs whose forward derivatives (tangents) are required by this formula.
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs whose primal values are required by this formula.
    required_inputs_primal: tuple[str, ...] | None

    # Whether this formula requires the original (pre-mutation) value of self.
    # Only used by inplace operations.
    required_original_self_value: bool

    # Distinguishes where this formula came from: per the field name, True when
    # an inplace op is re-using the out-of-place formula, and False when the
    # formula was specified directly in derivatives.yaml.
    is_reusing_outplace_formula: bool
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker
# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    """Backward/forward derivative information for one NativeFunction,
    as read from derivatives.yaml."""

    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunction having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: str | None

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: list[bool] | None

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions
    output_differentiability_conditions: list[str] | None

    @property
    def has_derivatives(self) -> bool:
        """True if at least one input argument has a backward formula."""
        return len(self.args_with_derivatives) > 0

    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        """Build the DifferentiabilityInfo for the ``{view}_copy`` variant of a view op.

        ``{view}_copy`` ops can use exactly the same derivative formulas as the
        original view op, so every derivative-related field is shared unchanged.
        Only ``name``, ``func`` and ``op`` are switched to their ``_copy``
        counterparts. See Note [Codegen'd {view}_copy Operators].

        Args:
            g: the view group whose ``view_copy`` function becomes ``func``.

        Returns:
            The new info, or None when ``g`` has no ``view_copy`` function.
        """
        if g.view_copy is None:
            return None
        f = g.view_copy

        # Append "_copy" to the base name of the operator while keeping any
        # overload name the same: "view" -> "view_copy",
        # "view.dtype" -> "view_copy.dtype". partition() yields an empty
        # separator when there is no overload, so no dangling "." is emitted
        # (the previous split/join produced "view_copy." in that case).
        base_name, period, overload_name = self.name.partition(".")
        view_copy_name = f"{base_name}_copy{period}{overload_name}"
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op
            name=view_copy_name,
            func=f,
            op=view_copy_op_name,
            # But keep all derivative info the same
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker
def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
    """Report whether any backward formula of ``info`` mentions ``ident``.

    Matching uses ``IDENT_REGEX`` formatted with ``ident``. A missing ``info``
    (None) never uses anything.
    """
    if info is None:
        return False
    return any(
        re.search(IDENT_REGEX.format(ident), deriv.formula)
        for deriv in info.derivatives
    )
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker
def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    """True if a backward formula of ``info`` references ``retain_variables``."""
    return uses_ident(info, ident="retain_variables")
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker
def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    """True if a backward formula of ``info`` references the identifier ``grad``."""
    return uses_ident(info, ident="grad")
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker
# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's a processed Argument that is differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument.
@dataclass(frozen=True)
class DifferentiableInput:
    """A differentiable input argument as seen by the autograd codegen."""

    # Argument name.
    name: str
    # Argument type from the function schema.
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str
236*da0073e9SAndroid Build Coastguard Worker
237*da0073e9SAndroid Build Coastguard Worker
# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's a processed Return that is differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and is only used in the context of the autograd codegen.
@dataclass(frozen=True)
class DifferentiableOutput:
    """A differentiable return value as seen by the autograd codegen."""

    # Output name (always populated, see above).
    name: str
    # Return type from the function schema.
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    """A NativeFunction paired with its per-dispatch-key derivative information."""

    # The native function being wrapped.
    func: NativeFunction
    # Backward differentiability info keyed by dispatch key (e.g. "Default"),
    # or None when no derivatives.yaml entry applies.
    info: dict[str, DifferentiabilityInfo] | None
    # Forward derivatives; presumably keyed by dispatch key like `info` -- TODO confirm.
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """Decide how generated autograd code calls the underlying implementation.

    Returns one of two strategies:

    - ``"use_derived"``: call the derived implementation with unpacked
      tensors. Chosen when ``fn.func`` is abstract, or when at least one of
      the per-dispatch-key infos in ``fn.info`` has derivatives.
    - ``"use_type"``: call the concrete implementation. The functions it
      invokes dispatch back through autograd, so differentiability comes from
      those calls. Chosen in every other case.

    The result guards generation of functions in VariableType and
    ADInplaceOrViewType.
    """
    # Abstract: there is no concrete implementation to fall back on, so the
    # derived implementation must be called with unpacked tensors.
    if fn.func.is_abstract:
        return "use_derived"

    # A derivative defined for ANY dispatch key is enough; the wrappers must be
    # generated as long as one key has one. Calling the derived implementation
    # is also preferred for performance: internal calls to other ATen
    # functions won't have their history tracked. For factory functions it
    # additionally keeps returned tensors at _version 0.
    per_key_infos = fn.info
    if per_key_infos is not None:
        if any(key_info.has_derivatives for key_info in per_key_infos.values()):
            return "use_derived"

    # Concrete and not declared in derivatives.yaml: assume it is built out of
    # differentiable functions. (If that assumption fails, gradcheck will.)
    return "use_type"
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker
def is_foreach_func(f: NativeFunction) -> bool:
    """True if ``f``'s base operator name carries the ``_foreach_`` prefix."""
    base_name = f.func.name.name.base
    return base_name.startswith("_foreach_")
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker
# note(crcrpar): Most foreach functions can reference an out-of-place `torch`
# function whose schema kind is functional for their backward derivatives (and
# forward derivatives in the future), i.e., they would find such a reference in
# `functional_info_by_signature`. There are, however, some exceptions:

# Foreach ops that may use an in-place op as their derivative reference;
# `is_reference_for_foreach` rejects in-place reference schemas for all others.
_foreach_with_inplace_ref = {"_foreach_zero_"}
# Foreach `.Tensor` overloads (going by their overload names).
# NOTE(review): not referenced in the part of this file reviewed here;
# presumably consumed further down -- confirm.
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}
# The following do not support the `alpha` kwarg that their non-foreach
# versions support, so `is_reference_for_foreach` skips its argument-count
# equality check for them.
_skip_argument_len_check = {
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
}
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    """Check whether ``function_schema`` is a native, non-foreach function that
    the foreach function ``f`` can reference to generate its derivatives.

    All of the following must hold:

    * the schema's base name equals ``f``'s base name with ``_foreach_`` removed;
    * the schema is not in-place, unless ``f`` is in ``_foreach_with_inplace_ref``;
    * both take the same number of non-out arguments (not checked for the
      overloads in ``_skip_argument_len_check``);
    * for each pair of arguments, the reference argument's type equals the
      foreach argument's type or that type's ``elem`` (when it has one).
    """
    stripped_base = f.func.name.name.base.split("_foreach_")[-1]
    if stripped_base != function_schema.name.name.base:
        return False

    # In-place references are only acceptable for the listed exceptions.
    if (
        function_schema.name.name.inplace
        and str(f.func.name) not in _foreach_with_inplace_ref
    ):
        return False

    if str(f.func.name) not in _skip_argument_len_check:
        if len(f.func.arguments.flat_non_out) != len(
            function_schema.arguments.flat_non_out
        ):
            return False

    # zip() stops at the shorter list, which matters only for the overloads
    # exempted from the length check above.
    for foreach_arg, reference_arg in zip(
        f.func.arguments.flat_non_out,
        function_schema.arguments.flat_non_out,
    ):
        accepted_types = (foreach_arg.type, getattr(foreach_arg.type, "elem", None))
        if reference_arg.type not in accepted_types:
            return False
    return True
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker# TODO(crcrpar): Avoid hard coding "Default" ideally.
366*da0073e9SAndroid Build Coastguard Workerdef gen_foreach_derivativeinfo(
367*da0073e9SAndroid Build Coastguard Worker    foreach_function: NativeFunction,
368*da0073e9SAndroid Build Coastguard Worker    functional_info_by_signature: dict[
369*da0073e9SAndroid Build Coastguard Worker        FunctionSchema, dict[str, DifferentiabilityInfo]
370*da0073e9SAndroid Build Coastguard Worker    ],
371*da0073e9SAndroid Build Coastguard Worker    non_functional_info_by_signature: dict[
372*da0073e9SAndroid Build Coastguard Worker        FunctionSchema, dict[str, DifferentiabilityInfo]
373*da0073e9SAndroid Build Coastguard Worker    ],
374*da0073e9SAndroid Build Coastguard Worker    dispatch_key: str = "Default",
375*da0073e9SAndroid Build Coastguard Worker) -> tuple[DifferentiabilityInfo | None, bool]:
376*da0073e9SAndroid Build Coastguard Worker    """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker    The second return value indicates whether the info is generated in this function.
379*da0073e9SAndroid Build Coastguard Worker    """
380*da0073e9SAndroid Build Coastguard Worker    ref_diff_info: DifferentiabilityInfo | None = None
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker    for function_schema, diff_info in functional_info_by_signature.items():
383*da0073e9SAndroid Build Coastguard Worker        if not is_reference_for_foreach(foreach_function, function_schema):
384*da0073e9SAndroid Build Coastguard Worker            continue
385*da0073e9SAndroid Build Coastguard Worker        ref_diff_info = diff_info[dispatch_key]
386*da0073e9SAndroid Build Coastguard Worker        if ref_diff_info is not None:
387*da0073e9SAndroid Build Coastguard Worker            break
388*da0073e9SAndroid Build Coastguard Worker    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
389*da0073e9SAndroid Build Coastguard Worker    # while the info of `zero_` is in non_functional_info_by_signature
390*da0073e9SAndroid Build Coastguard Worker    if (
391*da0073e9SAndroid Build Coastguard Worker        ref_diff_info is None
392*da0073e9SAndroid Build Coastguard Worker        and foreach_function.func.kind() == SchemaKind.inplace
393*da0073e9SAndroid Build Coastguard Worker        and str(foreach_function.func.name) in _foreach_with_inplace_ref
394*da0073e9SAndroid Build Coastguard Worker    ):
395*da0073e9SAndroid Build Coastguard Worker        for function_schema, diff_info in non_functional_info_by_signature.items():
396*da0073e9SAndroid Build Coastguard Worker            if not is_reference_for_foreach(foreach_function, function_schema):
397*da0073e9SAndroid Build Coastguard Worker                continue
398*da0073e9SAndroid Build Coastguard Worker            ref_diff_info = diff_info[dispatch_key]
399*da0073e9SAndroid Build Coastguard Worker            if ref_diff_info is not None:
400*da0073e9SAndroid Build Coastguard Worker                break
401*da0073e9SAndroid Build Coastguard Worker    if ref_diff_info is None:
402*da0073e9SAndroid Build Coastguard Worker        return None, False
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker    # non out-place uses the existing Derivative.
405*da0073e9SAndroid Build Coastguard Worker    if foreach_function.func.kind() == SchemaKind.inplace:
406*da0073e9SAndroid Build Coastguard Worker        return ref_diff_info, False
407*da0073e9SAndroid Build Coastguard Worker
408*da0073e9SAndroid Build Coastguard Worker    map_refarg2foreacharg, map_name2arg = {}, {}
409*da0073e9SAndroid Build Coastguard Worker    for i, (arg, ref_arg) in enumerate(
410*da0073e9SAndroid Build Coastguard Worker        zip(
411*da0073e9SAndroid Build Coastguard Worker            foreach_function.func.arguments.flat_non_out,
412*da0073e9SAndroid Build Coastguard Worker            function_schema.arguments.flat_non_out,
413*da0073e9SAndroid Build Coastguard Worker        )
414*da0073e9SAndroid Build Coastguard Worker    ):
415*da0073e9SAndroid Build Coastguard Worker        map_refarg2foreacharg[ref_arg.name] = arg.name
416*da0073e9SAndroid Build Coastguard Worker        map_name2arg[arg.name] = arg
417*da0073e9SAndroid Build Coastguard Worker
418*da0073e9SAndroid Build Coastguard Worker    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
419*da0073e9SAndroid Build Coastguard Worker    modified_derivative_formulas = []
420*da0073e9SAndroid Build Coastguard Worker    for i, derivative in enumerate(ref_diff_info.derivatives):
421*da0073e9SAndroid Build Coastguard Worker        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
422*da0073e9SAndroid Build Coastguard Worker            "result", "result[i]"
423*da0073e9SAndroid Build Coastguard Worker        )
424*da0073e9SAndroid Build Coastguard Worker        saved_inputs, saved_outputs = [], []
425*da0073e9SAndroid Build Coastguard Worker        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
426*da0073e9SAndroid Build Coastguard Worker        with local.parametrize(
427*da0073e9SAndroid Build Coastguard Worker            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
428*da0073e9SAndroid Build Coastguard Worker            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
429*da0073e9SAndroid Build Coastguard Worker        ):
430*da0073e9SAndroid Build Coastguard Worker            for ref_input in derivative.saved_inputs:
431*da0073e9SAndroid Build Coastguard Worker                ref_input_jit_name = ref_input.expr.split(".")[0]
432*da0073e9SAndroid Build Coastguard Worker                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
433*da0073e9SAndroid Build Coastguard Worker                if isinstance(map_name2arg[mapped_name].type, ListType):
434*da0073e9SAndroid Build Coastguard Worker                    mapped_expr = mapped_name + "[i]"
435*da0073e9SAndroid Build Coastguard Worker                else:
436*da0073e9SAndroid Build Coastguard Worker                    mapped_expr = mapped_name
437*da0073e9SAndroid Build Coastguard Worker                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
438*da0073e9SAndroid Build Coastguard Worker                modified_formula = modified_formula.replace(
439*da0073e9SAndroid Build Coastguard Worker                    cast(str, ref_input.nctype.name), new_expr
440*da0073e9SAndroid Build Coastguard Worker                )
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
443*da0073e9SAndroid Build Coastguard Worker                canonical_nctype = NamedCType(
444*da0073e9SAndroid Build Coastguard Worker                    nctype.name, nctype.type.remove_const_ref()
445*da0073e9SAndroid Build Coastguard Worker                )
446*da0073e9SAndroid Build Coastguard Worker                saved_inputs.append(
447*da0073e9SAndroid Build Coastguard Worker                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
448*da0073e9SAndroid Build Coastguard Worker                )
449*da0073e9SAndroid Build Coastguard Worker            for ref_output in derivative.saved_outputs:
450*da0073e9SAndroid Build Coastguard Worker                if ref_output.nctype.name == "result":
451*da0073e9SAndroid Build Coastguard Worker                    saved_outputs.append(
452*da0073e9SAndroid Build Coastguard Worker                        SavedAttribute(
453*da0073e9SAndroid Build Coastguard Worker                            nctype=NamedCType(
454*da0073e9SAndroid Build Coastguard Worker                                name="result", type=BaseCType(tensorListT)
455*da0073e9SAndroid Build Coastguard Worker                            ),
456*da0073e9SAndroid Build Coastguard Worker                            expr="result",
457*da0073e9SAndroid Build Coastguard Worker                        )
458*da0073e9SAndroid Build Coastguard Worker                    )
459*da0073e9SAndroid Build Coastguard Worker                else:
460*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError("")
461*da0073e9SAndroid Build Coastguard Worker        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
462*da0073e9SAndroid Build Coastguard Worker        all_var_names.extend(var_names)
463*da0073e9SAndroid Build Coastguard Worker        all_saved_inputs.extend(saved_inputs)
464*da0073e9SAndroid Build Coastguard Worker        all_saved_outputs.extend(saved_outputs)
465*da0073e9SAndroid Build Coastguard Worker        modified_derivative = Derivative(
466*da0073e9SAndroid Build Coastguard Worker            formula=modified_formula,
467*da0073e9SAndroid Build Coastguard Worker            original_formula=derivative.formula,
468*da0073e9SAndroid Build Coastguard Worker            var_names=tuple(var_names),
469*da0073e9SAndroid Build Coastguard Worker            saved_inputs=tuple(saved_inputs),
470*da0073e9SAndroid Build Coastguard Worker            saved_outputs=tuple(saved_outputs),
471*da0073e9SAndroid Build Coastguard Worker            named_gradients=set(),
472*da0073e9SAndroid Build Coastguard Worker        )
473*da0073e9SAndroid Build Coastguard Worker        modified_derivative_formulas.append(modified_derivative)
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker    with local.parametrize(
476*da0073e9SAndroid Build Coastguard Worker        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
477*da0073e9SAndroid Build Coastguard Worker        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
478*da0073e9SAndroid Build Coastguard Worker    ):
479*da0073e9SAndroid Build Coastguard Worker        args_with_derivatives = [
480*da0073e9SAndroid Build Coastguard Worker            Binding(
481*da0073e9SAndroid Build Coastguard Worker                name=arg.name,
482*da0073e9SAndroid Build Coastguard Worker                nctype=cpp.argument_type(arg, binds=arg.name),
483*da0073e9SAndroid Build Coastguard Worker                argument=arg,
484*da0073e9SAndroid Build Coastguard Worker                default=None,
485*da0073e9SAndroid Build Coastguard Worker            )
486*da0073e9SAndroid Build Coastguard Worker            for arg in foreach_function.func.arguments.flat_non_out
487*da0073e9SAndroid Build Coastguard Worker            if arg.name in all_var_names
488*da0073e9SAndroid Build Coastguard Worker        ]
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker    forward_derivatives: list[ForwardDerivative] = []
491*da0073e9SAndroid Build Coastguard Worker    fw_derivative: ForwardDerivative
492*da0073e9SAndroid Build Coastguard Worker    for fw_derivative in ref_diff_info.forward_derivatives:
493*da0073e9SAndroid Build Coastguard Worker        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
494*da0073e9SAndroid Build Coastguard Worker        var_types: list[Type] = list(fw_derivative.var_types)
495*da0073e9SAndroid Build Coastguard Worker        required_inputs_fw_grad: list[str] = []
496*da0073e9SAndroid Build Coastguard Worker        required_inputs_primal: list[str] = []
497*da0073e9SAndroid Build Coastguard Worker        if fw_derivative.required_inputs_fw_grad is not None:
498*da0073e9SAndroid Build Coastguard Worker            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
499*da0073e9SAndroid Build Coastguard Worker        if fw_derivative.required_inputs_primal:
500*da0073e9SAndroid Build Coastguard Worker            required_inputs_primal = list(fw_derivative.required_inputs_primal)
501*da0073e9SAndroid Build Coastguard Worker        modified_formula = fw_derivative.formula
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker        # Foreach's result is TensorList
504*da0073e9SAndroid Build Coastguard Worker        if "result" in modified_formula:
505*da0073e9SAndroid Build Coastguard Worker            modified_formula = fw_derivative.formula.replace("result", "result[i]")
506*da0073e9SAndroid Build Coastguard Worker
507*da0073e9SAndroid Build Coastguard Worker        for foreach_arg, ref_arg in zip(
508*da0073e9SAndroid Build Coastguard Worker            foreach_function.func.arguments.flat_non_out,
509*da0073e9SAndroid Build Coastguard Worker            ref_diff_info.func.func.arguments.flat_non_out,
510*da0073e9SAndroid Build Coastguard Worker        ):
511*da0073e9SAndroid Build Coastguard Worker            # Modify reference forward formula
512*da0073e9SAndroid Build Coastguard Worker            if (
513*da0073e9SAndroid Build Coastguard Worker                isinstance(foreach_arg.type, ListType)
514*da0073e9SAndroid Build Coastguard Worker                and not foreach_arg.type.is_tensor_like()
515*da0073e9SAndroid Build Coastguard Worker            ):
516*da0073e9SAndroid Build Coastguard Worker                # Assuming ScalarList
517*da0073e9SAndroid Build Coastguard Worker                modified_formula = modified_formula.replace(
518*da0073e9SAndroid Build Coastguard Worker                    ref_arg.name, foreach_arg.name + "[i]"
519*da0073e9SAndroid Build Coastguard Worker                )
520*da0073e9SAndroid Build Coastguard Worker            elif foreach_arg.type.is_tensor_like():
521*da0073e9SAndroid Build Coastguard Worker                # Assuming TensorList / Tensor
522*da0073e9SAndroid Build Coastguard Worker                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
523*da0073e9SAndroid Build Coastguard Worker                assert isinstance(foreach_arg.type, ListType) or (
524*da0073e9SAndroid Build Coastguard Worker                    foreach_arg.type == BaseType(BaseTy.Tensor)
525*da0073e9SAndroid Build Coastguard Worker                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
526*da0073e9SAndroid Build Coastguard Worker                ), f"{foreach_function.func.name}, {foreach_arg.type}"
527*da0073e9SAndroid Build Coastguard Worker                for suffix in ("_p", "_t"):
528*da0073e9SAndroid Build Coastguard Worker                    curr_expr = ref_arg.name + suffix
529*da0073e9SAndroid Build Coastguard Worker                    if curr_expr in modified_formula:
530*da0073e9SAndroid Build Coastguard Worker                        new_expr = foreach_arg.name + suffix
531*da0073e9SAndroid Build Coastguard Worker                        modified_formula = modified_formula.replace(curr_expr, new_expr)
532*da0073e9SAndroid Build Coastguard Worker            else:
533*da0073e9SAndroid Build Coastguard Worker                # Assuming Scalar
534*da0073e9SAndroid Build Coastguard Worker                if foreach_arg.name != ref_arg.name:
535*da0073e9SAndroid Build Coastguard Worker                    modified_formula = modified_formula.replace(
536*da0073e9SAndroid Build Coastguard Worker                        ref_arg.name, foreach_arg.name
537*da0073e9SAndroid Build Coastguard Worker                    )
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker            # note(crcrpar): there should exist a cooler way...
540*da0073e9SAndroid Build Coastguard Worker            for i, name in enumerate(var_names):
541*da0073e9SAndroid Build Coastguard Worker                if name == ref_arg.name:
542*da0073e9SAndroid Build Coastguard Worker                    var_names[i] = foreach_arg.name
543*da0073e9SAndroid Build Coastguard Worker                    var_types[i] = foreach_arg.type
544*da0073e9SAndroid Build Coastguard Worker            for i, name in enumerate(required_inputs_fw_grad):
545*da0073e9SAndroid Build Coastguard Worker                if name == ref_arg.name:
546*da0073e9SAndroid Build Coastguard Worker                    required_inputs_fw_grad[i] = foreach_arg.name
547*da0073e9SAndroid Build Coastguard Worker            for i, name in enumerate(required_inputs_primal):
548*da0073e9SAndroid Build Coastguard Worker                if name == ref_arg.name:
549*da0073e9SAndroid Build Coastguard Worker                    required_inputs_primal[i] = foreach_arg.name
550*da0073e9SAndroid Build Coastguard Worker        forward_derivatives.append(
551*da0073e9SAndroid Build Coastguard Worker            ForwardDerivative(
552*da0073e9SAndroid Build Coastguard Worker                formula=modified_formula,
553*da0073e9SAndroid Build Coastguard Worker                var_names=tuple(var_names),
554*da0073e9SAndroid Build Coastguard Worker                var_types=tuple(var_types),
555*da0073e9SAndroid Build Coastguard Worker                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
556*da0073e9SAndroid Build Coastguard Worker                required_inputs_primal=tuple(required_inputs_primal),
557*da0073e9SAndroid Build Coastguard Worker                required_original_self_value=fw_derivative.required_original_self_value,
558*da0073e9SAndroid Build Coastguard Worker                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
559*da0073e9SAndroid Build Coastguard Worker            )
560*da0073e9SAndroid Build Coastguard Worker        )
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker    return (
563*da0073e9SAndroid Build Coastguard Worker        DifferentiabilityInfo(
564*da0073e9SAndroid Build Coastguard Worker            name=foreach_function.func.name.name.base,
565*da0073e9SAndroid Build Coastguard Worker            func=foreach_function,
566*da0073e9SAndroid Build Coastguard Worker            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
567*da0073e9SAndroid Build Coastguard Worker            derivatives=modified_derivative_formulas,
568*da0073e9SAndroid Build Coastguard Worker            forward_derivatives=forward_derivatives,
569*da0073e9SAndroid Build Coastguard Worker            all_saved_inputs=tuple(set(all_saved_inputs)),
570*da0073e9SAndroid Build Coastguard Worker            all_saved_outputs=tuple(set(all_saved_outputs)),
571*da0073e9SAndroid Build Coastguard Worker            available_named_gradients=(),
572*da0073e9SAndroid Build Coastguard Worker            used_named_gradients=set(),
573*da0073e9SAndroid Build Coastguard Worker            args_with_derivatives=args_with_derivatives,
574*da0073e9SAndroid Build Coastguard Worker            non_differentiable_arg_names=[],
575*da0073e9SAndroid Build Coastguard Worker            output_differentiability=None,
576*da0073e9SAndroid Build Coastguard Worker            output_differentiability_conditions=None,
577*da0073e9SAndroid Build Coastguard Worker        ),
578*da0073e9SAndroid Build Coastguard Worker        True,
579*da0073e9SAndroid Build Coastguard Worker    )
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Workerdef match_differentiability_info(
583*da0073e9SAndroid Build Coastguard Worker    native_functions: list[NativeFunction],
584*da0073e9SAndroid Build Coastguard Worker    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
585*da0073e9SAndroid Build Coastguard Worker) -> list[NativeFunctionWithDifferentiabilityInfo]:
586*da0073e9SAndroid Build Coastguard Worker    """Sets the "derivative" key on declarations to matching autograd function
587*da0073e9SAndroid Build Coastguard Worker    In-place functions will use the out-of-place derivative definition if there
588*da0073e9SAndroid Build Coastguard Worker    is no in-place specific derivative.
589*da0073e9SAndroid Build Coastguard Worker    """
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker    functional_info_by_signature = {
592*da0073e9SAndroid Build Coastguard Worker        schema.signature(strip_default=True): info_dict
593*da0073e9SAndroid Build Coastguard Worker        for schema, info_dict in differentiability_infos.items()
594*da0073e9SAndroid Build Coastguard Worker        if schema.kind() == SchemaKind.functional
595*da0073e9SAndroid Build Coastguard Worker    }
596*da0073e9SAndroid Build Coastguard Worker    non_functional_info_by_signature = {
597*da0073e9SAndroid Build Coastguard Worker        schema.signature(strip_default=True): info_dict
598*da0073e9SAndroid Build Coastguard Worker        for schema, info_dict in differentiability_infos.items()
599*da0073e9SAndroid Build Coastguard Worker        if schema.kind() != SchemaKind.functional
600*da0073e9SAndroid Build Coastguard Worker    }
601*da0073e9SAndroid Build Coastguard Worker
602*da0073e9SAndroid Build Coastguard Worker    def find_info(
603*da0073e9SAndroid Build Coastguard Worker        f: NativeFunction,
604*da0073e9SAndroid Build Coastguard Worker    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
605*da0073e9SAndroid Build Coastguard Worker        # Don't bother matching info to generated out= variants
606*da0073e9SAndroid Build Coastguard Worker        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
607*da0073e9SAndroid Build Coastguard Worker            return None, False
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker        # (1) Check for an exact match
610*da0073e9SAndroid Build Coastguard Worker        if f.func in differentiability_infos:
611*da0073e9SAndroid Build Coastguard Worker            return differentiability_infos[f.func], True
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        # (2) If no exact match, check if the out-of-place variant
614*da0073e9SAndroid Build Coastguard Worker        # of this operator has a match.
615*da0073e9SAndroid Build Coastguard Worker        # i.e mul() for mul_() or mul_out()
616*da0073e9SAndroid Build Coastguard Worker        # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing
617*da0073e9SAndroid Build Coastguard Worker        # native functions instead of the out-place counterparts.
618*da0073e9SAndroid Build Coastguard Worker        f_sig = f.func.signature(strip_default=True)
619*da0073e9SAndroid Build Coastguard Worker        if f_sig in functional_info_by_signature and not is_foreach_func(f):
620*da0073e9SAndroid Build Coastguard Worker            return functional_info_by_signature[f_sig], False
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker        # (3) Some operators have a derivative explicitly defined for the mutable
623*da0073e9SAndroid Build Coastguard Worker        # variant, but get a code-generated out-of-place variant which does *not*
624*da0073e9SAndroid Build Coastguard Worker        # come with a derivative formula.
625*da0073e9SAndroid Build Coastguard Worker        # For the generated out-of-place variant, use the mutable variant's formula
626*da0073e9SAndroid Build Coastguard Worker        # if it exists.
627*da0073e9SAndroid Build Coastguard Worker        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
628*da0073e9SAndroid Build Coastguard Worker            info_dict = non_functional_info_by_signature[f_sig]
629*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
630*da0073e9SAndroid Build Coastguard Worker            assert not any(
631*da0073e9SAndroid Build Coastguard Worker                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
632*da0073e9SAndroid Build Coastguard Worker                for info in info_dict.values()
633*da0073e9SAndroid Build Coastguard Worker            ), f"""\
634*da0073e9SAndroid Build Coastguard WorkerAttempted to convert a derivative formula for a mutable operator
635*da0073e9SAndroid Build Coastguard Worker to be used by automatically by its functional variant ("{str(f.func)}").
636*da0073e9SAndroid Build Coastguard Worker this is not currently supported (we'd need to fix up the formula in the codegen)."""
637*da0073e9SAndroid Build Coastguard Worker            return info_dict, False
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
640*da0073e9SAndroid Build Coastguard Worker        if is_foreach_func(f):
641*da0073e9SAndroid Build Coastguard Worker            assert f.func not in differentiability_infos
642*da0073e9SAndroid Build Coastguard Worker            diff_info, is_generated = gen_foreach_derivativeinfo(
643*da0073e9SAndroid Build Coastguard Worker                f,
644*da0073e9SAndroid Build Coastguard Worker                functional_info_by_signature,
645*da0073e9SAndroid Build Coastguard Worker                non_functional_info_by_signature,
646*da0073e9SAndroid Build Coastguard Worker            )
647*da0073e9SAndroid Build Coastguard Worker            if diff_info is None:
648*da0073e9SAndroid Build Coastguard Worker                return None, False
649*da0073e9SAndroid Build Coastguard Worker            # TODO(crcrpar): Avoid hard coding "Default" ideally.
650*da0073e9SAndroid Build Coastguard Worker            diff_info_dict = {"Default": diff_info}
651*da0073e9SAndroid Build Coastguard Worker            if is_generated:
652*da0073e9SAndroid Build Coastguard Worker                differentiability_infos[f.func] = diff_info_dict
653*da0073e9SAndroid Build Coastguard Worker                functional_info_by_signature[f.func] = diff_info_dict
654*da0073e9SAndroid Build Coastguard Worker            return diff_info_dict, is_generated
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker        return None, False
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker    result: list[NativeFunctionWithDifferentiabilityInfo] = []
659*da0073e9SAndroid Build Coastguard Worker    for f in native_functions:
660*da0073e9SAndroid Build Coastguard Worker        info_dict, is_exact_match = find_info(f)
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
663*da0073e9SAndroid Build Coastguard Worker        # 'self' derivatives of an inplace function, so we must check for this case.
664*da0073e9SAndroid Build Coastguard Worker        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
665*da0073e9SAndroid Build Coastguard Worker            for info in info_dict.values():
666*da0073e9SAndroid Build Coastguard Worker                for derivative in info.derivatives:
667*da0073e9SAndroid Build Coastguard Worker                    if "self" in derivative.var_names:
668*da0073e9SAndroid Build Coastguard Worker                        for saved_input in derivative.saved_inputs:
669*da0073e9SAndroid Build Coastguard Worker                            assert "strides_or_error" not in saved_input.expr, (
670*da0073e9SAndroid Build Coastguard Worker                                "Calling '.strides()' in the 'self' derivative formula of an "
671*da0073e9SAndroid Build Coastguard Worker                                f"in-place function is not supported: {f.func}"
672*da0073e9SAndroid Build Coastguard Worker                            )
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker        if not info_dict:
675*da0073e9SAndroid Build Coastguard Worker            result.append(
676*da0073e9SAndroid Build Coastguard Worker                NativeFunctionWithDifferentiabilityInfo(
677*da0073e9SAndroid Build Coastguard Worker                    func=f, info=None, fw_derivatives=None
678*da0073e9SAndroid Build Coastguard Worker                )
679*da0073e9SAndroid Build Coastguard Worker            )
680*da0073e9SAndroid Build Coastguard Worker            continue
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
683*da0073e9SAndroid Build Coastguard Worker        for key, info in info_dict.items():
684*da0073e9SAndroid Build Coastguard Worker            if not info.forward_derivatives:
685*da0073e9SAndroid Build Coastguard Worker                fw_derivative_dict[key] = []
686*da0073e9SAndroid Build Coastguard Worker                continue
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker            forward_derivatives = info.forward_derivatives
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker            # For functions that have a single def for out-of-place and inplace (like abs())
691*da0073e9SAndroid Build Coastguard Worker            if f.func.kind() == SchemaKind.inplace:
692*da0073e9SAndroid Build Coastguard Worker                # For inplace functions there is a little bit of work to do:
693*da0073e9SAndroid Build Coastguard Worker                #  1) Validate the formula and make sure the input that is modified in not used:
694*da0073e9SAndroid Build Coastguard Worker                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
695*da0073e9SAndroid Build Coastguard Worker                #      we make sure that the original value of the input that is being modified inplace (self_p) is
696*da0073e9SAndroid Build Coastguard Worker                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
697*da0073e9SAndroid Build Coastguard Worker                #      trigger a clone of the original input.
698*da0073e9SAndroid Build Coastguard Worker                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
699*da0073e9SAndroid Build Coastguard Worker                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
700*da0073e9SAndroid Build Coastguard Worker                #      populated by cloned version of the original input (either the clone done by the backward AD
701*da0073e9SAndroid Build Coastguard Worker                #      logic if self is also used in a backward formula or a special clone that we add).
702*da0073e9SAndroid Build Coastguard Worker                #  2) At this point, there cannot be a self_p in the formula.
703*da0073e9SAndroid Build Coastguard Worker                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
704*da0073e9SAndroid Build Coastguard Worker                #     simply called self (as it is modified inplace).
705*da0073e9SAndroid Build Coastguard Worker                #  4) Update the required primals data in case it used to contain "result" but should now contain
706*da0073e9SAndroid Build Coastguard Worker                #     "self"
707*da0073e9SAndroid Build Coastguard Worker                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
708*da0073e9SAndroid Build Coastguard Worker                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
709*da0073e9SAndroid Build Coastguard Worker                #     already exists.
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker                assert (
712*da0073e9SAndroid Build Coastguard Worker                    len(info.forward_derivatives) == 1
713*da0073e9SAndroid Build Coastguard Worker                )  # Only single output inplace should exist
714*da0073e9SAndroid Build Coastguard Worker                fw_info = info.forward_derivatives[0]
715*da0073e9SAndroid Build Coastguard Worker                formula = fw_info.formula
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker                def replace_self_with_original_self(formula: str, postfix: str) -> str:
718*da0073e9SAndroid Build Coastguard Worker                    def repl(m: re.Match[str]) -> str:
719*da0073e9SAndroid Build Coastguard Worker                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker                if re.search(IDENT_REGEX.format("self_p"), formula):
724*da0073e9SAndroid Build Coastguard Worker                    if is_exact_match:
725*da0073e9SAndroid Build Coastguard Worker                        # For manually defined formulas, don't allow the original value to be used
726*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError(
727*da0073e9SAndroid Build Coastguard Worker                            f'The formula for "{f.func.name}" is using the original value of self '
728*da0073e9SAndroid Build Coastguard Worker                            "that is being modified inplace. This would lead to wrong forward gradients. "
729*da0073e9SAndroid Build Coastguard Worker                            'Please use "result" in the formula only.'
730*da0073e9SAndroid Build Coastguard Worker                        )
731*da0073e9SAndroid Build Coastguard Worker                    else:
732*da0073e9SAndroid Build Coastguard Worker                        # When the original formula is out of place, we save a clone of the primal
733*da0073e9SAndroid Build Coastguard Worker                        # value to be able to access this value if needed
734*da0073e9SAndroid Build Coastguard Worker                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
735*da0073e9SAndroid Build Coastguard Worker                        formula = replace_self_with_original_self(formula, "_p")
736*da0073e9SAndroid Build Coastguard Worker                        formula = replace_self_with_original_self(formula, "_t")
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker                # replace "result" from the formula by "self_p"
739*da0073e9SAndroid Build Coastguard Worker                def repl(m: re.Match[str]) -> str:
740*da0073e9SAndroid Build Coastguard Worker                    return f"{m.group(1)}self_p{m.group(2)}"
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker                required_primals = fw_info.required_inputs_primal
745*da0073e9SAndroid Build Coastguard Worker                if re.search(IDENT_REGEX.format("self_p"), formula):
746*da0073e9SAndroid Build Coastguard Worker                    required_primals = (
747*da0073e9SAndroid Build Coastguard Worker                        required_primals + ("self",) if required_primals else ("self",)
748*da0073e9SAndroid Build Coastguard Worker                    )
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker                if not is_exact_match:
751*da0073e9SAndroid Build Coastguard Worker                    # NOTE [In-place forward AD formula Optimization]
752*da0073e9SAndroid Build Coastguard Worker                    #
753*da0073e9SAndroid Build Coastguard Worker                    # This optimization transforms the formula to directly do inplace, i.e.
754*da0073e9SAndroid Build Coastguard Worker                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
755*da0073e9SAndroid Build Coastguard Worker                    #
756*da0073e9SAndroid Build Coastguard Worker                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
757*da0073e9SAndroid Build Coastguard Worker                    # 2) "op" in (1) needs to be the same as the op the derivative is for
758*da0073e9SAndroid Build Coastguard Worker                    #
759*da0073e9SAndroid Build Coastguard Worker                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
760*da0073e9SAndroid Build Coastguard Worker                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
761*da0073e9SAndroid Build Coastguard Worker                    is_single_method_on_self_t = False
762*da0073e9SAndroid Build Coastguard Worker                    directly_do_inplace = False
763*da0073e9SAndroid Build Coastguard Worker                    op_name: str | None = None
764*da0073e9SAndroid Build Coastguard Worker                    between_parens: str | None = None
765*da0073e9SAndroid Build Coastguard Worker                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
766*da0073e9SAndroid Build Coastguard Worker                    if match:
767*da0073e9SAndroid Build Coastguard Worker                        op_name, between_parens = match.group(1), match.group(2)
768*da0073e9SAndroid Build Coastguard Worker
769*da0073e9SAndroid Build Coastguard Worker                        # We want to...
770*da0073e9SAndroid Build Coastguard Worker                        #   Match: self_t.op1(other_p.op2(arg))
771*da0073e9SAndroid Build Coastguard Worker                        #   Avoid: self_t.op1(args) + self_t.op2(args)
772*da0073e9SAndroid Build Coastguard Worker                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
773*da0073e9SAndroid Build Coastguard Worker                        def check_parens_nest_level_gt_zero(s: str) -> bool:
774*da0073e9SAndroid Build Coastguard Worker                            level = 1
775*da0073e9SAndroid Build Coastguard Worker                            for ch in s:
776*da0073e9SAndroid Build Coastguard Worker                                if ch == ")":
777*da0073e9SAndroid Build Coastguard Worker                                    level -= 1
778*da0073e9SAndroid Build Coastguard Worker                                    if level == 0:
779*da0073e9SAndroid Build Coastguard Worker                                        return False
780*da0073e9SAndroid Build Coastguard Worker                                if ch == "(":
781*da0073e9SAndroid Build Coastguard Worker                                    level += 1
782*da0073e9SAndroid Build Coastguard Worker                            return True
783*da0073e9SAndroid Build Coastguard Worker
784*da0073e9SAndroid Build Coastguard Worker                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
785*da0073e9SAndroid Build Coastguard Worker                            between_parens
786*da0073e9SAndroid Build Coastguard Worker                        )
787*da0073e9SAndroid Build Coastguard Worker                        directly_do_inplace = (
788*da0073e9SAndroid Build Coastguard Worker                            is_single_method_on_self_t and op_name == info.name
789*da0073e9SAndroid Build Coastguard Worker                        )
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker                    if directly_do_inplace:
792*da0073e9SAndroid Build Coastguard Worker                        assert op_name is not None
793*da0073e9SAndroid Build Coastguard Worker                        assert between_parens is not None
794*da0073e9SAndroid Build Coastguard Worker                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
795*da0073e9SAndroid Build Coastguard Worker                    else:
796*da0073e9SAndroid Build Coastguard Worker                        # Make sure that the forward grad is modified inplace when the original formula
797*da0073e9SAndroid Build Coastguard Worker                        # is out of place
798*da0073e9SAndroid Build Coastguard Worker                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker                required_original_self_value = bool(
801*da0073e9SAndroid Build Coastguard Worker                    re.search(IDENT_REGEX.format("original_self_p"), formula)
802*da0073e9SAndroid Build Coastguard Worker                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
803*da0073e9SAndroid Build Coastguard Worker
804*da0073e9SAndroid Build Coastguard Worker                forward_derivatives = [
805*da0073e9SAndroid Build Coastguard Worker                    ForwardDerivative(
806*da0073e9SAndroid Build Coastguard Worker                        formula=formula,
807*da0073e9SAndroid Build Coastguard Worker                        var_names=("self",),
808*da0073e9SAndroid Build Coastguard Worker                        var_types=fw_info.var_types,
809*da0073e9SAndroid Build Coastguard Worker                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
810*da0073e9SAndroid Build Coastguard Worker                        required_inputs_primal=required_primals,
811*da0073e9SAndroid Build Coastguard Worker                        required_original_self_value=required_original_self_value,
812*da0073e9SAndroid Build Coastguard Worker                        is_reusing_outplace_formula=not is_exact_match,
813*da0073e9SAndroid Build Coastguard Worker                    ),
814*da0073e9SAndroid Build Coastguard Worker                ]
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker            fw_derivative_dict[key] = forward_derivatives
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker        result.append(
819*da0073e9SAndroid Build Coastguard Worker            NativeFunctionWithDifferentiabilityInfo(
820*da0073e9SAndroid Build Coastguard Worker                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
821*da0073e9SAndroid Build Coastguard Worker            )
822*da0073e9SAndroid Build Coastguard Worker        )
823*da0073e9SAndroid Build Coastguard Worker
824*da0073e9SAndroid Build Coastguard Worker    return result
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker
def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    """Decide whether the value called ``name`` with schema type ``type`` is differentiable.

    The type must be tensor-like. When ``info`` is provided, ``name`` must
    additionally be absent from ``info.non_differentiable_arg_names``. When
    ``info`` is None only the type is checked.
    """
    tensor_like = type.is_tensor_like()
    if not tensor_like:
        # Non-tensor-like values can never carry a gradient.
        return tensor_like
    if info is None:
        # No derivative info to consult: the type alone decides.
        return True
    return name not in info.non_differentiable_arg_names
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Worker
def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    """Collect the outputs of ``fn.func`` that participate in differentiation.

    Args:
        fn: the native function together with its differentiability info.
            ``fn.info`` is indexed by ``key`` when it is non-empty; when it is
            falsy, no info is used.
        key: which entry of ``fn.info`` to consult (defaults to "Default").

    Returns:
        One ``DifferentiableOutput`` per selected return value. Each entry
        holds its C++ return name, schema type and C++ type (computed with
        ``symint=True``).

        * If the info carries an explicit ``output_differentiability`` list,
          the outputs whose flag is truthy are returned.
        * Otherwise outputs are filtered with ``is_differentiable``. If
          ``uses_single_grad(info)`` holds, only the first surviving output
          is kept.

    Raises:
        RuntimeError: if ``output_differentiability`` does not have exactly
            one entry per output, or if it contains ``False`` for an in-place
            schema.
    """
    native_fn = fn.func
    info = fn.info[key] if fn.info else None

    # Describe every schema return, paired with its C++ return name.
    all_outputs: list[DifferentiableOutput] = [
        DifferentiableOutput(
            name=ret_name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for ret_name, ret in zip(
            cpp.return_names(native_fn), native_fn.func.returns
        )
    ]

    flags = info.output_differentiability if info else None
    if flags is None:
        # No explicit per-output annotation: infer from each output's type and
        # the info's list of non-differentiable names.
        candidates = [
            out
            for out in all_outputs
            if is_differentiable(out.name, out.type, info)
        ]
        return candidates[:1] if uses_single_grad(info) else candidates

    if len(flags) != len(all_outputs):
        raise RuntimeError(
            f"The length of output_differentiability ({len(flags)}), "
            f"does not match the number of outputs ({len(all_outputs)})."
        )
    # In-place ops may not mark an output non-differentiable (per the message
    # below, the version_counter would not get updated).
    if False in flags and native_fn.func.kind() == SchemaKind.inplace:
        raise RuntimeError(
            "output_differentiability=False for inplace operation (version_counter won't get updated)"
        )
    return [out for flag, out in zip(flags, all_outputs) if flag]
871