xref: /aosp_15_r20/external/pytorch/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from enum import auto, Enum
from functools import partial
from typing import Any, Callable, Dict, Iterator, Optional, Tuple

import torch
import torch.nn as nn
from torch.autograd.graph import save_on_cpu
from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs
from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint


_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."

class CheckpointImpl(Enum):
    REENTRANT = auto()
    NO_REENTRANT = auto()

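# Note added for illustration (not in the upstream file): the two enum values map directly
# onto the ``use_reentrant`` flag of ``torch.utils.checkpoint.checkpoint``. The helper below
# is a hypothetical name that mirrors the translation done in ``CheckpointWrapper.__init__``.
def _example_use_reentrant_flag(impl: CheckpointImpl) -> bool:
    # REENTRANT -> use_reentrant=True, NO_REENTRANT -> use_reentrant=False.
    return impl == CheckpointImpl.REENTRANT

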
class ActivationWrapper(torch.nn.Module, ABC):
    """
    Base class for Activation Checkpoint and Activation Offload.

    Not meant to be instantiated directly.
    """

    def __init__(self, mod):
        super().__init__()
        self._checkpoint_wrapped_module = mod
        # state_dict post hook to remove prefix to allow loading into a
        # non-checkpoint wrapped module.
        self._register_state_dict_hook(self._post_state_dict_hook)
        # load_state_dict pre-hook to allow loading back into
        # checkpoint-wrapped module.
        self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    @abstractmethod
    def forward(self, *args, **kwargs):
        raise ValueError("Subclasses should implement forward().")

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self._checkpoint_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is a nn.Sequential."""
        return self._checkpoint_wrapped_module.__getitem__(key)  # type: ignore[operator]

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Override :meth:`named_parameters()` to intercept parameter names.

        Removes all occurrences of ``_CHECKPOINT_PREFIX`` from the yielded names.
        """
        for param_name, param in super().named_parameters(*args, **kwargs):
            yield param_name.replace(_CHECKPOINT_PREFIX, ""), param

    @staticmethod
    def _post_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> Dict[str, Any]:
        """
        ``_post_state_dict_hook()`` is called after the ``state_dict()`` of this module is executed.

        For ``checkpoint_wrapper``, it strips the checkpoint-wrapped module prefix so that
        this module's state_dict can be loaded into non-checkpointed modules. The state_dict
        can still be loaded into checkpoint-wrapped modules because this class adds the
        prefix back before loading the state_dict.
        """
        _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix)
        return state_dict

    @staticmethod
    def _pre_load_state_dict_hook(
        module: nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        """
        ``_pre_load_state_dict_hook()`` is called before ``self._load_from_state_dict()`` is called.

        For ``checkpoint_wrapper``, it adds back the module prefix so that state_dicts from
        non-checkpointed modules can be loaded into checkpoint-wrapped modules properly.
        """
        _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}")


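# Illustrative sketch (added here, not in the upstream file): the hooks above make the
# wrapper transparent with respect to parameter names and checkpoints. Using the
# ``CheckpointWrapper`` subclass defined below, parameter names and state_dict keys carry
# no ``_checkpoint_wrapped_module.`` prefix, so state_dicts can be exchanged freely with an
# unwrapped copy of the same module.
def _example_state_dict_round_trip() -> None:
    plain = nn.Linear(4, 4)
    wrapped = CheckpointWrapper(nn.Linear(4, 4))
    # named_parameters() strips the prefix, so both modules expose "weight" and "bias".
    assert sorted(name for name, _ in wrapped.named_parameters()) == ["bias", "weight"]
    # The post-state_dict hook strips the prefix from the saved keys...
    assert sorted(wrapped.state_dict()) == sorted(plain.state_dict())
    # ...and the pre-load hook adds it back, so loading works in both directions.
    plain.load_state_dict(wrapped.state_dict())
    wrapped.load_state_dict(plain.state_dict())

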
class OffloadWrapper(ActivationWrapper):
    """Wrap a module so that tensors saved for backward are offloaded to pinned CPU memory."""

    def __init__(self, mod):
        super().__init__(mod)

    def forward(self, *args, **kwargs):
        with save_on_cpu(pin_memory=True):
            return self._checkpoint_wrapped_module(*args, **kwargs)


class CheckpointWrapper(ActivationWrapper):
    """
    An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing.

    Note that this module is not meant to be used directly; instead, it should be
    used through the ``checkpoint_wrapper`` function.
    """

    def __init__(
        self,
        mod: torch.nn.Module,
        checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
        checkpoint_fn=None,
        **checkpoint_fn_kwargs,
    ):
        super().__init__(mod)
        self.checkpoint_impl = checkpoint_impl
        if checkpoint_fn is None:
            # use torch.utils.checkpoint
            self.checkpoint_fn = partial(
                torch_utils_checkpoint,
                use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
                **checkpoint_fn_kwargs,
            )
        else:
            # Construct user-specified checkpoint function.
            self.checkpoint_fn = partial(
                checkpoint_fn,
                **checkpoint_fn_kwargs,
            )

    def forward(self, *args, **kwargs):
        # Support keyword arguments for reentrant checkpoint. Note that this
        # only works if user has specified self.checkpoint_impl and is not
        # using their own custom checkpoint_fn.
        if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
            # Pack the args and kwargs
            flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)

            # Function that only takes (packed) args, but can unpack them
            # into the original args and kwargs for the checkpointed
            # function, and runs that function.
            def my_function(*inputs):
                # unpack back into args and kwargs
                unpacked_args, unpacked_kwargs = _unpack_kwargs(inputs, kwarg_keys)
                # run original module
                return self._checkpoint_wrapped_module(
                    *unpacked_args, **unpacked_kwargs
                )

            # Pass the function that only takes packed args into reentrant
            # checkpoint API.
            return self.checkpoint_fn(  # type: ignore[misc]
                my_function,
                *flat_args,
            )
        else:
            return self.checkpoint_fn(  # type: ignore[misc]
                self._checkpoint_wrapped_module, *args, **kwargs
            )


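# Round-trip sketch (added for illustration, not in the upstream file): the reentrant
# checkpoint API only accepts positional inputs, so ``CheckpointWrapper.forward`` above
# flattens kwargs into a positional tuple plus their keys and rebuilds them inside the
# checkpointed function. The expected behavior of the private helpers is sketched below.
def _example_pack_unpack_kwargs() -> None:
    flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
    # Positional values come first, keyword values follow in key order.
    assert flat_args == (1, 2, 3, 4)
    assert kwarg_keys == ("a", "b")
    args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
    assert args == (1, 2) and kwargs == {"a": 3, "b": 4}

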
def offload_wrapper(module: torch.nn.Module) -> torch.nn.Module:
    """
    Wrap a module for activation offloading to CPU.

    Offloads intermediate activations to the CPU for modules wrapped with this function.
    Wrappers with activation offload can be composed with ones that do recomputation-based
    checkpointing to trade off increased compute versus increased CPU
    memory usage and additional H2D transfers.

    Usage::
        offloaded_module = offload_wrapper(module)
        outputs = offloaded_module(inputs)
    Args:
        module (nn.Module):
            The module to be wrapped
    Returns:
        (nn.Module):
            Wrapped module
    """
    return OffloadWrapper(module)


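# Composition sketch (added for illustration, not in the upstream file): as the docstring
# above notes, offloading composes with recomputation-based checkpointing. One hypothetical
# way to stack both wrappers around a single block:
def _example_compose_offload_and_checkpoint(block: nn.Module) -> nn.Module:
    # Inner wrapper recomputes activations in backward; the outer wrapper offloads whatever
    # is still saved for backward to pinned CPU memory.
    return offload_wrapper(checkpoint_wrapper(block))

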
def checkpoint_wrapper(
    module: torch.nn.Module,
    checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT,
    checkpoint_fn=None,
    **checkpoint_fn_kwargs,
) -> torch.nn.Module:
    """
    Wrap a module for activation checkpointing.

    If the module is wrapped with this function, all subsequent calls to the module will
    automatically perform checkpointing without the user having to explicitly call the
    ``checkpoint`` function.

    Usage::
        checkpointed_module = checkpoint_wrapper(module)
        outputs = checkpointed_module(inputs)
    Args:
        module (nn.Module):
            The module to be wrapped
        checkpoint_impl (Optional[CheckpointImpl]):
            The checkpointing implementation to use. Note that this will only
            be passed into the ``torch.utils.checkpoint.checkpoint``
            implementation, and is ignored if a custom ``checkpoint_fn`` is
            specified. Note that for implementations using reentrant checkpoint
            from ``torch.utils.checkpoint``, keyword arguments will only be
            supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT``.
        checkpoint_fn (Optional[Callable]):
            Functional checkpoint implementation to use. If this is specified,
            it will be used over the default ``torch.utils.checkpoint.checkpoint``
            implementation and the ``checkpoint_impl`` argument will be ignored.
        **checkpoint_fn_kwargs (Dict[str, Any]): Keyword arguments to pass into ``checkpoint_fn``.

    Returns:
        (nn.Module):
            Wrapped module
    """

    if checkpoint_impl == CheckpointImpl.REENTRANT:
        warnings.warn(
            f"Please specify {CheckpointImpl.NO_REENTRANT} as "
            f"{CheckpointImpl.REENTRANT} will soon be removed as "
            "the default and eventually deprecated.",
            FutureWarning,
            stacklevel=2,
        )
    return CheckpointWrapper(
        module,
        checkpoint_impl,
        checkpoint_fn,
        **checkpoint_fn_kwargs,
    )


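# Illustrative sketch (added here, not in the upstream file): extra keyword arguments passed
# to ``checkpoint_wrapper`` are forwarded to the underlying checkpoint function. For example,
# ``preserve_rng_state`` is a keyword argument of ``torch.utils.checkpoint.checkpoint``;
# disabling it skips RNG stashing, which is slightly cheaper when the wrapped block is
# deterministic (no dropout or other randomness).
def _example_checkpoint_wrapper_kwargs(block: nn.Module) -> nn.Module:
    return checkpoint_wrapper(
        block,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        preserve_rng_state=False,  # forwarded into torch.utils.checkpoint.checkpoint
    )

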
def apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda _: True,
    auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None,
):
    """
    Apply :func:`checkpoint_wrapper` to modules within ``model`` based on a user-defined configuration.

    For each module within ``model``, ``check_fn`` is used to decide
    whether the module should be wrapped with :func:`checkpoint_wrapper` or not.

    Note::
        This function modifies ``model`` in place and replaces appropriate layers with
        their checkpoint-wrapped modules.
    Note::
        This function will not wrap the overall root module. If this is needed, please directly use
        :func:`checkpoint_wrapper` or :func:`offload_wrapper`.
    Usage::
        model = nn.Sequential(
            nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
        )
        check_fn = lambda l: isinstance(l, nn.Linear)
        # checkpoint activations
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
        # Or offload activations to CPU
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn)
    Args:
        model (nn.Module):
            The model whose submodules should be wrapped with activation checkpointing.
        checkpoint_wrapper_fn (Optional[Callable[[nn.Module], nn.Module]]):
            A ``Callable`` which will wrap modules.
        check_fn (Optional[Callable[[nn.Module], bool]]):
            A lambda function which will be passed each child submodule of ``model`` and returns
            ``True`` or ``False`` depending on whether the submodule should be wrapped.
        auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A policy to wrap model's
            submodules with AC. Note that if this is specified, it takes precedence over ``check_fn``.
    Returns: None (``model`` is modified in place)
    """
    # TODO: Importing inside function to avoid circular import issue between FSDP and
    # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code.
    from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply
    from torch.distributed.fsdp.wrap import (
        _Policy,
        _recursive_wrap,
        lambda_auto_wrap_policy,
    )

    policy = (
        auto_wrap_policy
        if auto_wrap_policy is not None
        else partial(lambda_auto_wrap_policy, lambda_fn=check_fn)
    )
    if not callable(policy):
        if not isinstance(policy, _Policy):
            raise ValueError(
                f"Expected {policy} to be callable or be a pre-defined wrap policy"
            )
        target_module_to_kwargs = policy._run_policy(
            model, ignored_modules=set(), root_kwargs={}
        )
        wrap_fn = _construct_wrap_fn(
            model, target_module_to_kwargs, checkpoint_wrapper_fn
        )
        _post_order_apply(model, wrap_fn)
        return

    _recursive_wrap(
        module=model,
        auto_wrap_policy=policy,  # type: ignore[arg-type]
        wrapper_cls=checkpoint_wrapper_fn,
        ignored_modules=set(),
        ignored_params=set(),
        only_wrap_children=True,
    )
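
# Illustrative sketch (added here, not in the upstream file): ``auto_wrap_policy`` accepts the
# same policy objects used for FSDP auto-wrapping, and when given it takes precedence over
# ``check_fn``. The example assumes ``ModuleWrapPolicy`` from ``torch.distributed.fsdp.wrap``
# and a model whose submodules include ``nn.Linear`` layers.
def _example_apply_ac_with_policy(model: nn.Module) -> None:
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
    )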