1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import contextlib
4import platform
5import uuid
6import warnings
7import weakref
8from collections import defaultdict
9from typing import *  # noqa: F403
10import enum
11from weakref import ReferenceType
12
13import torch
14import torch.fx.traceback as fx_traceback
15from torch._functorch._aot_autograd.functional_utils import is_fun
16from torch.utils._pytree import tree_map
17from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
18from torch.utils._python_dispatch import TorchDispatchMode
19
20__all__ = [
21    "checkpoint",
22    "checkpoint_sequential",
23    "CheckpointError",
24    "CheckpointFunction",
25    "check_backward_validity",
26    "detach_variable",
27    "get_device_states",
28    "set_device_states",
29    "noop_context_fn",
30    "set_checkpoint_early_stop",
31    "DefaultDeviceType",
32    "set_checkpoint_debug_enabled",
33    "CheckpointPolicy",
34    "SelectiveCheckpointContext",
35    "create_selective_checkpoint_contexts",
36    "SAC_IGNORED_OPS",
37]
38
39_DEFAULT_DETERMINISM_MODE = "default"
40
41_checkpoint_debug_enabled: Optional[bool] = None
42
43
44@contextlib.contextmanager
45def set_checkpoint_debug_enabled(enabled: Optional[bool]):
46    """
47    Context manager that sets whether checkpoint should print additional debug
48    information when running. See the ``debug`` flag for
49    :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
50    when set, this context manager overrides the value of ``debug`` passed to
51    checkpoint. To defer to the local setting, pass ``None`` to this context.
52
53    Args:
54        enabled (bool): Whether checkpoint should print debug information.
55            Default is ``None``.
56    """
57    global _checkpoint_debug_enabled
58    try:
59        prev = _checkpoint_debug_enabled
60        _checkpoint_debug_enabled = enabled
61        yield
62    finally:
63        _checkpoint_debug_enabled = prev
64
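# A minimal usage sketch (illustrative only; ``run_fn`` and ``inp`` are
# placeholders, not part of this module): force debug traces for every
# checkpoint in a region regardless of the per-call ``debug`` argument.
#
#   with set_checkpoint_debug_enabled(True):
#       out = checkpoint(run_fn, inp, use_reentrant=False)
#   out.sum().backward()  # operator logs are attached to any mismatch error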
65
66def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
67    if isinstance(inputs, tuple):
68        out = []
69        for inp in inputs:
70            if not isinstance(inp, torch.Tensor):
71                out.append(inp)
72                continue
73
74            x = inp.detach()
75            x.requires_grad = inp.requires_grad
76            out.append(x)
77        return tuple(out)
78    else:
79        raise RuntimeError(
80            "Only tuple of tensors is supported. Got unsupported input type: ",
81            type(inputs).__name__,
82        )
83
84
85def check_backward_validity(inputs: Iterable[Any]) -> None:
86    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
87        warnings.warn(
88            "None of the inputs have requires_grad=True. Gradients will be None"
89        )
90
91
92def _get_device_module(device="cuda"):
93    if device == "meta":
94        return torch.device("meta")
95    device_module = getattr(torch, device)
96    return device_module
97
98
99class DefaultDeviceType:
100    r"""
101    A class that manages the default device type for checkpointing.
102
103    If no non-CPU tensors are present, the default device type will
104    be used. The default value is 'cuda'. The device type is used in
105    the checkpointing process when determining which device states
106    to save and restore for recomputation.
107    """
108
109    _default_device_type = "cuda"
110
111    @staticmethod
112    def set_device_type(device: str = "cuda"):
113        """
114        Set the default device type for checkpointing.
115
116        Args:
117            device (str): The device type to be set as default. Default is 'cuda'.
118        """
119        DefaultDeviceType._default_device_type = device
120
121    @staticmethod
122    def get_device_type() -> str:
123        """
124        Get the current default device type for checkpointing.
125
126        Returns:
127            str: The current default device type.
128        """
129        return DefaultDeviceType._default_device_type
130
131
132def _infer_device_type(*args):
133    device_types = []
134
135    def add_device_types(arg):
136        nonlocal device_types
137        if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu":
138            device_types.append(arg.device.type)
139    tree_map(add_device_types, args)
140
141    device_types_set = set(device_types)
142    if len(device_types_set) > 1:
143        warnings.warn(
144            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
145            "Device state will only be saved for devices of a single device type, and the remaining "
146            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
147            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
148            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
149            f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}"
150        )
151    if len(device_types) == 0:
152        return DefaultDeviceType.get_device_type()
153    elif "cuda" in device_types_set:
154        return "cuda"
155    else:
156        return device_types[0]
157
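# Worked sketch of the inference above (device types are illustrative):
#   - no tensor args, or CPU tensors only     -> DefaultDeviceType ("cuda" unless changed)
#   - any non-CPU arg on a CUDA device        -> "cuda"
#   - non-CPU args on a single other backend  -> that device type (e.g. "xpu")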
158
159# We can't know if the run_fn will internally move some args to different devices,
160# which would require logic to preserve rng states for those devices as well.
161# We could, out of paranoia, stash and restore ALL the rng states for all visible devices,
162# but that seems very wasteful for most cases.  Compromise: stash the RNG state only for
163# the devices of all Tensor args.
164#
165# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
166def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
167    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
168    # the conditionals short-circuit.
169    fwd_device_ids = []
170
171    def add_device_ids(arg):
172        nonlocal fwd_device_ids
173        if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
174            fwd_device_ids.append(arg.get_device())
175    tree_map(add_device_ids, args)
176
177    fwd_device_states = []
178    device_module = _get_device_module(_infer_device_type(*args))
179    for device_id in fwd_device_ids:
180        with device_module.device(device_id):
181            fwd_device_states.append(device_module.get_rng_state())
182
183    return fwd_device_ids, fwd_device_states
184
185
186def set_device_states(devices, states, *, device_type=None) -> None:
187    """Sets random number generator states for the specified devices.
188
189    Args:
190        devices: Device ids to set states for.
191        states: States to set.
192        device_type: ``device_type`` of the devices to set states for. Default
193            is the device returned by a call to ``DefaultDeviceType.get_device_type()``,
194            which is ``cuda`` if not changed by calling ``DefaultDeviceType::set_device_type()``.
195    """
196    if device_type is None:
197        device_type = DefaultDeviceType.get_device_type()
198    if device_type == "meta":
199        return
200    device_module = _get_device_module(device_type)
201    for device, state in zip(devices, states):
202        with device_module.device(device):
203            device_module.set_rng_state(state)
204
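# Hedged sketch of how the two helpers above pair up (``fn`` and ``x`` are
# illustrative placeholders): stash per-device RNG state before running a
# function, then restore it so a second run sees the same random streams.
#
#   devices, states = get_device_states(x)
#   cpu_state = torch.get_rng_state()
#   y1 = fn(x)
#   torch.set_rng_state(cpu_state)
#   set_device_states(devices, states)
#   y2 = fn(x)  # dropout/random ops behave as in the first run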
205
206def _get_autocast_kwargs(device_type="cuda"):
207    if torch.amp.is_autocast_available(device_type):
208        device_autocast_kwargs = {
209            "enabled": torch.is_autocast_enabled(device_type),
210            "dtype": torch.get_autocast_dtype(device_type),
211            "cache_enabled": torch.is_autocast_cache_enabled(),
212        }
213    else:
214        device_autocast_kwargs = None
215
216    cpu_autocast_kwargs = {
217        "enabled": torch.is_autocast_enabled('cpu'),
218        "dtype": torch.get_autocast_dtype('cpu'),
219        "cache_enabled": torch.is_autocast_cache_enabled(),
220    }
221
222    return device_autocast_kwargs, cpu_autocast_kwargs
223
224
225class CheckpointFunction(torch.autograd.Function):
226    @staticmethod
227    def forward(ctx, run_function, preserve_rng_state, *args):
228        check_backward_validity(args)
229        ctx.run_function = run_function
230        ctx.preserve_rng_state = preserve_rng_state
231        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
232        ctx.device_type = _infer_device_type(*args)
233        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
234            ctx.device_type
235        )
236        if preserve_rng_state:
237            ctx.fwd_cpu_state = torch.get_rng_state()
238            # Don't eagerly initialize the cuda context by accident.
239            # (If the user intends that the context is initialized later, within their
240            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
241            # we have no way to anticipate this will happen before we run the function.)
242            ctx.had_device_in_fwd = False
243            device_module = _get_device_module(ctx.device_type)
244            if getattr(device_module, "_initialized", False):
245                ctx.had_device_in_fwd = True
246                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
247
248        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
249        # to be filled out during the backward.
250        ctx.inputs = []
251        ctx.tensor_indices = []
252        tensor_inputs = []
253        for i, arg in enumerate(args):
254            if torch.is_tensor(arg):
255                tensor_inputs.append(arg)
256                ctx.tensor_indices.append(i)
257                ctx.inputs.append(None)
258            else:
259                ctx.inputs.append(arg)
260
261        ctx.save_for_backward(*tensor_inputs)
262
263        with torch.no_grad():
264            outputs = run_function(*args)
265        return outputs
266
267    @staticmethod
268    def backward(ctx, *args):
269        if not torch.autograd._is_checkpoint_valid():
270            raise RuntimeError(
271                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
272                " with .grad() or passing an `inputs` parameter to .backward()."
273                " To resolve this error, you can either set use_reentrant=False,"
274                " or call .backward() without passing the `inputs` argument."
275            )
276        # Copy the list to avoid modifying original list.
277        inputs = list(ctx.inputs)
278        tensor_indices = ctx.tensor_indices
279        tensors = ctx.saved_tensors
280
281        # Fill in inputs with appropriate saved tensors.
282        for i, idx in enumerate(tensor_indices):
283            inputs[idx] = tensors[i]
284
285        # Stash the surrounding rng state, and mimic the state that was
286        # present at this time during forward.  Restore the surrounding state
287        # when we're done.
288        rng_devices = []
289        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
290            rng_devices = ctx.fwd_devices
291        with torch.random.fork_rng(
292            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type
293        ):
294            if ctx.preserve_rng_state:
295                torch.set_rng_state(ctx.fwd_cpu_state)
296                if ctx.had_device_in_fwd:
297                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type)
298            detached_inputs = detach_variable(tuple(inputs))
299
300            device_autocast_ctx = torch.amp.autocast(
301                device_type=ctx.device_type, **ctx.device_autocast_kwargs
302            ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext()
303            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
304                outputs = ctx.run_function(*detached_inputs)
305
306        if isinstance(outputs, torch.Tensor):
307            outputs = (outputs,)
308
309        # run backward() with only tensors that require grad
310        outputs_with_grad = []
311        args_with_grad = []
312        for i in range(len(outputs)):
313            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
314                outputs_with_grad.append(outputs[i])
315                args_with_grad.append(args[i])
316        if len(outputs_with_grad) == 0:
317            raise RuntimeError(
318                "none of the outputs has requires_grad=True,"
319                " this checkpoint() is not necessary"
320            )
321        torch.autograd.backward(outputs_with_grad, args_with_grad)
322        grads = tuple(
323            inp.grad if isinstance(inp, torch.Tensor) else None
324            for inp in detached_inputs
325        )
326
327        return (None, None) + grads
328
329
330def noop_context_fn():
331    return contextlib.nullcontext(), contextlib.nullcontext()
332
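# Illustrative context_fn sketch (the profiler labels are arbitrary, not part
# of this module): run the original forward and the recomputation under
# separate profiler ranges. Only supported with use_reentrant=False.
#
#   def profiled_context_fn():
#       return (
#           torch.profiler.record_function("checkpoint_forward"),
#           torch.profiler.record_function("checkpoint_recompute"),
#       )
#
#   out = checkpoint(fn, x, use_reentrant=False, context_fn=profiled_context_fn)
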
333# TorchDynamo does not step inside utils.checkpoint function.  The flow
334# looks like this:
335#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
336#     speculatively checking if the forward function is safe to trace.
337#  2) If yes, then Dynamo-generated Fx graph has the wrapped higher
338#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
339#  3) If not, then TorchDynamo falls back to eager by performing a graph
340#     break. And here, the following disable wrapper ensures that
341#     TorchDynamo does not trigger again on the frames created by
342#     utils.checkpoint innards.
343@torch._disable_dynamo
344def checkpoint(
345    function,
346    *args,
347    use_reentrant: Optional[bool] = None,
348    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
349    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
350    debug: bool = False,
351    **kwargs
352):
353    r"""Checkpoint a model or part of the model.
354
355    Activation checkpointing is a technique that trades compute for memory.
356    Instead of keeping tensors needed for backward alive until they are used in
357    gradient computation during backward, forward computation in checkpointed
358    regions omits saving tensors for backward and recomputes them during the
359    backward pass. Activation checkpointing can be applied to any part of a
360    model.
361
362    There are currently two checkpointing implementations available, determined
363    by the :attr:`use_reentrant` parameter. It is recommended that you use
364    ``use_reentrant=False``. Please refer to the note below for a discussion of
365    their differences.
366
367    .. warning::
368
369        If the :attr:`function` invocation during the backward pass differs
370        from the forward pass, e.g., due to a global variable, the checkpointed
371        version may not be equivalent, potentially causing an
372        error being raised or leading to silently incorrect gradients.
373
374    .. warning::
375
376        The ``use_reentrant`` parameter should be passed explicitly. In version
377        2.5 we will raise an exception if ``use_reentrant`` is not passed.
378        If you are using the ``use_reentrant=True`` variant, please refer to the
379        note below for important considerations and potential limitations.
380
381    .. note::
382
383        The reentrant variant of checkpoint (``use_reentrant=True``) and
384        the non-reentrant variant of checkpoint (``use_reentrant=False``)
385        differ in the following ways:
386
387        * Non-reentrant checkpoint stops recomputation as soon as all needed
388          intermediate activations have been recomputed. This feature is enabled
389          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
390          Reentrant checkpoint always recomputes :attr:`function` in its
391          entirety during the backward pass.
392
393        * The reentrant variant does not record the autograd graph during the
394          forward pass, as it runs with the forward pass under
395          :func:`torch.no_grad`. The non-reentrant version does record the
396          autograd graph, allowing one to perform backward on the graph within
397          checkpointed regions.
398
399        * The reentrant checkpoint only supports the
400          :func:`torch.autograd.backward` API for the backward pass without its
401          `inputs` argument, while the non-reentrant version supports all ways
402          of performing the backward pass.
403
404        * At least one input and output must have ``requires_grad=True`` for the
405          reentrant variant. If this condition is unmet, the checkpointed part
406          of the model will not have gradients. The non-reentrant version does
407          not have this requirement.
408
409        * The reentrant version does not consider tensors in nested structures
410          (e.g., custom objects, lists, dicts, etc) as participating in
411          autograd, while the non-reentrant version does.
412
413        * The reentrant checkpoint does not support checkpointed regions with
414          detached tensors from the computational graph, whereas the
415          non-reentrant version does. For the reentrant variant, if the
416          checkpointed segment contains tensors detached using ``detach()`` or
417          with :func:`torch.no_grad`, the backward pass will raise an error.
418          This is because ``checkpoint`` makes all the outputs require gradients
419          and this causes issues when a tensor is defined to have no gradient in
420          the model. To avoid this, detach the tensors outside of the
421          ``checkpoint`` function.
422
423    Args:
424        function: describes what to run in the forward pass of the model or
425            part of the model. It should also know how to handle the inputs
426            passed as the tuple. For example, in LSTM, if user passes
427            ``(activation, hidden)``, :attr:`function` should correctly use the
428            first input as ``activation`` and the second input as ``hidden``
429        preserve_rng_state(bool, optional):  If ``False``, omit stashing and
430            restoring the RNG state during each checkpoint. Note that under
431            torch.compile, this flag doesn't take effect and we always preserve
432            the RNG state. Default: ``True``
433        use_reentrant(bool):
434            specify whether to use the activation checkpoint variant that
435            requires reentrant autograd. This parameter should be passed
436            explicitly. In version 2.5 we will raise an exception if
437            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
438            ``checkpoint`` will use an implementation that does not require
439            reentrant autograd. This allows ``checkpoint`` to support additional
440            functionality, such as working as expected with
441            ``torch.autograd.grad`` and support for keyword arguments input into
442            the checkpointed function.
443        context_fn(Callable, optional): A callable returning a tuple of two
444            context managers. The function and its recomputation will be run
445            under the first and second context managers respectively.
446            This argument is only supported if ``use_reentrant=False``.
447        determinism_check(str, optional): A string specifying the determinism
448            check to perform. By default it is set to ``"default"`` which
449            compares the shapes, dtypes, and devices of the recomputed tensors
450            against those of the saved tensors. To turn off this check, specify
451            ``"none"``. Currently these are the only two supported values.
452            Please open an issue if you would like to see more determinism
453            checks. This argument is only supported if ``use_reentrant=False``;
454            if ``use_reentrant=True``, the determinism check is always disabled.
455        debug(bool, optional): If ``True``, error messages will also include
456            a trace of the operators run during the original forward computation
457            as well as the recomputation. This argument is only supported if
458            ``use_reentrant=False``.
459        args: tuple containing inputs to the :attr:`function`
460
461    Returns:
462        Output of running :attr:`function` on :attr:`*args`
463    """
464    if use_reentrant is None:
465        warnings.warn(
466            "torch.utils.checkpoint: the use_reentrant parameter should be "
467            "passed explicitly. In version 2.5 we will raise an exception "
468            "if use_reentrant is not passed. use_reentrant=False is "
469            "recommended, but if you need to preserve the current default "
470            "behavior, you can pass use_reentrant=True. Refer to docs for more "
471            "details on the differences between the two variants.",
472            stacklevel=2
473        )
474        use_reentrant = True
475
476    # Hack to mix *args with **kwargs in a python 2.7-compliant way
477    preserve = kwargs.pop("preserve_rng_state", True)
478    if kwargs and use_reentrant:
479        raise ValueError(
480            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
481        )
482
483    if use_reentrant:
484        if context_fn is not noop_context_fn or debug is not False:
485            raise ValueError(
486                "Passing `context_fn` or `debug` is only supported when "
487                "use_reentrant=False."
488            )
489        return CheckpointFunction.apply(function, preserve, *args)
490    else:
491        gen = _checkpoint_without_reentrant_generator(
492            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
493        )
494        # Runs pre-forward logic
495        next(gen)
496        ret = function(*args, **kwargs)
497        # Runs post-forward logic
498        try:
499            next(gen)
500        except StopIteration:
501            return ret
502
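# A minimal usage sketch for the function above (module and sizes are
# illustrative, not part of this file): checkpoint only the memory-hungry
# block of a forward pass and pass ``use_reentrant`` explicitly.
#
#   class Block(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.net = torch.nn.Sequential(
#               torch.nn.Linear(128, 128),
#               torch.nn.ReLU(),
#               torch.nn.Linear(128, 128),
#           )
#
#       def forward(self, x):
#           # Activations inside self.net are recomputed during backward.
#           return checkpoint(self.net, x, use_reentrant=False)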
503
504def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
505    r"""Checkpoint a sequential model to save memory.
506
507    Sequential models execute a list of modules/functions in order
508    (sequentially). Therefore, we can divide such a model into various segments
509    and checkpoint each segment. All segments except the last will not store
510    the intermediate activations. The inputs of each checkpointed segment will
511    be saved for re-running the segment in the backward pass.
512
513    .. warning::
514        The ``use_reentrant`` parameter should be passed explicitly. In version
515        2.5 we will raise an exception if ``use_reentrant`` is not passed.
516        If you are using the ``use_reentrant=True`` variant, please see
517        :func:`~torch.utils.checkpoint.checkpoint` for
518        the important considerations and limitations of this variant. It is
519        recommended that you use ``use_reentrant=False``.
520
521    .. warning::
522        Since PyTorch 1.4, ``checkpoint_sequential`` accepts only one Tensor as
523        the input and intermediate outputs, just like :class:`torch.nn.Sequential`.
524
525    Args:
526        functions: A :class:`torch.nn.Sequential` or the list of modules or
527            functions (comprising the model) to run sequentially.
528        segments: Number of chunks to create in the model
529        input: A Tensor that is input to :attr:`functions`
530        preserve_rng_state(bool, optional):  If ``False``, omit stashing and
531            restoring the RNG state during each checkpoint.
532            Default: ``True``
533        use_reentrant(bool):
534            specify whether to use the activation checkpoint variant that
535            requires reentrant autograd. This parameter should be passed
536            explicitly. In version 2.5 we will raise an exception if
537            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
538            ``checkpoint`` will use an implementation that does not require
539            reentrant autograd. This allows ``checkpoint`` to support additional
540            functionality, such as working as expected with
541            ``torch.autograd.grad`` and support for keyword arguments input into
542            the checkpointed function.
543
544    Returns:
545        Output of running :attr:`functions` sequentially on :attr:`input`
546
547    Example:
548        >>> # xdoctest: +SKIP("stub")
549        >>> model = nn.Sequential(...)
550        >>> input_var = checkpoint_sequential(model, chunks, input_var)
551    """
552    if use_reentrant is None:
553        warnings.warn(
554            "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
555            "parameter should be passed explicitly. "
556            "In version 2.5 we will raise an exception if use_reentrant "
557            "is not passed. use_reentrant=False is "
558            "recommended, but if you need to preserve the current default "
559            "behavior, you can pass use_reentrant=True. Refer to docs for more "
560            "details on the differences between the two variants."
561        )
562        use_reentrant = True
563
564    # Hack for keyword-only parameter in a python 2.7-compliant way
565    preserve = kwargs.pop("preserve_rng_state", True)
566    if kwargs:
567        raise ValueError(
568            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
569        )
570
571    def run_function(start, end, functions):
572        def forward(input):
573            for j in range(start, end + 1):
574                input = functions[j](input)
575            return input
576
577        return forward
578
579    if isinstance(functions, torch.nn.Sequential):
580        functions = list(functions.children())
581
582    segment_size = len(functions) // segments
583    # run the last chunk without checkpointing so its activations are saved normally
584    end = -1
585    for start in range(0, segment_size * (segments - 1), segment_size):
586        end = start + segment_size - 1
587        input = checkpoint(
588            run_function(start, end, functions),
589            input,
590            use_reentrant=use_reentrant,
591            preserve_rng_state=preserve,
592        )
593    return run_function(end + 1, len(functions) - 1, functions)(input)
594
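# Worked sketch of the segmentation above (numbers are illustrative): with
# len(functions) == 10 and segments == 3, segment_size == 3, so the loop
# checkpoints functions [0..2] and [3..5], and the trailing call runs
# functions [6..9] outside of checkpoint so that their activations are stored
# normally for backward.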
595
596def _internal_assert(cond):
597    if not cond:
598        raise AssertionError(
599            "Something went unexpectedly wrong in activation checkpoint. "
600            "Please report this bug by filing an issue to PyTorch."
601        )
602
603
604# NOTE [ Nestable Checkpoint ]
605#
606# The semantics of nested checkpoint can be defined by two basic rules.
607# Following the two rules leads to an important implication that is central
608# to motivating the design.
609#
610# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
611#         from any outer layers of checkpoint.
612#
613# Rule 2. The inputs of inner checkpoints are treated as tensors saved to its
614#         parent checkpoint.
615#
616# Implication: To recompute any given saved tensor, we need to recompute all of
617#              the checkpoints wrapping it.
618#
619# Why is this implied? To unpack a saved tensor X during backward we need to
620# recompute the inner-most checkpoint (#1), and in order to recompute that
621# checkpoint we need to have its inputs, which are managed by that checkpoint's
622# parent (#2), which thus also needs to be recomputed first. Continue this line
623# of reasoning and we realize that in order to unpack X, all checkpoints that
624# were active at the time X was saved need to be recomputed. (unless we have
625# already done so in that backward for some other saved tensor).
626#
627# In practice, we use a noop autograd Function to save inputs as saved tensors.
628# During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
629# recompute.
630#
631# Rule 3. We should start recomputation as if there are no checkpoints currently
632#         active. Checkpoints encountered during recomputation are still
633#         respected.
634#
635# When we start recomputation, we push the saved variable hook meant for
636# recomputation on the stack. See examples in Rule 6 for more context.
637#
638#                                  * * * *
639#
640# Beyond the basic semantics specific to nested checkpoint, we impose several
641# more constraints that may apply to checkpointing in general.
642#
643# Rule 4. Lifetime of recomputed tensors
644#
645#         Recomputed tensors are considered specific to particular invocations
646#         of backward and are always cleared immediately as they are unpacked.
647#         In particular, we require this to happen even if retain_graph=True.
648#
649# [ Implementation details of Rule 4 ]
650#
651# If we were okay with recomputed tensors staying alive after backward is run
652# with retain_graph=True, we would store recomputed variables as the values of a
653# WeakKeyDictionary and pack strong references to the keys, so that as we
654# backward, those packed keys would be cleared as long as retain_graph=False.
655# Clearing the packed key clears the corresponding entry in the WKD.
656#
657# If we wish recomputed variables to be immediately cleared as we unpack them in
658# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
659# backward automatically. Instead of packing the strong reference to the key
660# directly, we pack a container object, which we manually clear as we unpack.
661#
662# An important detail is that if a second backward happens, the second
663# recomputation needs to reset the container with a newly created key.
664#
665# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
666#         know we need.
667#
668# [ Implementation details of Rule 5 ]
669#
670# During recomputation, raise an exception if the number of recomputed tensors
671# matches the number of tensors that we expected to recompute. We wrap the
672# recomputation call with a try-catch to catch this specific exception. See
673# Rule #6 below for some examples.
674#
675# Rule 6. We support doing backward inside checkpoint context
676#
677# [ retain_graph is True ]
678#
679# def fn(x):
680#   y = x.sin()
681#   z = y.cos()
682#   gx, = torch.autograd.grad(z, x, retain_graph=True)
683#   return gx, z
684#
685# out = checkpoint(fn, inp)
686# out.backward()
687#
688# Because z is saved by cos while checkpoint is enabled, it would not be
689# actually saved, and so the .grad() call inside must trigger a recomputation.
690#
691# During recomputation the "inner pack hook" has two responsibilities:
692#
693# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
694# 2) Pack the actual tensor (detached) so that one may perform backward on the
695#    recomputed graph. The tensors saved to this graph will live until the end
696#    of recomputation, or die earlier if someone performs backward with
697#    retain_graph=False.
698#
699# More generally performing backward on the recomputed graph occurs in the
700# following cases:
701# - If backward is performed inside forward,
702#   - During the original forward IF early-stop is disabled
703#   - During the original backward
704# - If there are multiple .grad()/.backward() calls, we would perform backward
705#   on the recomputed graph even if early-stop is enabled (see the example below)
706#
707# [ retain_graph is False ]
708#
709# The example below shows what happens if during recomputation we find that some
710# of the tensors we are trying to recompute have already been cleared.
711#
712# Spoiler: we don't do anything special, we just skip over them!
713#
714# def fn(x):
715#   y = x.sin()                           # (1)
716#   z = y.cos()                           # (2)
717#   gx, = torch.autograd.grad(z, x)       # (3)
718#   return x.cos() * gx                   # (4)
719#
720# out = checkpoint(fn, inp)
721# out.backward()                          # (5)
722#
723# 1, 2. Don't save x and y since we are inside a checkpoint.
724# 3. Trigger a recompute of fn since x and y weren't saved.
725#    And depending on whether early stop is enabled, either stop at (2) or
726#    continue running the function.
727#    Because we are running backward with retain_graph=False, we clear x and y's
728#    holders.
729# 4. Don't save x since we are inside a checkpoint.
730# 5. Calling backward triggers another recompute of fn. During recompute, we see
731#    that x and y have already been cleared in the original graph as indicated
732#    by holder=None. We skip over them. We still save x at (4) (since its holder
733#    is still alive.)
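#
# A concrete (illustrative) shape of the nesting that the rules above
# describe; ``inp`` is a placeholder:
#
#   def inner(x):
#       return x.sin()
#
#   def outer(x):
#       return checkpoint(inner, x.cos(), use_reentrant=False)
#
#   out = checkpoint(outer, inp, use_reentrant=False)
#
# By Rule 2, the input handed to the inner checkpoint (x.cos()) is treated as
# a tensor saved by the outer checkpoint, so unpacking anything saved inside
# inner first triggers a recompute of outer (the Implication above).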
734
735_enable_checkpoint_early_stop = True
736
737
738@contextlib.contextmanager
739def set_checkpoint_early_stop(enable: bool):
740    """Context manager that sets whether checkpoint should stop recomputation early.
741
742    By default, non-reentrant checkpoint stops recomputation as soon as it
743    has computed all needed Tensors. This context manager can be used to disable
744    that feature if it is problematic for your specific application.
745
746    This context manager only needs to be active when forward is run. It does
747    not need to be active during backward.
748
749    Example::
750
751    >>> # xdoctest: +SKIP(failing)
753    >>> with set_checkpoint_early_stop(False):
754    ...     # Any checkpoint under this context manager will respect this
755    ...     # context manager, even if its backward is performed outside.
756    ...     out = checkpoint(fn, inputs)
757    ...
758    >>> out.backward()
759    """
760    global _enable_checkpoint_early_stop
761    try:
762        prev = _enable_checkpoint_early_stop
763        _enable_checkpoint_early_stop = enable
764        yield
765    finally:
766        _enable_checkpoint_early_stop = prev
767
768
769class _Handle:
770    pass
771
772
773class _Holder:
774    def __init__(self):
775        self.handles: Dict[int, Optional[_Handle]] = {}
776
777
778class _NoopSaveInputs(torch.autograd.Function):
779    @staticmethod
780    def forward(*args):
781        return torch.empty((0,))
782
783    @staticmethod
784    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
785        # Only tensors can be saved with ctx.save_for_backward, everything else
786        # is captured by get_args, which is saved directly on ctx
787        tensor_indices, tensors = zip(
788            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
789        )
790        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
791        # args but with tensors replaced with None as placeholders
792        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
793
794        def get_args(saved_tensors):
795            # restore the placeholders with the original tensors grabbed from
796            # ctx.saved_tensors (which may be saved on a parent checkpoint if
797            # this checkpoint is nested, and that would trigger a recursive
798            # unpack!)
799            ret = [
800                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
801                for i, o in enumerate(args)
802            ]
803            # grab the tail since we also saved the dummy to avoid having to explicitly
804            # handle the case where there are no tensor inputs
805            return ret[1:]
806
807        ctx.get_args = get_args
808        ctx.save_for_backward(*tensors)
809
810    @staticmethod
811    def backward(ctx, *grad_outputs):
812        raise AssertionError("Did not expect to backward on this graph")
813
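# Worked sketch of get_args above (values are illustrative): if the function
# is applied as _NoopSaveInputs.apply(dummy, 3, t) where dummy and t are
# tensors, then args == [None, 3, None], ctx.saved_tensors == (dummy, t),
# and get_args(ctx.saved_tensors) rebuilds [dummy, 3, t] before returning the
# tail [3, t], i.e. the original inputs with the dummy dropped.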
814
815class _CheckpointFrame:
816    def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
817        self.recompute_fn = recompute_fn
818        self.input_saver = None
819        self.weak_holders: List[ReferenceType] = []
820        # We store this as a weakkeydictionary so that in the case of a partial
821        # backward, the entries in the dict are cleared alongside the Holder
822        # which will be removed when the SavedVariable is cleared.
823        self.recomputed: DefaultDict[
824            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
825        ] = defaultdict(weakref.WeakKeyDictionary)
826        # We need both recomp_counter and recomputed since they can diverge
827        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
828        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
829        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)
830
831        # See Rule 5
832        self.early_stop = early_stop
833
834        # Debugging
835        self.metadata_fn = metadata_fn
836        self.unpack_error_cb = unpack_error_cb
837        self.x_metadatas = []
838        self.forward_completed = False
839        self.ignore_saved_mismatch = False
840
841    def check_recomputed_tensors_match(self, gid):
842        if self.ignore_saved_mismatch:
843            # TODO: we can probably make this check stricter by checking that
844            #       the metadata of the first tensors still match.
845            return
846        # NOTE [ Error handling for checkpoint ]
847        #
848        # At a high level, we need to check that the tensors saved
849        # during original forward matches tensors saved during recompute
850        # This means handling 3 cases:
851        #
852        # 1. During recompute, more tensors were saved.
853        #
854        #    Usually this case is hidden by the StopRecomputationError when
855        #    early stop is enabled; if it is not, we would have errored
856        #    anyway because there aren't enough weak_holders. But we
857        #    do want to have a nice error. See the _recomputation_hook
858        #    for details.
859        if not len(self.weak_holders) == self.recomp_counter[gid]:
860            # 2. During recompute, fewer tensors were saved
861            #
862        # We know that every time we save something during the original forward
863        # we append to weak_holders, and every time we save a tensor
864        # during recompute we increment recomp_counter.
865            raise CheckpointError(
866                "torch.utils.checkpoint: A different number of tensors was saved "
867                "during the original forward and recomputation.\n"
868                f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
869                f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
870            )
871
872        # 3. During recompute, the same tensors were saved, but they
873        #    have different metadata
874        nb_meta_different = []
875        for idx, weak_holder in enumerate(self.weak_holders):
876            holder = weak_holder()
877            if holder is None:
878                continue
879            # We've seen all holders since we iterate over them in order.
880            # For every holder that is still alive now, it must've been
881            # alive when we saw it during recompute, therefore, the
882            # gid must be set.
883            _internal_assert(gid in holder.handles)
884            # We know this is the first unpack, so it couldn't have been set
885            # to None yet.
886            _internal_assert(holder.handles[gid] is not None)
887            # We always set these together in the recomputation hook
888            _internal_assert(holder.handles[gid] in self.recomputed[gid])
889            # see pack hook, x_metadata is 1:1 with weak_holders.
890            x_meta = self.x_metadatas[idx]
891            recomputed_x = self.recomputed[gid][holder.handles[gid]]
892            if x_meta != self.metadata_fn(recomputed_x):
893                nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))
894
895        if len(nb_meta_different) > 0:
896            mismatched_tensors = ""
897            for idx, x_meta, recomputed_meta in nb_meta_different:
898                mismatched_tensors += (
899                    f"tensor at position {idx}:\n"
900                    f"saved metadata: {x_meta}\n"
901                    f"recomputed metadata: {recomputed_meta}\n"
902                )
903            raise CheckpointError(
904                "torch.utils.checkpoint: Recomputed values for the following tensors "
905                "have different metadata than during the forward pass.\n"
906                f"{mismatched_tensors}"
907            )
908
909
910_checkpoint_error_template = """ \
911An error happened while unpacking tensors; dumping logs of latest computation
912because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
913Scroll all the way down for guidance on how to navigate these logs.
914
915+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
916|        1. Stack traces of the operators that ran in the original forward     |
917+------------------------------------------------------------------------------+
918
919{forward_traces}
920+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
921|        2. Stack traces of the operators that ran during recomputation        |
922+------------------------------------------------------------------------------+
923
924{recompute_traces}
925+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
926|       3. Log of operators in the original forward and recomputation          |
927+------------------------------------------------------------------------------+
928(Scroll up to correlate stack traces with each operation listed below. This
929 helps identify their source in the code.)
930
931IMPORTANT: Differences in "detach" calls between the original forward and the
932           recomputation are expected. They are introduced by the checkpointing
933           mechanism and can be ignored.
934
935Operations executed during the original forward:
936
937{forward_ops}
938
939Operations executed during recomputation:
940
941{recompute_ops}
942
943+------------------------------------------------------------------------------+
944 ERROR: Detected non-determinism while running activation checkpointing
945
946 You are seeing this error because you passed `debug=True` to checkpoint and
947 the tensors saved during the original forward differ from those saved
948 during recomputation. This can happen if different operators were run in the
949 original forward and in the recomputation.
950
951 To identify where the mismatch may be coming from, you can do the following:
952
953 1) Compare the operators ran during original forward and recomputation to
954    see where they differ. These operators are printed above in the order they
955    were executed.
956
957 2) Review the stack trace for each operator to locate its invocation source.
958    Each operator's stack trace is printed in their execution order.
959
960 Note that the logs can be quite long. Here's how they are structured:
961 (Tip: you can Ctrl-f for these headers)
962
963 1. Stack traces of the operators that ran in the original forward
964 2. Stack traces of the operators that ran during recomputation
965 3. Log of operators in the original forward and recomputation
966 4. Error message                                             <--- You are here
967--------------------------------------------------------------------------------
968"""
969
970class CheckpointError(RuntimeError):
971    pass
972
973
974def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
975    # This function returns the context_fn and error_cb to be used by the
976    # checkpointing mechanism. error_cb is invoked when an error is detected
977    # during unpack.
978
979    # record_context_cpp is not supported on non-Linux, non-x86_64 platforms
980    cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'
981
982    class CaptureLogs:
983        def __init__(self):
984            self.logs = None
985            self.tbs = None
986
987        def get_context_manager(self):
988            @contextlib.contextmanager
989            def logging_mode():
990                with LoggingTensorMode(), \
991                     capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
992                    self.logs, self.tbs = logs_and_tb
993                    yield logs_and_tb
994            return logging_mode()
995
996    capture_logs_fwd = CaptureLogs()
997    capture_logs_recompute = CaptureLogs()
998
999    def unpack_error_cb(e: CheckpointError):
1000        def get_str_tb(label, capture_logs):
1001            out = ""
1002            total_len = len(capture_logs.logs)
1003            for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
1004                out += f"{log}   ({i + 1} of {total_len} in {label})\n\n"
1005                found_torch_dispatch = False
1006                for line in tb:
1007                    # Start printing stack trace only after __torch_dispatch__ is found
1008                    is_torch_dispatch = line['name'] == '__torch_dispatch__'
1009                    if not found_torch_dispatch and not is_torch_dispatch:
1010                        continue
1011                    elif is_torch_dispatch:
1012                        found_torch_dispatch = True
1013                        continue
1014                    out += f"{line['filename']}:{line['line']}:{line['name']}\n"
1015                out += "\n\n"
1016            return out
1017        assert capture_logs_fwd.logs is not None
1018        assert capture_logs_recompute.logs is not None
1019        raise CheckpointError(
1020            _checkpoint_error_template.format(
1021                forward_traces=get_str_tb("original", capture_logs_fwd),
1022                recompute_traces=get_str_tb("recompute", capture_logs_recompute),
1023                forward_ops="\n".join(capture_logs_fwd.logs),
1024                recompute_ops="\n".join(capture_logs_recompute.logs)
1025            )
1026        ) from e
1027
1028    def context_fn():
1029        return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()
1030
1031    return context_fn, unpack_error_cb
1032
1033def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
1034    # These properties are fast to check, easy to understand
1035    return {
1036        "shape": x.shape,
1037        "dtype": x.dtype,
1038        "device": x.device
1039    }
1040
1041_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
1042    _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
1043    "none": lambda _: None,
1044}
1045
1046# See Rule 5
1047class _StopRecomputationError(Exception):
1048    pass
1049
1050
1051class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
1052    def __init__(self, target_frame_ref: ReferenceType, gid: int):
1053        def pack_hook(x):
1054            x = x.detach() if x.requires_grad else x
1055            target_frame = target_frame_ref()
1056            assert target_frame is not None  # appease mypy
1057            recomp_idx = target_frame.recomp_counter[gid]
1058            target_frame.recomp_counter[gid] += 1
1059
1060            if recomp_idx >= len(target_frame.weak_holders):
1061                assert not target_frame.early_stop
1062                if not target_frame.forward_completed:
1063                    # We run into this case when early stop is not enabled and
1064                    # grad is called within checkpoint.
1065                    # We need to set this flag, so we don't error out later when
1066                    # we check if the number of tensors saved during forward and
1067                    # recomputation match.
1068                    target_frame.ignore_saved_mismatch = True
1069                    return x
1070                raise CheckpointError(
1071                    "torch.utils.checkpoint: trying to save more tensors during "
1072                    "recomputation than during the original forward pass."
1073                )
1074
1075            holder = target_frame.weak_holders[recomp_idx]()
1076
1077            # This holder may have been cleared because someone may have called
1078            # backward within forward. If so, we don't need to save.
1079            if holder is not None:
1080                _internal_assert(holder.handles.get(gid, None) is None)
1081                holder.handles[gid] = _Handle()
1082                target_frame.recomputed[gid][holder.handles[gid]] = x
1083
1084            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
1085                target_frame.weak_holders
1086            ):
1087                raise _StopRecomputationError
1088            # See Rule 6: [ retain_graph is True ] above
1089            return x
1090
1091        def unpack_hook(x):
1092            # See Rule 6: [ retain_graph is True ] above for an example of when
1093            # the graph created during recomputation could be backwarded.
1094            return x
1095
1096        super().__init__(pack_hook, unpack_hook)
1097
1098
1099class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
1100    def __init__(self, frame):
1101        def pack_hook(x):
1102            # See Rule 4 above
1103            holder = _Holder()
1104            frame.weak_holders.append(weakref.ref(holder))
1105            # Save metadata to detect non-determinism
1106            if frame.metadata_fn is not None:
1107                with torch.no_grad():
1108                    frame.x_metadatas.append(frame.metadata_fn(x))
1109            return holder
1110
1111        def unpack_hook(holder):
1112            gid = torch._C._current_graph_task_id()
1113            if gid == -1:
1114                # generate a temporary id if we trigger unpack outside of a backward call
1115                gid = int(uuid.uuid4())
1116
1117            if not frame.is_recomputed[gid]:
1118                ctx = frame.input_saver.grad_fn
1119                args = ctx.get_args(ctx.saved_tensors)
1120
1121                try:
1122                    with _recomputation_hook(
1123                        weakref.ref(frame), gid
1124                    ), torch.autograd.enable_grad():
1125                        frame.recompute_fn(*args)
1126                except _StopRecomputationError:
1127                    pass
1128                frame.is_recomputed[gid] = True
1129                frame.check_recomputed_tensors_match(gid)
1130
1131            _internal_assert(gid in holder.handles)
1132
1133            if holder.handles[gid] is None:
1134                raise CheckpointError(
1135                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
1136                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
1137                    "so only once. Otherwise please open an issue with details on your use case."
1138                )
1139            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
1140            ret = frame.recomputed[gid][holder.handles[gid]]
1141            holder.handles[gid] = None
1142            return ret
1143
1144        if frame.unpack_error_cb is not None:
1145            def unpack_hook_with_error_cb(holder):
1146                try:
1147                    return unpack_hook(holder)
1148                except CheckpointError as e:
1149                    frame.unpack_error_cb(e)
1150            super().__init__(pack_hook, unpack_hook_with_error_cb)
1151        else:
1152            super().__init__(pack_hook, unpack_hook)
1153
1154
1155def _is_compiling(func, args, kwargs):
1156    # Check if we are under AOTAutograd tracing
1157    # There should probably be a better way to do this...
1158    # TODO: unify _is_compiling across all compile stacks
1159    for arg in args:
1160        if isinstance(arg, torch.Tensor) and is_fun(arg):
1161            return True
1162    return False
1163
1164
1165class _VersionWrapper:
1166    # Check that cached tensors are not mutated.
1167    def __init__(self, val):
1168        self.val: Union[torch.Tensor, Any] = val
1169        self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None
1170
1171    def get_val(self, allow_cache_entry_mutation):
1172        if self.version is not None and not allow_cache_entry_mutation:
1173            if self.val._version != self.version:
1174                # Can we give user a stack trace of where the mutation happened?
1175                raise RuntimeError(
1176                    "Tensor cached during selective activation checkpoint has been mutated"
1177                )
1178        return self.val
1179
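# Hedged sketch of the mutation check above (tensor values are illustrative):
#
#   t = torch.ones(2)
#   w = _VersionWrapper(t)
#   t.add_(1)   # bumps t._version past the recorded version
#   w.get_val(allow_cache_entry_mutation=False)   # raises RuntimeError
#   w.get_val(allow_cache_entry_mutation=True)    # returns the mutated tensor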
1180
1181def _maybe_detach(x, any_ret_has_alias_info):
1182    # We detach for two separate reasons:
1183    # - For view ops, we need to ensure that when the tensor is returned from
1184    #   CachedDispatchMode, as_view sees that the AutogradMeta is nullptr
1185    # - Avoid reference cycles
1186    # For case 1, it is not enough to check whether x has differentiable dtype
1187    # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
1188    # when the tensor is a view.
1189    if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
1190        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
1191            # Ensure that view performed beneath autograd properly propagates
1192            # version counter. TODO: Use reentrant_dispatch instead of
1193            # manually manipulating dispatch keys. Using reentrant_dispatch
1194            # would respect inference_mode, though that is not relevant for
1195            # this case.
1196            x = x.detach()
1197    return x
1198
1199
1200class SelectiveCheckpointContext:
1201    """
1202    Context passed to policy function during selective checkpointing.
1203
1204    This class is used to pass relevant metadata to the policy function during
1205    selective checkpointing. The metadata includes whether the current invocation
1206    of the policy function is during recomputation or not.
1207
1208    Example:
1209        >>> # xdoctest: +SKIP(stub)
1210        >>>
1211        >>> def policy_fn(ctx, op, *args, **kwargs):
1212        >>>    print(ctx.is_recompute)
1213        >>>
1214        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
1215        >>>
1216        >>> out = torch.utils.checkpoint.checkpoint(
1217        >>>     fn, x, y,
1218        >>>     use_reentrant=False,
1219        >>>     context_fn=context_fn,
1220        >>> )
1221    """
1222    def __init__(self, *, is_recompute):
1223        self.is_recompute = is_recompute
1224
1225
1226class CheckpointPolicy(enum.Enum):
1227    """
1228    Enum for specifying the policy for checkpointing during backpropagation.
1229
1230    The following policies are supported:
1231
1232    - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
1233      pass and will not be recomputed during the backward pass
1234    - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
1235      forward pass and will be recomputed during the backward pass
1236
1237    Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
1238    by other subsystems like `torch.compile`.
1239
1240    .. note::
1241        A policy function that always returns ``PREFER_RECOMPUTE`` is
1242        equivalent to vanilla checkpointing.
1243
1244        A policy function that returns ``PREFER_SAVE`` for every op is
1245        NOT equivalent to not using checkpointing. Using such a policy would
1246        save additional tensors not limited to ones that are actually needed for
1247        gradient computation.
1248    """
1249    MUST_SAVE = 0
1250    PREFER_SAVE = 1
1251    MUST_RECOMPUTE = 2
1252    PREFER_RECOMPUTE = 3
1253
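
# Example of a policy function returning the values above. This is an
# illustrative sketch only; `_example_save_mm_policy` is a hypothetical helper
# that is not used anywhere in this module.
def _example_save_mm_policy(ctx, op, *args, **kwargs):
    # Keep matmul outputs in memory; recompute everything else.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE
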
1254
1255def _policy_from_bool(b):
1256    # For backward compatibility
1257    return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
1258
1259
1260SAC_IGNORED_OPS = {
1261    # AC inserts a different number of detaches during forward and recompute.
1262    torch.ops.aten.detach.default,
1263    # AC's determinism check invokes additional metadata ops during forward.
1264    # With subclasses involved, these metadata ops become dispatchable, and
1265    # this can result in incorrectness if these ops are selected to be cached.
1266    torch.ops.prim.device.default,
1267} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
1268
1269
1270class _CachingTorchDispatchMode(TorchDispatchMode):
1271    # Used together with _CachedTorchDispatchMode to implement SAC.
1272    def __init__(self, policy_fn, storage):
1273        self.policy_fn = policy_fn
1274        self.storage = storage
1275
1276    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1277        if func in SAC_IGNORED_OPS:
1278            return func(*args, **kwargs)
1279
1280        kwargs = {} if kwargs is None else kwargs
1281        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
1282                                func, *args, **kwargs)
1283        if isinstance(policy, bool):
1284            policy = _policy_from_bool(policy)
1285
1286        is_compiling = _is_compiling(func, args, kwargs)
1287
1288        if is_compiling:
1289            # Overwrite each node's "recompute" tag to add in the user annotation.
1290            fx_traceback.current_meta["recompute"] = policy
1291
1292        out = func(*args, **kwargs)
1293
1294        any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)
1295
1296        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
1297            self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
1298        return out
1299
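# NB (informal): `storage`, shared by _CachingTorchDispatchMode above and
# _CachedTorchDispatchMode below, maps each saved OpOverload to a list of
# _VersionWrapper-wrapped outputs in the order the op was encountered, e.g. two
# saved aten.mm calls yield storage[torch.ops.aten.mm.default] == [out0, out1].
# _CachedTorchDispatchMode pops entries from the front of each list during
# recomputation, so the per-op call order must match between the two passes.
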
1300class _CachedTorchDispatchMode(TorchDispatchMode):
1301    # Used together with _CachingTorchDispatchMode to implement SAC.
1302    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
1303        self.policy_fn = policy_fn
1304        self.storage = storage
1305        self.allow_cache_entry_mutation = allow_cache_entry_mutation
1306
1307    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1308        if func in SAC_IGNORED_OPS:
1309            return func(*args, **kwargs)
1310
1311        kwargs = {} if kwargs is None else kwargs
1312        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
1313                                func, *args, **kwargs)
1314        if isinstance(policy, bool):
1315            policy = _policy_from_bool(policy)
1316
1317        is_compiling = _is_compiling(func, args, kwargs)
1318
1319        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
1320            storage = self.storage.get(func)
1321            if storage is None:
1322                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
1323            if len(storage) == 0:
1324                raise RuntimeError(
1325                    "Trying to backward an extra time. You are only allowed to backward once "
1326                    "on any region computed under selective activation checkpoint."
1327                )
1328            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
1329        else:
1330            out = func(*args, **kwargs)
1331        return out
1332
1333
1334def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
1335    """
1336    Helper to avoid recomputing certain ops during activation checkpointing.
1337
1338    Use this with `torch.utils.checkpoint.checkpoint` to control which
1339    operations are recomputed during the backward pass.
1340
1341    Args:
1342        policy_fn_or_list (Callable or List):
1343          - If a policy function is provided, it should accept a
1344            :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
1345            kwargs to the op, and return a :class:`CheckpointPolicy` enum value
1346            indicating whether the execution of the op should be recomputed or not.
1347          - If a list of operations is provided, it is equivalent to a policy
1348            returning `CheckpointPolicy.MUST_SAVE` for the specified
1349            operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
1350            operations.
1351        allow_cache_entry_mutation (bool, optional): By default, an error is
1352            raised, in order to ensure correctness, if any tensors cached by
1353            selective activation checkpoint are mutated. If set to `True`, this
1354            check is disabled.
1355    Returns:
1356        A tuple of two context managers.
1357
1358    Example:
1359        >>> # xdoctest: +REQUIRES(LINUX)
1360        >>> import functools
1361        >>>
1362        >>> x = torch.rand(10, 10, requires_grad=True)
1363        >>> y = torch.rand(10, 10, requires_grad=True)
1364        >>>
1365        >>> ops_to_save = [
1366        >>>    torch.ops.aten.mm.default,
1367        >>> ]
1368        >>>
1369        >>> def policy_fn(ctx, op, *args, **kwargs):
1370        >>>    if op in ops_to_save:
1371        >>>        return CheckpointPolicy.MUST_SAVE
1372        >>>    else:
1373        >>>        return CheckpointPolicy.PREFER_RECOMPUTE
1374        >>>
1375        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
1376        >>>
1377        >>> # or equivalently
1378        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
1379        >>>
1380        >>> def fn(x, y):
1381        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
1382        >>>
1383        >>> out = torch.utils.checkpoint.checkpoint(
1384        >>>     fn, x, y,
1385        >>>     use_reentrant=False,
1386        >>>     context_fn=context_fn,
1387        >>> )
1388    """
1389    # NB: If grad_mode is disabled, checkpoint would not run forward under
1390    #     context_fn anyway, so proceed as usual.
1391    if isinstance(policy_fn_or_list, list):
1392        for op in policy_fn_or_list:
1393            if not isinstance(op, torch._ops.OpOverload):
1394                _extra_msg = (
1395                    "Please update the OpOverloadPacket to a specific OpOverload."
1396                    "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
1397                ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
1398                raise ValueError(
1399                    f"Expected op in `op_list` to be an OpOverload but got: {op} "
1400                    f"of type {type(op)}. {_extra_msg}"
1401                )
1402
1403        def policy_fn(ctx, op, *args, **kwargs):
1404            if op in policy_fn_or_list:
1405                return CheckpointPolicy.MUST_SAVE
1406            else:
1407                return CheckpointPolicy.PREFER_RECOMPUTE
1408    elif callable(policy_fn_or_list):
1409        policy_fn = policy_fn_or_list
1410    else:
1411        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")
1412
1413    storage: Dict[Any, List[Any]] = defaultdict(list)
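    # The first returned mode is entered around the original forward (where it
    # populates `storage`); the second is entered around the recomputation
    # (where it consumes it). checkpoint(..., use_reentrant=False,
    # context_fn=...) enters them in that order, as in the docstring example.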
1414    return (
1415        _CachingTorchDispatchMode(policy_fn, storage),
1416        _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
1417    )
1418
1419# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
1420#     saving/restoring of global state are handled here.
1421
1422def _checkpoint_without_reentrant_generator(
1423    fn,
1424    preserve_rng_state=True,
1425    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
1426    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
1427    debug: bool = False,
1428    *args,
1429    **kwargs
1430):
1431    """Checkpointing without reentrant autograd.
1432
1433    Args:
1434        function: describes what to run in the forward pass of the model or
1435            part of the model. It should also know how to handle the inputs
1436            passed as the tuple. For example, in LSTM, if user passes
1437            ``(activation, hidden)``, :attr:`function` should correctly use the
1438            first input as ``activation`` and the second input as ``hidden``
1439        preserve_rng_state(bool, optional): If ``False``, omit stashing and
1440            restoring the RNG state during each checkpoint.
1441            Default: ``True``
1442        context_fn(Callable, optional): A callable returning a tuple of two
1443            context managers. The function and its recomputation will be run
1444            under the first and second context managers respectively.
1445        determinism_check(str, optional): A string specifying the determinism
1446            check to perform. By default it is set to ``"default"`` which
1447            compares the shapes, dtypes, and devices of the recomputed tensors
1448            against those of the saved tensors. To turn off this check, specify
1449            ``"none"``. Currently these are the only two supported values.
1450            Please open an issue if you would like to see more determinism
1451            checks.
1452        debug(bool, optional): If ``True``, error messages will also include
1453            a trace of the operators run during the original forward computation
1454            as well as the recomputation.
1455        *args: Arguments to pass in to the given ``function``.
1456        **kwargs: Keyword arguments to pass into the given ``function``.
1457    """
1458    unpack_error_cb = None
1459
1460    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
1461        if context_fn != noop_context_fn:
1462            raise ValueError(
1463                "debug=True is incompatible with non-default context_fn"
1464            )
1465        context_fn, unpack_error_cb = _get_debug_context_and_cb()
1466
1467    if determinism_check in _allowed_determinism_checks_to_fns:
1468        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
1469    else:
1470        raise ValueError(
1471            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
1472            f"but got {determinism_check}"
1473        )
1474
1475    device_type = _infer_device_type(*args)
1476    device_module = _get_device_module(device_type)
1477    forward_context, recompute_context = context_fn()
1478    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
1479        assert (
1480            isinstance(forward_context, TorchDispatchMode) and
1481            isinstance(recompute_context, TorchDispatchMode)
1482        ), \
1483            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
1484            "must generate a tuple of two `TorchDispatchMode`s."
1485    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
1486    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)
1487
1488    if preserve_rng_state:
1489        fwd_cpu_state = torch.get_rng_state()
1490        # Don't eagerly initialize the cuda context by accident.
1491        # (If the user intends that the context is initialized later, within their
1492        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
1493        # we have no way to anticipate whether this will happen before we run
1494        # the function. If they do so, we raise an error.)
1495        had_device_in_fwd = False
1496        if getattr(device_module, "_initialized", False):
1497            had_device_in_fwd = True
1498            fwd_devices, fwd_device_states = get_device_states(*args)
1499
1500    def recompute_fn(*inputs):
1501        kwargs, *args = inputs
1502        # This will be called later during recomputation. This wrapping enables
1503        # the necessary global state to be captured.
1504        rng_devices = []
1505        if preserve_rng_state and had_device_in_fwd:
1506            rng_devices = fwd_devices
1507        with torch.random.fork_rng(
1508            devices=rng_devices, enabled=preserve_rng_state, device_type=device_type
1509        ):
1510            if preserve_rng_state:
1511                torch.set_rng_state(fwd_cpu_state)
1512                if had_device_in_fwd:
1513                    set_device_states(fwd_devices, fwd_device_states, device_type=device_type)
1514
1515            device_autocast_ctx = torch.amp.autocast(
1516                device_type=device_type, **device_autocast_kwargs
1517            ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext()
1518            with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
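                # NB: the return value of fn is intentionally not used here;
                # recomputation only serves to re-produce the intermediate
                # activations that the backward pass will unpack.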
1519                fn(*args, **kwargs)
1520
1521    new_frame = _CheckpointFrame(
1522        recompute_fn,
1523        _enable_checkpoint_early_stop,
1524        unpack_error_cb,
1525        metadata_fn
1526    )
1527    dummy = torch.empty((0,), requires_grad=True)
1528    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)
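    # NB: `dummy` requires grad, so _NoopSaveInputs produces a grad_fn whenever
    # grad mode is enabled, even if none of the real inputs require grad; the
    # presence of that grad_fn is what the check below relies on.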
1529
1530    # When ambient grad_mode is False
1531    if new_frame.input_saver.grad_fn is None:
1532        yield
1533        return
1534
1535    with _checkpoint_hook(new_frame), forward_context:
1536        yield
1537    new_frame.forward_completed = True
1538
1539    if getattr(device_module, "_initialized", False) and \
1540       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
1541        # Device was not initialized before running the forward, so we didn't
1542        # stash the device state.
1543        raise RuntimeError(
1544            "PyTorch's device state was initialized in the forward pass "
1545            "of a Checkpoint, which is not allowed. Please open an issue "
1546            "if you need this feature."
1547        )
1548
1549    return
1550