1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport re 4*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass 5*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, Sequence 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerfrom torchgen import local 8*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp 9*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT 10*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 11*da0073e9SAndroid Build Coastguard Worker BaseTy, 12*da0073e9SAndroid Build Coastguard Worker BaseType, 13*da0073e9SAndroid Build Coastguard Worker FunctionSchema, 14*da0073e9SAndroid Build Coastguard Worker ListType, 15*da0073e9SAndroid Build Coastguard Worker NativeFunction, 16*da0073e9SAndroid Build Coastguard Worker NativeFunctionsViewGroup, 17*da0073e9SAndroid Build Coastguard Worker SchemaKind, 18*da0073e9SAndroid Build Coastguard Worker Type, 19*da0073e9SAndroid Build Coastguard Worker) 20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import IDENT_REGEX 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker# Represents a saved attribute involved in backward calculation. 24*da0073e9SAndroid Build Coastguard Worker# Note that it can be a derived property of an input argument, e.g.: 25*da0073e9SAndroid Build Coastguard Worker# we could save `other.scalar_type()` instead of the entire `other` tensor. 
@dataclass(frozen=True)
class SavedAttribute:
    # A saved attribute involved in the backward calculation. It may be a
    # derived property of an input argument instead of the whole argument.

    # The NamedCType holds the updated name and cpp type of the attribute
    # for the name, Suffix is appended if it's derived property, e.g.: `other_scalar_type`
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str


# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: set[str]


# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (legit C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Name of the output arguments for which this formula calculates forward
    # derivatives
    var_names: tuple[str, ...]

    # Type of the output arguments for which this formula calculates forward
    # derivatives
    var_types: tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs for which the primal is required for this formula
    required_inputs_primal: tuple[str, ...] | None

    # Flag to specify if this formula requires the original value of self
    # This is only used by inplace operations
    required_original_self_value: bool

    # If this formula is specified in derivatives.yaml or if we are re-using the
    # out of place formula for inplace
    is_reusing_outplace_formula: bool


# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunction having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: str | None

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: list[bool] | None

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions
    output_differentiability_conditions: list[str] | None

    @property
    def has_derivatives(self) -> bool:
        """Whether at least one argument has a derivative formula."""
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
    # but with a new operator name.
    # This is used when generating "copy" variants of view ops,
    # which are able to use the exact same derivative formula as the original view op
    # See Note [Codegen'd {view}_copy Operators]
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        """Return this info retargeted at `g.view_copy`, or None if `g` has none.

        Only `name`, `func` and `op` differ in the returned object; every
        other field is carried over unchanged (the same objects, not copies).
        """
        from dataclasses import replace

        if g.view_copy is None:
            return None

        # Append a "_copy" to the base name of the operator (but keep the overload name the same).
        # Everything after the first "." is treated as the overload name. A name
        # without any "." therefore produces "<base>_copy." with a trailing period.
        base_name, _, overload_name = self.name.partition(".")
        view_copy_name = f"{base_name}_copy.{overload_name}"
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        # Use the "_copy" version of name/func/op, but keep all derivative info the same.
        return replace(
            self,
            name=view_copy_name,
            func=g.view_copy,
            op=view_copy_op_name,
        )


def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
    """Return True if any backward formula of `info` matches IDENT_REGEX for `ident`.

    A missing `info` (None) never uses any identifier.
    """
    if info is None:
        return False
    # The pattern does not depend on the derivative, so build it once.
    pattern = IDENT_REGEX.format(ident)
    return any(re.search(pattern, derivative.formula) for derivative in info.derivatives)


def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    """Return True if any backward formula references `retain_variables`."""
    return uses_ident(info, "retain_variables")


