xref: /aosp_15_r20/external/pytorch/tools/autograd/load_derivatives.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Parses derivatives.yaml into autograd functions
2*da0073e9SAndroid Build Coastguard Worker#
3*da0073e9SAndroid Build Coastguard Worker# Each autograd function is represented by `DifferentiabilityInfo` containing
4*da0073e9SAndroid Build Coastguard Worker# a list of `Derivative`. See `torchgen.api.autograd` for the data models.
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport re
9*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
10*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Counter, Dict, Sequence, Set, Tuple
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerimport yaml
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
15*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.autograd import (
16*da0073e9SAndroid Build Coastguard Worker    Derivative,
17*da0073e9SAndroid Build Coastguard Worker    DifferentiabilityInfo,
18*da0073e9SAndroid Build Coastguard Worker    ForwardDerivative,
19*da0073e9SAndroid Build Coastguard Worker    SavedAttribute,
20*da0073e9SAndroid Build Coastguard Worker)
21*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
22*da0073e9SAndroid Build Coastguard Worker    BaseCType,
23*da0073e9SAndroid Build Coastguard Worker    Binding,
24*da0073e9SAndroid Build Coastguard Worker    boolT,
25*da0073e9SAndroid Build Coastguard Worker    CppSignatureGroup,
26*da0073e9SAndroid Build Coastguard Worker    layoutT,
27*da0073e9SAndroid Build Coastguard Worker    longT,
28*da0073e9SAndroid Build Coastguard Worker    NamedCType,
29*da0073e9SAndroid Build Coastguard Worker    OptionalCType,
30*da0073e9SAndroid Build Coastguard Worker    scalarTypeT,
31*da0073e9SAndroid Build Coastguard Worker    SpecialArgName,
32*da0073e9SAndroid Build Coastguard Worker    stringT,
33*da0073e9SAndroid Build Coastguard Worker    symIntArrayRefT,
34*da0073e9SAndroid Build Coastguard Worker    SymIntT,
35*da0073e9SAndroid Build Coastguard Worker    tensorGeometryT,
36*da0073e9SAndroid Build Coastguard Worker    tensorOptionsT,
37*da0073e9SAndroid Build Coastguard Worker    typeAndSizeT,
38*da0073e9SAndroid Build Coastguard Worker    VectorCType,
39*da0073e9SAndroid Build Coastguard Worker)
40*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import with_native_function
41*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
42*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
43*da0073e9SAndroid Build Coastguard Worker    AUTOGRAD_KEYS,
44*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
45*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
46*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsViewGroup,
47*da0073e9SAndroid Build Coastguard Worker    OperatorName,
48*da0073e9SAndroid Build Coastguard Worker    SchemaKind,
49*da0073e9SAndroid Build Coastguard Worker    Type,
50*da0073e9SAndroid Build Coastguard Worker    Variant,
51*da0073e9SAndroid Build Coastguard Worker)
52*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import concatMap, IDENT_REGEX, split_name_params
53*da0073e9SAndroid Build Coastguard Workerfrom torchgen.yaml_utils import YamlLoader
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
# Return type of load_derivatives: a mapping
#   FunctionSchema -> {dispatch key -> DifferentiabilityInfo}
# paired with the set of dispatch keys/aliases seen in derivatives.yaml.
DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]

# Module-level memo for load_derivatives, keyed by
# (derivatives_yaml_path, native_yaml_path).
_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}

# AUTOGRAD_KEYS copied into a set for membership tests.
# NOTE(review): the consumers of this set live outside the visible chunk.
_VALID_AUTOGRAD_KEYS = {*AUTOGRAD_KEYS}
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker
# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
def add_view_copy_derivatives(
    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    view_groups: list[NativeFunctionsViewGroup],
) -> None:
    """Add generated ``{view}_copy`` derivative entries to ``infos`` in place.

    Each per-dispatch-key info dict in ``infos`` is scanned in order. For every
    info whose function is a view op with a ``view_copy`` variant,
    ``create_view_copy_from_view_derivative`` is called. Each non-None result is
    recorded under its own dispatch key. Scanning of that dict stops at the
    first info that is not such a view op.

    The collected entries are registered under the copy op's schema. This only
    happens if that schema is not already a key of ``infos``, so manually
    written derivatives always win.

    Args:
        infos: FunctionSchema -> {dispatch key -> DifferentiabilityInfo}.
            Mutated in place.
        view_groups: view groups used to map a view op's name to its group.
    """
    # Map each view op's name to its view group.
    view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
        g.view.func.name: g for g in view_groups
    }

    # Collect the new entries separately; ``infos`` is being iterated below.
    view_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}

    for info_dispatch_dict in infos.values():
        # Reset per schema so the target schema is always bound before use
        # and can never carry over from an earlier iteration.
        fn_schema: FunctionSchema | None = None
        view_copy_differentiability_infos: dict[str, DifferentiabilityInfo] = {}
        for dispatch_key, info in info_dispatch_dict.items():
            maybe_view_group = view_name_to_group.get(info.func.func.name)
            if maybe_view_group is None or maybe_view_group.view_copy is None:
                # Not a view op with a view_copy variant: stop scanning this schema.
                break
            view_copy_info = info.create_view_copy_from_view_derivative(
                maybe_view_group
            )
            if view_copy_info is not None:
                fn_schema = view_copy_info.func.func
                view_copy_differentiability_infos[dispatch_key] = view_copy_info
        # Prefer manually-defined derivatives if any.
        if (
            view_copy_differentiability_infos
            and fn_schema is not None
            and fn_schema not in infos
        ):
            view_infos[fn_schema] = view_copy_differentiability_infos

    infos.update(view_infos)
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker
def _normalize_legacy_definition(defn_dict: dict[str, Any]) -> dict[str, Any]:
    """Rewrite an old-style derivatives.yaml entry (no ``dispatch`` key) into the current shape.

    Entries that already have ``dispatch`` are returned unchanged. Otherwise the
    remaining keys become the ``"Default"`` dispatch entry. ``name`` and a truthy
    ``output_differentiability`` stay at the top level. Note that
    ``defn_dict`` itself is mutated (those keys are popped from it).
    """
    if "dispatch" in defn_dict:
        return defn_dict
    specification = defn_dict.pop("name")
    output_differentiability = defn_dict.pop("output_differentiability", None)
    normalized: dict[str, Any] = {
        "name": specification,
        "dispatch": {"Default": defn_dict},
    }
    if output_differentiability:
        normalized["output_differentiability"] = output_differentiability
    return normalized


