# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import contextlib
import platform
import uuid
import warnings
import weakref
from collections import defaultdict
from typing import *  # noqa: F403
import enum
from weakref import ReferenceType

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.utils._pytree import tree_map
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode

__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
    "CheckpointPolicy",
    "SelectiveCheckpointContext",
    "create_selective_checkpoint_contexts",
    "SAC_IGNORED_OPS",
]

_DEFAULT_DETERMINISM_MODE = "default"

_checkpoint_debug_enabled: Optional[bool] = None


@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
    """
    Context manager that sets whether checkpoint should print additional debug
    information when running. See the ``debug`` flag for
    :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
    when set, this context manager overrides the value of ``debug`` passed to
    checkpoint. To defer to the local setting, pass ``None`` to this context.

    Args:
        enabled (bool): Whether checkpoint should print debug information.
            Default is ``None``.
    """
    global _checkpoint_debug_enabled
    try:
        prev = _checkpoint_debug_enabled
        _checkpoint_debug_enabled = enabled
        yield
    finally:
        _checkpoint_debug_enabled = prev


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ",
            type(inputs).__name__,
        )


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )


def _get_device_module(device="cuda"):
    if device == "meta":
        return torch.device("meta")
    device_module = getattr(torch, device)
    return device_module


class DefaultDeviceType:
    r"""
    A class that manages the default device type for checkpointing.

    If no non-CPU tensors are present, the default device type will
    be used. The default value is 'cuda'. The device type is used in
    the checkpointing process when determining which device states
    to save and restore for recomputation.
    """

    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Set the default device type for checkpointing.

        Args:
            device (str): The device type to be set as default. Default is 'cuda'.
        """
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Get the current default device type for checkpointing.

        Returns:
            str: The current default device type.
        """
        return DefaultDeviceType._default_device_type


def _infer_device_type(*args):
    device_types = []

    def add_device_types(arg):
        nonlocal device_types
        if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu":
            device_types.append(arg.device.type)
    tree_map(add_device_types, args)

    device_types_set = set(device_types)
    if len(device_types_set) > 1:
        warnings.warn(
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
            "Device state will only be saved for devices of a single device type, and the remaining "
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
            f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}"
        )
    if len(device_types) == 0:
        return DefaultDeviceType.get_device_type()
    elif "cuda" in device_types_set:
        return "cuda"
    else:
        return device_types[0]


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_device_ids = []

    def add_device_ids(arg):
        nonlocal fwd_device_ids
        if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
            fwd_device_ids.append(arg.get_device())
    tree_map(add_device_ids, args)

    fwd_device_states = []
    device_module = _get_device_module(_infer_device_type(*args))
    for device_id in fwd_device_ids:
        with device_module.device(device_id):
            fwd_device_states.append(device_module.get_rng_state())

    return fwd_device_ids, fwd_device_states


def set_device_states(devices, states, *, device_type=None) -> None:
    """Sets random number generator states for the specified devices.

    Args:
        devices: Device ids to set states for.
        states: States to set.
        device_type: ``device_type`` of the devices to set states for. Default
            is the device returned by a call to ``DefaultDeviceType.get_device_type()``,
            which is ``cuda`` if not changed by calling ``DefaultDeviceType.set_device_type()``.
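
    A minimal, hedged sketch of how this pairs with :func:`get_device_states`
    (it assumes a CUDA device is present; ``run_fn`` and ``args`` are
    placeholders, not part of this module)::

        >>> # xdoctest: +SKIP("requires a CUDA device")
        >>> devices, states = get_device_states(*args)
        >>> out = run_fn(*args)          # may consume device RNG
        >>> set_device_states(devices, states, device_type="cuda")
        >>> # the device RNG is now back to its pre-run_fn state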
    """
    if device_type is None:
        device_type = DefaultDeviceType.get_device_type()
    if device_type == "meta":
        return
    device_module = _get_device_module(device_type)
    for device, state in zip(devices, states):
        with device_module.device(device):
            device_module.set_rng_state(state)


def _get_autocast_kwargs(device_type="cuda"):
    if torch.amp.is_autocast_available(device_type):
        device_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(device_type),
            "dtype": torch.get_autocast_dtype(device_type),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
    else:
        device_autocast_kwargs = None

    cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled('cpu'),
        "dtype": torch.get_autocast_dtype('cpu'),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

    return device_autocast_kwargs, cpu_autocast_kwargs


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device_type = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device_type
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device_type)
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
                " with .grad() or passing an `inputs` parameter to .backward()."
                " To resolve this error, you can either set use_reentrant=False,"
                " or call .backward() without passing the `inputs` argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
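        #
        # For intuition, the CPU-only version of this save/replay pattern is
        # just the following (a standalone sketch, not code used by this class):
        #
        #   state = torch.get_rng_state()
        #   a = torch.rand(3)             # consumes RNG
        #   torch.set_rng_state(state)
        #   b = torch.rand(3)             # replays exactly the same values
        #   assert torch.equal(a, b)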
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type)
            detached_inputs = detach_variable(tuple(inputs))

            device_autocast_ctx = torch.amp.autocast(
                device_type=ctx.device_type, **ctx.device_autocast_kwargs
            ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads


def noop_context_fn():
    return contextlib.nullcontext(), contextlib.nullcontext()

# TorchDynamo does not step inside utils.checkpoint function.  The flow
# looks like this
#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
#     speculatively checking if the forward function is safe to trace.
#  2) If yes, then the Dynamo-generated Fx graph has the wrapped higher
#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
#  3) If not, then TorchDynamo falls back to eager by performing a graph
#     break. And here, the following disable wrapper ensures that
#     TorchDynamo does not trigger again on the frames created by
#     utils.checkpoint innards.
@torch._disable_dynamo
def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    **kwargs
):
    r"""Checkpoint a model or part of the model.

    Activation checkpointing is a technique that trades compute for memory.
    Instead of keeping tensors needed for backward alive until they are used in
    gradient computation during backward, forward computation in checkpointed
    regions omits saving tensors for backward and recomputes them during the
    backward pass. Activation checkpointing can be applied to any part of a
    model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an
        error being raised or leading to silently incorrect gradients.

    .. warning::

        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please refer to the
        note below for important considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs with the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          `inputs` argument, while the non-reentrant version supports all ways
          of performing the backward pass.

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``False``, omit stashing and restoring
            the RNG state during each checkpoint. Note that under torch.compile,
            this flag doesn't take effect and we always preserve RNG state.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.5 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks. This argument is only supported if ``use_reentrant=False``;
            if ``use_reentrant=True``, the determinism check is always disabled.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint: the use_reentrant parameter should be "
            "passed explicitly. In version 2.5 we will raise an exception "
            "if use_reentrant is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants.",
            stacklevel=2
        )
        use_reentrant = True

    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs and use_reentrant:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    if use_reentrant:
        if context_fn is not noop_context_fn or debug is not False:
            raise ValueError(
                "Passing `context_fn` or `debug` is only supported when "
                "use_reentrant=False."
            )
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        gen = _checkpoint_without_reentrant_generator(
            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
        )
        # Runs pre-forward logic
        next(gen)
        ret = function(*args, **kwargs)
        # Runs post-forward logic
        try:
            next(gen)
        except StopIteration:
            return ret


def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
    r"""Checkpoint a sequential model to save memory.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.5 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please see
        :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): If ``False``, omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.5 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
            "parameter should be passed explicitly. "
            "In version 2.5 we will raise an exception if use_reentrant "
            "is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants."
        )
        use_reentrant = True

    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)


def _internal_assert(cond):
    if not cond:
        raise AssertionError(
            "Something went unexpectedly wrong in activation checkpoint. "
            "Please report this bug by filing an issue to PyTorch."
        )


# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
# Following the two rules leads to an important implication that is central
# to motivating the design.
#
# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
#         from any outer layers of checkpoint.
#
# Rule 2. The inputs of inner checkpoints are treated as tensors saved to its
#         parent checkpoint.
#
# Implication: To recompute any given saved tensor, we need to recompute all of
#              the checkpoints wrapping it.
#
# Why is this implied? To unpack a saved tensor X during backward we need to
# recompute the inner-most checkpoint (#1), and in order to recompute that
# checkpoint we need to have its inputs, which are managed by that checkpoint's
# parent (#2), which thus also needs to be recomputed first. Continue this line
# of reasoning and we realize that in order to unpack X, all checkpoints that
# were active at the time X was saved need to be recomputed. (unless we have
# already done so in that backward for some other saved tensor).
#
# In practice, we use a noop autograd Function to save inputs as saved tensors.
# During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
# recompute.
#
# Rule 3. We should start recomputation as if there are no checkpoints currently
#         active. Checkpoints encountered during recomputation are still
#         respected.
#
# When we start recomputation, we push the saved variable hook meant for
# recomputation on the stack. See examples in Rule 6 for more context.
#
#                                  * * * *
#
# Beyond the basic semantics specific to nested checkpoint, we impose several
# more constraints that may apply to checkpointing in general.
#
# Rule 4. Lifetime of recomputed tensors
#
#         Recomputed tensors are considered specific to particular invocations
#         of backward and are always cleared immediately as they are unpacked.
#         In particular, we require this to happen even if retain_graph=True.
#
# [ Implementation details of Rule 4 ]
#
# If we were okay with recomputed tensors staying alive after backward is run
# with retain_graph=True, we would store recomputed variables as the values of a
# WeakKeyDictionary and pack strong references to the keys, so that as we
# backward, those packed keys would be cleared as long as retain_graph=False.
# Clearing the packed key clears the corresponding entry in the WKD.
#
# If we wish recomputed variables to be immediately cleared as we unpack them in
# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
# backward automatically. Instead of packing the strong reference to the key
# directly, we pack a container object, which we manually clear as we unpack.
#
# An important detail is that if a second backward happens, the second
# recomputation needs to reset the container with a newly created key.
#
# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
#         know we need.
#
# [ Implementation details of Rule 5 ]
#
# During recomputation, raise an exception if the number of recomputed tensors
# matches the number of tensors that we expected to recompute. We wrap the
# recomputation call with a try-catch to catch this specific exception. See
# Rule #6 below for some examples.
#
# Rule 6. We support doing backward inside checkpoint context
#
# [ retain_graph is True ]
#
# def fn(x):
#   y = x.sin()
#   z = y.cos()
#   gx, = torch.autograd.grad(z, x, retain_graph=True)
#   return gx, z
#
# out = checkpoint(fn)(inp)
# out.backward()
#
# Because z is saved by cos while checkpoint is enabled, it would not be
# actually saved, and so the .grad() call inside must trigger a recomputation.
#
# During recomputation the "inner pack hook" has two responsibilities:
#
# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
# 2) Pack the actual tensor (detached) so that one may perform backward on the
#    recomputed graph. The tensors saved to this graph will live until the end
#    of recomputation, or die earlier if someone performs backward with
#    retain_graph=False.
#
# More generally performing backward on the recomputed graph occurs in the
# following cases:
# - If backward is performed inside forward,
#   - During the original forward IF early-stop is disabled
#   - During the original backward
# - If there are multiple .grad()/.backward() calls, we would perform backward
#   on the recomputed graph even if early-stop is enabled (see the example below)
#
# [ retain_graph is False ]
#
# The example below shows what happens if during recomputation we find that some
# of the tensors we are trying to recompute have already been cleared.
#
# Spoiler: we don't do anything special, we just skip over them!
#
# def fn(x):
#   y = x.sin()                           # (1)
#   z = y.cos()                           # (2)
#   gx, = torch.autograd.grad(z, x)       # (3)
#   return x.cos() * gx                   # (4)
#
# out = checkpoint(fn)(inp)
# out.backward()                          # (5)
#
# 1, 2. Don't save x and y since we are inside a checkpoint.
# 3. Trigger a recompute of fn since x and y weren't saved.
#    And depending on whether early stop is enabled, either stop at (2) or
#    continue running the function.
#    Because we are running backward with retain_graph=False, we clear x and y's
#    holders.
# 4. Don't save x since we are inside a checkpoint.
# 5. Calling backward triggers another recompute of fn. During recompute, we see
#    that x and y have already been cleared in the original graph as indicated
#    by holder=None. We skip over them. We still save x at (4) (since its holder
#    is still alive.)

_enable_checkpoint_early_stop = True


@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
    """Context manager that sets whether checkpoint should stop recomputation early.

    By default, non-reentrant checkpoint stops recomputation as soon as it
    has computed all needed Tensors. This context manager can be used to disable
    that feature if it is problematic for your specific application.

    This context manager only needs to be active when forward is run. It does
    not need to be active during backward.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs)
        ...
        >>> out.backward()
    """
    global _enable_checkpoint_early_stop
    try:
        prev = _enable_checkpoint_early_stop
        _enable_checkpoint_early_stop = enable
        yield
    finally:
        _enable_checkpoint_early_stop = prev


class _Handle:
    pass


class _Holder:
    def __init__(self):
        self.handles: Dict[int, Optional[_Handle]] = {}


class _NoopSaveInputs(torch.autograd.Function):
    @staticmethod
    def forward(*args):
        return torch.empty((0,))

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        # Only tensors can be saved with ctx.save_for_backward, everything else
        # is captured by get_args, which is saved directly on ctx
        tensor_indices, tensors = zip(
            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
        )
        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
        # args but with tensors replaced with None as placeholders
        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

        def get_args(saved_tensors):
            # restore the placeholders with the original tensors grabbed from
            # ctx.saved_tensors (which may be saved on a parent checkpoint if
            # this checkpoint is nested, and that would trigger a recursive
            # unpack!)
            ret = [
                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
                for i, o in enumerate(args)
            ]
            # grab the tail since we also saved the dummy to avoid having to explicitly
            # handle the case where there are no tensor inputs
            return ret[1:]

        ctx.get_args = get_args
        ctx.save_for_backward(*tensors)

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise AssertionError("Did not expect to backward on this graph")


class _CheckpointFrame:
    def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
        self.recompute_fn = recompute_fn
        self.input_saver = None
        self.weak_holders: List[ReferenceType] = []
        # We store this as a weakkeydictionary so that in the case of a partial
        # backward, the entries in the dict are cleared alongside the Holder
        # which will be removed when the SavedVariable is cleared.
        self.recomputed: DefaultDict[
            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
        ] = defaultdict(weakref.WeakKeyDictionary)
        # We need both recomp_counter and recomputed since they can diverge
        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)

        # See Rule 5
        self.early_stop = early_stop

        # Debugging
        self.metadata_fn = metadata_fn
        self.unpack_error_cb = unpack_error_cb
        self.x_metadatas = []
        self.forward_completed = False
        self.ignore_saved_mismatch = False

    def check_recomputed_tensors_match(self, gid):
        if self.ignore_saved_mismatch:
            # TODO: we can probably make this check stricter by checking that
            #       the metadata of the first tensors still match.
            return
        # NOTE [ Error handling for checkpoint ]
        #
        # At a high level, we need to check that the tensors saved during the
        # original forward match the tensors saved during recompute.
        # This means handling 3 cases:
        #
        # 1. During recompute, more tensors were saved.
        #
        #    Usually this case is hidden by the StopRecomputationError when
        #    early stop is enabled; even if it is not, we would have errored
        #    already in the pack hook because there aren't enough weak_holders.
        #    Either way we want to surface a nice error. See the
        #    _recomputation_hook for details.
        if not len(self.weak_holders) == self.recomp_counter[gid]:
            # 2. During recompute, fewer tensors were saved
            #
            # We know that every time we save something during the original
            # forward we append to weak_holders, and every time we save a
            # tensor during recompute we increment recomp_counter.
            raise CheckpointError(
                "torch.utils.checkpoint: A different number of tensors was saved "
                "during the original forward and recomputation.\n"
                f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
                f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
            )

        # 3. During recompute, the same tensors were saved, but they
        #    have different metadata
        nb_meta_different = []
        for idx, weak_holder in enumerate(self.weak_holders):
            holder = weak_holder()
            if holder is None:
                continue
            # We've seen all holders since we iterate over them in order
            # For every holder that is still alive now, it must've been
            # alive when we saw it during recompute, therefore, the
            # gid must be set.
            _internal_assert(gid in holder.handles)
            # We know this is the first unpack, so it couldn't have been set
            # to None yet.
            _internal_assert(holder.handles[gid] is not None)
            # We always set these together in the recomputation hook
            _internal_assert(holder.handles[gid] in self.recomputed[gid])
            # see pack hook, x_metadata is 1:1 with weak_holders.
            x_meta = self.x_metadatas[idx]
            recomputed_x = self.recomputed[gid][holder.handles[gid]]
            if x_meta != self.metadata_fn(recomputed_x):
                nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))

        if len(nb_meta_different) > 0:
            mismatched_tensors = ""
            for idx, x_meta, recomputed_meta in nb_meta_different:
                mismatched_tensors += (
                    f"tensor at position {idx}:\n"
                    f"saved metadata: {x_meta}\n"
                    f"recomputed metadata: {recomputed_meta}\n"
                )
            raise CheckpointError(
                "torch.utils.checkpoint: Recomputed values for the following tensors "
                "have different metadata than during the forward pass.\n"
                f"{mismatched_tensors}"
            )


_checkpoint_error_template = """ \
An error happened while unpacking tensors; dumping logs of latest computation
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
Scroll all the way down for guidance on how to navigate these logs.

+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|        1. Stack traces of the operators that ran in the original forward     |
+------------------------------------------------------------------------------+

{forward_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|        2. Stack traces of the operators that ran during recomputation        |
+------------------------------------------------------------------------------+

{recompute_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
|       3. Log of operators in the original forward and recomputation          |
+------------------------------------------------------------------------------+
(Scroll up to correlate stack traces with each operation listed below. This
 helps identify their source in the code.)

IMPORTANT: Differences in "detach" calls between the original forward and the
           recomputation are expected. They are introduced by the checkpointing
           mechanism and can be ignored.

Operations executed during the original forward:

{forward_ops}

Operations executed during recomputation:

{recompute_ops}

+------------------------------------------------------------------------------+
 ERROR: Detected non-determinism while running activation checkpointing

 You are seeing this error because you passed `debug=True` to checkpoint and
 the tensors saved during the original forward differ from those saved during
 recomputation. This can happen if different operators were run in the
 original forward and in the recomputation.

 To identify where the mismatch may be coming from, you can do the following:

 1) Compare the operators run during the original forward and recomputation to
    see where they differ. These operators are printed above in the order they
    were executed.

 2) Review the stack trace for each operator to locate its invocation source.
    Each operator's stack trace is printed in its execution order.

 Note that the logs can be quite long. Here's how they are structured:
 (Tip: you can Ctrl-f for these headers)

 1. Stack traces of the operators that ran in the original forward
 2. Stack traces of the operators that ran during recomputation
 3. Log of operators in the original forward and recomputation
 4. Error message                                             <--- You are here
--------------------------------------------------------------------------------
"""

class CheckpointError(RuntimeError):
    pass


def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
    # This function returns the context_fn and error_cb to be used by the
    # checkpointing mechanism. error_cb is invoked when an error is detected
    # during unpack.
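    #
    # For reference, both ways of reaching this debug path (this simply mirrors
    # the public API defined earlier in this file):
    #
    #   out = checkpoint(fn, x, use_reentrant=False, debug=True)
    #   # or, overriding the local ``debug`` flag:
    #   with set_checkpoint_debug_enabled(True):
    #       out = checkpoint(fn, x, use_reentrant=False)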

    # record_context_cpp is not supported on non-linux non-x86_64 platforms
    cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'

    class CaptureLogs:
        def __init__(self):
            self.logs = None
            self.tbs = None

        def get_context_manager(self):
            @contextlib.contextmanager
            def logging_mode():
                with LoggingTensorMode(), \
                     capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
                    self.logs, self.tbs = logs_and_tb
                    yield logs_and_tb
            return logging_mode()

    capture_logs_fwd = CaptureLogs()
    capture_logs_recompute = CaptureLogs()

    def unpack_error_cb(e: CheckpointError):
        def get_str_tb(label, capture_logs):
            out = ""
            total_len = len(capture_logs.logs)
            for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
                out += f"{log}   ({i + 1} of {total_len} in {label})\n\n"
                found_torch_dispatch = False
                for line in tb:
                    # Start printing stack trace only after __torch_dispatch__ is found
                    is_torch_dispatch = line['name'] == '__torch_dispatch__'
                    if not found_torch_dispatch and not is_torch_dispatch:
                        continue
                    elif is_torch_dispatch:
                        found_torch_dispatch = True
                        continue
                    out += f"{line['filename']}:{line['line']}:{line['name']}\n"
                out += "\n\n"
            return out
        assert capture_logs_fwd.logs is not None
        assert capture_logs_recompute.logs is not None
        raise CheckpointError(
            _checkpoint_error_template.format(
                forward_traces=get_str_tb("original", capture_logs_fwd),
                recompute_traces=get_str_tb("recompute", capture_logs_recompute),
                forward_ops="\n".join(capture_logs_fwd.logs),
                recompute_ops="\n".join(capture_logs_recompute.logs)
            )
        ) from e

    def context_fn():
        return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()

    return context_fn, unpack_error_cb

def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
    # These properties are fast to check, easy to understand
    return {
        "shape": x.shape,
        "dtype": x.dtype,
        "device": x.device
    }

_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
    _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
    "none": lambda _: None,
}

# See Rule 5
class _StopRecomputationError(Exception):
    pass


class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int):
        def pack_hook(x):
            x = x.detach() if x.requires_grad else x
            target_frame = target_frame_ref()
            assert target_frame is not None  # appease mypy
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            if recomp_idx >= len(target_frame.weak_holders):
                assert not target_frame.early_stop
                if not target_frame.forward_completed:
                    # We run into this case when early stop is not enabled and
                    # we call grad within checkpoint.
                    # We need to set this flag, so we don't error out later when
                    # we check if the number of tensors saved during forward and
                    # recomputation match.
                    target_frame.ignore_saved_mismatch = True
                    return x
                raise CheckpointError(
                    "torch.utils.checkpoint: trying to save more tensors during "
                    "recomputation than during the original forward pass."
                )

            holder = target_frame.weak_holders[recomp_idx]()

            # This holder may have been cleared because someone may have called
            # backward within forward. If so, we don't need to save.
            if holder is not None:
                _internal_assert(holder.handles.get(gid, None) is None)
                holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x

            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
                target_frame.weak_holders
            ):
                raise _StopRecomputationError
            # See Rule 6: [ retain_graph is True ] above
            return x

        def unpack_hook(x):
            # See Rule 6: [ retain_graph is True ] above for an example of when
            # the graph created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)


class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, frame):
        def pack_hook(x):
            # See Rule 4 above
            holder = _Holder()
            frame.weak_holders.append(weakref.ref(holder))
            # Save metadata to detect non-determinism
            if frame.metadata_fn is not None:
                with torch.no_grad():
                    frame.x_metadatas.append(frame.metadata_fn(x))
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
                args = ctx.get_args(ctx.saved_tensors)

                try:
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                except _StopRecomputationError:
                    pass
                frame.is_recomputed[gid] = True
                frame.check_recomputed_tensors_match(gid)

            _internal_assert(gid in holder.handles)

            if holder.handles[gid] is None:
                raise CheckpointError(
                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
                    "so only once. Otherwise please open an issue with details on your use case."
                )
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
            ret = frame.recomputed[gid][holder.handles[gid]]
            holder.handles[gid] = None
            return ret

        if frame.unpack_error_cb is not None:
            def unpack_hook_with_error_cb(holder):
                try:
                    return unpack_hook(holder)
                except CheckpointError as e:
                    frame.unpack_error_cb(e)
            super().__init__(pack_hook, unpack_hook_with_error_cb)
        else:
            super().__init__(pack_hook, unpack_hook)


def _is_compiling(func, args, kwargs):
    # Check if we are under AOTAutograd tracing
    # There should probably be a better way to do this...
    # TODO: unify _is_compiling across all compile stacks
    for arg in args:
        if isinstance(arg, torch.Tensor) and is_fun(arg):
            return True
    return False


class _VersionWrapper:
    # Check that cached tensors are not mutated.
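    #
    # The check relies on ``Tensor._version``, which is bumped on every
    # in-place mutation. A small illustrative sketch (not code used by this
    # class):
    #
    #   t = torch.ones(2)
    #   v0 = t._version
    #   t.add_(1)            # in-place op
    #   assert t._version > v0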
    def __init__(self, val):
        self.val: Union[torch.Tensor, Any] = val
        self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None

    def get_val(self, allow_cache_entry_mutation):
        if self.version is not None and not allow_cache_entry_mutation:
            if self.val._version != self.version:
                # Can we give user a stack trace of where the mutation happened?
                raise RuntimeError(
                    "Tensor cached during selective activation checkpoint has been mutated"
                )
        return self.val


def _maybe_detach(x, any_ret_has_alias_info):
    # We detach for two separate reasons:
    # - For view ops, we need to ensure that when the tensor is returned from
    #   CachedDispatchMode, as_view sees that the AutogradMeta is nullptr
    # - Avoid reference cycles
    # For case 1, it is not enough to check whether x has differentiable dtype
    # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
    # when the tensor is a view.
    if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
            # Ensure that view performed beneath autograd properly propagates
            # version counter. TODO: Use reentrant_dispatch instead of
            # manually manipulating dispatch keys. Using reentrant_dispatch
            # would respect inference_mode, though that is not relevant for
            # this case.
            x = x.detach()
    return x


class SelectiveCheckpointContext:
    """
    Context passed to policy function during selective checkpointing.

    This class is used to pass relevant metadata to the policy function during
    selective checkpointing. The metadata includes whether the current invocation
    of the policy function is during recomputation or not.

    Example:
        >>> # xdoctest: +SKIP(stub)
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>     print(ctx.is_recompute)
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    def __init__(self, *, is_recompute):
        self.is_recompute = is_recompute


class CheckpointPolicy(enum.Enum):
    """
    Enum for specifying the policy for checkpointing during backpropagation.

    The following policies are supported:

    - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
      pass and will not be recomputed during the backward pass
    - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
      forward pass and will be recomputed during the backward pass

    Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
    by other subsystems like `torch.compile`.

    .. note::
        A policy function that always returns ``PREFER_RECOMPUTE`` is
        equivalent to vanilla checkpointing.

        A policy function that returns ``PREFER_SAVE`` for every op is
        NOT equivalent to not using checkpointing. Using such a policy would
        save additional tensors not limited to ones that are actually needed for
        gradient computation.
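
    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> # A hedged sketch of a policy function consuming this enum; the op
        >>> # filter is illustrative, not a recommendation.
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>     if op == torch.ops.aten.mm.default:
        >>>         return CheckpointPolicy.MUST_SAVE
        >>>     return CheckpointPolicy.PREFER_RECOMPUTE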
    """
    MUST_SAVE = 0
    PREFER_SAVE = 1
    MUST_RECOMPUTE = 2
    PREFER_RECOMPUTE = 3


def _policy_from_bool(b):
    # For backward compatibility
    return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE


SAC_IGNORED_OPS = {
    # AC inserts a different number of detach calls during forward and recompute.
    torch.ops.aten.detach.default,
    # AC's determinism check invokes additional metadata ops during forward.
    # With subclasses involved, these metadata ops become dispatchable; this
    # can result in incorrectness if the outputs of these ops are cached.
    torch.ops.prim.device.default,
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)


class _CachingTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachedTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if is_compiling:
            # Overwrite each node's "recompute" tag to add in the user annotation.
            fx_traceback.current_meta["recompute"] = policy

        out = func(*args, **kwargs)

        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
        return out

class _CachedTorchDispatchMode(TorchDispatchMode):
    # Used together with _CachingTorchDispatchMode to implement SAC.
    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
        self.policy_fn = policy_fn
        self.storage = storage
        self.allow_cache_entry_mutation = allow_cache_entry_mutation

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func in SAC_IGNORED_OPS:
            return func(*args, **kwargs)

        kwargs = {} if kwargs is None else kwargs
        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
                                func, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)

        is_compiling = _is_compiling(func, args, kwargs)

        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
            storage = self.storage.get(func)
            if storage is None:
                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
            if len(storage) == 0:
                raise RuntimeError(
                    "Trying to backward an extra time. You are only allowed to backward once "
                    "on any region computed under selective activation checkpoint."
                )
            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
        else:
            out = func(*args, **kwargs)
        return out


def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
    """
    Helper to avoid recomputing certain ops during activation checkpointing.

    Use this with `torch.utils.checkpoint.checkpoint` to control which
    operations are recomputed during the backward pass.

    Args:
        policy_fn_or_list (Callable or List):
          - If a policy function is provided, it should accept a
            :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
            kwargs to the op, and return a :class:`CheckpointPolicy` enum value
            indicating whether the execution of the op should be recomputed or not.
          - If a list of operations is provided, it is equivalent to a policy
            returning `CheckpointPolicy.MUST_SAVE` for the specified
            operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
            operations.
        allow_cache_entry_mutation (bool, optional): By default, an error is
            raised if any tensors cached by selective activation checkpoint are
            mutated in order to ensure correctness. If set to `True`, this check
            is disabled.
    Returns:
        A tuple of two context managers.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>> import functools
        >>>
        >>> x = torch.rand(10, 10, requires_grad=True)
        >>> y = torch.rand(10, 10, requires_grad=True)
        >>>
        >>> ops_to_save = [
        >>>    torch.ops.aten.mm.default,
        >>> ]
        >>>
        >>> def policy_fn(ctx, op, *args, **kwargs):
        >>>    if op in ops_to_save:
        >>>        return CheckpointPolicy.MUST_SAVE
        >>>    else:
        >>>        return CheckpointPolicy.PREFER_RECOMPUTE
        >>>
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        >>>
        >>> # or equivalently
        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
        >>>
        >>> def fn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> out = torch.utils.checkpoint.checkpoint(
        >>>     fn, x, y,
        >>>     use_reentrant=False,
        >>>     context_fn=context_fn,
        >>> )
    """
    # NB: If grad_mode is disabled, checkpoint would not run forward under
    # context_fn anyway, so proceed as usual.
    if isinstance(policy_fn_or_list, list):
        for op in policy_fn_or_list:
            if not isinstance(op, torch._ops.OpOverload):
                _extra_msg = (
                    "Please update the OpOverloadPacket to a specific OpOverload. "
                    "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
                ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
                raise ValueError(
                    f"Expected op in `op_list` to be an OpOverload but got: {op} "
                    f"of type {type(op)}. {_extra_msg}"
                )

        def policy_fn(ctx, op, *args, **kwargs):
            if op in policy_fn_or_list:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE
    elif callable(policy_fn_or_list):
        policy_fn = policy_fn_or_list
    else:
        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")

    storage: Dict[Any, List[Any]] = defaultdict(list)
    return (
        _CachingTorchDispatchMode(policy_fn, storage),
        _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
    )

# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
#     saving/restoring of global state is handled here.
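#
# Roughly, the non-reentrant path of ``checkpoint`` above drives this
# generator as follows (a simplified sketch of the existing call site, not an
# additional API):
#
#   gen = _checkpoint_without_reentrant_generator(
#       fn, preserve_rng_state, context_fn, determinism_check, debug, *args)
#   next(gen)           # pre-forward: install _checkpoint_hook / forward_context
#   ret = fn(*args)     # user function runs under the saved-tensors hooks
#   next(gen)           # post-forward checks; raises StopIteration when done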

def _checkpoint_without_reentrant_generator(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    *args,
    **kwargs
):
    """Checkpointing without reentrant autograd.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``False``, omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    unpack_error_cb = None

    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device_type = _infer_device_type(*args)
    device_module = _get_device_module(device_type)
    forward_context, recompute_context = context_fn()
    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
        assert (
            isinstance(forward_context, TorchDispatchMode) and
            isinstance(recompute_context, TorchDispatchMode)
        ), \
            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
            "must generate a tuple of two `TorchDispatchMode`s."
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    def recompute_fn(*inputs):
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device_type
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states, device_type=device_type)

            device_autocast_ctx = torch.amp.autocast(
                device_type=device_type, **device_autocast_kwargs
            ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext()
            with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(
        recompute_fn,
        _enable_checkpoint_early_stop,
        unpack_error_cb,
        metadata_fn
    )
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    with _checkpoint_hook(new_frame), forward_context:
        yield
    new_frame.forward_completed = True

    if getattr(device_module, "_initialized", False) and \
       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return
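

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only; it simply exercises the public API
# documented above and is not executed as part of this module):
#
#   import torch
#   import torch.nn as nn
#   from torch.utils.checkpoint import checkpoint
#
#   model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
#   x = torch.randn(4, 16, requires_grad=True)
#   out = checkpoint(model, x, use_reentrant=False)
#   out.sum().backward()   # activations are recomputed here rather than stored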