def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    """Return True if any backward formula references the identifier `grad`."""
    return uses_ident(info, "grad")


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    # Immutable record describing one differentiable input by name, schema
    # type and C++ type string.
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    # Immutable record describing one differentiable output by name, schema
    # type and C++ type string.
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    # Pairs a NativeFunction with its optional differentiability info and
    # forward derivatives; both mappings are keyed by string and are read
    # per dispatch key by dispatch_strategy() below.
    func: NativeFunction
    info: dict[str, DifferentiabilityInfo] | None
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration?  There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance).  Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors.  If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)

    Returns the string "use_derived" or "use_type".
    """
    # If the function is abstract (not implemented on at::Type), we must
    # call the implementation on the derived type with unpacked tensors.
    if fn.func.is_abstract:
        return "use_derived"

    # fn is derived as long as any of its per-key differentiability infos
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
    # and ADInplaceOrViewType. We want to generate these functions as long as a
    # derivative is defined for ANY dispatch key.
    #
    # If the function has a derivative specified and is concrete, we could
    # call either implementation. We prefer calling the derived
    # type's implementation with unpacked tensors because it is more
    # performant in some cases: any internal calls to other ATen functions
    # won't have the history tracked.
    #
    # If the function has a type dispatched argument (i.e. is a factory),
    # we prefer calling the derived type's implementation both because it is
    # more performant and to ensure factory functions return tensors with _version
    # of 0 (probably not strictly necessary, but nice to have to keep versions simple
    # to understand).
    if fn.info is not None and any(
        info.has_derivatives for info in fn.info.values()
    ):
        return "use_derived"

    # If the function is concrete (we don't have to override it) and we
    # didn't declare it in derivatives.yaml, we'll assume that it is
    # actually implemented out of differentiable functions. (This
    # assumption might not hold, but then you'll see gradcheck fail.)
    return "use_type"


def is_foreach_func(f: NativeFunction) -> bool:
    """Return True if the base operator name of `f` starts with `_foreach_`."""
    return f.func.name.name.base.startswith("_foreach_")


# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such one in `functional_info_by_signature`.
# There however are some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}
# The following do not support the alpha kwarg, which the nonforeach versions support.
_skip_argument_len_check = {
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
}


# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
# reference to generate derivatives.
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    """Return True if `function_schema` can serve as the reference for foreach function `f`.

    All of the following must hold:
      - the base name of `f` with the `_foreach_` prefix split off equals the
        base name of `function_schema`;
      - `function_schema` is not in-place, unless `f` is listed in
        `_foreach_with_inplace_ref`;
      - both have the same number of non-out arguments, unless `f` is listed
        in `_skip_argument_len_check`;
      - each reference argument type equals the paired foreach argument type
        or that type's `elem` attribute (when it has one).
    """
    if f.func.name.name.base.split("_foreach_")[-1] != function_schema.name.name.base:
        return False

    foreach_name = str(f.func.name)
    if (
        function_schema.name.name.inplace
        and foreach_name not in _foreach_with_inplace_ref
    ):
        return False

    foreach_args = f.func.arguments.flat_non_out
    ref_args = function_schema.arguments.flat_non_out
    if foreach_name not in _skip_argument_len_check and len(foreach_args) != len(
        ref_args
    ):
        return False

    # zip() stops at the shorter list, so when the length check was skipped the
    # extra reference arguments are simply not compared.
    return all(
        ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
        for arg, ref_arg in zip(foreach_args, ref_args)
    )


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
    """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.

    The reference info is the first entry of ``functional_info_by_signature``
    whose schema satisfies ``is_reference_for_foreach``. For in-place foreach
    functions listed in ``_foreach_with_inplace_ref``,
    ``non_functional_info_by_signature`` is searched as a fallback.

    Args:
        foreach_function: The foreach native function to find/generate info for.
        functional_info_by_signature: Infos of functional ops, keyed by schema.
        non_functional_info_by_signature: Infos of non-functional ops, keyed by
            schema.
        dispatch_key: Key used to index each per-schema info dict.

    Returns:
        ``(info, is_generated)``:
        * ``(None, False)`` when no reference info is found,
        * ``(reference_info, False)`` for in-place foreach functions,
        * ``(new_info, True)`` when a new info is built from the reference.
        The second value indicates whether the info is generated in this function.

    Raises:
        KeyError: If a matching reference has no entry for ``dispatch_key``.
        RuntimeError: If a reference derivative saves an output that is not
            named ``"result"``.
    """

    def find_reference(
        info_by_signature: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    ) -> tuple[FunctionSchema | None, DifferentiabilityInfo | None]:
        # Return the first (schema, info) pair that is a reference for
        # `foreach_function`, or (None, None) if there is none.
        for schema, diff_info in info_by_signature.items():
            if not is_reference_for_foreach(foreach_function, schema):
                continue
            candidate = diff_info[dispatch_key]
            if candidate is not None:
                return schema, candidate
        return None, None

    ref_schema, ref_diff_info = find_reference(functional_info_by_signature)
    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
    # while the info of `zero_` is in non_functional_info_by_signature
    if (
        ref_diff_info is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        ref_schema, ref_diff_info = find_reference(non_functional_info_by_signature)
    if ref_diff_info is None or ref_schema is None:
        return None, False

    # non out-place uses the existing Derivative.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    # Positional correspondence between the reference's arguments and the
    # foreach function's arguments.
    map_refarg2foreacharg, map_name2arg = {}, {}
    for arg, ref_arg in zip(
        foreach_function.func.arguments.flat_non_out,
        ref_schema.arguments.flat_non_out,
    ):
        map_refarg2foreacharg[ref_arg.name] = arg.name
        map_name2arg[arg.name] = arg

    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
    modified_derivative_formulas = []
    for derivative in ref_diff_info.derivatives:
        # The "[i]" index is emitted literally into the formula text.
        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
            "result", "result[i]"
        )
        saved_inputs, saved_outputs = [], []
        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
        with local.parametrize(
            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
        ):
            for ref_input in derivative.saved_inputs:
                # `expr` may be a derived property such as `other.scalar_type()`;
                # the part before the first "." is the argument name.
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                if isinstance(map_name2arg[mapped_name].type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name == "result":
                    saved_outputs.append(
                        SavedAttribute(
                            nctype=NamedCType(
                                name="result", type=BaseCType(tensorListT)
                            ),
                            expr="result",
                        )
                    )
                else:
                    raise RuntimeError(
                        f"Cannot generate foreach derivative for "
                        f"{foreach_function.func.name}: unsupported saved output "
                        f"{ref_output.nctype.name!r} (only 'result' is supported)"
                    )
        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
        all_var_names.extend(var_names)
        all_saved_inputs.extend(saved_inputs)
        all_saved_outputs.extend(saved_outputs)
        modified_derivative_formulas.append(
            Derivative(
                formula=modified_formula,
                original_formula=derivative.formula,
                var_names=tuple(var_names),
                saved_inputs=tuple(saved_inputs),
                saved_outputs=tuple(saved_outputs),
                named_gradients=set(),
            )
        )

    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_function.func.arguments.flat_non_out
            if arg.name in all_var_names
        ]

    forward_derivatives: list[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        fw_var_names: list[str] = list(fw_derivative.var_names)
        fw_var_types: list[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: list[str] = []
        required_inputs_primal: list[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # Foreach's result is TensorList
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # Rename the reference argument to the foreach argument in every
            # name list carried by the forward derivative.
            for i, name in enumerate(fw_var_names):
                if name == ref_arg.name:
                    fw_var_names[i] = foreach_arg.name
                    fw_var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(fw_var_names),
                var_types=tuple(fw_var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            # Deduplicate while keeping first-occurrence order. `tuple(set(...))`
            # would order elements by hash, and string hashes are randomized per
            # process, making the generated output non-reproducible.
            all_saved_inputs=tuple(dict.fromkeys(all_saved_inputs)),
            all_saved_outputs=tuple(dict.fromkeys(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )


582*da0073e9SAndroid Build Coastguard Workerdef match_differentiability_info( 583*da0073e9SAndroid Build Coastguard Worker native_functions: list[NativeFunction], 584*da0073e9SAndroid Build Coastguard Worker differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], 585*da0073e9SAndroid Build Coastguard Worker) -> list[NativeFunctionWithDifferentiabilityInfo]: 586*da0073e9SAndroid Build Coastguard Worker """Sets the "derivative" key on declarations to matching autograd function 587*da0073e9SAndroid Build Coastguard Worker In-place functions will use the out-of-place derivative definition if there 588*da0073e9SAndroid Build Coastguard Worker is no in-place specific derivative. 589*da0073e9SAndroid Build Coastguard Worker """ 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker functional_info_by_signature = { 592*da0073e9SAndroid Build Coastguard Worker schema.signature(strip_default=True): info_dict 593*da0073e9SAndroid Build Coastguard Worker for schema, info_dict in differentiability_infos.items() 594*da0073e9SAndroid Build Coastguard Worker if schema.kind() == SchemaKind.functional 595*da0073e9SAndroid Build Coastguard Worker } 596*da0073e9SAndroid Build Coastguard Worker non_functional_info_by_signature = { 597*da0073e9SAndroid Build Coastguard Worker schema.signature(strip_default=True): info_dict 598*da0073e9SAndroid Build Coastguard Worker for schema, info_dict in differentiability_infos.items() 599*da0073e9SAndroid Build Coastguard Worker if schema.kind() != SchemaKind.functional 600*da0073e9SAndroid Build Coastguard Worker } 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker def find_info( 603*da0073e9SAndroid Build Coastguard Worker f: NativeFunction, 604*da0073e9SAndroid Build Coastguard Worker ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]: 605*da0073e9SAndroid Build Coastguard Worker # Don't bother matching info to generated out= variants 
606*da0073e9SAndroid Build Coastguard Worker if "generated" in f.tags and f.func.kind() == SchemaKind.out: 607*da0073e9SAndroid Build Coastguard Worker return None, False 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker # (1) Check for an exact match 610*da0073e9SAndroid Build Coastguard Worker if f.func in differentiability_infos: 611*da0073e9SAndroid Build Coastguard Worker return differentiability_infos[f.func], True 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker # (2) If no exact match, check if the out-of-place variant 614*da0073e9SAndroid Build Coastguard Worker # of this operator has a match. 615*da0073e9SAndroid Build Coastguard Worker # i.e mul() for mul_() or mul_out() 616*da0073e9SAndroid Build Coastguard Worker # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing 617*da0073e9SAndroid Build Coastguard Worker # native functions instead of the out-place counterparts. 618*da0073e9SAndroid Build Coastguard Worker f_sig = f.func.signature(strip_default=True) 619*da0073e9SAndroid Build Coastguard Worker if f_sig in functional_info_by_signature and not is_foreach_func(f): 620*da0073e9SAndroid Build Coastguard Worker return functional_info_by_signature[f_sig], False 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker # (3) Some operators have a derivative explicitly defined for the mutable 623*da0073e9SAndroid Build Coastguard Worker # variant, but get a code-generated out-of-place variant which does *not* 624*da0073e9SAndroid Build Coastguard Worker # come with a derivative formula. 625*da0073e9SAndroid Build Coastguard Worker # For the generated out-of-place variant, use the mutable variant's formula 626*da0073e9SAndroid Build Coastguard Worker # if it exists. 
627*da0073e9SAndroid Build Coastguard Worker if "generated" in f.tags and f_sig in non_functional_info_by_signature: 628*da0073e9SAndroid Build Coastguard Worker info_dict = non_functional_info_by_signature[f_sig] 629*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389 630*da0073e9SAndroid Build Coastguard Worker assert not any( 631*da0073e9SAndroid Build Coastguard Worker any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs) 632*da0073e9SAndroid Build Coastguard Worker for info in info_dict.values() 633*da0073e9SAndroid Build Coastguard Worker ), f"""\ 634*da0073e9SAndroid Build Coastguard WorkerAttempted to convert a derivative formula for a mutable operator 635*da0073e9SAndroid Build Coastguard Worker to be used by automatically by its functional variant ("{str(f.func)}"). 636*da0073e9SAndroid Build Coastguard Worker this is not currently supported (we'd need to fix up the formula in the codegen).""" 637*da0073e9SAndroid Build Coastguard Worker return info_dict, False 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml` 640*da0073e9SAndroid Build Coastguard Worker if is_foreach_func(f): 641*da0073e9SAndroid Build Coastguard Worker assert f.func not in differentiability_infos 642*da0073e9SAndroid Build Coastguard Worker diff_info, is_generated = gen_foreach_derivativeinfo( 643*da0073e9SAndroid Build Coastguard Worker f, 644*da0073e9SAndroid Build Coastguard Worker functional_info_by_signature, 645*da0073e9SAndroid Build Coastguard Worker non_functional_info_by_signature, 646*da0073e9SAndroid Build Coastguard Worker ) 647*da0073e9SAndroid Build Coastguard Worker if diff_info is None: 648*da0073e9SAndroid Build Coastguard Worker return None, False 649*da0073e9SAndroid Build Coastguard Worker # TODO(crcrpar): Avoid hard coding "Default" ideally. 
650*da0073e9SAndroid Build Coastguard Worker diff_info_dict = {"Default": diff_info} 651*da0073e9SAndroid Build Coastguard Worker if is_generated: 652*da0073e9SAndroid Build Coastguard Worker differentiability_infos[f.func] = diff_info_dict 653*da0073e9SAndroid Build Coastguard Worker functional_info_by_signature[f.func] = diff_info_dict 654*da0073e9SAndroid Build Coastguard Worker return diff_info_dict, is_generated 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker return None, False 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker result: list[NativeFunctionWithDifferentiabilityInfo] = [] 659*da0073e9SAndroid Build Coastguard Worker for f in native_functions: 660*da0073e9SAndroid Build Coastguard Worker info_dict, is_exact_match = find_info(f) 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker # Currently, the '.strides()' to 'strides_or_error' replacement does not support 663*da0073e9SAndroid Build Coastguard Worker # 'self' derivatives of an inplace function, so we must check for this case. 
664*da0073e9SAndroid Build Coastguard Worker if f.func.kind() == SchemaKind.inplace and (info_dict is not None): 665*da0073e9SAndroid Build Coastguard Worker for info in info_dict.values(): 666*da0073e9SAndroid Build Coastguard Worker for derivative in info.derivatives: 667*da0073e9SAndroid Build Coastguard Worker if "self" in derivative.var_names: 668*da0073e9SAndroid Build Coastguard Worker for saved_input in derivative.saved_inputs: 669*da0073e9SAndroid Build Coastguard Worker assert "strides_or_error" not in saved_input.expr, ( 670*da0073e9SAndroid Build Coastguard Worker "Calling '.strides()' in the 'self' derivative formula of an " 671*da0073e9SAndroid Build Coastguard Worker f"in-place function is not supported: {f.func}" 672*da0073e9SAndroid Build Coastguard Worker ) 673*da0073e9SAndroid Build Coastguard Worker 674*da0073e9SAndroid Build Coastguard Worker if not info_dict: 675*da0073e9SAndroid Build Coastguard Worker result.append( 676*da0073e9SAndroid Build Coastguard Worker NativeFunctionWithDifferentiabilityInfo( 677*da0073e9SAndroid Build Coastguard Worker func=f, info=None, fw_derivatives=None 678*da0073e9SAndroid Build Coastguard Worker ) 679*da0073e9SAndroid Build Coastguard Worker ) 680*da0073e9SAndroid Build Coastguard Worker continue 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {} 683*da0073e9SAndroid Build Coastguard Worker for key, info in info_dict.items(): 684*da0073e9SAndroid Build Coastguard Worker if not info.forward_derivatives: 685*da0073e9SAndroid Build Coastguard Worker fw_derivative_dict[key] = [] 686*da0073e9SAndroid Build Coastguard Worker continue 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker forward_derivatives = info.forward_derivatives 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Worker # For functions that have a single def for out-of-place and 
inplace (like abs()) 691*da0073e9SAndroid Build Coastguard Worker if f.func.kind() == SchemaKind.inplace: 692*da0073e9SAndroid Build Coastguard Worker # For inplace functions there is a little bit of work to do: 693*da0073e9SAndroid Build Coastguard Worker # 1) Validate the formula and make sure the input that is modified in not used: 694*da0073e9SAndroid Build Coastguard Worker # - If there is a formula for the inplace variant of the function (is_exact_match == True) then 695*da0073e9SAndroid Build Coastguard Worker # we make sure that the original value of the input that is being modified inplace (self_p) is 696*da0073e9SAndroid Build Coastguard Worker # not used in the formula. Note that the formula can use "original_self_p" here and that would 697*da0073e9SAndroid Build Coastguard Worker # trigger a clone of the original input. 698*da0073e9SAndroid Build Coastguard Worker # - If we are re-using the out of place formula (is_exact_match == False) then we replace every 699*da0073e9SAndroid Build Coastguard Worker # occurrence of self_p and self_t by original_self_p and original_self_t. These will be 700*da0073e9SAndroid Build Coastguard Worker # populated by cloned version of the original input (either the clone done by the backward AD 701*da0073e9SAndroid Build Coastguard Worker # logic if self is also used in a backward formula or a special clone that we add). 702*da0073e9SAndroid Build Coastguard Worker # 2) At this point, there cannot be a self_p in the formula. 703*da0073e9SAndroid Build Coastguard Worker # 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is 704*da0073e9SAndroid Build Coastguard Worker # simply called self (as it is modified inplace). 
705*da0073e9SAndroid Build Coastguard Worker # 4) Update the required primals data in case it used to contain "result" but should now contain 706*da0073e9SAndroid Build Coastguard Worker # "self" 707*da0073e9SAndroid Build Coastguard Worker # 5) If it is not an exact match, the user formula is not modifying the existing forward grad 708*da0073e9SAndroid Build Coastguard Worker # inplace as it should. So add some code that makes sure that we do so if the forward grad 709*da0073e9SAndroid Build Coastguard Worker # already exists. 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker assert ( 712*da0073e9SAndroid Build Coastguard Worker len(info.forward_derivatives) == 1 713*da0073e9SAndroid Build Coastguard Worker ) # Only single output inplace should exist 714*da0073e9SAndroid Build Coastguard Worker fw_info = info.forward_derivatives[0] 715*da0073e9SAndroid Build Coastguard Worker formula = fw_info.formula 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker def replace_self_with_original_self(formula: str, postfix: str) -> str: 718*da0073e9SAndroid Build Coastguard Worker def repl(m: re.Match[str]) -> str: 719*da0073e9SAndroid Build Coastguard Worker return f"{m.group(1)}original_self{postfix}{m.group(2)}" 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula) 722*da0073e9SAndroid Build Coastguard Worker 723*da0073e9SAndroid Build Coastguard Worker if re.search(IDENT_REGEX.format("self_p"), formula): 724*da0073e9SAndroid Build Coastguard Worker if is_exact_match: 725*da0073e9SAndroid Build Coastguard Worker # For manually defined formulas, don't allow the original value to be used 726*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 727*da0073e9SAndroid Build Coastguard Worker f'The formula for "{f.func.name}" is using the original value of self ' 728*da0073e9SAndroid Build 
Coastguard Worker "that is being modified inplace. This would lead to wrong forward gradients. " 729*da0073e9SAndroid Build Coastguard Worker 'Please use "result" in the formula only.' 730*da0073e9SAndroid Build Coastguard Worker ) 731*da0073e9SAndroid Build Coastguard Worker else: 732*da0073e9SAndroid Build Coastguard Worker # When the original formula is out of place, we save a clone of the primal 733*da0073e9SAndroid Build Coastguard Worker # value to be able to access this value if needed 734*da0073e9SAndroid Build Coastguard Worker # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t" 735*da0073e9SAndroid Build Coastguard Worker formula = replace_self_with_original_self(formula, "_p") 736*da0073e9SAndroid Build Coastguard Worker formula = replace_self_with_original_self(formula, "_t") 737*da0073e9SAndroid Build Coastguard Worker 738*da0073e9SAndroid Build Coastguard Worker # replace "result" from the formula by "self_p" 739*da0073e9SAndroid Build Coastguard Worker def repl(m: re.Match[str]) -> str: 740*da0073e9SAndroid Build Coastguard Worker return f"{m.group(1)}self_p{m.group(2)}" 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker formula = re.sub(IDENT_REGEX.format("result"), repl, formula) 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker required_primals = fw_info.required_inputs_primal 745*da0073e9SAndroid Build Coastguard Worker if re.search(IDENT_REGEX.format("self_p"), formula): 746*da0073e9SAndroid Build Coastguard Worker required_primals = ( 747*da0073e9SAndroid Build Coastguard Worker required_primals + ("self",) if required_primals else ("self",) 748*da0073e9SAndroid Build Coastguard Worker ) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker if not is_exact_match: 751*da0073e9SAndroid Build Coastguard Worker # NOTE [In-place forward AD formula Optimization] 752*da0073e9SAndroid Build Coastguard Worker 
# 753*da0073e9SAndroid Build Coastguard Worker # This optimization transforms the formula to directly do inplace, i.e. 754*da0073e9SAndroid Build Coastguard Worker # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met: 755*da0073e9SAndroid Build Coastguard Worker # 756*da0073e9SAndroid Build Coastguard Worker # 1) the formula satisfies the pattern: "self_t.op(*args)" 757*da0073e9SAndroid Build Coastguard Worker # 2) "op" in (1) needs to be the same as the op the derivative is for 758*da0073e9SAndroid Build Coastguard Worker # 759*da0073e9SAndroid Build Coastguard Worker # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2) 760*da0073e9SAndroid Build Coastguard Worker # If there is a need, we can relax (2) to allow any op that has an in-place variant 761*da0073e9SAndroid Build Coastguard Worker is_single_method_on_self_t = False 762*da0073e9SAndroid Build Coastguard Worker directly_do_inplace = False 763*da0073e9SAndroid Build Coastguard Worker op_name: str | None = None 764*da0073e9SAndroid Build Coastguard Worker between_parens: str | None = None 765*da0073e9SAndroid Build Coastguard Worker match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula) 766*da0073e9SAndroid Build Coastguard Worker if match: 767*da0073e9SAndroid Build Coastguard Worker op_name, between_parens = match.group(1), match.group(2) 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker # We want to... 
770*da0073e9SAndroid Build Coastguard Worker # Match: self_t.op1(other_p.op2(arg)) 771*da0073e9SAndroid Build Coastguard Worker # Avoid: self_t.op1(args) + self_t.op2(args) 772*da0073e9SAndroid Build Coastguard Worker # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args) 773*da0073e9SAndroid Build Coastguard Worker def check_parens_nest_level_gt_zero(s: str) -> bool: 774*da0073e9SAndroid Build Coastguard Worker level = 1 775*da0073e9SAndroid Build Coastguard Worker for ch in s: 776*da0073e9SAndroid Build Coastguard Worker if ch == ")": 777*da0073e9SAndroid Build Coastguard Worker level -= 1 778*da0073e9SAndroid Build Coastguard Worker if level == 0: 779*da0073e9SAndroid Build Coastguard Worker return False 780*da0073e9SAndroid Build Coastguard Worker if ch == "(": 781*da0073e9SAndroid Build Coastguard Worker level += 1 782*da0073e9SAndroid Build Coastguard Worker return True 783*da0073e9SAndroid Build Coastguard Worker 784*da0073e9SAndroid Build Coastguard Worker is_single_method_on_self_t = check_parens_nest_level_gt_zero( 785*da0073e9SAndroid Build Coastguard Worker between_parens 786*da0073e9SAndroid Build Coastguard Worker ) 787*da0073e9SAndroid Build Coastguard Worker directly_do_inplace = ( 788*da0073e9SAndroid Build Coastguard Worker is_single_method_on_self_t and op_name == info.name 789*da0073e9SAndroid Build Coastguard Worker ) 790*da0073e9SAndroid Build Coastguard Worker 791*da0073e9SAndroid Build Coastguard Worker if directly_do_inplace: 792*da0073e9SAndroid Build Coastguard Worker assert op_name is not None 793*da0073e9SAndroid Build Coastguard Worker assert between_parens is not None 794*da0073e9SAndroid Build Coastguard Worker formula = f"self_t_raw.defined() ? 
self_t_raw.{op_name}_({between_parens}) : {formula}" 795*da0073e9SAndroid Build Coastguard Worker else: 796*da0073e9SAndroid Build Coastguard Worker # Make sure that the forward grad is modified inplace when the original formula 797*da0073e9SAndroid Build Coastguard Worker # is out of place 798*da0073e9SAndroid Build Coastguard Worker formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}" 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker required_original_self_value = bool( 801*da0073e9SAndroid Build Coastguard Worker re.search(IDENT_REGEX.format("original_self_p"), formula) 802*da0073e9SAndroid Build Coastguard Worker ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula)) 803*da0073e9SAndroid Build Coastguard Worker 804*da0073e9SAndroid Build Coastguard Worker forward_derivatives = [ 805*da0073e9SAndroid Build Coastguard Worker ForwardDerivative( 806*da0073e9SAndroid Build Coastguard Worker formula=formula, 807*da0073e9SAndroid Build Coastguard Worker var_names=("self",), 808*da0073e9SAndroid Build Coastguard Worker var_types=fw_info.var_types, 809*da0073e9SAndroid Build Coastguard Worker required_inputs_fw_grad=fw_info.required_inputs_fw_grad, 810*da0073e9SAndroid Build Coastguard Worker required_inputs_primal=required_primals, 811*da0073e9SAndroid Build Coastguard Worker required_original_self_value=required_original_self_value, 812*da0073e9SAndroid Build Coastguard Worker is_reusing_outplace_formula=not is_exact_match, 813*da0073e9SAndroid Build Coastguard Worker ), 814*da0073e9SAndroid Build Coastguard Worker ] 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker fw_derivative_dict[key] = forward_derivatives 817*da0073e9SAndroid Build Coastguard Worker 818*da0073e9SAndroid Build Coastguard Worker result.append( 819*da0073e9SAndroid Build Coastguard Worker NativeFunctionWithDifferentiabilityInfo( 820*da0073e9SAndroid Build Coastguard Worker func=f, 
def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    """Decide whether a value named ``name`` of schema type ``type`` is differentiable.

    Args:
        name: Name of the argument/return being tested.
        type: Its schema type; only ``type.is_tensor_like()`` is consulted.
        info: Differentiability info for the function, or ``None`` when the
            function has none.

    Returns:
        ``True`` when the type is tensor-like and the name is not listed in
        ``info.non_differentiable_arg_names`` (or there is no ``info``).
    """
    # Values that are not tensor-like are never differentiable.
    if not type.is_tensor_like():
        return False
    # Without info there is no exclusion list to consult.
    if info is None:
        return True
    return name not in info.non_differentiable_arg_names


def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    """Return the outputs of ``fn`` that are treated as differentiable.

    Args:
        fn: The native function paired with its differentiability info.
        key: Which entry of ``fn.info`` to use when ``fn.info`` is non-empty.

    Returns:
        A list of ``DifferentiableOutput``, one per selected return.

    Raises:
        RuntimeError: If ``output_differentiability`` is given but its length
            differs from the number of returns, or if it contains ``False``
            for a function whose schema kind is ``SchemaKind.inplace``.
    """
    native_fn = fn.func
    info = fn.info[key] if fn.info else None

    # One entry per declared return, pairing the C++ return name with its type.
    outputs: list[DifferentiableOutput] = [
        DifferentiableOutput(
            name=ret_name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for ret_name, ret in zip(cpp.return_names(native_fn), native_fn.func.returns)
    ]

    output_differentiability = info.output_differentiability if info else None

    if output_differentiability is None:
        # No explicit per-output flags: fall back to the type/name based test.
        candidates = [
            out for out in outputs if is_differentiable(out.name, out.type, info)
        ]
        # When uses_single_grad(info) holds, only the first candidate is kept.
        return candidates[:1] if uses_single_grad(info) else candidates

    # Explicit flags must line up one-to-one with the returns.
    if len(output_differentiability) != len(outputs):
        raise RuntimeError(
            f"The length of output_differentiability ({len(output_differentiability)}), "
            f"does not match the number of outputs ({len(outputs)})."
        )
    if False in output_differentiability and native_fn.func.kind() == SchemaKind.inplace:
        raise RuntimeError(
            "output_differentiability=False for inplace operation (version_counter won't get updated)"
        )
    # Keep exactly the outputs whose flag is truthy, in declaration order.
    return [
        out for flag, out in zip(output_differentiability, outputs) if flag
    ]