from __future__ import annotations

import re
from dataclasses import dataclass
from typing import cast, Sequence

from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX


# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # The NamedCType holds the updated name and cpp type of the attribute.
    # For the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`.
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str

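# E.g. (illustrative): saving `other.scalar_type()` rather than the whole `other`
# tensor would be represented roughly as
#   SavedAttribute(
#       nctype=NamedCType("other_scalar_type", BaseCType(scalarTypeT)),
#       expr="other.scalar_type()",
#   )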

# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: set[str]

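# E.g. (illustrative, continuing the example in the comments above): such a
# Derivative would carry roughly
#   formula="mul_tensor_backward(grad, self, other_scalar_type)"
#   original_formula="mul_tensor_backward(grad, self, other.scalar_type())"
#   saved_inputs=(<SavedAttribute for `self`>, <SavedAttribute for `other_scalar_type`>)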

# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (legit C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Names of the output arguments for which this formula calculates forward
    # derivatives
    var_names: tuple[str, ...]

    # Types of the output arguments for which this formula calculates forward
    # derivatives
    var_types: tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs for which the primal is required for this formula
    required_inputs_primal: tuple[str, ...] | None

    # Flag to specify if this formula requires the original value of self.
    # This is only used by inplace operations.
    required_original_self_value: bool

    # Whether this formula was specified in derivatives.yaml or we are re-using
    # the out-of-place formula for the inplace variant
    is_reusing_outplace_formula: bool


# Represents differentiability info for a NativeFunction.
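# Such an entry in derivatives.yaml looks roughly like (illustrative):
#   - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
#     self: mul_tensor_backward(grad, other, self.scalar_type())
#     other: mul_tensor_backward(grad, self, other.scalar_type())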
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunctions having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: str | None

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs.
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs.
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: list[bool] | None

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions.
    output_differentiability_conditions: list[str] | None

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
    # but with a new operator name.
    # This is used when generating "copy" variants of view ops,
    # which are able to use the exact same derivative formula as the original view op.
    # See Note [Codegen'd {view}_copy Operators]
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        if g.view_copy is None:
            return None
        f = g.view_copy

        name_split_by_period = self.name.split(".", maxsplit=2)
        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
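        # e.g. (illustrative) "transpose.int" -> "transpose_copy.int"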
        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
            name_split_by_period[1:]
        )
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op
            name=view_copy_name,
            func=f,
            op=view_copy_op_name,
            # But keep all derivative info the same
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )


def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "retain_variables")


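# True if any backward formula references the bare "grad" identifier (the gradient
# of the single/first differentiable output); IDENT_REGEX matches whole identifiers
# only, so e.g. "grads[0]" does not count.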
def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "grad")


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It represents processed Arguments that are differentiable and is only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It represents processed Returns that are differentiable (in compliance with the
#   `output_differentiability` field defined in derivatives.yaml, if specified)
#   and is only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
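    # Keyed by dispatch key (e.g. "Default"); None when no matching entry was found
    # in derivatives.yaml (nor generated, e.g. for foreach functions).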
    info: dict[str, DifferentiabilityInfo] | None
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration?  There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance).  Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors.  If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)
    """
    # fn is derived as long as any of its per-key differentiability infos
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
    # and ADInplaceOrViewType. We want to generate these functions as long as a
    # derivative is defined for ANY dispatch key.
    if fn.func.is_abstract or (
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
    ):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.

        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer calling the derived
        # type's implementation with unpacked tensors because it is more
        # performant in some cases: any internal calls to other ATen functions
        # won't have the history tracked.

        # If the function has a type dispatched argument (i.e. is a factory),
        # we prefer calling the derived type's implementation both because it is
        # more performant and to ensure factory functions return tensors with _version
        # of 0 (probably not strictly necessary, but nice to have to keep versions simple
        # to understand).

        return "use_derived"
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we'll assume that it is
        # actually implemented out of differentiable functions. (This
        # assumption might not hold, but then you'll see gradcheck fail.)
        return "use_type"


def is_foreach_func(f: NativeFunction) -> bool:
    return f.func.name.name.base.startswith("_foreach_")


# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they can find such a function in `functional_info_by_signature`. There are, however, some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}
# The following do not support the alpha kwarg, which the non-foreach versions support.
_skip_argument_len_check = {
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
}


# Checks if `function_schema` is a native, non-foreach function that `f`, a foreach function,
# references to generate derivatives.
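# E.g. (illustrative; schemas abbreviated): `_foreach_add.List(Tensor[] self, Tensor[] other, Scalar alpha=1)`
# references `add.Tensor(Tensor self, Tensor other, Scalar alpha=1)`: the base names match
# after stripping "_foreach_", the non-out argument counts match, and each reference
# argument's type equals either the foreach argument's type or its list element type.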
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    return (
        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
        and (
            not function_schema.name.name.inplace
            or str(f.func.name) in _foreach_with_inplace_ref
        )
        and (
            str(f.func.name) in _skip_argument_len_check
            or len(f.func.arguments.flat_non_out)
            == len(function_schema.arguments.flat_non_out)
        )
        and all(
            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
            for arg, ref_arg in zip(
                f.func.arguments.flat_non_out,
                function_schema.arguments.flat_non_out,
            )
        )
    )


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
    """Generate DifferentiabilityInfo for an out-place foreach function; return the existing one for in-place.

    The second return value indicates whether the info is generated in this function.
    """
    ref_diff_info: DifferentiabilityInfo | None = None

    for function_schema, diff_info in functional_info_by_signature.items():
        if not is_reference_for_foreach(foreach_function, function_schema):
            continue
        ref_diff_info = diff_info[dispatch_key]
        if ref_diff_info is not None:
            break
    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
    # while the info of `zero_` is in non_functional_info_by_signature.
    if (
        ref_diff_info is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        for function_schema, diff_info in non_functional_info_by_signature.items():
            if not is_reference_for_foreach(foreach_function, function_schema):
                continue
            ref_diff_info = diff_info[dispatch_key]
            if ref_diff_info is not None:
                break
    if ref_diff_info is None:
        return None, False

    # Non out-place (i.e. in-place) foreach reuses the existing Derivative as is.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    map_refarg2foreacharg, map_name2arg = {}, {}
    for i, (arg, ref_arg) in enumerate(
        zip(
            foreach_function.func.arguments.flat_non_out,
            function_schema.arguments.flat_non_out,
        )
    ):
        map_refarg2foreacharg[ref_arg.name] = arg.name
        map_name2arg[arg.name] = arg

    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
    modified_derivative_formulas = []
    for i, derivative in enumerate(ref_diff_info.derivatives):
        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
            "result", "result[i]"
        )
        saved_inputs, saved_outputs = [], []
        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
        with local.parametrize(
            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
        ):
            for ref_input in derivative.saved_inputs:
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                if isinstance(map_name2arg[mapped_name].type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name == "result":
                    saved_outputs.append(
                        SavedAttribute(
                            nctype=NamedCType(
                                name="result", type=BaseCType(tensorListT)
                            ),
                            expr="result",
                        )
                    )
                else:
                    raise RuntimeError(
                        "Unexpected saved output in the reference derivative: "
                        f"{ref_output.nctype.name}"
                    )
        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
        all_var_names.extend(var_names)
        all_saved_inputs.extend(saved_inputs)
        all_saved_outputs.extend(saved_outputs)
        modified_derivative = Derivative(
            formula=modified_formula,
            original_formula=derivative.formula,
            var_names=tuple(var_names),
            saved_inputs=tuple(saved_inputs),
            saved_outputs=tuple(saved_outputs),
            named_gradients=set(),
        )
        modified_derivative_formulas.append(modified_derivative)

    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_function.func.arguments.flat_non_out
            if arg.name in all_var_names
        ]

    forward_derivatives: list[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
        var_types: list[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: list[str] = []
        required_inputs_primal: list[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # Foreach's result is TensorList
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # note(crcrpar): there should exist a cooler way...
            for i, name in enumerate(var_names):
                if name == ref_arg.name:
                    var_names[i] = foreach_arg.name
                    var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(var_names),
                var_types=tuple(var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )


def match_differentiability_info(
    native_functions: list[NativeFunction],
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to the matching autograd function.

    In-place functions will use the out-of-place derivative definition if there
    is no in-place-specific derivative.
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }

    def find_info(
        f: NativeFunction,
    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e. mul() for mul_() or mul_out()
        # note(crcrpar): Skip foreach functions here because in-place foreach functions use the backward
        # defined for the existing native functions instead of their out-place counterparts.
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature and not is_foreach_func(f):
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info_dict = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
                for info in info_dict.values()
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
 to be used automatically by its functional variant ("{str(f.func)}").
 This is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
        if is_foreach_func(f):
            assert f.func not in differentiability_infos
            diff_info, is_generated = gen_foreach_derivativeinfo(
                f,
                functional_info_by_signature,
                non_functional_info_by_signature,
            )
            if diff_info is None:
                return None, False
            # TODO(crcrpar): Avoid hard coding "Default" ideally.
            diff_info_dict = {"Default": diff_info}
            if is_generated:
                differentiability_infos[f.func] = diff_info_dict
                functional_info_by_signature[f.func] = diff_info_dict
            return diff_info_dict, is_generated

        return None, False

    result: list[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if not info_dict:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            # For functions that have a single def for out-of-place and inplace (like abs())
            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified is not used:
                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                #      we make sure that the original value of the input that is being modified inplace (self_p) is
                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
                #      trigger a clone of the original input.
                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                #      populated by a cloned version of the original input (either the clone done by the backward AD
                #      logic if self is also used in a backward formula or a special clone that we add).
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                #     simply called self (as it is modified inplace).
                #  4) Update the required primals data in case it used to contain "result" but should now contain
                #     "self".
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
                #     already exists.
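                # Illustrative (hypothetical formula): re-using an out-of-place formula
                # "self_t * foo(self_p)" for the in-place variant would first become
                # "original_self_t * foo(original_self_p)", any "result" would then be
                # renamed to "self_p", and the result would finally be wrapped as
                # "self_t_raw.defined() ? self_t_raw.copy_(...) : ...".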

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: re.Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # replace "result" from the formula by "self_p"
                def repl(m: re.Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
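                    # E.g. (hypothetical formula): "self_t.mul(other_p)" for an op named
                    # "mul" would be rewritten to
                    # "self_t_raw.defined() ? self_t_raw.mul_(other_p) : self_t.mul(other_p)".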
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: str | None = None
                    between_parens: str | None = None
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # We want to...
                        #   Match: self_t.op1(other_p.op2(arg))
                        #   Avoid: self_t.op1(args) + self_t.op2(args)
                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                        directly_do_inplace = (
                            is_single_method_on_self_t and op_name == info.name
                        )

                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result


def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    return type.is_tensor_like() and (
        info is None or name not in info.non_differentiable_arg_names
    )


def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    f = fn.func
    info = fn.info[key] if fn.info else None
    outputs: list[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}) "
                f"does not match the number of outputs ({len(outputs)})."
            )
        differentiable_outputs: list[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs
    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
    )
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
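

# Typical usage (illustrative sketch; `load_derivatives`, which parses derivatives.yaml
# into the per-schema info dicts, is an assumption about the caller and lives elsewhere
# in the autograd codegen):
#   differentiability_infos = load_derivatives(...)
#   fns = match_differentiability_info(native_functions, differentiability_infos)
#   for fn in fns:
#       outs = gen_differentiable_outputs(fn, key="Default")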
871