# Owner(s): ["oncall: distributed"]

"""Tests that FSDP produces the same parameters as DDP when part of the model
(the "trunk") is frozen, either by setting ``requires_grad=False`` or by
zeroing out gradients (``param.grad = None``) after each backward pass."""

import contextlib
import sys
from enum import Enum

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    # dev-asan + multiprocessing spawn is a known-bad combination; bail early.
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Model(nn.Module):
    """Simple conv trunk + linear head model.

    Args:
        with_fsdp: whether the trunk/head should be FSDP-wrapped.
        freeze_after_wrap_fsdp: if True (and ``with_fsdp``), wrap with FSDP
            here in ``__init__``; otherwise the caller wraps later via
            :meth:`fsdp_wrap`.
        disable_autograd: if True, run the trunk under ``torch.no_grad`` so
            it produces no gradients (an alternative way of "freezing").
        fsdp_kwargs: keyword arguments forwarded to the FSDP constructor.
    """

    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        # NOTE(review): stored but not read anywhere in this file.
        self.device = torch.cuda.current_device()
        self.head = nn.Linear(64, 10)
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        # Context manager used around the trunk's forward: no_grad when the
        # trunk is "frozen" via disabled autograd, a no-op otherwise.
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        """Wrap trunk and head as separate FSDP units."""
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        # Only the trunk runs under the (possibly no_grad) context; the head
        # always trains.
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)


class NestedTrunkModel(nn.Module):
    """Variant of :class:`Model` whose trunk contains nested sub-blocks, so
    FSDP wrapping is applied both to each child block and to the trunk itself
    (exercising nested FSDP instances).

    Constructor arguments match :class:`Model`.
    """

    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        )
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        """Wrap each trunk child, then the trunk, then the head with FSDP."""
        for name, child in self.trunk.named_children():
            wrapped_child = FSDP(child, **fsdp_kwargs)
            setattr(self.trunk, name, wrapped_child)
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)

    def _create_block(
        self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp
    ):
        # ``with_fsdp``/``freeze_after_wrap_fsdp`` are accepted for signature
        # parity but unused here: wrapping happens in ``fsdp_wrap``.
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        return block


class FreezingMethod(str, Enum):
    """How the trunk is frozen during training."""

    GradToNone = "grad_to_none"  # zero grads manually after backward
    RequiresGrad = "requires_grad"  # mark trunk params requires_grad=False


class TestFreezingWeights(FSDPTest):
    """Checks FSDP == DDP parameter values after training with a frozen trunk."""

    def _create_model(
        self,
        with_fsdp,
        with_nested_trunk,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        """Build either the flat or nested-trunk model with the given config."""
        if with_nested_trunk:
            model = NestedTrunkModel(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        else:
            model = Model(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        return model

    def _dist_train(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        with_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        """Run 3 SGD steps with either FSDP or DDP and return final params.

        Returns the full (unsharded) parameters for FSDP, or the plain
        parameter list for DDP, so the two can be compared elementwise.
        """
        torch.manual_seed(0)
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        fsdp_kwargs = {
            "device_id": self.rank,
            "forward_prefetch": forward_prefetch,
        }

        ddp_kwargs = {
            "device_ids": [self.rank],
            # With autograd disabled on the trunk, DDP must tolerate params
            # that receive no gradient.
            "find_unused_parameters": disable_autograd,
        }

        model = self._create_model(
            with_fsdp,
            with_nested_trunk,
            freeze_after_wrap_fsdp,
            disable_autograd,
            fsdp_kwargs,
        )
        model = model.cuda()

        # freezing the trunk using requires_grad.
        if freezing_method == FreezingMethod.RequiresGrad:
            for param in model.trunk.parameters():
                param.requires_grad = False

        if with_fsdp:
            if not freeze_after_wrap_fsdp:
                # Wrap after the freeze above, covering the "freeze first,
                # wrap second" ordering.
                model.fsdp_wrap(fsdp_kwargs)
            model = FSDP(model, **fsdp_kwargs)
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

        target = torch.tensor([0, 1], dtype=torch.long).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        for _ in range(3):
            out = model(batch)
            fake_loss = criterion(out, target)
            optimizer.zero_grad()
            fake_loss.backward()
            if freezing_method == FreezingMethod.GradToNone:
                # Both FSDP and DDP wrappers expose the user model via
                # ``.module``; drop trunk grads so the step leaves it frozen.
                for param in model.module.trunk.parameters():
                    param.grad = None
            optimizer.step()

        if with_fsdp:
            return get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("with_nested_trunk", [True, False])
    @parametrize(
        "freezing_method", [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]
    )
    @parametrize("freeze_after_wrap_fsdp", [True, False])
    @parametrize("disable_autograd", [True, False])
    @parametrize("forward_prefetch", [True, False])
    def test_freezing_weights(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        """Train the same model under DDP and FSDP and assert identical params."""
        # DDP
        ddp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=False,
            disable_autograd=disable_autograd,
            forward_prefetch=False,  # does not apply to DDP
        )

        # FSDP
        fsdp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=True,
            disable_autograd=disable_autograd,
            forward_prefetch=forward_prefetch,
        )

        self.assertEqual(
            ddp_state,
            fsdp_state,
            exact_device=True,
            msg="FullyShardedDataParallel states didn't match PyTorch DDP states",
        )

        if freezing_method == FreezingMethod.RequiresGrad:
            # Freezing via requires_grad must also round-trip through FSDP's
            # full-parameter gathering.
            for ddp_param, fsdp_param in zip(ddp_state, fsdp_state):
                self.assertEqual(ddp_param.requires_grad, fsdp_param.requires_grad)


instantiate_parametrized_tests(TestFreezingWeights)

if __name__ == "__main__":
    run_tests()