# Owner(s): ["oncall: distributed"]
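"""Tests for checkpoint_wrapper, offload_wrapper, and apply_activation_checkpointing."""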

import contextlib
import unittest
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
    CheckpointWrapper,
    offload_wrapper,
    OffloadWrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils.checkpoint import checkpoint


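# Autograd nodes expose the tensors they saved for backward as attributes
# prefixed with "_saved_" (e.g. ``grad_fn._saved_self``), and link to their
# input nodes via ``next_functions``.  Both are used by the CPU-offload test
# below to walk the autograd graph.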
_SAVED_PREFIX = "_saved_"
GRAD_FN_NEXT_FUNCTIONS = "next_functions"


class CheckpointWrapperTest(TestCase):
    def test_load_activation_checkpointed_module(self):
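        """
        Verifies that a checkpoint-wrapped module and its unwrapped
        counterpart expose interchangeable state_dicts, i.e. wrapping does
        not change the parameter keys.
        """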
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(
            lin,
            checkpoint_fn=checkpoint,
            # checkpoint kwargs
            use_reentrant=True,
            preserve_rng_state=False,
        )
        state_dict = deepcopy(lin.state_dict())
        # Load into non-checkpoint wrapped linear module
        lin_new = nn.Linear(10, 10, bias=False)
        lin_new.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)
            self.assertTrue(torch.allclose(p1, p2))

        # Load non-checkpoint wrapped module into checkpoint wrapped one
        # Make params different
        for p in lin_new.parameters():
            with torch.no_grad():
                p.add_(0.5)

        state_dict = deepcopy(lin_new.state_dict())
        # Verify checkpoint wrapped linear can load unwrapped linear
        lin.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)

    def test_checkpoint_wrapper_kwarg_support(self):
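        """
        Checks that checkpoint-wrapped modules (reentrant, non-reentrant, and
        offload variants) forward positional, keyword, extra, and
        keyword-only arguments to the wrapped module's forward.
        """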
        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 10)

            def forward(self, a, b, c=None, d=None, **kwargs):
                return (self.lin(a), self.lin(b), self.lin(c), self.lin(d))

        for wrapper in [
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
            partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
            offload_wrapper,
        ]:
            with self.subTest(wrapper=wrapper):
                model = wrapper(MyModel())
                if wrapper == offload_wrapper:
                    self.assertTrue(isinstance(model, OffloadWrapper))
                else:
                    self.assertTrue(isinstance(model, CheckpointWrapper))
                # Verify kwargs can be passed in
                inp = torch.ones(4, 10, requires_grad=True)
                out = model(inp, inp, c=inp, d=inp, e=inp, f=inp)
                self.assertTrue(isinstance(out, tuple))
                self.assertEqual(4, len(out))
                # Without kwargs should have equivalent gradient requirements.
                out_no_kwarg = model(inp, inp, inp, inp)
                for t1, t2 in zip(out_no_kwarg, out):
                    self.assertEqual(t1, t2)
                    self.assertEqual(t1.requires_grad, t2.requires_grad)

        # Test model that enforces kwarg inputs
        class ModelEnforceKwarg(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 10)

            def forward(self, *, a=None, b=None):
                return (self.lin(a), self.lin(b))

        model = checkpoint_wrapper(
            ModelEnforceKwarg(), checkpoint_impl=CheckpointImpl.REENTRANT
        )

        inp = torch.ones(4, 10, requires_grad=True)
        out = model(a=inp, b=inp)
        self.assertEqual(2, len(out))

    def test_checkpoint_wrapper_args_kwargs(self):
        """
        Tests that checkpoint_wrapper can pass down args / kwargs to configure
        torch.utils.checkpoint.
        """

        count = 0

        @contextlib.contextmanager
        def ctx_manager():
            nonlocal count
            count += 1
            yield

        def get_ctx_mgrs():
            return (ctx_manager(), ctx_manager())

        # kwargs test
        torch_utils_checkpoint = torch.utils.checkpoint.checkpoint
        m = checkpoint_wrapper(
            torch.nn.Linear(1, 1),
            checkpoint_fn=torch_utils_checkpoint,
            use_reentrant=False,
            context_fn=get_ctx_mgrs,
        )
        m(torch.randn(2, 1)).sum().backward()
        self.assertEqual(2, count)

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    def test_checkpoint_wrapper_parity(self):
135        """
136        Tests that using checkpoint_wrapper or the functional
137        torch.utils.checkpoint (with the same reentrant config)
138        results in the same maximum memory usage, i.e. they are
139        equivalent memory usage wise.
140        """

        class Model(nn.Module):
            def __init__(
                self,
                n: int,
                use_cp: bool,
                use_wrapper: bool = False,
                use_reentrant: bool = True,
            ):
                super().__init__()
                self.layers = nn.ModuleList()
                self.n = n
                self.use_cp = use_cp
                self.use_wrapper = use_wrapper
                self.use_reentrant = use_reentrant
                wrp = partial(
                    checkpoint_wrapper,
                    checkpoint_impl=CheckpointImpl.REENTRANT
                    if use_reentrant
                    else CheckpointImpl.NO_REENTRANT,
                )
                for i in range(self.n):
                    l = nn.Sequential(
                        nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
                    )
                    if self.use_wrapper:
                        l = wrp(l)
                    self.layers.append(l)

            def forward(self, x):
                for i in range(self.n):
                    if self.use_wrapper or not self.use_cp:
                        x = self.layers[i](x)
                    else:
                        x = checkpoint(
                            self.layers[i], x, use_reentrant=self.use_reentrant
                        )
                return x

        def test(use_checkpointing, use_wrapper, use_reentrant):
            a = Model(
                8,
                use_checkpointing,
                use_wrapper=use_wrapper,
                use_reentrant=use_reentrant,
            ).cuda()
            x = torch.randn(10000, 256, requires_grad=True).cuda()
            torch.cuda.reset_peak_memory_stats()
            loss = a(x).sum()
            loss.backward()
            return torch.cuda.max_memory_allocated()

        functional_no_reentrant = test(
            use_checkpointing=True, use_wrapper=False, use_reentrant=False
        )
        wrapper_no_reentrant = test(
            use_checkpointing=False, use_wrapper=True, use_reentrant=False
        )
        self.assertEqual(functional_no_reentrant, wrapper_no_reentrant)

        functional_reentrant = test(
            use_checkpointing=True, use_wrapper=False, use_reentrant=True
        )
        wrapper_reentrant = test(
            use_checkpointing=False, use_wrapper=True, use_reentrant=True
        )
        self.assertEqual(functional_reentrant, wrapper_reentrant)

    def test_forward_missing_attributes(self):
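        """
        Ensures that attribute and index lookups not defined on the wrapper
        are forwarded to the wrapped module.
        """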
        lin = nn.Linear(1, 1)
        m = nn.Sequential(lin, lin)
        wrapped = CheckpointWrapper(m)
        # Test indexing is forwarded
        self.assertEqual(wrapped[0], lin)
        # Test missing attributes are forwarded.
        m._foo = "bar"
        self.assertEqual(wrapped._foo, "bar")

    def test_apply_activation_checkpointing(self):
221        """
222        Ensures that `apply_activation_checkpointing` can be used
223        to swap modules for their checkpoint-wrapped counterparts given
224        a model.
225        """

        class LinearWithBatchNorm(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 10)
                self.bn = nn.BatchNorm1d(10)
                self.nested_linear = nn.Sequential(nn.Linear(10, 10))

            def forward(self, x):
                return self.bn(self.nested_linear(self.lin(x)))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = nn.Sequential(
                    LinearWithBatchNorm(), LinearWithBatchNorm(), LinearWithBatchNorm()
                )

            def forward(self, x):
                return self.seq(x)

        def check_fn(l):
            return isinstance(l, nn.Linear)

        n_linear = None

        for i, wrapper in enumerate(
            [
                partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
                partial(
                    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
                ),
                offload_wrapper,
            ]
        ):
            model = MyModel()
            if n_linear is None:
                n_linear = sum(
                    1 if isinstance(x, nn.Linear) else 0 for x in model.modules()
                )

            with self.subTest(wrapper=wrapper):
                if i != 0:
                    apply_activation_checkpointing(
                        model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn
                    )
                else:
                    apply_activation_checkpointing(
                        model,
                        checkpoint_wrapper_fn=wrapper,
                        auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
                    )
                n_linear_wrapped = sum(
                    1 if isinstance(x, nn.Linear) else 0 for x in model.modules()
                )
                n_checkpointed = sum(
                    1 if isinstance(x, (CheckpointWrapper, OffloadWrapper)) else 0
                    for x in model.modules()
                )
                self.assertEqual(n_checkpointed, n_linear_wrapped)
                self.assertEqual(n_linear, n_linear_wrapped)
                for j in range(3):
                    self.assertTrue(
                        isinstance(
                            model.seq[j].lin, (CheckpointWrapper, OffloadWrapper)
                        )
                    )
                    self.assertTrue(
                        isinstance(
                            model.seq[j].nested_linear[0],
                            (CheckpointWrapper, OffloadWrapper),
                        )
                    )

                inp = torch.randn(4, 10, requires_grad=True)
                for _ in range(6):
                    # Kwarg input
                    loss = model(x=inp).sum()
                    self.assertTrue(loss.requires_grad)
                    loss.backward()
                    # ensure checkpointed part of model has gradients
                    for j in range(3):
                        weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight
                        bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias
                        weight_nested_lin = (
                            model.seq[j]
                            .nested_linear[0]
                            ._checkpoint_wrapped_module.weight
                        )
                        bias_nested_lin = (
                            model.seq[j]
                            .nested_linear[0]
                            ._checkpoint_wrapped_module.bias
                        )
                        for param in [
                            weight_lin,
                            bias_lin,
                            weight_nested_lin,
                            bias_nested_lin,
                        ]:
                            self.assertTrue(param.requires_grad)
                            self.assertFalse(param.grad is None)

    def test_fqn(self):
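        """
        Checks that the parameter FQNs reported by named_parameters() match
        the keys of the checkpoint-wrapped module's state_dict.
        """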
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(lin)
        state_dict = lin.state_dict()
        for fqn, _ in lin.named_parameters():
            self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.")

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    def test_checkpoint_wrapper_cpu_offload(self):
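        """
        Verifies that offload_wrapper offloads the tensors saved for backward
        to CPU by inspecting every saved tensor reachable from the loss's
        grad_fn.
        """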
        model = nn.Sequential(
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        ).cuda()

        # Patch saved_tensors_hooks so the unpack hook keeps the tensor on CPU
        # for testing; otherwise, accessing a saved tensor during the DFS below
        # would run the original unpack hook and move the tensor back to GPU.
        def patched_init(saved_tensor_hook_obj, pack_hook, _):
            saved_tensor_hook_obj.pack_hook = pack_hook

            def testing_cpu_offload_unpack_hook(packed):
                _, tensor = packed
                return tensor

            saved_tensor_hook_obj.unpack_hook = testing_cpu_offload_unpack_hook

        orig_init = torch.autograd.graph.saved_tensors_hooks.__init__
        torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init

        model = offload_wrapper(model)

        inp = torch.randn(3, 10, device="cuda")
        loss = model(inp).sum()

        # All autograd saved tensors should be offloaded to CPU.
        offload_verified = False

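        # Walk the autograd graph: every tensor a node saved for backward is
        # exposed as a "_saved_*" attribute and must live on CPU.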
        def dfs(grad_fn):
            for e in dir(grad_fn):
                if not e.startswith(_SAVED_PREFIX):
                    continue

                saved = getattr(grad_fn, e)
                if isinstance(saved, torch.Tensor):
                    self.assertEqual(torch.device("cpu"), saved.device)
                    nonlocal offload_verified
                    offload_verified = True

            if hasattr(grad_fn, GRAD_FN_NEXT_FUNCTIONS):
                for next_grad_fn, _ in grad_fn.next_functions:
                    dfs(next_grad_fn)

        dfs(loss.grad_fn)

        self.assertTrue(offload_verified)

        torch.autograd.graph.saved_tensors_hooks.__init__ = orig_init


if __name__ == "__main__":
    run_tests()