from __future__ import annotations

import re
from dataclasses import dataclass
from typing import cast, Sequence

from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX


# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # The NamedCType holds the updated name and cpp type of the attribute.
    # For the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`.
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str


# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: set[str]


# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (legit C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Names of the output arguments for which this formula calculates forward
    # derivatives
    var_names: tuple[str, ...]

    # Types of the output arguments for which this formula calculates forward
    # derivatives
    var_types: tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs for which the primal is required for this formula
    required_inputs_primal: tuple[str, ...] | None

    # Flag to specify if this formula requires the original value of self.
    # This is only used by inplace operations.
    required_original_self_value: bool

    # If this formula is specified in derivatives.yaml or if we are re-using the
    # out of place formula for inplace
    is_reusing_outplace_formula: bool


# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunctions having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: str | None

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs.
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs.
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: list[bool] | None

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions.
    output_differentiability_conditions: list[str] | None

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
    # but with a new operator name.
    # This is used when generating "copy" variants of view ops,
    # which are able to use the exact same derivative formula as the original view op.
    # See Note [Codegen'd {view}_copy Operators]
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        if g.view_copy is None:
            return None
        f = g.view_copy

        name_split_by_period = self.name.split(".", maxsplit=2)
        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
            name_split_by_period[1:]
        )
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op
            name=view_copy_name,
            func=f,
            op=view_copy_op_name,
            # But keep all derivative info the same
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )


def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "retain_variables")


def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "grad")


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
    info: dict[str, DifferentiabilityInfo] | None
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration? There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance). Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors. If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)
    """
    # fn is derived as long as any of its per-key differentiability infos
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
    # and ADInplaceOrViewType. We want to generate these functions as long as a
    # derivative is defined for ANY dispatch key.
    if fn.func.is_abstract or (
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
    ):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.

        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer calling the derived
        # type's implementation with unpacked tensors because it is more
        # performant in some cases: any internal calls to other ATen functions
        # won't have the history tracked.

        # If the function has a type dispatched argument (i.e. is a factory),
        # we prefer calling the derived type's implementation both because it is
        # more performant and to ensure factory functions return tensors with _version
        # of 0 (probably not strictly necessary, but nice to have to keep versions simple
        # to understand).

        return "use_derived"
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we'll assume that it is
        # actually implemented out of differentiable functions. (This
        # assumption might not hold, but then you'll see gradcheck fail.)
        return "use_type"


def is_foreach_func(f: NativeFunction) -> bool:
    return f.func.name.name.base.startswith("_foreach_")


# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find one in `functional_info_by_signature`. There are, however, some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}
# The following do not support the alpha kwarg, which the non-foreach versions support.
_skip_argument_len_check = {
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
}


# Checks if `function_schema` is a native, non-foreach function that `f`, a foreach function,
# references to generate derivatives.
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    return (
        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
        and (
            not function_schema.name.name.inplace
            or str(f.func.name) in _foreach_with_inplace_ref
        )
        and (
            str(f.func.name) in _skip_argument_len_check
            or len(f.func.arguments.flat_non_out)
            == len(function_schema.arguments.flat_non_out)
        )
        and all(
            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
            for arg, ref_arg in zip(
                f.func.arguments.flat_non_out,
                function_schema.arguments.flat_non_out,
            )
        )
    )


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
    """Generate DifferentiabilityInfo for an out-place foreach function, returning the existing one for in-place.

    The second return value indicates whether the info is generated in this function.
    """
    ref_diff_info: DifferentiabilityInfo | None = None

    for function_schema, diff_info in functional_info_by_signature.items():
        if not is_reference_for_foreach(foreach_function, function_schema):
            continue
        ref_diff_info = diff_info[dispatch_key]
        if ref_diff_info is not None:
            break
    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
    # while the info of `zero_` is in non_functional_info_by_signature
    if (
        ref_diff_info is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        for function_schema, diff_info in non_functional_info_by_signature.items():
            if not is_reference_for_foreach(foreach_function, function_schema):
                continue
            ref_diff_info = diff_info[dispatch_key]
            if ref_diff_info is not None:
                break
    if ref_diff_info is None:
        return None, False

    # In-place (non out-place) functions reuse the existing Derivative.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    map_refarg2foreacharg, map_name2arg = {}, {}
    for i, (arg, ref_arg) in enumerate(
        zip(
            foreach_function.func.arguments.flat_non_out,
            function_schema.arguments.flat_non_out,
        )
    ):
        map_refarg2foreacharg[ref_arg.name] = arg.name
        map_name2arg[arg.name] = arg

    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
    modified_derivative_formulas = []
    for i, derivative in enumerate(ref_diff_info.derivatives):
        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
            "result", "result[i]"
        )
        saved_inputs, saved_outputs = [], []
        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
        with local.parametrize(
            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
        ):
            for ref_input in derivative.saved_inputs:
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                if isinstance(map_name2arg[mapped_name].type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name == "result":
                    saved_outputs.append(
                        SavedAttribute(
                            nctype=NamedCType(
                                name="result", type=BaseCType(tensorListT)
                            ),
                            expr="result",
                        )
                    )
                else:
                    raise RuntimeError(
                        f"Unsupported saved output for foreach derivatives: {ref_output.nctype.name}"
                    )
        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
        all_var_names.extend(var_names)
        all_saved_inputs.extend(saved_inputs)
        all_saved_outputs.extend(saved_outputs)
        modified_derivative = Derivative(
            formula=modified_formula,
            original_formula=derivative.formula,
            var_names=tuple(var_names),
            saved_inputs=tuple(saved_inputs),
            saved_outputs=tuple(saved_outputs),
            named_gradients=set(),
        )
        modified_derivative_formulas.append(modified_derivative)

    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_function.func.arguments.flat_non_out
            if arg.name in all_var_names
        ]

    forward_derivatives: list[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
        var_types: list[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: list[str] = []
        required_inputs_primal: list[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # Foreach's result is TensorList
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # note(crcrpar): there should exist a cooler way...
            for i, name in enumerate(var_names):
                if name == ref_arg.name:
                    var_names[i] = foreach_arg.name
                    var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(var_names),
                var_types=tuple(var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )


def match_differentiability_info(
    native_functions: list[NativeFunction],
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to the matching autograd function.

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
589 """ 590 591 functional_info_by_signature = { 592 schema.signature(strip_default=True): info_dict 593 for schema, info_dict in differentiability_infos.items() 594 if schema.kind() == SchemaKind.functional 595 } 596 non_functional_info_by_signature = { 597 schema.signature(strip_default=True): info_dict 598 for schema, info_dict in differentiability_infos.items() 599 if schema.kind() != SchemaKind.functional 600 } 601 602 def find_info( 603 f: NativeFunction, 604 ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]: 605 # Don't bother matching info to generated out= variants 606 if "generated" in f.tags and f.func.kind() == SchemaKind.out: 607 return None, False 608 609 # (1) Check for an exact match 610 if f.func in differentiability_infos: 611 return differentiability_infos[f.func], True 612 613 # (2) If no exact match, check if the out-of-place variant 614 # of this operator has a match. 615 # i.e mul() for mul_() or mul_out() 616 # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing 617 # native functions instead of the out-place counterparts. 618 f_sig = f.func.signature(strip_default=True) 619 if f_sig in functional_info_by_signature and not is_foreach_func(f): 620 return functional_info_by_signature[f_sig], False 621 622 # (3) Some operators have a derivative explicitly defined for the mutable 623 # variant, but get a code-generated out-of-place variant which does *not* 624 # come with a derivative formula. 625 # For the generated out-of-place variant, use the mutable variant's formula 626 # if it exists. 627 if "generated" in f.tags and f_sig in non_functional_info_by_signature: 628 info_dict = non_functional_info_by_signature[f_sig] 629 # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389 630 assert not any( 631 any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs) 632 for info in info_dict.values() 633 ), f"""\ 634Attempted to convert a derivative formula for a mutable operator 635 to be used by automatically by its functional variant ("{str(f.func)}"). 636 this is not currently supported (we'd need to fix up the formula in the codegen).""" 637 return info_dict, False 638 639 # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml` 640 if is_foreach_func(f): 641 assert f.func not in differentiability_infos 642 diff_info, is_generated = gen_foreach_derivativeinfo( 643 f, 644 functional_info_by_signature, 645 non_functional_info_by_signature, 646 ) 647 if diff_info is None: 648 return None, False 649 # TODO(crcrpar): Avoid hard coding "Default" ideally. 650 diff_info_dict = {"Default": diff_info} 651 if is_generated: 652 differentiability_infos[f.func] = diff_info_dict 653 functional_info_by_signature[f.func] = diff_info_dict 654 return diff_info_dict, is_generated 655 656 return None, False 657 658 result: list[NativeFunctionWithDifferentiabilityInfo] = [] 659 for f in native_functions: 660 info_dict, is_exact_match = find_info(f) 661 662 # Currently, the '.strides()' to 'strides_or_error' replacement does not support 663 # 'self' derivatives of an inplace function, so we must check for this case. 
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if not info_dict:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            # For functions that have a single def for out-of-place and inplace (like abs())
            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified is not used:
                #     - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                #       we make sure that the original value of the input that is being modified inplace (self_p) is
                #       not used in the formula. Note that the formula can use "original_self_p" here and that would
                #       trigger a clone of the original input.
                #     - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                #       occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                #       populated by a cloned version of the original input (either the clone done by the backward AD
                #       logic if self is also used in a backward formula or a special clone that we add).
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                #     simply called self (as it is modified inplace).
                #  4) Update the required primals data in case it used to contain "result" but should now contain
                #     "self".
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
                #     already exists.

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: re.Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # replace "result" from the formula by "self_p"
                def repl(m: re.Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: str | None = None
                    between_parens: str | None = None
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # We want to...
                        #   Match: self_t.op1(other_p.op2(arg))
                        #   Avoid: self_t.op1(args) + self_t.op2(args)
                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                        directly_do_inplace = (
                            is_single_method_on_self_t and op_name == info.name
                        )

                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result


def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    return type.is_tensor_like() and (
        info is None or name not in info.non_differentiable_arg_names
    )


def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    f = fn.func
    info = fn.info[key] if fn.info else None
    outputs: list[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}) "
                f"does not match the number of outputs ({len(outputs)})."
            )
        differentiable_outputs: list[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs
    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
    )
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
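

# A minimal usage sketch (illustrative only; not executed by the codegen).
# How `native_functions` and `differentiability_infos` are produced (e.g. by
# tools/autograd/load_derivatives.py) is an assumption about the caller and
# lives outside this module:
#
#     funcs_with_info = match_differentiability_info(
#         native_functions, differentiability_infos
#     )
#     for fn in funcs_with_info:
#         if dispatch_strategy(fn) == "use_derived":
#             outputs = gen_differentiable_outputs(fn, key="Default")
#             ...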