1# Owner(s): ["oncall: distributed"] 2 3import contextlib 4import unittest 5from copy import deepcopy 6from functools import partial 7 8import torch 9import torch.nn as nn 10from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 11 apply_activation_checkpointing, 12 checkpoint_wrapper, 13 CheckpointImpl, 14 CheckpointWrapper, 15 offload_wrapper, 16 OffloadWrapper, 17) 18from torch.distributed.fsdp.wrap import ModuleWrapPolicy 19from torch.testing._internal.common_utils import run_tests, TestCase 20from torch.utils.checkpoint import checkpoint 21 22 23_SAVED_PREFIX = "_saved_" 24GRAD_FN_NEXT_FUNCTIONS = "next_functions" 25 26 27class CheckpointWrapperTest(TestCase): 28 def test_load_activation_checkpointed_module(self): 29 lin = nn.Linear(10, 10, bias=False) 30 lin = checkpoint_wrapper( 31 lin, 32 checkpoint_fn=checkpoint, 33 # checkpoint kwargs 34 use_reentrant=True, 35 preserve_rng_state=False, 36 ) 37 state_dict = deepcopy(lin.state_dict()) 38 # Load into non-checkpoint wrapped linear module 39 lin_new = nn.Linear(10, 10, bias=False) 40 lin_new.load_state_dict(state_dict) 41 for p1, p2 in zip(lin.parameters(), lin_new.parameters()): 42 self.assertEqual(p1, p2) 43 self.assertTrue(torch.allclose(p1, p2)) 44 45 # Load non-checkpoint wrapped module into checkpoint wrapped one 46 # Make params different 47 for p in lin_new.parameters(): 48 with torch.no_grad(): 49 p.add_(0.5) 50 51 state_dict = deepcopy(lin_new.state_dict()) 52 # Verify checkpoint wrapped linear can load unwrapped linear 53 lin.load_state_dict(state_dict) 54 for p1, p2 in zip(lin.parameters(), lin_new.parameters()): 55 self.assertEqual(p1, p2) 56 57 def test_checkpoint_wrapper_kwarg_support(self): 58 class MyModel(nn.Module): 59 def __init__(self) -> None: 60 super().__init__() 61 self.lin = nn.Linear(10, 10) 62 63 def forward(self, a, b, c=None, d=None, **kwargs): 64 return (self.lin(a), self.lin(b), self.lin(c), self.lin(d)) 65 66 for wrapper in [ 67 partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT), 68 partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT), 69 offload_wrapper, 70 ]: 71 with self.subTest(wrapper=wrapper): 72 model = wrapper(MyModel()) 73 if wrapper == offload_wrapper: 74 self.assertTrue(isinstance(model, OffloadWrapper)) 75 else: 76 self.assertTrue(isinstance(model, CheckpointWrapper)) 77 # Verify kwargs can be passed in 78 inp = torch.ones(4, 10, requires_grad=True) 79 out = model(inp, inp, c=inp, d=inp, e=inp, f=inp) 80 self.assertTrue(isinstance(out, tuple)) 81 self.assertEqual(4, len(out)) 82 # Without kwargs should have equivalent gradient requirements. 83 out_no_kwarg = model(inp, inp, inp, inp) 84 for t1, t2 in zip(out_no_kwarg, out): 85 self.assertEqual(t1, t2) 86 self.assertEqual(t1.requires_grad, t2.requires_grad) 87 88 # Test model that enforces kwarg inputs 89 class ModelEnforceKwarg(nn.Module): 90 def __init__(self) -> None: 91 super().__init__() 92 self.lin = nn.Linear(10, 10) 93 94 def forward(self, *, a=None, b=None): 95 return (self.lin(a), self.lin(b)) 96 97 model = checkpoint_wrapper( 98 ModelEnforceKwarg(), checkpoint_impl=CheckpointImpl.REENTRANT 99 ) 100 101 inp = torch.ones(4, 10, requires_grad=True) 102 out = model(a=inp, b=inp) 103 self.assertEqual(2, len(out)) 104 105 def test_checkpoint_wrapper_args_kwargs(self): 106 """ 107 Tests that checkpoint_wrapper can pass down args / kwargs to configure 108 torch.utils.checkpoint. 109 """ 110 111 count = 0 112 113 @contextlib.contextmanager 114 def ctx_manager(): 115 nonlocal count 116 count += 1 117 yield 118 119 def get_ctx_mgrs(): 120 return (ctx_manager(), ctx_manager()) 121 122 # kwargs test 123 torch_utils_checkpoint = torch.utils.checkpoint.checkpoint 124 m = checkpoint_wrapper( 125 torch.nn.Linear(1, 1), 126 checkpoint_fn=torch_utils_checkpoint, 127 use_reentrant=False, 128 context_fn=get_ctx_mgrs, 129 ) 130 m(torch.randn(2, 1)).sum().backward() 131 self.assertEqual(2, count) 132 133 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") 134 def test_checkpoint_wrapper_parity(self): 135 """ 136 Tests that using checkpoint_wrapper or the functional 137 torch.utils.checkpoint (with the same reentrant config) 138 results in the same maximum memory usage, i.e. they are 139 equivalent memory usage wise. 140 """ 141 142 class Model(nn.Module): 143 def __init__( 144 self, 145 n: int, 146 use_cp: bool, 147 use_wrapper: bool = False, 148 use_reentrant: bool = True, 149 ): 150 super().__init__() 151 self.layers = nn.ModuleList() 152 self.n = n 153 self.use_cp = use_cp 154 self.use_wrapper = use_wrapper 155 self.use_reentrant = use_reentrant 156 wrp = partial( 157 checkpoint_wrapper, 158 checkpoint_impl=CheckpointImpl.REENTRANT 159 if use_reentrant 160 else CheckpointImpl.NO_REENTRANT, 161 ) 162 for i in range(self.n): 163 l = nn.Sequential( 164 nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256) 165 ) 166 use_checkpoint_wrapper = self.use_wrapper 167 if use_checkpoint_wrapper: 168 l = wrp(l) 169 self.layers.append(l) 170 171 def forward(self, x): 172 for i in range(self.n): 173 if self.use_wrapper or not self.use_cp: 174 x = self.layers[i](x) 175 else: 176 x = checkpoint( 177 self.layers[i], x, use_reentrant=self.use_reentrant 178 ) 179 return x 180 181 def test(use_checkpointing, use_wrapper, use_reentrant): 182 a = Model( 183 8, 184 use_checkpointing, 185 use_wrapper=use_wrapper, 186 use_reentrant=use_reentrant, 187 ).cuda() 188 x = torch.randn(10000, 256, requires_grad=True).cuda() 189 torch.cuda.reset_peak_memory_stats() 190 loss = a(x).sum() 191 loss.backward() 192 return torch.cuda.max_memory_allocated() 193 194 functional_no_reentrant = test( 195 use_checkpointing=True, use_wrapper=False, use_reentrant=False 196 ) 197 wrapper_no_reentrant = test( 198 use_checkpointing=False, use_wrapper=True, use_reentrant=False 199 ) 200 self.assertEqual(functional_no_reentrant, wrapper_no_reentrant) 201 202 functional_reentrant = test( 203 use_checkpointing=True, use_wrapper=False, use_reentrant=True 204 ) 205 wrapper_reentrant = test( 206 use_checkpointing=False, use_wrapper=True, use_reentrant=True 207 ) 208 self.assertEqual(functional_reentrant, wrapper_reentrant) 209 210 def test_forward_missing_attributes(self): 211 lin = nn.Linear(1, 1) 212 m = nn.Sequential(lin, lin) 213 wrapped = CheckpointWrapper(m) 214 # Test indexing is forwarded 215 self.assertEqual(wrapped[0], lin) 216 # Test missing attributes are forwarded. 217 m._foo = "bar" 218 self.assertEqual(wrapped._foo, "bar") 219 220 def test_apply_activation_checkpointing(self): 221 """ 222 Ensures that `apply_activation_checkpointing` can be used 223 to swap modules for their checkpoint-wrapped counterparts given 224 a model. 225 """ 226 227 class LinearWithBatchNorm(nn.Module): 228 def __init__(self) -> None: 229 super().__init__() 230 self.lin = nn.Linear(10, 10) 231 self.bn = nn.BatchNorm1d(10) 232 self.nested_linear = nn.Sequential(nn.Linear(10, 10)) 233 234 def forward(self, x): 235 return self.bn(self.nested_linear(self.lin(x))) 236 237 class MyModel(nn.Module): 238 def __init__(self) -> None: 239 super().__init__() 240 self.seq = nn.Sequential( 241 LinearWithBatchNorm(), LinearWithBatchNorm(), LinearWithBatchNorm() 242 ) 243 244 def forward(self, x): 245 return self.seq(x) 246 247 def check_fn(l): 248 return isinstance(l, nn.Linear) 249 250 n_linear = None 251 252 for i, wrapper in enumerate( 253 [ 254 partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT), 255 partial( 256 checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT 257 ), 258 offload_wrapper, 259 ] 260 ): 261 model = MyModel() 262 if n_linear is None: 263 n_linear = sum( 264 1 if isinstance(x, nn.Linear) else 0 for x in model.modules() 265 ) 266 267 with self.subTest(wrapper=wrapper): 268 if i != 0: 269 apply_activation_checkpointing( 270 model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn 271 ) 272 else: 273 apply_activation_checkpointing( 274 model, 275 checkpoint_wrapper_fn=wrapper, 276 auto_wrap_policy=ModuleWrapPolicy({nn.Linear}), 277 ) 278 n_linear_wrapped = sum( 279 1 if isinstance(x, nn.Linear) else 0 for x in model.modules() 280 ) 281 n_checkpointed = sum( 282 1 if isinstance(x, (CheckpointWrapper, OffloadWrapper)) else 0 283 for x in model.modules() 284 ) 285 self.assertEqual(n_checkpointed, n_linear_wrapped) 286 self.assertEqual(n_linear, n_linear_wrapped) 287 for j in range(3): 288 self.assertTrue( 289 isinstance( 290 model.seq[j].lin, (CheckpointWrapper, OffloadWrapper) 291 ) 292 ) 293 self.assertTrue( 294 isinstance( 295 model.seq[j].nested_linear[0], 296 (CheckpointWrapper, OffloadWrapper), 297 ) 298 ) 299 300 inp = torch.randn(4, 10, requires_grad=True) 301 for i in range(6): 302 # Kwarg input 303 loss = model(x=inp).sum() 304 self.assertTrue(loss.requires_grad) 305 loss.backward() 306 # ensure checkpointed part of model has gradients 307 for j in range(3): 308 weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight 309 bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias 310 weight_nested_lin = ( 311 model.seq[j] 312 .nested_linear[0] 313 ._checkpoint_wrapped_module.weight 314 ) 315 bias_nested_lin = ( 316 model.seq[j] 317 .nested_linear[0] 318 ._checkpoint_wrapped_module.bias 319 ) 320 for param in [ 321 weight_lin, 322 bias_lin, 323 weight_nested_lin, 324 bias_nested_lin, 325 ]: 326 self.assertTrue(param.requires_grad) 327 self.assertFalse(param.grad is None) 328 329 def test_fqn(self): 330 lin = nn.Linear(10, 10, bias=False) 331 lin = checkpoint_wrapper(lin) 332 state_dict = lin.state_dict() 333 for fqn, _ in lin.named_parameters(): 334 self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.") 335 336 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") 337 def test_checkpoint_wrapper_cpu_offload(self): 338 model = nn.Sequential( 339 nn.Linear(10, 10), 340 nn.Linear(10, 10), 341 nn.Linear(10, 10), 342 ).cuda() 343 344 # Patch saved_tensor_hooks to make the unpack keep the tensor on CPU for 345 # testing, otherwise the tensor access during the DFS will cause orig 346 # unpack to run, transferring the tensor back to GPU. 347 def patched_init(saved_tensor_hook_obj, pack_hook, _): 348 saved_tensor_hook_obj.pack_hook = pack_hook 349 350 def testing_cpu_offload_unpack_hook(packed): 351 _, tensor = packed 352 return tensor 353 354 saved_tensor_hook_obj.unpack_hook = testing_cpu_offload_unpack_hook 355 356 orig_init = torch.autograd.graph.saved_tensors_hooks.__init__ 357 torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init 358 359 model = offload_wrapper(model) 360 361 inp = torch.randn(3, 10, device="cuda") 362 loss = model(inp).sum() 363 364 # All autograd saved tensors should be offloaded to CPU. 365 offload_verified = False 366 367 def dfs(grad_fn): 368 for e in dir(grad_fn): 369 if not e.startswith(_SAVED_PREFIX): 370 continue 371 372 saved = getattr(grad_fn, e) 373 if isinstance(saved, torch.Tensor): 374 self.assertEqual(torch.device("cpu"), saved.device) 375 nonlocal offload_verified 376 offload_verified = True 377 378 if hasattr(grad_fn, GRAD_FN_NEXT_FUNCTIONS): 379 for next_grad_fn, _ in grad_fn.next_functions: 380 dfs(next_grad_fn) 381 382 dfs(loss.grad_fn) 383 384 self.assertTrue(offload_verified) 385 386 torch.autograd.graph.saved_tensors_hooks.__init__ = orig_init 387 388 389if __name__ == "__main__": 390 run_tests() 391