# Owner(s): ["oncall: distributed"]

import contextlib
import sys
from enum import Enum

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

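# Simple trunk + head model. The convolutional trunk is the part that gets
# frozen in the tests; the linear head always trains. The trunk can optionally
# run under torch.no_grad() (disable_autograd) and can be wrapped in FSDP
# either at construction time or later via fsdp_wrap().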
class Model(nn.Module):
    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.device = torch.cuda.current_device()
        self.head = nn.Linear(64, 10)
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)

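# Variant of Model whose trunk is built from two convolutional sub-blocks so
# that FSDP can be applied in a nested fashion: fsdp_wrap() wraps each block
# individually and then wraps the whole trunk on top.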
class NestedTrunkModel(nn.Module):
    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        )
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        for name, child in self.trunk.named_children():
            wrapped_child = FSDP(child, **fsdp_kwargs)
            setattr(self.trunk, name, wrapped_child)
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)

    def _create_block(
        self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp
    ):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        return block

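# The two ways the trunk is frozen in the tests:
#   - RequiresGrad: set requires_grad=False on the trunk parameters before
#     training, so no gradients are computed for them.
#   - GradToNone: let gradients be computed, but reset the trunk's .grad
#     fields to None after every backward pass, before the optimizer step.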
class FreezingMethod(str, Enum):
    GradToNone = "grad_to_none"
    RequiresGrad = "requires_grad"

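# Trains the same toy model with a frozen trunk under DDP and under FSDP and
# checks that the resulting full (unsharded) parameters match.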
class TestFreezingWeights(FSDPTest):
    def _create_model(
        self,
        with_fsdp,
        with_nested_trunk,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        if with_nested_trunk:
            model = NestedTrunkModel(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        else:
            model = Model(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        return model

    def _dist_train(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        with_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        torch.manual_seed(0)
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        fsdp_kwargs = {
            "device_id": self.rank,
            "forward_prefetch": forward_prefetch,
        }

        ddp_kwargs = {
            "device_ids": [self.rank],
            # When the trunk runs under no_grad, its parameters never receive
            # gradients, so DDP must be told to expect unused parameters.
            "find_unused_parameters": disable_autograd,
        }

        model = self._create_model(
            with_fsdp,
            with_nested_trunk,
            freeze_after_wrap_fsdp,
            disable_autograd,
            fsdp_kwargs,
        )
        model = model.cuda()

        # Freeze the trunk up front by disabling requires_grad on its parameters.
        if freezing_method == FreezingMethod.RequiresGrad:
            for param in model.trunk.parameters():
                param.requires_grad = False

        if with_fsdp:
            if not freeze_after_wrap_fsdp:
                model.fsdp_wrap(fsdp_kwargs)
            model = FSDP(model, **fsdp_kwargs)
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

        target = torch.tensor([0, 1], dtype=torch.long).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        for iteration in range(3):
            out = model(batch)
            fake_loss = criterion(out, target)
            optimizer.zero_grad()
            fake_loss.backward()
            if freezing_method == FreezingMethod.GradToNone:
                # Drop the trunk gradients after backward so the optimizer
                # step leaves the trunk parameters untouched.
                for param in model.module.trunk.parameters():
                    param.grad = None
            optimizer.step()

        if with_fsdp:
            return get_full_params(model)

        return list(model.parameters())

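    # Compare DDP and FSDP end states across all combinations of trunk nesting,
    # freezing method, wrap order, autograd disabling, and forward prefetch.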
    @skip_if_lt_x_gpu(2)
    @parametrize("with_nested_trunk", [True, False])
    @parametrize(
        "freezing_method", [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]
    )
    @parametrize("freeze_after_wrap_fsdp", [True, False])
    @parametrize("disable_autograd", [True, False])
    @parametrize("forward_prefetch", [True, False])
    def test_freezing_weights(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        # DDP
        ddp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=False,
            disable_autograd=disable_autograd,
            forward_prefetch=False,  # does not apply to DDP
        )

        # FSDP
        fsdp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=True,
            disable_autograd=disable_autograd,
            forward_prefetch=forward_prefetch,
        )

        self.assertEqual(
            ddp_state,
            fsdp_state,
            exact_device=True,
            msg="FullyShardedDataParallel states didn't match PyTorch DDP states",
        )

        if freezing_method == FreezingMethod.RequiresGrad:
            for ddp_param, fsdp_param in zip(ddp_state, fsdp_state):
                self.assertEqual(ddp_param.requires_grad, fsdp_param.requires_grad)


instantiate_parametrized_tests(TestFreezingWeights)

if __name__ == "__main__":
    run_tests()