1# mypy: allow-untyped-defs 2import warnings 3from abc import ABC, abstractmethod 4from enum import auto, Enum 5from functools import partial 6from typing import Any, Callable, Dict, Iterator, Optional, Tuple 7 8import torch 9import torch.nn as nn 10from torch.autograd.graph import save_on_cpu 11from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs 12from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint 13 14 15_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module" 16_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "." 17 18 19class CheckpointImpl(Enum): 20 REENTRANT = auto() 21 NO_REENTRANT = auto() 22 23 24class ActivationWrapper(torch.nn.Module, ABC): 25 """ 26 Base class for Activation Checkpoint and Activation Offload. 27 28 Not meant to be instantiated directly. 29 """ 30 31 def __init__(self, mod): 32 super().__init__() 33 self._checkpoint_wrapped_module = mod 34 # state_dict post hook to remove prefix to allow loading into a 35 # non-checkpoint wrapped module. 36 self._register_state_dict_hook(self._post_state_dict_hook) 37 # load_state_dict pre-hook to allow loading back into 38 # checkpoint-wrapped module. 39 self.register_load_state_dict_pre_hook(self._pre_load_state_dict_hook) 40 41 @abstractmethod 42 def forward(self, *args, **kwargs): 43 raise ValueError("Subclasses should implement forward().") 44 45 def __getattr__(self, name: str) -> Any: 46 """Forward missing attributes to wrapped module.""" 47 try: 48 return super().__getattr__(name) # defer to nn.Module's logic 49 except AttributeError: 50 return getattr(self._checkpoint_wrapped_module, name) 51 52 def __getitem__(self, key: int) -> Any: 53 """Forward indexing calls in case the module is a nn.Sequential.""" 54 return self._checkpoint_wrapped_module.__getitem__(key) # type: ignore[operator] 55 56 def named_parameters( 57 self, 58 *args, 59 **kwargs, 60 ) -> Iterator[Tuple[str, torch.nn.Parameter]]: 61 """ 62 Override :meth:`named_parameters()` to intercept parameter names. 63 64 remove all occurrences of ``_CHECKPOINT_PREFIX``. 65 """ 66 for param_name, param in super().named_parameters(*args, **kwargs): 67 yield param_name.replace(_CHECKPOINT_PREFIX, ""), param 68 69 @staticmethod 70 def _post_state_dict_hook( 71 module: nn.Module, 72 state_dict: Dict[str, Any], 73 prefix: str, 74 *args: Any, 75 ) -> Dict[str, Any]: 76 """ 77 _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed. 78 79 For ``checkpoint_wrapper``, it will strip checkpoint-wrapped module prefix, 80 so that this module can be loaded into non-checkpointed modules. 81 It would still be able to be loaded into checkpoint-wrapped modules as this class, 82 adds the prefix back before loading the state_dict. 83 """ 84 _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix) 85 return state_dict 86 87 @staticmethod 88 def _pre_load_state_dict_hook( 89 module: nn.Module, 90 state_dict: Dict[str, Any], 91 prefix: str, 92 *args: Any, 93 ) -> None: 94 """ 95 ``_pre_state_dict_hook` is called before ``self._load_from_state_dict()`` is called. 96 97 For ``checkpoint_wrapper``, it will add back the module 98 prefix so that non-checkpointed modules can be loaded into 99 checkpoint_wrapper modules properly. 100 """ 101 _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}") 102 103 104class OffloadWrapper(ActivationWrapper): 105 def __init__(self, mod): 106 super().__init__(mod) 107 108 def forward(self, *args, **kwargs): 109 with save_on_cpu(pin_memory=True): 110 return self._checkpoint_wrapped_module(*args, **kwargs) 111 112 113class CheckpointWrapper(ActivationWrapper): 114 """ 115 An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing. 116 117 Note that this module is not meant to be used directly but instead, 118 it is to be used through the ``checkpoint_wrapper`` function. 119 """ 120 121 def __init__( 122 self, 123 mod: torch.nn.Module, 124 checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, 125 checkpoint_fn=None, 126 **checkpoint_fn_kwargs, 127 ): 128 super().__init__(mod) 129 self.checkpoint_impl = checkpoint_impl 130 if checkpoint_fn is None: 131 # use torch.utils.checkpoint 132 self.checkpoint_fn = partial( 133 torch_utils_checkpoint, 134 use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT), 135 **checkpoint_fn_kwargs, 136 ) 137 else: 138 # Construct user-specified checkpoint function. 139 self.checkpoint_fn = partial( 140 checkpoint_fn, 141 **checkpoint_fn_kwargs, 142 ) 143 144 def forward(self, *args, **kwargs): 145 # Support keyword arguments for reentrant checkpoint. Note that this 146 # only works if user has specified self.checkpoint_impl and is not 147 # using their own custom checkpoint_fn. 148 if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}: 149 # Pack the args and kwargs 150 flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs) 151 152 # Function that only takes (packed) args, but can unpack them 153 # into the original args and kwargs for the checkpointed 154 # function, and runs that function. 155 def my_function(*inputs): 156 # unpack back into args and kwargs 157 unpacked_args, unpacked_kwargs = _unpack_kwargs(inputs, kwarg_keys) 158 # run original module 159 return self._checkpoint_wrapped_module( 160 *unpacked_args, **unpacked_kwargs 161 ) 162 163 # Pass the function that only takes packed args into reentrant 164 # checkpoint API. 165 return self.checkpoint_fn( # type: ignore[misc] 166 my_function, 167 *flat_args, 168 ) 169 else: 170 return self.checkpoint_fn( # type: ignore[misc] 171 self._checkpoint_wrapped_module, *args, **kwargs 172 ) 173 174 175def offload_wrapper(module: torch.nn.Module) -> torch.nn.Module: 176 """ 177 Wrap a module for activation offloading to CPU. 178 179 Offloads intermediate activations to the CPU for modules wrapped with this function. 180 Wrappers with activation offload can be composed with ones that do recomputation-based 181 checkpoint to trade off increased compute versus increased CPU 182 memory usage and additional H2D transfers. 183 184 Usage:: 185 offloaded_module = offload_wrapper(module) 186 outputs = checkpointed_module(inputs) 187 Args: 188 module (nn.Module): 189 The module to be wrapped 190 Returns: 191 (nn.Module): 192 Wrapped module 193 """ 194 return OffloadWrapper(module) 195 196 197def checkpoint_wrapper( 198 module: torch.nn.Module, 199 checkpoint_impl: CheckpointImpl = CheckpointImpl.NO_REENTRANT, 200 checkpoint_fn=None, 201 **checkpoint_fn_kwargs, 202) -> torch.nn.Module: 203 """ 204 Wrap a module for activation checkpointing. 205 206 If the module is wrapped with this function, all subsequent calls to the module will, 207 automatically perform checkpointing without the user having to explicitly call ``checkpoint`` function. 208 209 Usage:: 210 checkpointed_module = checkpoint_wrapper(module) 211 outputs = checkpointed_module(inputs) 212 Args: 213 module (nn.Module): 214 The module to be wrapped 215 checkpoint_impl (Optional[CheckpointImpl]): 216 The checkpointing implementation to use. Note that this will only 217 be passed into the ``torch.utils.checkpoint.checkpoint`` 218 implementation, and is ignored if a custom ``checkpoint_fn`` is 219 specified. Note that for implementations using reentrant checkpoint 220 from ``torch.utils.checkpoint``, keyword arguments will only be 221 supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`. 222 checkpoint_fn (Optional[Callable]): 223 Functional checkpoint implementation to use. If this is specified, 224 it will be used over the default ``torch.utils.checkpoint.checkpoint`` 225 implementation and the `checkpoint_impl` argument will be ignored. 226 **checkpoint_fn_kwargs: (Dict[str, Any]): Keyword arguments to pass into `checkpoint_fn`. 227 228 Returns: 229 (nn.Module): 230 Wrapped module 231 """ 232 233 if checkpoint_impl == CheckpointImpl.REENTRANT: 234 warnings.warn( 235 f"Please specify {CheckpointImpl.NO_REENTRANT} as " 236 f"{CheckpointImpl.REENTRANT} will soon be removed as " 237 "the default and eventually deprecated.", 238 FutureWarning, 239 stacklevel=2, 240 ) 241 return CheckpointWrapper( 242 module, 243 checkpoint_impl, 244 checkpoint_fn, 245 **checkpoint_fn_kwargs, 246 ) 247 248 249def apply_activation_checkpointing( 250 model, 251 checkpoint_wrapper_fn=checkpoint_wrapper, 252 check_fn=lambda _: True, 253 auto_wrap_policy: Optional[Callable[[nn.Module, bool, int], bool]] = None, 254): 255 """ 256 Apply :func:`checkpoint_wrapper` to modules within `model` based on a user-defined configuration. 257 258 For each module within `model`, the `check_fn` is used to decide 259 whether `module` should be wrapped with :func:`checkpoint_wrapper` or not. 260 261 Note:: 262 This function modifies `model` in place and replaces appropriate layers with 263 their checkpoint-wrapped modules. 264 Note:: 265 This function will not wrap the overall root module. If this is needed, please directly use 266 :func:`checkpoint_wrapper` or :func:`offload_wrapper`. 267 Usage:: 268 model = nn.Sequential( 269 nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10) 270 ) 271 check_fn = lambda l: isinstance(l, nn.Linear) 272 # checkpoint activations 273 apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn) 274 # Or offload activations to CPU 275 apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn) 276 Args: 277 model (nn.Module): 278 The model whose submodules should be wrapped with activation checkpointing. 279 checkpoint_wrapper_fn (Optional[Callable[nn.Module]]) 280 A ``Callable`` which will wrap modules 281 check_fn (Optional[Callable[nn.Module, nn.Module]]) 282 A lambda function which will be passed each child submodule of ``model`` and returns 283 ``True`` or ``False`` depending on whether the submodule should be wrapped. 284 auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): A policy to wrap model's 285 submodules with AC. Note that if this is specified, it takes precedence over ``check_fn``. 286 Returns: None (`model` is modified inplace) 287 """ 288 # TODO: Importing inside function to avoid circular import issue between FSDP and 289 # checkpoint_wrapper. This can be resolved once wrap() APIs are decoupled from FSDP code. 290 from torch.distributed.fsdp._wrap_utils import _construct_wrap_fn, _post_order_apply 291 from torch.distributed.fsdp.wrap import ( 292 _Policy, 293 _recursive_wrap, 294 lambda_auto_wrap_policy, 295 ) 296 297 policy = ( 298 auto_wrap_policy 299 if auto_wrap_policy is not None 300 else partial(lambda_auto_wrap_policy, lambda_fn=check_fn) 301 ) 302 if not callable(policy): 303 if not isinstance(policy, _Policy): 304 raise ValueError( 305 f"Expected {policy} to be callable or be a pre-defined wrap policy" 306 ) 307 target_module_to_kwargs = policy._run_policy( 308 model, ignored_modules=set(), root_kwargs={} 309 ) 310 wrap_fn = _construct_wrap_fn( 311 model, target_module_to_kwargs, checkpoint_wrapper_fn 312 ) 313 _post_order_apply(model, wrap_fn) 314 return 315 316 _recursive_wrap( 317 module=model, 318 auto_wrap_policy=policy, # type: ignore[arg-type] 319 wrapper_cls=checkpoint_wrapper_fn, 320 ignored_modules=set(), 321 ignored_params=set(), 322 only_wrap_children=True, 323 ) 324