xref: /aosp_15_r20/external/pytorch/test/distributed/fsdp/test_shard_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import torch
4from torch.distributed.distributed_c10d import _get_default_group
5from torch.distributed.fsdp._shard_utils import (
6    _create_chunk_dtensor,
7    _create_chunk_sharded_tensor,
8)
9from torch.testing._internal.common_fsdp import FSDPTest
10from torch.testing._internal.common_utils import run_tests
11from torch.testing._internal.distributed._tensor.common_dtensor import (
12    DTensorTestBase,
13    skip_if_lt_x_gpu,
14    with_comms,
15)
16
17
class TestShardUtilsDistributed(FSDPTest):
    """Round-trip check for ``_create_chunk_sharded_tensor`` across two ranks."""

    @property
    def world_size(self):
        """Number of ranks this test case runs with."""
        return 2

    def _create_tensor(self, *size):
        """Return a CUDA tensor of shape ``size`` drawn right after seeding with 0."""
        # Reseeding before every draw makes the generated values reproducible.
        torch.manual_seed(0)
        host_tensor = torch.rand(*size)
        return host_tensor.cuda()

    @skip_if_lt_x_gpu(2)
    def test_create_chunk_sharded_tensor(self):
        # Shapes cover 1-D and 2-D inputs with dim-0 lengths of 1, 12 and 25.
        shapes = ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6))
        for shape in shapes:
            expected = self._create_tensor(*shape)

            sharded = _create_chunk_sharded_tensor(
                expected,
                self.rank,
                self.world_size,
                torch.cuda.device_count(),
                _get_default_group(),
            )

            # Only rank 0 allocates a destination buffer; other ranks pass None.
            on_rank_zero = self.rank == 0
            if on_rank_zero:
                gathered = torch.empty(*shape).cuda()
            else:
                gathered = None
            sharded.gather(0, gathered)

            # The gathered result must reproduce the tensor that was sharded.
            if on_rank_zero:
                self.assertEqual(expected, gathered)
44
45
class TestShardUtilsDistributedDTensor(DTensorTestBase):
    """Checks ``_create_chunk_dtensor`` against ``torch.chunk`` along dim 0."""

    @property
    def world_size(self):
        """Number of ranks this test case runs with."""
        return 2

    def _create_tensor(self, *size):
        """Return a CUDA tensor of shape ``size`` drawn right after seeding with 0."""
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_create_chunk_dtensor(self):
        device_mesh = self.build_device_mesh()

        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            # Reference split; torch.chunk can return fewer than world_size
            # chunks (e.g. when dim 0 has length 1).
            tensor_chunks = torch.chunk(tensor, self.world_size, dim=0)

            dtensor = _create_chunk_dtensor(tensor, self.rank, device_mesh)
            local_tensor = dtensor.to_local()

            if local_tensor.numel() != 0:
                # A non-empty local shard must equal this rank's reference chunk.
                self.assertEqual(local_tensor, tensor_chunks[self.rank])
            else:
                # An empty local shard is only acceptable on ranks that have no
                # reference chunk. assertGreaterEqual reports both operands on
                # failure instead of a bare "False != True".
                self.assertGreaterEqual(self.rank, len(tensor_chunks))
72
73
# Run this module's tests through run_tests (from
# torch.testing._internal.common_utils) when the file is executed as a script.
if __name__ == "__main__":
    run_tests()
76