# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import functools
import threading
import unittest
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import DeviceMesh
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.two_tensor import TwoTensor


def two_tensor_fsdp_pre_all_gather(
    self, mesh: DeviceMesh
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
    all_gather_inputs = (self.a, self.b)
    metadata = None
    return all_gather_inputs, metadata


def two_tensor_fsdp_post_all_gather(
    self,
    all_gather_outputs: Tuple[torch.Tensor, ...],
    metadata: Any,
    param_dtype: torch.dtype,
    *,
    out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
    assert metadata is None, f"{metadata}"
    a, b = all_gather_outputs
    if out is not None:
        assert isinstance(out, TwoTensor), f"{type(out)}"
        if a.dtype == param_dtype:
            assert a.untyped_storage().data_ptr() == out.a.untyped_storage().data_ptr()
            assert b.untyped_storage().data_ptr() == out.b.untyped_storage().data_ptr()
        else:
            assert out.a.dtype == param_dtype, f"{out.a.dtype} {param_dtype}"
            assert out.b.dtype == param_dtype, f"{out.b.dtype} {param_dtype}"
            out.a.copy_(a)
            out.b.copy_(b)
        return
    tensors_to_free = (a, b)
    # If the cast is real, then the all-gather outputs will not alias the
    # returned `TwoTensor`'s `a` and `b`
    two_tensor = TwoTensor(a, b).to(param_dtype)
    return two_tensor, tensors_to_free


class TestFullyShardAllGatherExtensionsCommon:
    @property
    def world_size(self) -> int:
        return 2

    @contextlib.contextmanager
    def _patch_two_tensor_fsdp_all_gather(self):
        lock = threading.Lock()
        TwoTensor.fsdp_pre_all_gather = two_tensor_fsdp_pre_all_gather
        TwoTensor.fsdp_post_all_gather = two_tensor_fsdp_post_all_gather
        dist.barrier()
        try:
            yield
        finally:
            dist.barrier()
            with lock:  # only one thread needs to delete
                if hasattr(TwoTensor, "fsdp_pre_all_gather"):
                    delattr(TwoTensor, "fsdp_pre_all_gather")
                if hasattr(TwoTensor, "fsdp_post_all_gather"):
                    delattr(TwoTensor, "fsdp_post_all_gather")

    def _init_two_tensor_mlp(self) -> nn.Module:
        # Disable bias because the reference model will end up with a bias
        # gradient that is a `TwoTensor`, whereas the FSDP model does not
        model = nn.Sequential(*[MLP(8, bias=False) for _ in range(3)])
        for mlp in model:
            mlp.in_proj.weight = nn.Parameter(
                TwoTensor(mlp.in_proj.weight, mlp.in_proj.weight.clone())
            )
            mlp.out_proj.weight = nn.Parameter(
                TwoTensor(mlp.out_proj.weight, mlp.out_proj.weight.clone())
            )
        return model


class TestFullyShardAllGatherExtensionsMultiProcess(
    TestFullyShardAllGatherExtensionsCommon, FSDPTest
):
    @skip_if_lt_x_gpu(2)
    def test_all_gather_extensions_train_parity(self):
        with self._patch_two_tensor_fsdp_all_gather():
            self.run_subtests(
                {"reshard_after_forward": [True, False]},
                self._test_all_gather_extensions_train_parity,
            )

    def _test_all_gather_extensions_train_parity(self, reshard_after_forward: bool):
        torch.manual_seed(42)
        model = self._init_two_tensor_mlp()
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=True)
        fully_shard_fn = functools.partial(
            fully_shard, reshard_after_forward=reshard_after_forward
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        check_sharded_parity(self, ref_model, model)

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()
                if _model is ref_model:
                    for param_name, param in _model.named_parameters():
                        dist.all_reduce(param.grad)
                        param.grad.detach().div_(self.world_size)
            self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            check_sharded_parity(self, ref_model, model)


class TestFullyShardAllGatherExtensionsMultiThread(
    TestFullyShardAllGatherExtensionsCommon, FSDPTestMultiThread
):
    @property
    def device(self) -> torch.device:
        return torch.device("cuda:0")

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extensions_end_to_end(self):
        with self._patch_two_tensor_fsdp_all_gather():
            self.run_subtests(
                {"reshard_after_forward": [True, False]},
                self._test_all_gather_extensions_end_to_end,
            )

    def _test_all_gather_extensions_end_to_end(self, reshard_after_forward: bool):
        # Check that we can run the meta-device initialization flow
        with torch.device("meta"):
            model = self._init_two_tensor_mlp()
            for param in model.parameters():
                self.assertEqual(param.device, torch.device("meta"))
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
            mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16),
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        model.to_empty(device=self.device)
        for param in model.parameters():
            nn.init.trunc_normal_(param)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)

        # Run a few iterations to check for errors
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for _ in range(3):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extensions_monkey_patch(self):
        # Define a pre/post-all-gather pair that quantizes to bf16 for the
        # all-gather and de-quantizes back to the parameter dtype
        def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
            return (self.to(torch.bfloat16),), None

        def fsdp_post_all_gather(
            self,
            all_gather_outputs: Tuple[torch.Tensor, ...],
            metadata: Any,
            param_dtype: torch.dtype,
            *,
            out: Optional[torch.Tensor] = None,
        ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
            (tensor,) = all_gather_outputs
            assert metadata is None, f"{metadata}"
            assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
            if out is not None:
                out.copy_(tensor)
                return
            return tensor.to(param_dtype), (tensor,)

        with torch.device("meta"):
            model = self._init_two_tensor_mlp()
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        model.to_empty(device=self.device)
        for param in model.parameters():
            nn.init.trunc_normal_(param)
        # Monkey patch the pre/post-all-gather functions *after* `to_empty()`
        # since the local tensor objects change from materialization
        self.assertGreater(sum("weight" in n for n, _ in model.named_parameters()), 0)
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                local_param = param.to_local()
                # Monkey patch on the `torch.Tensor` to show that the extension
                # can work even without a subclass
                local_param.fsdp_pre_all_gather = fsdp_pre_all_gather
                local_param.fsdp_post_all_gather = fsdp_post_all_gather
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)

        # Run a few iterations to check for errors
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for _ in range(3):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()


if __name__ == "__main__":
    run_tests()