# Owner(s): ["module: c10d"]

import torch
import torch.distributed as dist
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import (
    _fused_all_gather_matmul_fallback,
    _fused_all_gather_scaled_matmul_fallback,
    _fused_matmul_reduce_scatter_fallback,
    _fused_scaled_matmul_reduce_scatter_fallback,
    enable_symm_mem_for_group,
    restride_A_for_fused_matmul_reduce_scatter,
    restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


def requires_cuda_p2p_access():
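    # Skip (but pass in Sandcastle) unless CUDA is available on an SM80+ GPU,
    # at least two devices are visible, and every device pair supports
    # peer-to-peer access.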
    cuda_p2p_access_available = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (8, 0)
        and torch.cuda.device_count() >= 2
    )
    num_devices = torch.cuda.device_count()
    for i in range(num_devices - 1):
        for j in range(i + 1, num_devices):
            if not torch.cuda.can_device_access_peer(i, j):
                cuda_p2p_access_available = False
                break
        if not cuda_p2p_access_available:
            break

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )


def requires_multicast_support():
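    # Skip (but pass in Sandcastle) unless device 0 reports CUDA multicast
    # support, which the multimem_* collectives below rely on.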
    has_multicast_support = (
        torch.cuda.is_available()
        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
    )
    return skip_but_pass_in_sandcastle_if(
        not has_multicast_support,
        "multicast support is not available",
    )


@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        enable_symm_mem_for_group(dist.group.WORLD.group_name)

    def _verify_symmetric_memory(self, symm_mem):
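        # Exercise the basic signaling and barrier protocol: rank 1 fills the
        # buffer with 42 and signals rank 0, which waits for the signal and
        # checks the value; a second round uses barriers to hand off 43.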
        self.assertEqual(symm_mem.world_size, 2)

        buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
        if symm_mem.rank == 0:
            symm_mem.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
            symm_mem.put_signal(dst_rank=0)

        symm_mem.barrier()

        if symm_mem.rank == 0:
            symm_mem.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
            symm_mem.barrier()

        symm_mem.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_cuda_nvlink_connectivity_detection(self) -> None:
        from torch._C._distributed_c10d import _detect_dma_connectivity

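        # _detect_dma_connectivity returns a square matrix with one row and
        # one column per visible CUDA device, describing pairwise NVLink
        # connectivity.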
        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
        self.assertEqual(connectivity.connection_type, "nvlink")
        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
        for row in connectivity.matrix:
            self.assertEqual(len(row), torch.cuda.device_count())

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name)

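        # rendezvous() returns None for tensors that were not allocated via
        # empty_strided_p2p; only symmetric-memory allocations can rendezvous.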
        t = torch.empty(shape, dtype=dtype, device=device)
        self.assertIsNone(_SymmetricMemory.rendezvous(t))

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        symm_mem = _SymmetricMemory.rendezvous(t)

        del t
        self._verify_symmetric_memory(symm_mem)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        alloc_id = 42  # Persistent allocation
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name, alloc_id)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        data_ptr = t.data_ptr()

        # Verify that persistent allocation would fail if there's an active
        # allocation with the same alloc_id.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.empty_strided_p2p(*alloc_args)

        # Verify that persistent allocation would succeed in the absence of
        # active allocations with the same alloc_id, and that the returned
        # tensor would have the same data pointer.
        del t
        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        self.assertEqual(t.data_ptr(), data_ptr)

        # Verify that get_symmetric_memory would fail if called before
        # rendezvous.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.get_symmetric_memory(t)

        symm_mem_0 = _SymmetricMemory.rendezvous(t)
        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
        self.assertEqual(id(symm_mem_0), id(symm_mem_1))

        self._verify_symmetric_memory(symm_mem_0)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
        Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]

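        # Compare the decomposed fallback (all-gather followed by per-B
        # matmuls) against the fused op; both the all-gather output and each
        # matmul output should match in values and strides.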
        ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )

        assert torch.allclose(ag_output_0, ag_output_1)
        assert ag_output_0.stride() == ag_output_1.stride()
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            assert torch.allclose(mm_output_0, mm_output_1)
            assert mm_output_0.stride() == mm_output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_scaled_matmul(self, gather_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda").to(
            torch.float8_e4m3fn
        )
        A_scale = torch.tensor(0.1, device="cuda")
        Bs = [
            torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)
        ]
        B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]
        out_dtypes = [None, torch.bfloat16, torch.float32]

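        # Same comparison as test_fused_all_gather_matmul, but with fp8
        # (float8_e4m3fn) operands and per-tensor scales; out_dtypes exercises
        # the default, bfloat16, and float32 output dtypes.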
        ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )

        self.assertTrue(
            torch.allclose(
                ag_output_0.to(torch.float32),
                ag_output_1.to(torch.float32),
            )
        )
        self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            self.assertTrue(
                torch.allclose(
                    mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
                )
            )
            self.assertEqual(mm_output_0.stride(), mm_output_1.stride())
            self.assertEqual(mm_output_0.dtype, mm_output_1.dtype)

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda")
        B = torch.rand(K, N, device="cuda")

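        # The fallback computes the matmul and then reduce-scatters the
        # result; the fused op must agree in both values and output strides.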
        output_0 = _fused_matmul_reduce_scatter_fallback(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )
        output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )

        assert torch.allclose(output_0, output_1)
        assert output_0.stride() == output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_scaled_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)
        A_scale = torch.tensor(0.1, device="cuda")
        B = torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T
        B_scale = torch.tensor(0.1, device="cuda")

        output_0 = _fused_scaled_matmul_reduce_scatter_fallback(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )
        output_1 = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )

        assert torch.allclose(output_0, output_1)
        assert output_0.stride() == output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @parametrize("dim", [0, 1, 2])
    def test_optimal_layout(self, dim: int) -> None:
        t = torch.rand(8, 64, 32, 16)

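        # The restride helpers return a tensor laid out so that `dim` becomes
        # the leading contiguous dimension (presumably the layout the fused
        # kernels expect) while the values are unchanged.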
        x = restride_A_shard_for_fused_all_gather_matmul(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

        x = restride_A_for_fused_matmul_reduce_scatter(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
        self._init_process()

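        # The all-gather should accept both an input already backed by
        # symmetric memory and a plain CUDA tensor.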
        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            ).fill_(self.rank)
        else:
            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)

        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 * self.world_size, 64))

        chunks = res.chunk(self.world_size)
        for r in range(self.world_size):
            self.assertTrue(chunks[r].eq(r).all())

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("reduce_op", ["sum", "avg"])
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_reduce_scatter(
        self, reduce_op: str, symm_mem_input: bool
    ) -> None:
        self._init_process()

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            )
        else:
            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)

        chunks = t.chunk(self.world_size)
        for r in range(self.world_size):
            chunks[r].fill_(r)

        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 // self.world_size, 64))

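        # Every rank fills chunk r with the value r, so rank i's scattered
        # chunk reduces world_size copies of i: i * world_size for "sum" and
        # i for "avg".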
        if reduce_op == "sum":
            expect = self.rank * self.world_size
        elif reduce_op == "avg":
            expect = self.rank
        else:
            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
        self.assertTrue(res.eq(expect).all())

        dist.destroy_process_group()

    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(1)

        self.assertTrue(t.data_ptr() % 16 == 0)
        self.assertTrue(align_bytes % t.element_size() == 0)
        self.assertTrue(size_bytes % t.element_size() == 0)

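        # Carve out a sub-slice at the requested byte offset and size so the
        # kernel is exercised at different alignments; elements outside the
        # slice must remain untouched.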
        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]

        torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
        self.assertTrue(x.eq(self.world_size).all().item())

        # Head and tail should not be written
        self.assertTrue(t[:shift].eq(1).all().item())
        self.assertTrue(t[shift + numel :].eq(1).all().item())
        dist.destroy_process_group()

    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_one_shot_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(0)

        self.assertTrue(t.data_ptr() % 16 == 0)
        self.assertTrue(align_bytes % t.element_size() == 0)
        self.assertTrue(size_bytes % t.element_size() == 0)

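        # Same alignment/size slicing as test_multimem_all_reduce, but the
        # one-shot variant returns a new tensor rather than reducing in place.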
        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]
        x.fill_(1)

        res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
        self.assertTrue(res.eq(self.world_size).all().item())
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()