# Owner(s): ["module: c10d"]

import torch
import torch.distributed as dist
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import (
    _fused_all_gather_matmul_fallback,
    _fused_all_gather_scaled_matmul_fallback,
    _fused_matmul_reduce_scatter_fallback,
    _fused_scaled_matmul_reduce_scatter_fallback,
    enable_symm_mem_for_group,
    restride_A_for_fused_matmul_reduce_scatter,
    restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


def requires_cuda_p2p_access():
    cuda_p2p_access_available = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (8, 0)
        and torch.cuda.device_count() >= 2
    )
    num_devices = torch.cuda.device_count()
    for i in range(num_devices - 1):
        for j in range(i + 1, num_devices):
            if not torch.cuda.can_device_access_peer(i, j):
                cuda_p2p_access_available = False
                break
        if not cuda_p2p_access_available:
            break

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )


def requires_multicast_support():
    has_multicast_support = (
        torch.cuda.is_available()
        and _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
    )
    return skip_but_pass_in_sandcastle_if(
        not has_multicast_support,
        "multicast support is not available",
    )


@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        enable_symm_mem_for_group(dist.group.WORLD.group_name)

    def _verify_symmetric_memory(self, symm_mem):
        self.assertEqual(symm_mem.world_size, 2)

        buf = symm_mem.get_buffer(0, (64, 64), torch.float32)

        # Rank 1 writes to rank 0's buffer and signals; rank 0 waits for the
        # signal before verifying the contents.
        if symm_mem.rank == 0:
            symm_mem.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
            symm_mem.put_signal(dst_rank=0)

        symm_mem.barrier()

        # Same remote-write check, but synchronized with barriers instead of
        # signals.
        if symm_mem.rank == 0:
            symm_mem.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
            symm_mem.barrier()

        symm_mem.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_cuda_nvlink_connectivity_detection(self) -> None:
        from torch._C._distributed_c10d import _detect_dma_connectivity

        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
        self.assertEqual(connectivity.connection_type, "nvlink")
        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
        for row in connectivity.matrix:
            self.assertEqual(len(row), torch.cuda.device_count())

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name)

        t = torch.empty(shape, dtype=dtype, device=device)
        self.assertIsNone(_SymmetricMemory.rendezvous(t))

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        symm_mem = _SymmetricMemory.rendezvous(t)

        del t
        self._verify_symmetric_memory(symm_mem)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        alloc_id = 42  # Persistent allocation
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name, alloc_id)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        data_ptr = t.data_ptr()

        # Verify that persistent allocation would fail if there's an active
        # allocation with the same alloc_id.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.empty_strided_p2p(*alloc_args)

        # Verify that persistent allocation would succeed once the active
        # allocation with the same alloc_id is freed, and that the returned
        # tensor would have the same data pointer.
        del t
        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        self.assertEqual(t.data_ptr(), data_ptr)

        # Verify that get_symmetric_memory would fail if called before
        # rendezvous.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.get_symmetric_memory(t)

        symm_mem_0 = _SymmetricMemory.rendezvous(t)
        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
        self.assertEqual(id(symm_mem_0), id(symm_mem_1))

        self._verify_symmetric_memory(symm_mem_0)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
        Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]

        ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )

        assert torch.allclose(ag_output_0, ag_output_1)
        assert ag_output_0.stride() == ag_output_1.stride()
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            assert torch.allclose(mm_output_0, mm_output_1)
            assert mm_output_0.stride() == mm_output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_scaled_matmul(self, gather_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda").to(
            torch.float8_e4m3fn
        )
        A_scale = torch.tensor(0.1, device="cuda")
        Bs = [
            torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)
        ]
        B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)]
        out_dtypes = [None, torch.bfloat16, torch.float32]

        ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
            A_shard,
            Bs,
            A_scale,
            B_scales,
            gather_dim=gather_dim,
            group_name=group.group_name,
            biases=[None] * len(Bs),
            result_scales=[None] * len(Bs),
            out_dtypes=out_dtypes,
            use_fast_accum=[None] * len(Bs),
        )

        self.assertTrue(
            torch.allclose(
                ag_output_0.to(torch.float32),
                ag_output_1.to(torch.float32),
            )
        )
        self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
            self.assertTrue(
                torch.allclose(
                    mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
                )
            )
            self.assertEqual(mm_output_0.stride(), mm_output_1.stride())
            self.assertEqual(mm_output_0.dtype, mm_output_1.dtype)

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda")
        B = torch.rand(K, N, device="cuda")

        output_0 = _fused_matmul_reduce_scatter_fallback(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )
        output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
        )

        assert torch.allclose(output_0, output_1)
        assert output_0.stride() == output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_scaled_matmul_reduce_scatter(self, scatter_dim: int) -> None:
        self._init_process()

        BATCH = 8
        M = 64
        N = 16
        K = 32
        group = dist.group.WORLD
        rank = self.rank
        world_size = self.world_size

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)
        A_scale = torch.tensor(0.1, device="cuda")
        B = torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T
        B_scale = torch.tensor(0.1, device="cuda")

        output_0 = _fused_scaled_matmul_reduce_scatter_fallback(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )
        output_1 = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
            A,
            B,
            A_scale,
            B_scale,
            "avg",
            scatter_dim,
            group.group_name,
            out_dtype=torch.bfloat16,
        )

        assert torch.allclose(output_0, output_1)
        assert output_0.stride() == output_1.stride()

        dist.destroy_process_group()

    @skipIfRocm
    @parametrize("dim", [0, 1, 2])
    def test_optimal_layout(self, dim: int) -> None:
        t = torch.rand(8, 64, 32, 16)

        x = restride_A_shard_for_fused_all_gather_matmul(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

        x = restride_A_for_fused_matmul_reduce_scatter(t, dim)
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
        self._init_process()

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            ).fill_(self.rank)
        else:
            t = torch.full(
                (64, 64), self.rank, dtype=torch.float32, device=self.device
            )

        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 * self.world_size, 64))

        chunks = res.chunk(self.world_size)
        for r in range(self.world_size):
            self.assertTrue(chunks[r].eq(r).all())

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @parametrize("reduce_op", ["sum", "avg"])
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_reduce_scatter(
        self, reduce_op: str, symm_mem_input: bool
    ) -> None:
        self._init_process()

        if symm_mem_input:
            t = _SymmetricMemory.empty_strided_p2p(
                size=(64, 64),
                stride=(64, 1),
                dtype=torch.float32,
                device=self.device,
                group_name="0",
            )
        else:
            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)

        chunks = t.chunk(self.world_size)
        for r in range(self.world_size):
            chunks[r].fill_(r)

        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
        res = torch.ops._c10d_functional.wait_tensor(res)
        self.assertEqual(res.shape, (64 // self.world_size, 64))

        if reduce_op == "sum":
            expect = self.rank * self.world_size
        elif reduce_op == "avg":
            expect = self.rank
        else:
            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
        self.assertTrue(res.eq(expect).all())

        dist.destroy_process_group()

    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(1)

        self.assertTrue(t.data_ptr() % 16 == 0)
        self.assertTrue(align_bytes % t.element_size() == 0)
        self.assertTrue(size_bytes % t.element_size() == 0)

        # Carve out a sub-slice of the buffer to exercise different
        # alignments and sizes.
        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]

        torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
        self.assertTrue(x.eq(self.world_size).all().item())

        # Head and tail should not be written
        self.assertTrue(t[:shift].eq(1).all().item())
        self.assertTrue(t[shift + numel :].eq(1).all().item())
        dist.destroy_process_group()

    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    @parametrize("dtype", [torch.float, torch.bfloat16])
    @parametrize("align_bytes", [4, 8, 16])
    @parametrize("size_bytes", [4, 8192, 8196])
    def test_multimem_one_shot_all_reduce(
        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
    ) -> None:
        self._init_process()
        group_name = dist.group.WORLD.group_name

        t = _SymmetricMemory.empty_strided_p2p(
            size=(16384,),
            stride=(1,),
            dtype=dtype,
            device=self.device,
            group_name=group_name,
        ).fill_(0)

        self.assertTrue(t.data_ptr() % 16 == 0)
        self.assertTrue(align_bytes % t.element_size() == 0)
        self.assertTrue(size_bytes % t.element_size() == 0)

        shift = align_bytes // t.element_size()
        numel = size_bytes // t.element_size()
        x = t[shift : shift + numel]
        x.fill_(1)

        res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
        self.assertTrue(res.eq(self.world_size).all().item())
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()