# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import (
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)


class TestShardUtilsDistributed(FSDPTest):
    """Tests for ``_create_chunk_sharded_tensor`` with a 2-rank process group."""

    @property
    def world_size(self):
        # Fixed at 2 ranks; @skip_if_lt_x_gpu(2) below enforces the GPU count.
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

    @skip_if_lt_x_gpu(2)
    def test_create_chunk_sharded_tensor(self):
        """Shard a tensor, gather it back on rank 0, and compare to the original.

        Sizes cover 1-D and 2-D tensors whose dim-0 is smaller than, equal to
        a multiple of, and not divisible by the world size.
        """
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)

            sharded_tensor = _create_chunk_sharded_tensor(
                tensor,
                self.rank,
                self.world_size,
                torch.cuda.device_count(),
                _get_default_group(),
            )
            # gather() requires a destination tensor only on the target rank.
            output = torch.empty(*size).cuda() if self.rank == 0 else None
            sharded_tensor.gather(0, output)
            if self.rank == 0:
                self.assertEqual(tensor, output)


class TestShardUtilsDistributedDTensor(DTensorTestBase):
    """Tests for ``_create_chunk_dtensor`` against ``torch.chunk`` semantics."""

    @property
    def world_size(self):
        # Fixed at 2 ranks; @skip_if_lt_x_gpu(2) below enforces the GPU count.
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_create_chunk_dtensor(self):
        """Each rank's local shard must match the corresponding torch.chunk piece.

        When dim-0 is smaller than the world size, later ranks receive an
        empty local tensor instead of a chunk.
        """
        device_mesh = self.build_device_mesh()

        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            tensor_chunks = torch.chunk(tensor, self.world_size, dim=0)

            dtensor = _create_chunk_dtensor(tensor, self.rank, device_mesh)
            local_tensor = dtensor.to_local()

            if local_tensor.numel() != 0:
                self.assertEqual(local_tensor, tensor_chunks[self.rank])
            else:
                # An empty local shard is only legal for ranks beyond the
                # number of chunks torch.chunk produced.
                self.assertEqual(self.rank >= len(tensor_chunks), True)


if __name__ == "__main__":
    run_tests()