# Owner(s): ["oncall: distributed"]

import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import DTensor
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.fc3 = nn.Linear(2, 2)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))


class ReplicateStateDictTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _init_pg(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def _check_state_dict_parity(self, sd_1, sd_2):
        for k1, k2 in zip(sd_1.keys(), sd_2.keys()):
            self.assertEqual(k1, k2)

        for v1, v2 in zip(sd_1.values(), sd_2.values()):
            self.assertEqual(v1, v2)

    def test_replicate_single_module_save_load(self):
        """
        Tests that replicate() on a single module state_dict
        matches local module state_dict.
        """
        self._init_pg()
        model = Net()
        replicate_model = replicate(deepcopy(model))
        local_sd = model.state_dict()
        ddp_sd = replicate_model.state_dict()
        self._check_state_dict_parity(local_sd, ddp_sd)

    def test_replicate_non_root_multiple_save_load(self):
        """
        Tests that replicate() on multiple submodules matches the
        local module state_dict.
        """
        self._init_pg()
        model = Net()
        replicate_model = deepcopy(model)
        replicate(replicate_model.fc1)
        replicate(replicate_model.fc2)
        replicate(replicate_model.fc3)

        local_sd = model.state_dict()
        ddp_sd = replicate_model.state_dict()
        self._check_state_dict_parity(local_sd, ddp_sd)


class ReplicateTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 2

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _init_pg(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def _compare_module(self, mod, replicate_mod):
        local_batch_size = 1
        global_batch_size = self.world_size * local_batch_size
        input = torch.randn(global_batch_size, 2)
        target = torch.randn(global_batch_size, 2)

        def step_model(model, input, target):
            model.train()
            output = model(input)
            loss = F.mse_loss(output, target.to(output.device))
            loss.backward()
            for param in model.parameters():
                with torch.no_grad():
                    param -= param.grad
                param.grad = None

        for iteration in range(2):
            # Step the local model on the global batch and the replicated
            # model on this rank's shard; DDP averages gradients across
            # ranks, so both should end up with identical parameters.
            step_model(mod, input, target)
            step_model(
                replicate_mod,
                input[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
                target[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
            )

            self.assertEqual(
                len(list(mod.parameters())),
                len(list(replicate_mod.parameters())),
            )
            for i, j in zip(mod.parameters(), replicate_mod.parameters()):
                self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)

            # Shuffle the input so that DDP input is different
            torch.manual_seed(iteration)
            input = input[torch.randperm(global_batch_size)]

    def test_replicate_single_module(self):
        self._init_pg()
        model = Net()
        replicate_model = replicate(deepcopy(model))
        self._compare_module(model, replicate_model)

    @skip_if_lt_x_gpu(2)
    def test_replicate_move_args_kwargs_to_device(self):
        class MyNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(2, 2)

            def forward(self, inp, *, kwarg=None):
                if kwarg is not None:
                    inp = inp @ kwarg
                return self.a(inp)

        self._init_pg()
        torch.cuda.set_device(self.rank)
        model = MyNet().cuda()
        replicate(model, device_id=torch.cuda.current_device())
        # CPU inputs ensure that replicate can move args and kwargs to device.
        a, b = torch.randn(2, 2), torch.randn(2, 2)
        model(a, kwarg=b).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_replicate_ignore_module(self):
        self._init_pg()
        torch.cuda.set_device(self.rank)
        # Seeds ensure different inputs and thus different local grads
        # across ranks.
        torch.manual_seed(self.rank)
        torch.cuda.manual_seed(self.rank)
        model = Net().cuda()
        replicate(model, ignored_modules=[model.fc1])
        # Input is already on GPU; scale by rank so local grads differ
        # across ranks.
        inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
        out = model(inp) * 10
        out.sum().backward()
        # fc1 grads should not be synchronized; fc2 and fc3 grads should be.
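        # all_gather collects every rank's local gradient; comparing rank 0's
        # copy against the others shows whether DDP synchronized that grad.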
        fc1_grad = model.fc1.weight.grad
        tensor_list = [torch.zeros_like(fc1_grad) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list, fc1_grad)
        grad, rest = tensor_list[0], tensor_list[1:]
        for g in rest:
            self.assertNotEqual(grad, g)

        for dp_grad in [model.fc2.weight.grad, model.fc3.weight.grad]:
            tensor_list = [
                torch.zeros_like(dp_grad) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(tensor_list, dp_grad)
            grad, rest = tensor_list[0], tensor_list[1:]
            for g in rest:
                self.assertEqual(grad, g)

    def test_replicate_multi_module(self):
        self._init_pg()
        model = Net()
        replicate_model = deepcopy(model)
        replicate(replicate_model.fc1)
        replicate(replicate_model.fc2)
        replicate(replicate_model.fc3)
        self._compare_module(model, replicate_model)

    def test_replicate_with_kwargs(self):
        self._init_pg()
        model = Net()
        replicate_model = replicate(
            deepcopy(model), bucket_cap_mb=1, gradient_as_bucket_view=True
        )
        self._compare_module(model, replicate_model)

    @skip_if_lt_x_gpu(2)
    def test_replicate_device_id(self):
        self._init_pg()
        model = Net()
        model_cuda = deepcopy(model).cuda()
        model_cuda2 = deepcopy(model_cuda)
        replicate(model, device_id=torch.device("cpu"))
        # DDP instance is attached in the first pre-forward.
        model(torch.randn(2, 2))
        replicate_ddp_weakref = replicate.state(model)._ddp_weakref()
        # Should be None for CPU training.
        self.assertEqual(None, replicate_ddp_weakref.device_ids)

        replicate(model_cuda, device_id=torch.device(torch.cuda.current_device()))
        # DDP instance is attached in the first pre-forward.
        model_cuda(torch.randn(2, 2))
        replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
        self.assertEqual([0], replicate_ddp_weakref.device_ids)
        # Pass in an int as device_id.
        replicate(model_cuda2, device_id=int(torch.cuda.current_device()))
        # DDP instance is attached in the first pre-forward.
        model_cuda2(torch.randn(2, 2))
        replicate_ddp_weakref = replicate.state(model_cuda2)._ddp_weakref()
        self.assertEqual([0], replicate_ddp_weakref.device_ids)

    def test_replicate_wrong_device_id_type(self):
        self._init_pg()
        model = Net()
        with self.assertRaisesRegex(
            RuntimeError, "Expected device_id to be int or torch.device"
        ):
            replicate(model, device_id=[torch.device("cpu")])


class ReplicateFullyShardInit(ReplicateTest):
    @skip_if_lt_x_gpu(2)
    def test_replicate_fully_shard_init(self):
        class ToyModel(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.linears = nn.Sequential(
                    nn.Linear(dim, dim, bias=False),
                    nn.Linear(dim, dim, bias=False),
                    nn.Linear(dim, dim, bias=False),
                )
                self.proj = nn.Linear(dim, dim, bias=False)

            def forward(self, x: torch.Tensor):
                y = self.linears(x)
                y = self.proj(y)
                return y

        self._init_pg()
        torch.cuda.set_device(self.rank)
        dim = 3
        bz = 2
        model = ToyModel(dim).cuda()
        for linear in model.linears:
            fully_shard(linear)
        fully_shard(model.linears)
        replicate(model, device_id=torch.cuda.current_device())
        for linear in model.linears:
            self.assertTrue(isinstance(linear.weight, DTensor))
        inp = torch.rand(bz, dim)
        # Trigger lazy init via a forward pass.
        model(inp).sum()
        for linear in model.linears:
            self.assertTrue(isinstance(linear.weight, DTensor))


if __name__ == "__main__":
    run_tests()