# mypy: allow-untyped-defs
from typing import Any, NamedTuple, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _unwrap_for_grad,
    _wrap_for_grad,
    current_level,
    TransformType,
)
from torch._functorch.apis import vmap
from torch._functorch.utils import enable_single_level_autograd_function
from torch._functorch.vmap import (
    _add_batch_dim,
    _broadcast_to_and_flatten,
    restore_vmap,
    unwrap_batched,
    wrap_batched,
)
from torch._ops import HigherOrderOperator
from torch.autograd.forward_ad import _set_fwd_grad_enabled


# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("custom_function_call")

    def __call__(self, autograd_function, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        return autograd_function.apply(*args, **kwargs)


# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
# dispatcher.
custom_function_call = CustomFunctionHigherOrderOperator()


# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (an autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
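#
# As a (hypothetical) end-to-end illustration; MySquare is not part of this
# file and is only sketched here for orientation:
#
#     class MySquare(torch.autograd.Function):
#         @staticmethod
#         def forward(x):
#             return x * x
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             ctx.save_for_backward(inputs[0])
#
#         @staticmethod
#         def backward(ctx, gy):
#             (x,) = ctx.saved_tensors
#             return 2 * x * gy
#
# Under torch.func.grad(MySquare.apply)(x), the call reaches
# custom_function_call, whose Grad rule (custom_function_call_grad below)
# generates a _SingleLevelFunction that unwraps x, re-enters
# custom_function_call at the lower level, and wraps the output for the
# current grad level.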
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out


def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor, lambda x: _unwrap_for_grad(x, level), operands
        )
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(
                autograd_function, *unwrapped_operands
            )

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output, unwrapped_operands, operands, wrap_fn
        )

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a subclass
    # with a given name. A Tensor's .grad_fn field has a class name that is the
    # original autograd.Function's name + Backward, so we do this to generate a
    # meaningful name.
    name = f"{autograd_function.__name__}Generated"
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
        },
    )
    return Generated


# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethods. The way it distinguishes
# between the vmap case and the {backward, jvp} case is whether the
# out_dims are specified or not.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"


# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
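#
# As a (hypothetical) example of why object identity matters, consider an
# in-place autograd.Function (backward omitted; the class is illustrative and
# not defined in this file):
#
#     class MyRelu_(torch.autograd.Function):
#         @staticmethod
#         def forward(x):
#             return x.relu_()
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             ctx.mark_dirty(inputs[0])
#
# mark_dirty requires `output is input` to hold, so the wrapping below returns
# the original input object (rather than constructing a fresh wrapper)
# whenever an output aliases one of the unwrapped inputs.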
def wrap_outputs_maintaining_identity(
    outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
):
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[possibly-undefined, index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)


# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# do this automatically, so handling it here keeps autograd.Function
# consistent with them.
#
# class MyExp(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         y = output
#         ctx.save_for_backward(y)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# In vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# because the y that was saved for backward by MyExp is a GradTensorWrapper
# but is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
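#
# For completeness, a minimal sketch of the MyMul referenced above (it is not
# defined in this file and its exact implementation does not matter for the
# NOTE):
#
# class MyMul(torch.autograd.Function):
#     @staticmethod
#     def forward(x, y):
#         return x * y
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         ctx.save_for_backward(*inputs)
#
#     @staticmethod
#     def backward(ctx, gxy):
#         x, y = ctx.saved_tensors
#         return gxy * y, gxy * x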


class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overriden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap


def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to return two values: an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if len(result) != 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")


@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
    if any(
        isinstance(val, torch.Tensor)
        for val in torch.utils._pytree.tree_flatten(kwargs)[0]
    ):
        raise NotImplementedError(
            f"Running vmap over an autograd.Function that receives kwarg-only "
            f"Tensor arguments is not supported. Please do not pass kwarg-only "
            f"Tensors to autograd.Function. Got: {kwargs}"
        )

    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overridden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overridden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )
        return custom_function_call_vmap_generate_rule(
            interpreter, autograd_function, *operands
        )

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/main/notes/extending.func.html"
        )

    return custom_function_call_vmap_helper(
        interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
    )


def custom_function_call_vmap_helper(
    interpreter, vmap_function, op, *operands, **kwargs
):
    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
    # If none of the tensors are batched at the current level, then we skip
    # the current level. This saves the user from needing to handle this case
    # in their vmap staticmethod (and is consistent with our C++ batching rule
    # API).
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            if isinstance(op, torch.autograd.function.FunctionMeta):
                return custom_function_call(op, *operands)
            else:
                return op(*operands, **kwargs)

    with interpreter.lower():
        result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
        validate_vmap_returns_tuple_of_two_elements(result)
        unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return (
            output
            if out_dim is None
            else _add_batch_dim(output, out_dim, current_level)
        )

    return wrap_outputs_maintaining_identity(
        unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims
    )


def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness()
    )

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())


@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(
    interpreter, autograd_function, generate_vmap_rule, *operands
):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")


def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    # The following values are saved from forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap(...)) but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    init_val = "not populated"
    out_dims = init_val
    input_shapes: Any = init_val
    saved_tensors_bdims: Any = init_val

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness
        )(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap BatchedTensors into (tensor, bdim)
            # - call save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes is used by reductify later to reduce expanded
            # gradients to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details.
            nonlocal input_shapes_
            input_shapes_ = tuple(
                inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs
            )
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != init_val
        assert saved_tensors_bdims != init_val

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context,
            (saved_tensors_bdims, tangent_in_dims),
            batch_size,
            randomness,
        )(ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        assert out_dims != init_val
        assert input_shapes != init_val
        assert saved_tensors_bdims != init_val

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context,
            ((saved_tensors_bdims, out_dims),),
            batch_size,
            randomness,
        )((ctx.saved_tensors, grad_outputs))
        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f"Vmapped{autograd_function.__name__}"
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
            "generate_vmap_rule": True,
        },
    )

    def get_out_dims():
        assert out_dims != init_val
        return out_dims

    return Generated, get_out_dims


# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents = pytree.arg_tree_leaves(*tangents)
    result = [
        None if tangent is None else in_dim
        for in_dim, tangent in zip(flat_in_dims, flat_tangents)
    ]
    return pytree.tree_unflatten(result, spec)


# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function:
#
# class Sum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.sum()
#
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         ctx.x_shape = inputs[0].shape
#
#     @staticmethod
#     def backward(ctx, gy):
#         return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return vmap(Sum.forward, in_dims)(x)
#
#     @staticmethod
#     def setup_context(ctx, inputs, outputs):
#         Sum.setup_context(ctx, inputs, outputs)
#
#     @staticmethod
#     def backward(ctx, gy):
#         def backward_no_context(gy):
#             return gy.expand(ctx.x_shape)
#
#         dims = (0,)
#         gx = vmap(backward_no_context, dims)(gy)
#         return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we're doing:
#
# def backward_no_context(gy):
#     return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result: gx has shape [B, B, 4], but each per-example
# gradient should have shape [4] (so gx should be [B, 4]). Performing vmap over
# setup_context means the saved shape is [4], which leads to the correct
# result shape for gx.


# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs.
class WrappedCtx:
    _pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f"PyTorch reserves the {reserved_attrs} fields on ctx. "
                    "Please name your fields on ctx something else to avoid name "
                    "collision."
                )
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)


# Wraps ctx to create a new ctx object that overrides saved_tensors.
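# A (hypothetical) illustration of the override below; the names `ctx`, `t1`,
# and `t2` are purely illustrative:
#
#     wrapped = CtxWithSavedTensors(ctx, (t1, t2))
#     wrapped.saved_tensors   # always (t1, t2), regardless of ctx.saved_tensors
#     wrapped.another_field   # everything else is forwarded to ctx via WrappedCtx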
class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ("_pt_new_saved_tensors", *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors


class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = (
        "_pt_saved_tensors_bdims",
        "_pt_current_level",
        *WrappedCtx._pt_reserved_attrs,
    )

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims


def reductify(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in zip(
            grad_input,
            grad_input_bdim,
            input_bdim,
            target_shape_without_bdim_to_reduce_to,
        )
    )
    return result


def reductify_leaf(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
    # For reverse-mode AD, given a grad_input and an input, it is valid for
    # the user to return a grad_input that has a broadcasted shape when
    # compared to the input. In this situation, autograd automatically
    # reduces the grad_input to the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by the
    # `target_shape_without_bdim_to_reduce_to` argument: if it is not None,
    # we do the reduction manually; otherwise, we do not reduce.
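    #
    # As a concrete (hypothetical) instance of [example 1]: with batch_size=2
    # and input_bdim=0, a grad_input of shape [3, 4] is expanded below to
    # shape [2, 3, 4] (bdim at dim 0); sum_to_size then reduces each
    # per-example gradient to the target shape [4], yielding a [2, 4] gradient
    # that lines up with the [B=2, 4] input.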
    assert input_bdim is not None

    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(
            torch.Tensor.sum_to_size,
            in_dims=(grad_input_bdim, None),
            out_dims=input_bdim,
        )(grad_input, target_shape_without_bdim_to_reduce_to)

    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input


def autograd_function_forward_rewritten(original_forward, original_setup_context):
    def new_forward(ctx, *args, **kwargs):
        output = original_forward(*args, **kwargs)
        original_setup_context(ctx, args, output)
        return output

    return new_forward


class AutogradFunctionApply(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("autograd_function_apply")

    def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
        saved_values = None
        args_tensor_mask = fwd_kwargs["args_tensor_mask"]
        non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
        length_of_tensor_args = sum(args_tensor_mask)
        # Keep only the original tensor args from fwd_args; lifted freevars
        # should not be args of ApplyTemplate.apply since we don't need to
        # compute gradients for them.
        new_fwd_args = fwd_args[:length_of_tensor_args]

        class ApplyTemplate(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                nonlocal saved_values
                output, saved_values = fwd(None, *fwd_args)

                # Handle the case where the user called
                # ctx.mark_non_differentiable() in the original fwd function.
                if len(non_differentiable_idx) > 0:
                    non_differentiable_output = []
                    for i, x in enumerate(output):
                        if i in non_differentiable_idx:
                            non_differentiable_output.append(x)
                    ctx.mark_non_differentiable(*non_differentiable_output)

                return output

            @staticmethod
            def backward(ctx, *grad):
                return bwd(None, *grad, *saved_values)

        return ApplyTemplate.apply(*new_fwd_args)


autograd_function_apply = AutogradFunctionApply()
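

# A hedged sketch of how autograd_function_apply is expected to be invoked;
# the callables and values below are illustrative and not taken from this
# file. fwd must return (output, saved_values), and bwd receives the grads
# followed by saved_values, mirroring ApplyTemplate above.
#
#     def fwd(ctx, x):
#         return x.sin(), [x]
#
#     def bwd(ctx, gy, x):
#         return gy * x.cos()
#
#     y = autograd_function_apply(
#         fwd,
#         bwd,
#         x,
#         args_tensor_mask=[True],
#         non_differentiable_idx=[],
#     )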