# Owner(s): ["oncall: distributed"]

import re
import sys

import torch
import torch.cuda
import torch.cuda.nccl as nccl
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    load_tests,
    NoTest,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_ROCM,
    TestCase,
)


HIP_VERSION = (
    0.0
    if torch.version.hip is None
    else float(re.search(r"^\d+\.\d+", torch.version.hip)[0])
)
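# For example, a ROCm build where torch.version.hip is a string such as
# "5.7.31921" (version string format assumed here) would give
# HIP_VERSION == 5.7; CUDA builds report None and fall back to 0.0.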

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

nGPUs = torch.cuda.device_count()
if not TEST_CUDA:
    print("CUDA not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


datatypes = [torch.float]
if (
    TEST_CUDA and c10d.is_nccl_available() and nccl.version() >= (2, 10)
) or TEST_WITH_ROCM:
    datatypes.append(torch.bfloat16)
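# nccl.version() reports the library version as a tuple (e.g. (2, 10, 3)), so
# the comparison above gates bfloat16 on NCCL >= 2.10, the release that added
# bfloat16 support; on ROCm builds the dtype is appended unconditionally.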


class TestNCCL(TestCase):
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_unique_id(self, device):
        uid = nccl.unique_id()
        self.assertIsInstance(uid, bytes)
        self.assertGreater(len(uid), 1)
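        # The unique id is the rendezvous token that every rank of a NCCL
        # communicator must receive (typically via a distributed store) before
        # it can join the communicator; this test only checks that the binding
        # hands back a non-trivial bytes blob.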

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_broadcast(self, device, dtype):
        expected = torch.zeros(128).uniform_().to(dtype=dtype)
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tensors)
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)

        # Test with tuple
        tensors = [expected.cuda()]
        for device in range(1, torch.cuda.device_count()):
            tensors.append(torch.zeros(128, dtype=dtype, device=device))

        nccl.broadcast(tuple(tensors))
        for i in range(torch.cuda.device_count()):
            self.assertEqual(tensors[i], expected)
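        # A minimal sketch of the call outside this harness (assuming two
        # visible GPUs); broadcast copies the tensor at index 0 into every
        # other tensor in the collection:
        #
        #   src = torch.rand(128, device="cuda:0")
        #   dst = torch.zeros(128, device="cuda:1")
        #   nccl.broadcast([src, dst])  # afterwards dst holds a copy of src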

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tensors)

        self.assertEqual(tensors[0], expected)

        # Test with tuple
        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.reduce(tuple(tensors))

        self.assertEqual(tensors[0], expected)
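        # Sketch of the plain call (assuming nGPUs >= 2); by default the sum
        # of all inputs lands in the tensor at index 0, as asserted above:
        #
        #   chunks = [torch.ones(128, device=i) for i in range(nGPUs)]
        #   nccl.reduce(chunks)  # chunks[0] now holds the elementwise sum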

    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @skip_but_pass_in_sandcastle_if(
        # `dtype` is not defined at decoration time (hence the noqa); the
        # leading `and` clauses short-circuit, so the undefined name is only
        # reached (and would raise a NameError) on ROCm builds older than 3.5.
        TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16,  # noqa: F821
        "Skip bfloat16 test for ROCm < 3.5",
    )
    @dtypes(*datatypes)
    def test_all_reduce(self, device, dtype):
        cpu_tensors = [
            torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(128, dtype=dtype)
        for t in cpu_tensors:
            expected.add_(t)

        tensors = [cpu_tensors[i].cuda(i) for i in range(nGPUs)]
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        tensors = tuple(cpu_tensors[i].cuda(i) for i in range(nGPUs))
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)

        # Test with set.
        tensors = {cpu_tensors[i].cuda(i) for i in range(nGPUs)}
        nccl.all_reduce(tensors)

        for tensor in tensors:
            self.assertEqual(tensor, expected)
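        # Sketch of the call on its own (assuming nGPUs >= 2); all_reduce sums
        # in place, so every tensor in the collection ends up holding the same
        # reduced result:
        #
        #   shards = [torch.ones(128, device=i) for i in range(nGPUs)]
        #   nccl.all_reduce(shards)  # each shards[i] now holds nGPUs everywhere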

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    def test_collective_errors(self, device):
        t = torch.rand(10).cuda(0)
        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.broadcast(t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.all_gather(t, t)

        with self.assertRaisesRegex(
            TypeError, "Inputs should be a collection of tensors"
        ):
            nccl.reduce_scatter(t, t)
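        # Each collective expects a collection (list, tuple, or set) of
        # tensors rather than a bare tensor; wrapping the tensor satisfies the
        # type check, e.g. nccl.all_reduce([t]) instead of nccl.all_reduce(t).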

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_all_gather(self, device, dtype):
        cpu_inputs = [torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)]
        expected = torch.cat(cpu_inputs, 0)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(inputs, outputs)

        for tensor in outputs:
            self.assertEqual(tensor, expected)

        # Test with tuple.
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [
            torch.zeros(128 * nGPUs, device=i, dtype=dtype) for i in range(nGPUs)
        ]
        nccl.all_gather(tuple(inputs), tuple(outputs))

        for tensor in outputs:
            self.assertEqual(tensor, expected)
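        # Sketch of the shape contract (assuming nGPUs >= 2): each device
        # contributes a 128-element input, and every output buffer must be
        # large enough for the concatenation of all inputs:
        #
        #   ins = [torch.rand(128, device=i) for i in range(nGPUs)]
        #   outs = [torch.empty(128 * nGPUs, device=i) for i in range(nGPUs)]
        #   nccl.all_gather(ins, outs)  # every outs[i] holds the inputs in device order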

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_ROCM and HIP_VERSION < 3.5, "Skip NCCL tests for ROCm"
    )
    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected")
    @dtypes(*datatypes)
    def test_reduce_scatter(self, device, dtype):
        in_size = 32 * nGPUs
        out_size = 32

        cpu_inputs = [
            torch.zeros(in_size).uniform_().to(dtype=dtype) for i in range(nGPUs)
        ]
        expected = torch.zeros(in_size, dtype=dtype)
        for t in cpu_inputs:
            expected.add_(t)
        expected = expected.view(nGPUs, 32)

        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(inputs, outputs)

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])

        # Test with tuple
        inputs = [cpu_inputs[i].cuda(i) for i in range(nGPUs)]
        outputs = [torch.zeros(out_size, device=i, dtype=dtype) for i in range(nGPUs)]
        nccl.reduce_scatter(tuple(inputs), tuple(outputs))

        for i in range(nGPUs):
            self.assertEqual(outputs[i], expected[i])
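        # Sketch of the layout (assuming nGPUs >= 2): reduce_scatter is the
        # inverse layout of all_gather, so each input holds nGPUs chunks and
        # each device receives the sum of its own chunk across all inputs:
        #
        #   ins = [torch.rand(32 * nGPUs, device=i) for i in range(nGPUs)]
        #   outs = [torch.empty(32, device=i) for i in range(nGPUs)]
        #   nccl.reduce_scatter(ins, outs)  # outs[i] == sum of chunk i of every input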


instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()