def load_derivatives(
    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
) -> DerivativeRet:
    """Parse derivatives.yaml into differentiability infos, memoizing the result.

    Args:
        derivatives_yaml_path: path to derivatives.yaml.
        native_yaml_path: path to the native functions yaml passed to
            ``parse_native_yaml``.
        tags_yaml_path: path to the tags yaml passed to ``parse_native_yaml``.

    Returns:
        ``(infos, used_dispatch_keys)``.
        - ``infos`` maps each FunctionSchema to {dispatch key ->
          DifferentiabilityInfo}, including generated ``{view}_copy`` entries.
        - ``used_dispatch_keys`` holds the dispatch keys/aliases collected by
          ``create_differentiability_info`` while parsing derivatives.yaml.

    Results are cached per ``(derivatives_yaml_path, native_yaml_path)``. Every
    hit returns the very same objects, so callers must not mutate them.
    """
    # NOTE(review): tags_yaml_path is not part of the cache key. A later call with
    # the same yaml paths but a different tags file returns the first result.
    key = (derivatives_yaml_path, native_yaml_path)
    cached = _GLOBAL_LOAD_DERIVATIVE_CACHE.get(key)
    if cached is not None:
        return cached

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(derivatives_yaml_path, encoding="utf-8") as f:
        definitions = yaml.load(f, Loader=YamlLoader)

    funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
    # From the parsed native functions, separate out the (generated) view_copy functions,
    # so we can generate derivatives for them separately.
    native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
    native_functions = concatMap(
        lambda g: [g]
        if isinstance(g, NativeFunction)
        else list(g.functions(include_copy=True)),
        native_functions_with_view_groups,
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # What's the difference between function schema v.s. signature?
    # function schema is the complete declaration including mutability annotation / default value and etc.
    # signature is the canonical schema for a group of functions (in-place/out/functional variants)
    # that are semantically related.
    functions_by_signature: dict[
        FunctionSchema, list[NativeFunction]
    ] = defaultdict(list)
    functions_by_schema: dict[str, NativeFunction] = {}
    for function in native_functions:
        functions_by_signature[function.func.signature()].append(function)
        schema_str = str(function.func)
        assert (
            schema_str not in functions_by_schema
        ), f"Duplicate native function schema: {schema_str}"
        functions_by_schema[schema_str] = function

    # Keep track of how many of which ops we've seen so we can
    # disambiguate them with a numeric suffix.
    op_counter = Counter[str]()

    # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
    # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
    # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
    used_dispatch_keys: set[str] = set()
    for defn_dict in definitions:
        # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
        name, per_dispatch_diffinfos = create_differentiability_info(
            _normalize_legacy_definition(defn_dict),
            functions_by_signature,
            functions_by_schema,
            op_counter,
            used_dispatch_keys,
        )
        infos[name] = per_dispatch_diffinfos

    add_view_copy_derivatives(infos, view_groups)

    # Cache both the loaded infos and the set of all dispatch keys/aliases
    # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
    # VariableType.cpp, where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used.
    result = (infos, used_dispatch_keys)
    _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = result
    return result
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker
# TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
@with_native_function
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
    """Return the C++ argument bindings of ``f`` in its non-method form.

    The SymInt signature is used when the signature group provides one;
    otherwise the plain signature is used.
    """
    group = CppSignatureGroup.from_native_function(f, method=False)
    chosen_signature = (
        group.signature if group.symint_signature is None else group.symint_signature
    )
    return chosen_signature.arguments()
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker
def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    """Build the backward ``Derivative`` described by ``formula`` for ``f``.

    The formula is run through ``saved_variables`` twice: first against ``f``'s
    C++ arguments, then against its named returns. Each pass yields an updated
    formula plus the saved inputs/outputs stored on the result. The unmodified
    formula is kept as ``original_formula``.

    Args:
        f: the native function the formula belongs to.
        formula: the backward formula text from derivatives.yaml.
        var_names: variable names the formula is defined for. They are passed
            to ``saved_variables`` and stored on the result.
        available_named_gradients: gradient names that may appear in the
            formula. Only those referenced by the final formula are recorded.

    Raises:
        RuntimeError: if the formula uses ``grads[i]`` with ``i`` greater than
            or equal to the number of forward outputs.
    """
    untouched_formula = formula

    input_ctypes: list[NamedCType] = [
        binding.nctype.remove_const_ref() for binding in cpp_arguments(f)
    ]

    # A return named ``self`` is referred to as ``result``.
    output_names = tuple(
        "result" if ret_name == "self" else ret_name
        for ret_name in cpp.return_names(f)
    )
    output_types = tuple(
        cpp.return_type(ret, symint=True).remove_const_ref()
        for ret in f.func.returns
    )
    output_ctypes = [
        NamedCType(out_name, out_type)
        for out_name, out_type in zip(output_names, output_types)
    ]

    formula, saved_inputs = saved_variables(formula, input_ctypes, var_names)
    formula, saved_outputs = saved_variables(formula, output_ctypes, var_names)

    referenced_gradients = set()
    for grad_name in available_named_gradients:
        if re.search(IDENT_REGEX.format(grad_name), formula):
            referenced_gradients.add(grad_name)

    # Reject grads[i] lookups past the end of the forward's outputs.
    output_count = len(f.func.returns)
    for grad_index in used_gradient_indices(formula):
        if grad_index >= output_count:
            raise RuntimeError(
                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
                f"used grads[{grad_index}], but the forward only returns {output_count} outputs."
            )

    return Derivative(
        formula=formula,
        original_formula=untouched_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=referenced_gradients,
    )
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker
def create_forward_derivative(
    f: NativeFunction, formula: str, names: tuple[str, ...]
) -> ForwardDerivative:
    """Build a ``ForwardDerivative`` for the outputs ``names`` of ``f``.

    Output types come from the returns of ``f`` whose declared name appears in
    ``names``, collected in return order. If none match, default return names
    are tried:
    - ``("result",)`` selects the single return;
    - each ``resultN`` selects the return at index ``N``. Names of any other
      form are ignored in this fallback.

    All ``required_*`` fields and ``is_reusing_outplace_formula`` are left at
    ``None``/``False`` here.

    Raises:
        AssertionError: if no output type can be determined, or if
            ``("result",)`` is used for a function without exactly one return.
    """
    # Preferred: returns whose declared name is listed in ``names``.
    matched_types = tuple(ret.type for ret in f.func.returns if ret.name in names)
    var_types: tuple[Type, ...] | None = matched_types if matched_types else None

    if var_types is None:
        if names == ("result",):
            # Single unnamed output.
            assert len(f.func.returns) == 1
            var_types = (f.func.returns[0].type,)
        else:
            # Positional spelling: result0, result1, ...
            positional_types = []
            for candidate in names:
                digits = re.findall(r"^result(\d+)$", candidate)
                if len(digits) == 1:
                    positional_types.append(f.func.returns[int(digits[0])].type)
            if positional_types:
                var_types = tuple(positional_types)

    assert var_types is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_names=names,
        var_types=var_types,
        required_inputs_fw_grad=None,
        required_inputs_primal=None,
        required_original_self_value=False,
        is_reusing_outplace_formula=False,
    )
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Workerdef postprocess_forward_derivatives(
273*da0073e9SAndroid Build Coastguard Worker    f: NativeFunction,
274*da0073e9SAndroid Build Coastguard Worker    defn_name: str,
275*da0073e9SAndroid Build Coastguard Worker    all_arg_names: list[str],
276*da0073e9SAndroid Build Coastguard Worker    derivatives: list[Derivative],
277*da0073e9SAndroid Build Coastguard Worker    forward_derivatives: list[ForwardDerivative],
278*da0073e9SAndroid Build Coastguard Worker    args_with_derivatives: Sequence[Binding],
279*da0073e9SAndroid Build Coastguard Worker) -> list[ForwardDerivative]:
280*da0073e9SAndroid Build Coastguard Worker    def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
281*da0073e9SAndroid Build Coastguard Worker        is_foreach = f.func.name.name.base.startswith("_foreach_")
282*da0073e9SAndroid Build Coastguard Worker        required_inputs = set()
283*da0073e9SAndroid Build Coastguard Worker        for arg in args_with_derivatives:
284*da0073e9SAndroid Build Coastguard Worker            if (
285*da0073e9SAndroid Build Coastguard Worker                arg.type in ("at::TensorList", "const at::ITensorListRef &")
286*da0073e9SAndroid Build Coastguard Worker                and not is_foreach
287*da0073e9SAndroid Build Coastguard Worker            ):
288*da0073e9SAndroid Build Coastguard Worker                # The functions taking TensorList handle everything internally
289*da0073e9SAndroid Build Coastguard Worker                continue
290*da0073e9SAndroid Build Coastguard Worker            arg_name = arg.name
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker            found = re.search(IDENT_REGEX.format(arg_name), formula)
293*da0073e9SAndroid Build Coastguard Worker            if found:
294*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
295*da0073e9SAndroid Build Coastguard Worker                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
296*da0073e9SAndroid Build Coastguard Worker                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
297*da0073e9SAndroid Build Coastguard Worker                    f"value and {arg_name}_t to access the tangent."
298*da0073e9SAndroid Build Coastguard Worker                )
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
301*da0073e9SAndroid Build Coastguard Worker            if found:
302*da0073e9SAndroid Build Coastguard Worker                required_inputs.add(arg_name)
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        return tuple(required_inputs)
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker    updated_derivatives: list[ForwardDerivative] = []
307*da0073e9SAndroid Build Coastguard Worker
308*da0073e9SAndroid Build Coastguard Worker    for defn in forward_derivatives:
309*da0073e9SAndroid Build Coastguard Worker        formula = defn.formula
310*da0073e9SAndroid Build Coastguard Worker        required_inputs_tangent = find_required_inputs(formula, "_t")
311*da0073e9SAndroid Build Coastguard Worker        if formula == "auto_element_wise":
312*da0073e9SAndroid Build Coastguard Worker            assert (
313*da0073e9SAndroid Build Coastguard Worker                f.func.kind() != SchemaKind.inplace
314*da0073e9SAndroid Build Coastguard Worker            ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
315*da0073e9SAndroid Build Coastguard Worker            if (
316*da0073e9SAndroid Build Coastguard Worker                (not len(args_with_derivatives) == 1)
317*da0073e9SAndroid Build Coastguard Worker                or len(forward_derivatives) > 1
318*da0073e9SAndroid Build Coastguard Worker                or len(forward_derivatives[0].var_names) > 1
319*da0073e9SAndroid Build Coastguard Worker            ):
320*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
321*da0073e9SAndroid Build Coastguard Worker                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
322*da0073e9SAndroid Build Coastguard Worker                    "forward definition of gradient as element_wise but this only "
323*da0073e9SAndroid Build Coastguard Worker                    "works for functions with a single differentiable input and a "
324*da0073e9SAndroid Build Coastguard Worker                    "single differentiable output."
325*da0073e9SAndroid Build Coastguard Worker                )
326*da0073e9SAndroid Build Coastguard Worker            if not len(derivatives) == 1:
327*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
328*da0073e9SAndroid Build Coastguard Worker                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
329*da0073e9SAndroid Build Coastguard Worker                    "forward definition of gradient as element_wise but it does not "
330*da0073e9SAndroid Build Coastguard Worker                    "defines the gradient formula for its argument which is required."
331*da0073e9SAndroid Build Coastguard Worker                )
332*da0073e9SAndroid Build Coastguard Worker            # This transformation is based on the observation that for element-wise functions, the Jacobian
333*da0073e9SAndroid Build Coastguard Worker            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
334*da0073e9SAndroid Build Coastguard Worker            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
335*da0073e9SAndroid Build Coastguard Worker            # So here we are going to re-use the backward formula and replace two things:
336*da0073e9SAndroid Build Coastguard Worker            # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
337*da0073e9SAndroid Build Coastguard Worker            # 2) all usage of an original input "foo" with its primal value "foo_p".
338*da0073e9SAndroid Build Coastguard Worker            # 3) conjugate the final result
339*da0073e9SAndroid Build Coastguard Worker            # For example, for abs, the backward formula is:
340*da0073e9SAndroid Build Coastguard Worker            #   grad * self.sgn()
341*da0073e9SAndroid Build Coastguard Worker            # And this function generates a forward formula that is:
342*da0073e9SAndroid Build Coastguard Worker            #   (self_t.conj() * self_p.sgn()).conj()
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker            backward_formula = derivatives[0].original_formula
345*da0073e9SAndroid Build Coastguard Worker            input_name = args_with_derivatives[0].name
346*da0073e9SAndroid Build Coastguard Worker
347*da0073e9SAndroid Build Coastguard Worker            # Do replacement 1) of the grad
348*da0073e9SAndroid Build Coastguard Worker            def repl(m: Any) -> str:
349*da0073e9SAndroid Build Coastguard Worker                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker            # Do replacement 2) of the input variables
354*da0073e9SAndroid Build Coastguard Worker            for arg in args_with_derivatives:
355*da0073e9SAndroid Build Coastguard Worker                arg_name = arg.name
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker                def repl(m: Any) -> str:
358*da0073e9SAndroid Build Coastguard Worker                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)
361*da0073e9SAndroid Build Coastguard Worker
362*da0073e9SAndroid Build Coastguard Worker            # Do the final conjugate 3)
363*da0073e9SAndroid Build Coastguard Worker            fw_formula = f"({fw_formula}).conj()"
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker            # Since there is a single differentiable inputs and we necessarily need its tangent we can
366*da0073e9SAndroid Build Coastguard Worker            # simply require all differentiable input's tangent.
367*da0073e9SAndroid Build Coastguard Worker            required_inputs_tangent = tuple(all_arg_names)
368*da0073e9SAndroid Build Coastguard Worker            formula = fw_formula
369*da0073e9SAndroid Build Coastguard Worker        elif formula == "auto_linear":
370*da0073e9SAndroid Build Coastguard Worker            if (
371*da0073e9SAndroid Build Coastguard Worker                len(forward_derivatives) > 1
372*da0073e9SAndroid Build Coastguard Worker                or len(forward_derivatives[0].var_names) > 1
373*da0073e9SAndroid Build Coastguard Worker            ):
374*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
375*da0073e9SAndroid Build Coastguard Worker                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
376*da0073e9SAndroid Build Coastguard Worker                    "forward definition of gradient as linear but this only works "
377*da0073e9SAndroid Build Coastguard Worker                    "for functions with a single differentiable output."
378*da0073e9SAndroid Build Coastguard Worker                )
379*da0073e9SAndroid Build Coastguard Worker            # This transformation is based on the observation that linear functions can be written as:
380*da0073e9SAndroid Build Coastguard Worker            #   y = f(x) = A * x
381*da0073e9SAndroid Build Coastguard Worker            # For some matrix A and the Jacobian of the function f is also A.
382*da0073e9SAndroid Build Coastguard Worker            # So doing J * v = A * v = f(v).
383*da0073e9SAndroid Build Coastguard Worker            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
384*da0073e9SAndroid Build Coastguard Worker            # We do this by calling the forward again by replacing any occurrence of the differentiable
385*da0073e9SAndroid Build Coastguard Worker            # input "foo" by it's tangent "foo_t".
386*da0073e9SAndroid Build Coastguard Worker            # Note that multiple inputs are not a problem as long as the function is truly linear wrt to
387*da0073e9SAndroid Build Coastguard Worker            # the vector where all the differentiable inputs are stacked.
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker            diff_arg_names = [arg.name for arg in args_with_derivatives]
390*da0073e9SAndroid Build Coastguard Worker            assert len(diff_arg_names) > 0
391*da0073e9SAndroid Build Coastguard Worker
392*da0073e9SAndroid Build Coastguard Worker            # Do replacement of input variables
393*da0073e9SAndroid Build Coastguard Worker            new_args = []
394*da0073e9SAndroid Build Coastguard Worker            for arg_name in all_arg_names:
395*da0073e9SAndroid Build Coastguard Worker                if arg_name in diff_arg_names:
396*da0073e9SAndroid Build Coastguard Worker                    arg_name = arg_name + "_t"
397*da0073e9SAndroid Build Coastguard Worker                new_args.append(arg_name)
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker            # TODO we are trolling
400*da0073e9SAndroid Build Coastguard Worker            if f.func.has_symint():
401*da0073e9SAndroid Build Coastguard Worker                defn_name += "_symint"
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
404*da0073e9SAndroid Build Coastguard Worker            if Variant.function in f.variants:
405*da0073e9SAndroid Build Coastguard Worker                fw_formula = f"at::{defn_name}({', '.join(new_args)})"
406*da0073e9SAndroid Build Coastguard Worker            else:
407*da0073e9SAndroid Build Coastguard Worker                assert Variant.method in f.variants
408*da0073e9SAndroid Build Coastguard Worker                fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})"
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker            # All of the input tangents are always used so all of them are required here.
411*da0073e9SAndroid Build Coastguard Worker            required_inputs_tangent = tuple(diff_arg_names)
412*da0073e9SAndroid Build Coastguard Worker            formula = fw_formula
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker        # At this point, the formula is final and is not modified anymore.
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker        # During forward formula, we use the primal instead of the input Tensors.
417*da0073e9SAndroid Build Coastguard Worker        # This call inspects the formula to find for which input's primal are used.
418*da0073e9SAndroid Build Coastguard Worker        required_inputs_primal = find_required_inputs(formula, "_p")
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker        updated_derivatives.append(
421*da0073e9SAndroid Build Coastguard Worker            ForwardDerivative(
422*da0073e9SAndroid Build Coastguard Worker                formula=formula,
423*da0073e9SAndroid Build Coastguard Worker                var_names=defn.var_names,
424*da0073e9SAndroid Build Coastguard Worker                var_types=defn.var_types,
425*da0073e9SAndroid Build Coastguard Worker                required_inputs_fw_grad=required_inputs_tangent,
426*da0073e9SAndroid Build Coastguard Worker                required_inputs_primal=required_inputs_primal,
427*da0073e9SAndroid Build Coastguard Worker                required_original_self_value=False,
428*da0073e9SAndroid Build Coastguard Worker                is_reusing_outplace_formula=False,
429*da0073e9SAndroid Build Coastguard Worker            )
430*da0073e9SAndroid Build Coastguard Worker        )
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker    return updated_derivatives
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Workerdef is_forward_derivative_definition(
436*da0073e9SAndroid Build Coastguard Worker    all_arg_names: list[str], names: tuple[str, ...]
437*da0073e9SAndroid Build Coastguard Worker) -> bool:
438*da0073e9SAndroid Build Coastguard Worker    for name in names:
439*da0073e9SAndroid Build Coastguard Worker        return name not in all_arg_names
440*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError("Expected `names` to be non-empty")
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Workerdef create_differentiability_info(
444*da0073e9SAndroid Build Coastguard Worker    defn_dict: dict[Any, Any],
445*da0073e9SAndroid Build Coastguard Worker    functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
446*da0073e9SAndroid Build Coastguard Worker    functions_by_schema: dict[str, NativeFunction],
447*da0073e9SAndroid Build Coastguard Worker    op_counter: Counter[str],
448*da0073e9SAndroid Build Coastguard Worker    used_dispatch_keys: set[str],
449*da0073e9SAndroid Build Coastguard Worker) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
450*da0073e9SAndroid Build Coastguard Worker    """Processes a single entry `defn` in derivatives.yaml"""
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker    def canonical_function(
453*da0073e9SAndroid Build Coastguard Worker        functions: Sequence[NativeFunction], name: str
454*da0073e9SAndroid Build Coastguard Worker    ) -> NativeFunction:
455*da0073e9SAndroid Build Coastguard Worker        for f in functions:
456*da0073e9SAndroid Build Coastguard Worker            if (
457*da0073e9SAndroid Build Coastguard Worker                not f.func.is_functional_fn()
458*da0073e9SAndroid Build Coastguard Worker                and not f.func.is_out_fn()
459*da0073e9SAndroid Build Coastguard Worker                and name == str(f.func.name.name)
460*da0073e9SAndroid Build Coastguard Worker            ):
461*da0073e9SAndroid Build Coastguard Worker                return f
462*da0073e9SAndroid Build Coastguard Worker        # some functions only have in-place variants
463*da0073e9SAndroid Build Coastguard Worker        assert name + "_" == cpp.name(functions[0].func)
464*da0073e9SAndroid Build Coastguard Worker        return functions[0]
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker    def split_names(raw_names: str) -> tuple[str, ...]:
467*da0073e9SAndroid Build Coastguard Worker        """Given "foo, bar", return ["foo", "bar"]."""
468*da0073e9SAndroid Build Coastguard Worker        return tuple(x.strip() for x in raw_names.split(","))
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
471*da0073e9SAndroid Build Coastguard Worker        """
472*da0073e9SAndroid Build Coastguard Worker        Check for some subtle mistakes one might make when writing derivatives.
473*da0073e9SAndroid Build Coastguard Worker        These mistakes will compile, but will be latent until a function is
474*da0073e9SAndroid Build Coastguard Worker        used with double backwards.
475*da0073e9SAndroid Build Coastguard Worker        """
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker        uses_grad = False  # true if any derivative uses "grad"
478*da0073e9SAndroid Build Coastguard Worker        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
479*da0073e9SAndroid Build Coastguard Worker        uses_named_grads = False  # true if any derivative uses "grad_{name}"
480*da0073e9SAndroid Build Coastguard Worker        used_grads_indices: list[int] = []  # which indices of grads are used
481*da0073e9SAndroid Build Coastguard Worker        for d in derivatives:
482*da0073e9SAndroid Build Coastguard Worker            formula = d.formula
483*da0073e9SAndroid Build Coastguard Worker            uses_grad = uses_grad or bool(
484*da0073e9SAndroid Build Coastguard Worker                re.findall(IDENT_REGEX.format("grad"), formula)
485*da0073e9SAndroid Build Coastguard Worker            )
486*da0073e9SAndroid Build Coastguard Worker            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
487*da0073e9SAndroid Build Coastguard Worker            uses_named_grads = uses_named_grads or bool(d.named_gradients)
488*da0073e9SAndroid Build Coastguard Worker            used_grads_indices.extend(used_gradient_indices(formula))
489*da0073e9SAndroid Build Coastguard Worker        # This is a basic sanity check: the number of places we see
490*da0073e9SAndroid Build Coastguard Worker        # "grads" should be no fewer than the number of indices we see
491*da0073e9SAndroid Build Coastguard Worker        # inside "grads". They may not be equal because we may use
492*da0073e9SAndroid Build Coastguard Worker        # "grads" without an index.
493*da0073e9SAndroid Build Coastguard Worker        assert num_grads_uses >= len(used_grads_indices)
494*da0073e9SAndroid Build Coastguard Worker        # Thus if the number is equal, every use of grads is also
495*da0073e9SAndroid Build Coastguard Worker        # indexed.
496*da0073e9SAndroid Build Coastguard Worker        only_used_grads_indices = num_grads_uses == len(used_grads_indices)
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker        if uses_grad and num_grads_uses > 0:
499*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
500*da0073e9SAndroid Build Coastguard Worker                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
501*da0073e9SAndroid Build Coastguard Worker                "mixes use of 'grad' and 'grads'. Consider replacing "
502*da0073e9SAndroid Build Coastguard Worker                "occurrences of 'grad' with 'grads[0]'"
503*da0073e9SAndroid Build Coastguard Worker            )
504*da0073e9SAndroid Build Coastguard Worker
505*da0073e9SAndroid Build Coastguard Worker        if only_used_grads_indices and set(used_grads_indices) == {0}:
506*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
507*da0073e9SAndroid Build Coastguard Worker                f"Derivative definition of {defn_name} in derivatives.yaml solely "
508*da0073e9SAndroid Build Coastguard Worker                "refers to 'grads[0]'.  If the first output is indeed the "
509*da0073e9SAndroid Build Coastguard Worker                "only differentiable output, replace 'grads[0]' with 'grad'; "
510*da0073e9SAndroid Build Coastguard Worker                "otherwise, there is a likely error in your derivatives "
511*da0073e9SAndroid Build Coastguard Worker                "declaration."
512*da0073e9SAndroid Build Coastguard Worker            )
513*da0073e9SAndroid Build Coastguard Worker
514*da0073e9SAndroid Build Coastguard Worker        if uses_named_grads and (uses_grad or num_grads_uses > 0):
515*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
516*da0073e9SAndroid Build Coastguard Worker                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
517*da0073e9SAndroid Build Coastguard Worker                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
518*da0073e9SAndroid Build Coastguard Worker                "only one method for identifying gradients."
519*da0073e9SAndroid Build Coastguard Worker            )
520*da0073e9SAndroid Build Coastguard Worker
521*da0073e9SAndroid Build Coastguard Worker    @with_native_function
522*da0073e9SAndroid Build Coastguard Worker    def set_up_derivatives(
523*da0073e9SAndroid Build Coastguard Worker        f: NativeFunction,
524*da0073e9SAndroid Build Coastguard Worker    ) -> tuple[
525*da0073e9SAndroid Build Coastguard Worker        Sequence[Derivative],
526*da0073e9SAndroid Build Coastguard Worker        Sequence[ForwardDerivative],
527*da0073e9SAndroid Build Coastguard Worker        Sequence[Binding],
528*da0073e9SAndroid Build Coastguard Worker        Sequence[str],
529*da0073e9SAndroid Build Coastguard Worker        Sequence[str],
530*da0073e9SAndroid Build Coastguard Worker    ]:
531*da0073e9SAndroid Build Coastguard Worker        # Set up the derivative information
532*da0073e9SAndroid Build Coastguard Worker        derivatives: list[Derivative] = []
533*da0073e9SAndroid Build Coastguard Worker        forward_derivatives: list[ForwardDerivative] = []
534*da0073e9SAndroid Build Coastguard Worker        non_differentiable_arg_names: list[str] = []
535*da0073e9SAndroid Build Coastguard Worker        args_with_derivatives_set: set[str] = set()
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker        all_arg_names = [a.name for a in cpp_arguments(f)]
538*da0073e9SAndroid Build Coastguard Worker        all_ret_names = [
539*da0073e9SAndroid Build Coastguard Worker            r.name for r in f.func.returns
540*da0073e9SAndroid Build Coastguard Worker        ]  # only used for the assert below
541*da0073e9SAndroid Build Coastguard Worker        # output_differentiability is captured from the enclosed
542*da0073e9SAndroid Build Coastguard Worker        # scope. Don't modify it.
543*da0073e9SAndroid Build Coastguard Worker        #
544*da0073e9SAndroid Build Coastguard Worker        # If it is not present, then no output is explicitly
545*da0073e9SAndroid Build Coastguard Worker        # undifferentiable.
546*da0073e9SAndroid Build Coastguard Worker        #
547*da0073e9SAndroid Build Coastguard Worker        # It may be present and shorter than the length of return
548*da0073e9SAndroid Build Coastguard Worker        # values. If that's the case, any return value that does not
549*da0073e9SAndroid Build Coastguard Worker        # have a corresponding entry is considered not differentiable.
550*da0073e9SAndroid Build Coastguard Worker        differentiability = output_differentiability or [True] * len(f.func.returns)
551*da0073e9SAndroid Build Coastguard Worker        # A return is available as a named gradient ...
552*da0073e9SAndroid Build Coastguard Worker        available_named_gradients = [
553*da0073e9SAndroid Build Coastguard Worker            f"grad_{ret.name}"
554*da0073e9SAndroid Build Coastguard Worker            for ret, differentiable in zip(f.func.returns, differentiability)
555*da0073e9SAndroid Build Coastguard Worker            # if it has not been explicitly made undifferentiable
556*da0073e9SAndroid Build Coastguard Worker            if differentiable
557*da0073e9SAndroid Build Coastguard Worker            # and if it has a name
558*da0073e9SAndroid Build Coastguard Worker            and ret.name is not None
559*da0073e9SAndroid Build Coastguard Worker            # and if its type is differentiable
560*da0073e9SAndroid Build Coastguard Worker            and ret.type.is_tensor_like()
561*da0073e9SAndroid Build Coastguard Worker        ]
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard Worker        for raw_names in sorted(defn.keys()):
564*da0073e9SAndroid Build Coastguard Worker            formula = defn[raw_names]
565*da0073e9SAndroid Build Coastguard Worker            names = split_names(raw_names)
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker            for name in names:
568*da0073e9SAndroid Build Coastguard Worker                assert not (name in all_arg_names and name in all_ret_names), (
569*da0073e9SAndroid Build Coastguard Worker                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
570*da0073e9SAndroid Build Coastguard Worker                    f"expected '{name}' to not be both an input arg and named return. "
571*da0073e9SAndroid Build Coastguard Worker                )
572*da0073e9SAndroid Build Coastguard Worker
573*da0073e9SAndroid Build Coastguard Worker            if is_forward_derivative_definition(all_arg_names, names):
574*da0073e9SAndroid Build Coastguard Worker                forward_derivatives.append(create_forward_derivative(f, formula, names))
575*da0073e9SAndroid Build Coastguard Worker            else:
576*da0073e9SAndroid Build Coastguard Worker                if formula.lower().strip() == "non_differentiable":
577*da0073e9SAndroid Build Coastguard Worker                    non_differentiable_arg_names += names
578*da0073e9SAndroid Build Coastguard Worker                else:
579*da0073e9SAndroid Build Coastguard Worker                    derivative = create_derivative(
580*da0073e9SAndroid Build Coastguard Worker                        f, formula, names, available_named_gradients
581*da0073e9SAndroid Build Coastguard Worker                    )
582*da0073e9SAndroid Build Coastguard Worker                    derivatives.append(derivative)
583*da0073e9SAndroid Build Coastguard Worker                    args_with_derivatives_set |= set(names)
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
586*da0073e9SAndroid Build Coastguard Worker        if overlap:
587*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
588*da0073e9SAndroid Build Coastguard Worker                f"derivatives definition for {defn} have overlapped non_differentiable "
589*da0073e9SAndroid Build Coastguard Worker                f"and differentiable variables: {overlap}"
590*da0073e9SAndroid Build Coastguard Worker            )
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker        # Next, let us determine the list of inputs in order.
593*da0073e9SAndroid Build Coastguard Worker        # TODO: do we need eagerly calculate and save it here? Can it be derived
594*da0073e9SAndroid Build Coastguard Worker        # from NativeFunction and `derivatives` on callsites instead?
595*da0073e9SAndroid Build Coastguard Worker        args_with_derivatives = [
596*da0073e9SAndroid Build Coastguard Worker            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
597*da0073e9SAndroid Build Coastguard Worker        ]
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker        # Postprocess forward derivatives definitions now that we know the differentiable arguments
600*da0073e9SAndroid Build Coastguard Worker        forward_derivatives = postprocess_forward_derivatives(
601*da0073e9SAndroid Build Coastguard Worker            f,
602*da0073e9SAndroid Build Coastguard Worker            defn_name,
603*da0073e9SAndroid Build Coastguard Worker            all_arg_names,
604*da0073e9SAndroid Build Coastguard Worker            derivatives,
605*da0073e9SAndroid Build Coastguard Worker            forward_derivatives,
606*da0073e9SAndroid Build Coastguard Worker            args_with_derivatives,
607*da0073e9SAndroid Build Coastguard Worker        )
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker        # Test to see if the use of 'grads' makes sense.
610*da0073e9SAndroid Build Coastguard Worker        check_grad_usage(defn_name, derivatives)
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker        return (
613*da0073e9SAndroid Build Coastguard Worker            derivatives,
614*da0073e9SAndroid Build Coastguard Worker            forward_derivatives,
615*da0073e9SAndroid Build Coastguard Worker            args_with_derivatives,
616*da0073e9SAndroid Build Coastguard Worker            non_differentiable_arg_names,
617*da0073e9SAndroid Build Coastguard Worker            available_named_gradients,
618*da0073e9SAndroid Build Coastguard Worker        )
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker    # NB: Removes 'name' from defn dictionary
621*da0073e9SAndroid Build Coastguard Worker    specification = defn_dict.pop("name")
622*da0073e9SAndroid Build Coastguard Worker    defn_name, _ = split_name_params(specification)
623*da0073e9SAndroid Build Coastguard Worker    # NB: Removes 'output_differentiability' from defn dictionary
624*da0073e9SAndroid Build Coastguard Worker    #     `None` means all differentiable.
625*da0073e9SAndroid Build Coastguard Worker    output_differentiability = defn_dict.pop("output_differentiability", None)
626*da0073e9SAndroid Build Coastguard Worker    output_differentiability_conditions = None
627*da0073e9SAndroid Build Coastguard Worker    if output_differentiability and any(
628*da0073e9SAndroid Build Coastguard Worker        isinstance(diff, str) for diff in output_differentiability
629*da0073e9SAndroid Build Coastguard Worker    ):
630*da0073e9SAndroid Build Coastguard Worker        if len(output_differentiability) != 1:
631*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
632*da0073e9SAndroid Build Coastguard Worker                f"Not supported: for {specification},"
633*da0073e9SAndroid Build Coastguard Worker                f"output_differentiability must either be "
634*da0073e9SAndroid Build Coastguard Worker                f"List[bool] or a List[str] where each str is a "
635*da0073e9SAndroid Build Coastguard Worker                f"condition. In the case where it is a condition, "
636*da0073e9SAndroid Build Coastguard Worker                f"we only support single-output functions. "
637*da0073e9SAndroid Build Coastguard Worker                f"Please file us an issue. "
638*da0073e9SAndroid Build Coastguard Worker            )
639*da0073e9SAndroid Build Coastguard Worker        output_differentiability_conditions = output_differentiability
640*da0073e9SAndroid Build Coastguard Worker        output_differentiability = [True]
641*da0073e9SAndroid Build Coastguard Worker
642*da0073e9SAndroid Build Coastguard Worker    schema_function = functions_by_schema.get(specification)
643*da0073e9SAndroid Build Coastguard Worker    if not schema_function:
644*da0073e9SAndroid Build Coastguard Worker        avail = "\n".join(
645*da0073e9SAndroid Build Coastguard Worker            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
646*da0073e9SAndroid Build Coastguard Worker        )
647*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
648*da0073e9SAndroid Build Coastguard Worker            f"could not find ATen function for schema: {specification} "
649*da0073e9SAndroid Build Coastguard Worker            f".  Available signatures:\n{avail}"
650*da0073e9SAndroid Build Coastguard Worker        )
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
653*da0073e9SAndroid Build Coastguard Worker    # to map in-place schemas to the out-of-place variants.
654*da0073e9SAndroid Build Coastguard Worker    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
655*da0073e9SAndroid Build Coastguard Worker    signature = schema_function.func.signature()
656*da0073e9SAndroid Build Coastguard Worker    functions = functions_by_signature[signature]
657*da0073e9SAndroid Build Coastguard Worker    if len(functions) == 0:
658*da0073e9SAndroid Build Coastguard Worker        avail = "\n".join(
659*da0073e9SAndroid Build Coastguard Worker            str(k)
660*da0073e9SAndroid Build Coastguard Worker            for k, v in functions_by_signature.items()
661*da0073e9SAndroid Build Coastguard Worker            if cpp.name(k) == defn_name
662*da0073e9SAndroid Build Coastguard Worker        )
663*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
664*da0073e9SAndroid Build Coastguard Worker            f"could not find ATen function for legacy signature: {signature} "
665*da0073e9SAndroid Build Coastguard Worker            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
666*da0073e9SAndroid Build Coastguard Worker            f"Available signatures:\n{avail}"
667*da0073e9SAndroid Build Coastguard Worker        )
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker    canonical = canonical_function(functions, defn_name)
670*da0073e9SAndroid Build Coastguard Worker    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
671*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
672*da0073e9SAndroid Build Coastguard Worker            f"Schema for {defn_name} has an argument named grad_input_mask, "
673*da0073e9SAndroid Build Coastguard Worker            "but this name would be shadowed by our codegen. "
674*da0073e9SAndroid Build Coastguard Worker            "Please use a different name in native_functions.yaml."
675*da0073e9SAndroid Build Coastguard Worker        )
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker    if "result" in (a.name for a in cpp_arguments(canonical)):
678*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
679*da0073e9SAndroid Build Coastguard Worker            f"Schema for {defn_name} has an argument named result, "
680*da0073e9SAndroid Build Coastguard Worker            "but this is only allowed for outputs."
681*da0073e9SAndroid Build Coastguard Worker            "Please use a different name in native_functions.yaml."
682*da0073e9SAndroid Build Coastguard Worker        )
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Worker    diffinfo_dict = {}
685*da0073e9SAndroid Build Coastguard Worker    for key, defn in defn_dict["dispatch"].items():
686*da0073e9SAndroid Build Coastguard Worker        if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
687*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
688*da0073e9SAndroid Build Coastguard Worker                f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
689*da0073e9SAndroid Build Coastguard Worker                f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
690*da0073e9SAndroid Build Coastguard Worker            )
691*da0073e9SAndroid Build Coastguard Worker        if key not in used_dispatch_keys:
692*da0073e9SAndroid Build Coastguard Worker            used_dispatch_keys.add(key)
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker        (
695*da0073e9SAndroid Build Coastguard Worker            derivatives,
696*da0073e9SAndroid Build Coastguard Worker            forward_derivatives,
697*da0073e9SAndroid Build Coastguard Worker            args_with_derivatives,
698*da0073e9SAndroid Build Coastguard Worker            non_differentiable_arg_names,
699*da0073e9SAndroid Build Coastguard Worker            available_named_gradients,
700*da0073e9SAndroid Build Coastguard Worker        ) = set_up_derivatives(canonical)
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker        used_named_gradients: set[str] = set()
703*da0073e9SAndroid Build Coastguard Worker        for d in derivatives:
704*da0073e9SAndroid Build Coastguard Worker            used_named_gradients |= d.named_gradients
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker        # only assign an op name if we are actually going to calculate a derivative
707*da0073e9SAndroid Build Coastguard Worker        op = None
708*da0073e9SAndroid Build Coastguard Worker        if args_with_derivatives:
709*da0073e9SAndroid Build Coastguard Worker            op_prefix = _create_op_prefix(defn_name)
710*da0073e9SAndroid Build Coastguard Worker            if key != "Default":
711*da0073e9SAndroid Build Coastguard Worker                op_prefix = op_prefix + key
712*da0073e9SAndroid Build Coastguard Worker            op = f"{op_prefix}{op_counter[op_prefix]}"
713*da0073e9SAndroid Build Coastguard Worker            op_counter[op_prefix] += 1
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker        diffinfo_dict[key] = DifferentiabilityInfo(
716*da0073e9SAndroid Build Coastguard Worker            name=defn_name,
717*da0073e9SAndroid Build Coastguard Worker            func=canonical,
718*da0073e9SAndroid Build Coastguard Worker            op=op,
719*da0073e9SAndroid Build Coastguard Worker            derivatives=derivatives,
720*da0073e9SAndroid Build Coastguard Worker            forward_derivatives=forward_derivatives,
721*da0073e9SAndroid Build Coastguard Worker            all_saved_inputs=dedup_vars(
722*da0073e9SAndroid Build Coastguard Worker                [v for d in derivatives for v in d.saved_inputs]
723*da0073e9SAndroid Build Coastguard Worker            ),
724*da0073e9SAndroid Build Coastguard Worker            all_saved_outputs=dedup_vars(
725*da0073e9SAndroid Build Coastguard Worker                [v for d in derivatives for v in d.saved_outputs]
726*da0073e9SAndroid Build Coastguard Worker            ),
727*da0073e9SAndroid Build Coastguard Worker            available_named_gradients=available_named_gradients,
728*da0073e9SAndroid Build Coastguard Worker            used_named_gradients=used_named_gradients,
729*da0073e9SAndroid Build Coastguard Worker            args_with_derivatives=args_with_derivatives,
730*da0073e9SAndroid Build Coastguard Worker            non_differentiable_arg_names=non_differentiable_arg_names,
731*da0073e9SAndroid Build Coastguard Worker            output_differentiability=output_differentiability,
732*da0073e9SAndroid Build Coastguard Worker            output_differentiability_conditions=output_differentiability_conditions,
733*da0073e9SAndroid Build Coastguard Worker        )
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker    return canonical.func, diffinfo_dict
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard WorkerGRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Workerdef used_gradient_indices(formula: str) -> list[int]:
742*da0073e9SAndroid Build Coastguard Worker    """Determine a list of gradient indices (the i in grads[i]) that
743*da0073e9SAndroid Build Coastguard Worker    are used by the formula.
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker    >>> used_gradient_indices("foo(grads[0], grads[1])")
746*da0073e9SAndroid Build Coastguard Worker    [0, 1]
747*da0073e9SAndroid Build Coastguard Worker    """
748*da0073e9SAndroid Build Coastguard Worker    return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker
def saved_variables(
    formula: str,
    nctypes: list[NamedCType],
    var_names: tuple[str, ...],
) -> tuple[str, tuple[SavedAttribute, ...]]:
    """Rewrite a derivative formula and collect the values it needs saved.

    Certain sub-expressions over a value (``self.sym_sizes()``,
    ``self.options()``, ``zeros_like(self)``, ...) can be evaluated when the
    autograd Function is created. Instead of saving the whole value, the
    cheaper result is saved, and each match in ``formula`` is replaced by the
    name of that saved attribute. Any value that is still referenced
    afterwards is saved as-is.

    Args:
        formula: the C++ derivative formula text.
        nctypes: named C++ types of the values the formula may reference;
            they are processed in order.
        var_names: names of the inputs this formula is the derivative for;
            only used to validate ``.sym_strides()`` usage.

    Returns:
        The rewritten formula and the attributes to save, in discovery order.
        The tuple may contain duplicates when a sub-expression occurs more
        than once.

    Raises:
        RuntimeError: if the formula uses ``.sizes()``, ``.size(int)`` or
            ``.strides()``; the SymInt variants must be used instead.
        AssertionError: if ``.sym_strides()`` is used in a formula that is not
            the derivative of exactly that one tensor.
    """

    def stride_expr(name: str) -> str:
        # strides_or_error() is only supported when this formula is the
        # derivative of the same single tensor whose strides are taken.
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    # Each entry is (pattern, info). `pattern` is a regex template whose `{}`
    # is filled with the (left-anchored) value name. Keys of `info`:
    #   suffix: appended to the value name to form the saved attribute's name
    #           (a str, or a callable taking the re.Match)
    #   nctype: builds the saved attribute's NamedCType from that name
    #   expr:   optional; builds the save-time expression from the value name
    #           (defaults to the matched text)
    #   res:    optional; builds the eval-time replacement text from the value
    #           name (defaults to the saved attribute's name)
    # Literal `.` in method-call patterns is escaped so it cannot match an
    # arbitrary character.
    REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}\.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sym_sizes() with self_sym_sizes_opt
        (
            r"{}->sym_sizes\(\)",
            {
                "suffix": "_sym_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? std::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : std::nullopt",
            },
        ),
        # replace self.sym_blocksize() with self_self_sym_blocksize_opt
        # NOTE: the suffix itself contains "self", so the saved name is
        # f"{name}_self_sym_blocksize_opt". It is left unchanged to keep the
        # generated attribute names stable.
        (
            r"{}\.sym_blocksize\(\)",
            {
                "suffix": "_self_sym_blocksize_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}\.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.sym_size(2) with self_sym_argsize_2
        # (a negative index such as -1 becomes minus_1)
        (
            r"{}\.sym_size\((-?\w+)\)",
            {
                "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}\.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_numel() with self_sym_numel
        (
            r"{}\.sym_numel\(\)",
            {
                "suffix": "_sym_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_sizes_symint(self) with self_args_sizes_symint
        (
            r"to_args_sizes_symint\({}\)",
            {
                "suffix": "_args_sizes_symint",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(SymIntT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # replace self.scalar_type() with self_scalar_type
        (
            r"{}\.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}\.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_strides() with self_sym_strides
        (
            r"{}\.sym_strides\(\)",
            {
                "suffix": "_sym_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}\.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}\.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: list[SavedAttribute] = []

    if ".sizes()" in formula or "->sizes()" in formula:
        raise RuntimeError(
            ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version,"
            + f".sym_sizes(), which returned a c10::SymIntArrayRef. formula={formula}"
        )
    if re.search(r"\.size\([-]?\d+\)", formula) or re.search(
        r"->size\([-]?\d+\)", formula
    ):
        # sym_size(int) yields a single c10::SymInt, not an array ref.
        raise RuntimeError(
            ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version,"
            + f".sym_size(int), which returned a c10::SymInt. formula={formula}"
        )
    if ".strides()" in formula or "->strides()" in formula:
        raise RuntimeError(
            ".strides() is not supported in derivative formulas. Instead, please use the SymInt version,"
            + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
        )
    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # Only match `name` as a whole identifier. Without this anchor, a
        # value such as `out` would also match inside `grad_out.sym_sizes()`.
        # That would save the wrong tensor's metadata and leave a dangling
        # `grad_out_sym_sizes` identifier in the formula.
        anchored_name = rf"(?<!\w){name}"

        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: re.Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            # `repl` closes over the loop's `info`/`name`; this is safe
            # because re.sub invokes it within the same iteration.
            formula = re.sub(regex.format(anchored_name), repl, formula)

        # std::optional<std::string> types stored in Backward nodes must be
        # converted to std::optional<c10::string_view> before being passed
        # into the backward function. The conversion is parenthesized so the
        # ternary stays a single operand wherever the value is used.
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"({name}.has_value() ? std::optional<c10::string_view>({name}.value()) : std::nullopt)",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker
984*da0073e9SAndroid Build Coastguard Workerdef _create_op_prefix(name: str) -> str:
985*da0073e9SAndroid Build Coastguard Worker    """Takes a native function name converts to a op prefix name.
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker    Note that the "name" parameter must be the native function name
988*da0073e9SAndroid Build Coastguard Worker    without the optional variant suffix, so "add" instead of
989*da0073e9SAndroid Build Coastguard Worker    "add.out".
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker    OP names correspond to classes, hence the change to title case.
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Worker    Example::
994*da0073e9SAndroid Build Coastguard Worker    >>> _create_op_prefix('add')
995*da0073e9SAndroid Build Coastguard Worker    'AddBackward'
996*da0073e9SAndroid Build Coastguard Worker    """
997*da0073e9SAndroid Build Coastguard Worker    camel_case = "".join([p.title() for p in name.split("_")])
998*da0073e9SAndroid Build Coastguard Worker    return (camel_case + "Backward").replace("ForwardBackward", "Backward")
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
    """Drop saved attributes whose name was already seen.

    The first attribute for each name is kept, and the input order is
    preserved. When a name is a ``SpecialArgName``, its ``.name`` string is
    used as the key.
    """
    # dicts preserve insertion order, and setdefault never overwrites, so the
    # first attribute for each key wins.
    first_by_name: dict[str, SavedAttribute] = {}
    for attr in vars:
        raw_name = attr.nctype.name
        key = raw_name.name if isinstance(raw_name, SpecialArgName) else raw_name
        first_by_name.setdefault(key, attr)
    return list(first_by_name.values())
1015