1# mypy: allow-untyped-decorators 2import socket 3import uuid 4from contextlib import contextmanager 5from datetime import timedelta 6from functools import partial 7from typing import Any, Callable, Dict, Generator, List, Optional, Tuple 8 9import torch 10import torch.distributed._functional_collectives as funcol 11import torch.distributed.distributed_c10d as c10d 12from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work 13 14 15_group_name_to_store: Dict[str, c10d.Store] = {} 16 17 18def enable_symm_mem_for_group(group_name: str) -> None: 19 """ 20 Enables symmetric memory for a process group. 21 22 Args: 23 group_name (str): the name of the process group. 24 """ 25 if group_name in _group_name_to_store: 26 return 27 28 group = c10d._resolve_process_group(group_name) 29 global_ranks = sorted(c10d._world.pg_group_ranks[group].keys()) 30 # Different subgroups with the same name should use different stores 31 global_ranks_str = "_".join(map(str, global_ranks)) 32 store = c10d.PrefixStore( 33 f"symmetric_memory-{global_ranks_str}", 34 c10d._get_process_group_store(group), 35 ) 36 # Use one store-based broadcast to bootstrap a file store from the process 37 # and simultaneously verify that all ranks are on the same host. 38 hostname = socket.gethostname() 39 if group.rank() == 0: 40 uid = str(uuid.uuid4()) 41 msg = f"{hostname}/{uid}" 42 store.set("init", msg) 43 else: 44 msg = store.get("init").decode("utf-8") 45 tokens = msg.split("/") 46 assert len(tokens) == 2, tokens 47 rank_0_hostname, uid = tokens 48 if hostname != rank_0_hostname: 49 raise RuntimeError( 50 "init_symmetric_memory_for_process_group() failed for " 51 f'group "{group_name}". Rank 0 and rank {group.rank()} ' 52 f"are on different hosts ({rank_0_hostname} and {hostname})" 53 ) 54 store = torch._C._distributed_c10d.FileStore(f"/tmp/{uid}", group.size()) 55 # TODO: check device connectiivity 56 _group_name_to_store[group_name] = store 57 _SymmetricMemory.set_group_info( 58 group_name, 59 group.rank(), 60 group.size(), 61 store, 62 ) 63 64 65_is_test_mode: bool = False 66 67 68@contextmanager 69def _test_mode() -> Generator[None, None, None]: 70 """ 71 Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops 72 defined in the ``symm_mem`` namespace to use fallback implementations. 73 74 The context manager is not thread safe. 75 """ 76 global _is_test_mode 77 prev = _is_test_mode 78 try: 79 _is_test_mode = True 80 yield 81 finally: 82 _is_test_mode = prev 83 84 85def is_symm_mem_enabled_for_group(group_name: str) -> bool: 86 """ 87 Check if symmetric memory is enabled for a process group. 88 89 Args: 90 group_name (str): the name of the process group. 91 """ 92 return _is_test_mode or group_name in _group_name_to_store 93 94 95_group_name_to_workspace_tensor: Dict[str, Optional[torch.Tensor]] = {} 96 97 98def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory: 99 """ 100 Get the symmetric memory workspace associated with the process group. If 101 ``min_size`` is greater than the workspace associated with ``group_name``, 102 the workspace will be re-allocated and re-rendezvous'd. 103 104 Args: 105 group_name (str): the name of the process group. 106 min_size (int): the size requirement for the workspace in bytes. 107 108 Returns: 109 _SymmetricMemory: the symmetric memory workspace associated with the 110 group. 111 """ 112 tensor = _group_name_to_workspace_tensor.get(group_name) 113 size = tensor.numel() * tensor.element_size() if tensor is not None else 0 114 if tensor is None or size < min_size: 115 tensor = _SymmetricMemory.empty_strided_p2p( 116 (max(size, min_size),), 117 [1], 118 torch.uint8, 119 torch.device(f"cuda:{torch.cuda.current_device()}"), 120 group_name, 121 ) 122 _group_name_to_workspace_tensor[group_name] = tensor 123 return _SymmetricMemory.rendezvous(tensor) 124 125 126_backend_stream: Optional[torch.cuda.Stream] = None 127 128 129def _get_backend_stream() -> torch.cuda.Stream: 130 global _backend_stream 131 if _backend_stream is None: 132 _backend_stream = torch.cuda.Stream() 133 return _backend_stream 134 135 136def _pipelined_all_gather_and_consume( 137 shard: torch.Tensor, 138 shard_consumer: Callable[[torch.Tensor, int], None], 139 ag_out: torch.Tensor, 140 group_name: str, 141) -> None: 142 """ 143 Perform the following logic with micro-pipelined computation and 144 communication: 145 146 tensor = all_gather_tensor(shard, gather_dim=1, group=group) 147 chunks = tensor.chunk(group.size()) 148 for src_rank, chunk in enumerate(chunks): 149 shard_consumer(chunk, src_rank) 150 151 NOTE: 152 - The shard passed to shard consumer will always be contiguous. 153 """ 154 p2p_workspace_size_req = shard.numel() * shard.element_size() 155 symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) 156 group_size = symm_mem.world_size 157 rank = symm_mem.rank 158 159 backend_stream = _get_backend_stream() 160 backend_stream.wait_stream(torch.cuda.current_stream()) 161 local_p2p_buf = symm_mem.get_buffer(rank, shard.shape, shard.dtype) 162 163 chunks = ag_out.chunk(group_size) 164 165 # While consuming local shard, copy it to the local p2p buffer 166 # in another stream. 167 shard_consumer(shard, rank) 168 chunks[rank].copy_(shard) 169 170 with torch.cuda.stream(backend_stream): 171 local_p2p_buf.copy_(shard) 172 symm_mem.barrier(channel=0) 173 torch.cuda.current_stream().wait_stream(backend_stream) 174 175 # At this point, all ranks have copied their local shard to 176 # their local p2p buffer. Each rank can now copy and consume 177 # remote shards. 178 for step in range(1, group_size): 179 if step % 2 == 0: 180 stream = torch.cuda.current_stream() 181 else: 182 stream = backend_stream 183 remote_rank = (step + rank) % group_size 184 remote_p2p_buf = symm_mem.get_buffer(remote_rank, shard.shape, shard.dtype) 185 with torch.cuda.stream(stream): 186 chunks[remote_rank].copy_(remote_p2p_buf) 187 shard_consumer(chunks[remote_rank], remote_rank) 188 189 with torch.cuda.stream(backend_stream): 190 symm_mem.barrier(channel=group_size % 2) 191 torch.cuda.current_stream().wait_stream(backend_stream) 192 193 194def _pipelined_produce_and_all2all( 195 chunk_producer: Callable[[int, torch.Tensor], None], 196 output: torch.Tensor, 197 group_name: str, 198) -> None: 199 """ 200 Perform the following logic with micro-pipelined computation and 201 communication: 202 203 chunks = [ 204 chunk_producer(dst_rank, chunks[dst_rank]) 205 for dst_rank in range(group_size): 206 ] 207 dist.all_to_all_single(output=output, input=torch.cat(chunks)) 208 """ 209 out_chunks = output.chunk(c10d._get_group_size_by_name(group_name)) 210 p2p_workspace_size_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 211 symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) 212 group_size = symm_mem.world_size 213 rank = symm_mem.rank 214 215 backend_stream = _get_backend_stream() 216 backend_stream.wait_stream(torch.cuda.current_stream()) 217 218 def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: 219 assert idx in (0, 1) 220 offset = 0 if idx == 0 else out_chunks[0].numel() 221 return symm_mem.get_buffer( 222 rank, out_chunks[0].shape, out_chunks[0].dtype, offset 223 ) 224 225 # Prepare two local p2p buffers, so that a remote rank can pull the result 226 # of step [i] in one p2p buffer while the local rank can compute the 227 # result of step [i+1] and write it directly the other p2p buffer. 228 local_p2p_buf_0 = get_p2p_buf(rank, 0) 229 local_p2p_buf_1 = get_p2p_buf(rank, 1) 230 231 for step in range(1, group_size): 232 remote_rank = (rank - step) % group_size 233 if step % 2 == 0: 234 stream = torch.cuda.current_stream() 235 other_stream = backend_stream 236 p2p_buf = local_p2p_buf_1 237 remote_p2p_buf = get_p2p_buf(remote_rank, 1) 238 else: 239 stream = backend_stream 240 other_stream = torch.cuda.current_stream() 241 p2p_buf = local_p2p_buf_0 242 remote_p2p_buf = get_p2p_buf(remote_rank, 0) 243 with torch.cuda.stream(stream): 244 chunk_producer((rank + step) % group_size, p2p_buf) 245 symm_mem.barrier(channel=step % 2) 246 # Make the other stream to wait for the barrier on the current 247 # stream to finish before chunk_producer to avoid the compute 248 # delaying the barrier. 249 other_stream.wait_stream(stream) 250 out_chunks[remote_rank].copy_(remote_p2p_buf) 251 252 chunk_producer(rank, out_chunks[rank]) 253 torch.cuda.current_stream().wait_stream(backend_stream) 254 255 256lib = torch.library.Library("symm_mem", "DEF") # noqa: TOR901 257lib.define( 258 "fused_all_gather_matmul(Tensor A, Tensor[] Bs, int gather_dim, str group_name) -> (Tensor, Tensor[])" 259) 260lib.define( 261 "fused_all_gather_scaled_matmul(" 262 "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, " 263 "int gather_dim, str group_name, " 264 "Tensor?[] biases, " 265 "Tensor?[] result_scales, " 266 "ScalarType?[] out_dtypes, " 267 "bool[] use_fast_accum) -> (Tensor, Tensor[])" 268) 269lib.define( 270 "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor" 271) 272lib.define( 273 "fused_scaled_matmul_reduce_scatter(" 274 "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, " 275 "str reduce_op, int scatter_dim, str group_name, " 276 "Tensor? bias = None, " 277 "Tensor? result_scale = None, " 278 "ScalarType? out_dtype = None, " 279 "bool use_fast_accum = False) -> Tensor" 280) 281lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor") 282lib.define( 283 "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor" 284) 285 286 287def _fused_all_gather_matmul_impl( 288 mm_out_op: torch._ops.OpOverload, 289 A_shard: torch.Tensor, 290 Bs: List[torch.Tensor], 291 kwargs_list: List[Dict[str, Any]], 292 out_dtypes: List[Optional[torch.dtype]], 293 gather_dim: int, 294 group_name: str, 295) -> Tuple[torch.Tensor, List[torch.Tensor]]: 296 if A_shard.dim() < 2: 297 raise ValueError("A_shard must be a matrix") 298 for B in Bs: 299 if B.dim() != 2: 300 raise ValueError("B must be a matrix") 301 if len(out_dtypes) != len(Bs): 302 raise ValueError("len(out_types) must be the same as len(Bs)") 303 if len(kwargs_list) != len(Bs): 304 raise ValueError("len(kwargs_list) must be the same as len(Bs)") 305 if gather_dim < 0 or gather_dim >= A_shard.dim(): 306 raise ValueError("Invalid gather_dim") 307 308 group = c10d._resolve_process_group(group_name) 309 310 # Move the gather_dim to the front and flatten the tensor into a 2D matrix. 311 # The flattened tensor doesn't need to be contiguous (for computation 312 # efficiency), as _pipelined_all_gather_and_consume guarantees that shards 313 # passed to shard_consumer are contiguous. 314 x = A_shard.movedim(gather_dim, 0) 315 leading_dims = [group.size()] + list(x.shape[:-1]) 316 x = x.flatten(0, -2) 317 318 # Helper function for reverting the above transformation 319 def unflatten(t: torch.Tensor) -> torch.Tensor: 320 return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) 321 322 ag_out = x.new_empty( 323 x.shape[0] * group.size(), 324 x.shape[1], 325 ) 326 outputs = [ 327 x.new_empty(x.shape[0] * group.size(), B.shape[1], dtype=out_dtype or B.dtype) 328 for B, out_dtype in zip(Bs, out_dtypes) 329 ] 330 output_shards = [output.chunk(group.size()) for output in outputs] 331 332 # Computing block-wise matmul along the first dim of A 333 def shard_consumer(shard: torch.Tensor, rank: int) -> None: 334 for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): 335 mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) 336 337 _pipelined_all_gather_and_consume( 338 x, 339 shard_consumer, 340 ag_out, 341 group_name, 342 ) 343 return unflatten(ag_out), [unflatten(output) for output in outputs] 344 345 346@torch.library.impl(lib, "fused_all_gather_matmul", "Meta") 347def _fused_all_gather_matmul_fallback( 348 A_shard: torch.Tensor, 349 Bs: List[torch.Tensor], 350 gather_dim: int, 351 group_name: str, 352) -> Tuple[torch.Tensor, List[torch.Tensor]]: 353 group_size = c10d._get_group_size_by_name(group_name) 354 A = torch.ops._c10d_functional.all_gather_into_tensor( 355 A_shard.contiguous(), group_size, group_name 356 ) 357 A = torch.ops._c10d_functional.wait_tensor(A) 358 A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) 359 return A.movedim(0, gather_dim), [ 360 torch.matmul(A, B).movedim(0, gather_dim) for B in Bs 361 ] 362 363 364@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") 365def _fused_all_gather_matmul( 366 A_shard: torch.Tensor, 367 Bs: List[torch.Tensor], 368 gather_dim: int, 369 group_name: str, 370) -> Tuple[torch.Tensor, List[torch.Tensor]]: 371 """ 372 Perform the following logic with micro-pipelined computation and 373 communication: 374 375 all_gather_tensor(A_shard, gather_dim, group_name) @ B 376 377 Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is 378 contiguous, no extra copy is required for input layout transformation. 379 Otherwise A_shard needs to be copied once. 380 """ 381 if _is_test_mode: 382 return _fused_all_gather_matmul_fallback(A_shard, Bs, gather_dim, group_name) 383 384 with torch.profiler.record_function("fused_all_gather_matmul"): 385 return _fused_all_gather_matmul_impl( 386 torch.ops.aten.mm.out, 387 A_shard, 388 Bs, 389 [{} for B in Bs], 390 [B.dtype for B in Bs], 391 gather_dim, 392 group_name, 393 ) 394 395 396@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta") 397def _fused_all_gather_scaled_matmul_fallback( 398 A_shard: torch.Tensor, 399 Bs: List[torch.Tensor], 400 A_scale: torch.Tensor, 401 B_scales: List[torch.Tensor], 402 gather_dim: int, 403 group_name: str, 404 biases: List[Optional[torch.Tensor]], 405 result_scales: List[Optional[torch.Tensor]], 406 out_dtypes: List[Optional[torch.dtype]], 407 use_fast_accum: List[bool], 408) -> Tuple[torch.Tensor, List[torch.Tensor]]: 409 out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) 410 411 group_size = c10d._get_group_size_by_name(group_name) 412 A = torch.ops._c10d_functional.all_gather_into_tensor( 413 A_shard.contiguous(), group_size, group_name 414 ) 415 A = torch.ops._c10d_functional.wait_tensor(A) 416 A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) 417 418 def scaled_matmul( 419 A: torch.Tensor, 420 B: torch.Tensor, 421 A_scale: torch.Tensor, 422 B_scale: torch.Tensor, 423 bias: Optional[torch.Tensor], 424 result_scale: Optional[torch.Tensor], 425 out_dtype: Optional[torch.dtype], 426 use_fast_accum: bool, 427 ) -> torch.Tensor: 428 leading_dims = A.shape[:-1] 429 res = torch.ops.aten._scaled_mm( 430 A.flatten(0, -2), B, A_scale, B_scale, out_dtype=out_dtype 431 ) 432 return res.unflatten(0, leading_dims) 433 434 return A.movedim(0, gather_dim), [ 435 scaled_matmul( 436 A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum 437 ).movedim(0, gather_dim) 438 for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip( 439 Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum 440 ) 441 ] 442 443 444@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA") 445def _fused_all_gather_scaled_matmul( 446 A_shard: torch.Tensor, 447 Bs: List[torch.Tensor], 448 A_scale: torch.Tensor, 449 B_scales: List[torch.Tensor], 450 gather_dim: int, 451 group_name: str, 452 biases: List[Optional[torch.Tensor]], 453 result_scales: List[Optional[torch.Tensor]], 454 out_dtypes: List[Optional[torch.dtype]], 455 use_fast_accum: List[bool], 456) -> Tuple[torch.Tensor, List[torch.Tensor]]: 457 """ 458 Perform the following logic with micro-pipelined computation and 459 communication: 460 461 A = all_gather_tensor(A_shard, gather_dim, group_name) 462 leading_dims = A.shape[:-1] 463 res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale) 464 res = res.unflatten(0, leading_dims) 465 466 Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is 467 contiguous, no extra copy is required for input layout transformation. 468 Otherwise A_shard needs to be copied once. 469 """ 470 out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) 471 472 if len(biases) != len(Bs): 473 raise ValueError("len(biases) must be the same as len(Bs)") 474 if len(result_scales) != len(Bs): 475 raise ValueError("len(result_scales) must be the same as len(Bs)") 476 if len(out_dtypes) != len(Bs): 477 raise ValueError("len(out_dtypes) must be the same as len(Bs)") 478 if len(use_fast_accum) != len(Bs): 479 raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)") 480 481 if _is_test_mode: 482 return _fused_all_gather_scaled_matmul_fallback( 483 A_shard, 484 Bs, 485 A_scale, 486 B_scales, 487 gather_dim, 488 group_name, 489 biases, 490 result_scales, 491 out_dtypes, 492 use_fast_accum, 493 ) 494 495 with torch.profiler.record_function("fused_all_gather_scaled_matmul"): 496 return _fused_all_gather_matmul_impl( 497 torch.ops.aten._scaled_mm.out, 498 A_shard, 499 Bs, 500 [ 501 { 502 "scale_a": A_scale, 503 "scale_b": B_scale, 504 "bias": bias, 505 "scale_result": result_scale, 506 "out_dtype": out_dtype, 507 "use_fast_accum": fast_accum, 508 } 509 for B_scale, bias, result_scale, out_dtype, fast_accum in zip( 510 B_scales, biases, result_scales, out_dtypes, use_fast_accum 511 ) 512 ], 513 out_dtypes, 514 gather_dim, 515 group_name, 516 ) 517 518 519def make_contiguous_for_perm( 520 t: torch.Tensor, 521 perm: List[int], 522) -> torch.Tensor: 523 """ 524 Restride `t` such that `t.permute(perm)` is contiguous. 525 """ 526 inv_perm = [0] * len(perm) 527 for i, p in enumerate(perm): 528 inv_perm[p] = i 529 return t.permute(perm).contiguous().permute(inv_perm) 530 531 532def restride_A_shard_for_fused_all_gather_matmul( 533 t: torch.Tensor, 534 gather_dim: int, 535) -> torch.Tensor: 536 """ 537 Restride the `A_shard` arg of `fused_all_gather_matmul` for optimal perf. 538 See the doc for `fused_all_gather_matmul` for detail. 539 """ 540 perm = list(range(len(t.shape))) 541 perm.insert(0, perm.pop(gather_dim)) 542 return make_contiguous_for_perm(t, perm) 543 544 545def _fused_matmul_reduce_scatter_impl( 546 mm_out_op: torch._ops.OpOverload, 547 A: torch.Tensor, 548 B: torch.Tensor, 549 kwargs: Dict[str, Any], 550 out_dtype: Optional[torch.dtype], 551 reduce_op: str, 552 scatter_dim: int, 553 group_name: str, 554) -> torch.Tensor: 555 if A.dim() < 2: 556 raise ValueError("A_shard must be a matrix") 557 if scatter_dim < 0 or scatter_dim >= A.dim(): 558 raise ValueError("Invalid gather_dim") 559 if B.dim() != 2: 560 raise ValueError("B must be a matrix") 561 if reduce_op == "sum": 562 reduce_fn = partial(torch.sum, dim=0) 563 elif reduce_op == "avg": 564 reduce_fn = partial(torch.mean, dim=0) 565 else: 566 raise ValueError("reduce_op must be sum or avg") 567 568 group = c10d._resolve_process_group(group_name) 569 out_shape = [*A.shape[:-1], B.shape[1]] 570 out_shape[scatter_dim] //= group.size() 571 572 # Move the gather_dim to the front and flatten the tensor into a 2D matrix 573 x = A.movedim(scatter_dim, 0) 574 leading_dims = [group.size()] + list(x.shape[:-1]) 575 leading_dims[1] //= group.size() 576 x = x.flatten(0, -2) 577 shards = x.chunk(group.size()) 578 579 # Computing block-wise matmul along the first dim of A 580 def chunk_producer(rank: int, out: torch.Tensor) -> None: 581 mm_out_op(shards[rank], B, **kwargs, out=out) 582 583 stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype) 584 585 _pipelined_produce_and_all2all( 586 chunk_producer, 587 stacked_partials, 588 group_name, 589 ) 590 # Ensures that the transpose and reduction produce contiguous result 591 # in a single reduction kernel. 592 return reduce_fn( 593 stacked_partials.view(*leading_dims, -1) 594 .movedim(1, scatter_dim + 1) 595 .movedim(0, scatter_dim), 596 dim=scatter_dim, 597 ) 598 599 600@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") 601def _fused_matmul_reduce_scatter_fallback( 602 A: torch.Tensor, 603 B: torch.Tensor, 604 reduce_op: str, 605 scatter_dim: int, 606 group_name: str, 607) -> torch.Tensor: 608 res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) 609 res = funcol.wait_tensor(res) 610 return res 611 612 613@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") 614def _fused_matmul_reduce_scatter( 615 A: torch.Tensor, 616 B: torch.Tensor, 617 reduce_op: str, 618 scatter_dim: int, 619 group_name: str, 620) -> torch.Tensor: 621 """ 622 Perform the following logic with micro-pipelined computation and 623 communication: 624 625 reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) 626 627 Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no 628 extra copy is required for input layout transformation. Otherwise A needs 629 to be copied once. 630 """ 631 if _is_test_mode: 632 return _fused_matmul_reduce_scatter_fallback( 633 A, B, reduce_op, scatter_dim, group_name 634 ) 635 636 with torch.profiler.record_function("fused_matmul_reduce_scatter"): 637 return _fused_matmul_reduce_scatter_impl( 638 mm_out_op=torch.ops.aten.mm.out, 639 A=A, 640 B=B, 641 kwargs={}, 642 out_dtype=A.dtype, 643 reduce_op=reduce_op, 644 scatter_dim=scatter_dim, 645 group_name=group_name, 646 ) 647 648 649@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "Meta") 650def _fused_scaled_matmul_reduce_scatter_fallback( 651 A: torch.Tensor, 652 B: torch.Tensor, 653 A_scale: torch.Tensor, 654 B_scale: torch.Tensor, 655 reduce_op: str, 656 scatter_dim: int, 657 group_name: str, 658 bias: Optional[torch.Tensor] = None, 659 result_scale: Optional[torch.Tensor] = None, 660 out_dtype: Optional[torch.dtype] = None, 661 use_fast_accum: bool = False, 662) -> torch.Tensor: 663 C = torch._scaled_mm( 664 A.flatten(0, -2).contiguous(), 665 B, 666 A_scale, 667 B_scale, 668 bias, 669 result_scale, 670 out_dtype, 671 use_fast_accum, 672 ) 673 C = C.view(*A.shape[:-1], B.shape[1]) 674 res = funcol.reduce_scatter_tensor( 675 C, 676 reduce_op, 677 scatter_dim, 678 group_name, 679 ) 680 res = funcol.wait_tensor(res) 681 return res 682 683 684@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA") 685def _fused_scaled_matmul_reduce_scatter( 686 A: torch.Tensor, 687 B: torch.Tensor, 688 A_scale: torch.Tensor, 689 B_scale: torch.Tensor, 690 reduce_op: str, 691 scatter_dim: int, 692 group_name: str, 693 bias: Optional[torch.Tensor] = None, 694 result_scale: Optional[torch.Tensor] = None, 695 out_dtype: Optional[torch.dtype] = None, 696 use_fast_accum: bool = False, 697) -> torch.Tensor: 698 if _is_test_mode: 699 return _fused_scaled_matmul_reduce_scatter_fallback( 700 A, 701 B, 702 A_scale, 703 B_scale, 704 reduce_op, 705 scatter_dim, 706 group_name, 707 bias, 708 result_scale, 709 out_dtype, 710 use_fast_accum, 711 ) 712 with torch.profiler.record_function("fused_matmul_reduce_scatter"): 713 return _fused_matmul_reduce_scatter_impl( 714 mm_out_op=torch.ops.aten._scaled_mm.out, 715 A=A, 716 B=B, 717 kwargs={ 718 "scale_a": A_scale, 719 "scale_b": B_scale, 720 "bias": bias, 721 "scale_result": result_scale, 722 "out_dtype": out_dtype, 723 "use_fast_accum": use_fast_accum, 724 }, 725 out_dtype=out_dtype, 726 reduce_op=reduce_op, 727 scatter_dim=scatter_dim, 728 group_name=group_name, 729 ) 730 731 732def restride_A_for_fused_matmul_reduce_scatter( 733 t: torch.Tensor, 734 gather_dim: int, 735) -> torch.Tensor: 736 """ 737 Restride the `A_shard` arg of `fused_matmul_reduce_scatter` for optimal 738 perf. See the doc for `fused_matmul_reduce_scatter` for detail. 739 """ 740 perm = list(range(len(t.shape))) 741 perm.insert(0, perm.pop(gather_dim)) 742 return make_contiguous_for_perm(t, perm) 743 744 745def _maybe_convert_scalar_types_to_dtypes( 746 scalar_types: List[Any], 747) -> List[Optional[torch.dtype]]: 748 """ 749 When a list of `torch.dtype`s is passed through the dispatcher as 750 `ScalarType[]`, it is converted to a list of scalar type enum values. This 751 function converts it back to a list of `torch.dtype`s. 752 """ 753 # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h 754 _SCALAR_TYPE_TO_DTYPE = { 755 0: torch.uint8, 756 1: torch.int8, 757 2: torch.short, 758 3: torch.int, 759 4: torch.int64, 760 5: torch.half, 761 6: torch.float, 762 7: torch.double, 763 8: torch.complex32, 764 9: torch.complex64, 765 10: torch.complex128, 766 11: torch.bool, 767 12: torch.qint8, 768 13: torch.quint8, 769 14: torch.qint32, 770 15: torch.bfloat16, 771 16: torch.float8_e5m2, 772 17: torch.float8_e4m3fn, 773 18: torch.float8_e5m2fnuz, 774 19: torch.float8_e4m3fnuz, 775 } 776 if any(not isinstance(x, (type(None), int)) for x in scalar_types): 777 return scalar_types 778 779 dtypes: List[Optional[torch.dtype]] = [] 780 for scalar_type in scalar_types: 781 if scalar_type is None: 782 dtypes.append(scalar_type) 783 elif scalar_type not in _SCALAR_TYPE_TO_DTYPE: 784 raise ValueError("Unrecognized scalar type {scalar_type}") 785 else: 786 dtypes.append(_SCALAR_TYPE_TO_DTYPE[scalar_type]) 787 return dtypes 788 789 790class Work(_Work): 791 def __init__(self) -> None: 792 super().__init__() 793 self.event = torch.cuda.Event() 794 self.event.record() 795 796 def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool: 797 self.event.wait() 798 return True 799 800 801""" 802NOTE [low-contention collectives] 803When a collective is overlapped with abundant compute, it makes sense to 804prioritize reducing the contention between the collective and the overlapped 805compute, even at the cost of a slightly slower collective. 806 807Common collective implementations (e.g., NCCL without user buffer 808registration) optimize for throughput with no ambient compute. However, such 809implementations may not be optimal when they are overlapped with compute: 810- These implementations typically fuse the entire collective into a single 811kernel and reserve SM resources based on the most demanding portion of the 812collective, even when a large portion of the collective does not require this 813much resource. 814- These implementations often use SM-based P2P copy as opposed to copy 815engine-based P2P copy. Copy engine-based P2P copy may not have a significant 816advantage when there's no ambient compute. However, it may significantly 817improve overall resource utilization in the presence of ambient compute. 818 819When overlapped with intensive compute (e.g., persistent matmul kernels), the 820SM-usage of a collective can lead to inefficient overlapping. 821 822Low-contention collectives achieve their goals with the following strategies: 823- Use copy engine-based copy whenever possible. 824- Break down portions of a collective with different resource requirements 825into multiple kernels. This improves the overlapping efficiency at the cost 826of additional launching overhead. 827""" 828 829 830@torch.library.impl(lib, "_low_contention_all_gather", "Meta") 831def _low_contention_all_gather_meta( 832 tensor: torch.Tensor, 833 group_name: str, 834) -> torch.Tensor: 835 group_size = c10d._get_group_size_by_name(group_name) 836 return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:]) 837 838 839@torch.library.impl(lib, "_low_contention_all_gather", "CUDA") 840def _low_contention_all_gather( 841 tensor: torch.Tensor, 842 group_name: str, 843) -> torch.Tensor: 844 """ 845 Performs all-gather with symmetric memory in a low-contention fashion. 846 847 When `tensor` is already in symmetric memory: 848 - The collective is carried out without using SMs. 849 - No symmetric memory workspace is required. 850 851 When `tensor` is not in symmetric memory: 852 - An extra SM-based copy is performed to copy the input data into the 853 symmetric memory workspace. 854 - Symmetric memory workspace size requirement: the size of `tensor`. 855 """ 856 symm_mem = _SymmetricMemory.rendezvous(tensor) 857 if symm_mem is not None: 858 input_is_symm_mem = True 859 else: 860 symm_mem = get_symm_mem_workspace( 861 group_name, tensor.numel() * tensor.element_size() 862 ) 863 input_is_symm_mem = False 864 865 rank = symm_mem.rank 866 world_size = symm_mem.world_size 867 868 output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:]) 869 chunks = output.chunk(world_size) 870 871 _get_backend_stream().wait_stream(torch.cuda.current_stream()) 872 with torch.cuda.stream(_get_backend_stream()): 873 if not input_is_symm_mem: 874 local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype) 875 local_buf.copy_(tensor) 876 # pull 877 symm_mem.barrier() 878 for step in range(0, world_size): 879 remote_rank = (rank - step) % world_size 880 src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) 881 chunks[remote_rank].copy_(src_buf) 882 symm_mem.barrier() 883 torch._C._distributed_c10d._register_work(output, Work()) 884 return output 885 886 887@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta") 888def _low_contention_reduce_scatter_meta( 889 tensor: torch.Tensor, 890 reduce_op: str, 891 group_name: str, 892) -> torch.Tensor: 893 group_size = c10d._get_group_size_by_name(group_name) 894 return tensor.unflatten(0, (group_size, -1)).mean(dim=0) 895 896 897def _low_contention_reduce_scatter_with_symm_mem_input( 898 tensor: torch.Tensor, 899 reduce_op: str, 900 symm_mem: _SymmetricMemory, 901) -> torch.Tensor: 902 rank = symm_mem.rank 903 world_size = symm_mem.world_size 904 905 assert tensor.shape[0] % world_size == 0 906 a2a_res = torch.empty_like(tensor) 907 chunks = a2a_res.chunk(world_size) 908 909 _get_backend_stream().wait_stream(torch.cuda.current_stream()) 910 with torch.cuda.stream(_get_backend_stream()): 911 # pull + offline reduction 912 symm_mem.barrier() 913 for step in range(0, world_size): 914 remote_rank = (rank - step) % world_size 915 src_buf = symm_mem.get_buffer( 916 remote_rank, 917 chunks[0].shape, 918 chunks[0].dtype, 919 chunks[0].numel() * rank, 920 ) 921 chunks[remote_rank].copy_(src_buf) 922 symm_mem.barrier() 923 924 ret = a2a_res.unflatten(0, (world_size, -1)) 925 if reduce_op == "sum": 926 ret = ret.sum(dim=0) 927 elif reduce_op == "avg": 928 ret = ret.mean(dim=0) 929 else: 930 raise ValueError(f"reduce_op ({reduce_op}) is not supported") 931 torch._C._distributed_c10d._register_work(ret, Work()) 932 return ret 933 934 935def _low_contention_reduce_scatter_with_workspace( 936 tensor: torch.Tensor, 937 reduce_op: str, 938 workspace: _SymmetricMemory, 939) -> torch.Tensor: 940 rank = workspace.rank 941 world_size = workspace.world_size 942 943 assert tensor.shape[0] % world_size == 0 944 chunks = tensor.chunk(world_size) 945 946 _get_backend_stream().wait_stream(torch.cuda.current_stream()) 947 with torch.cuda.stream(_get_backend_stream()): 948 # push + offline reduction 949 workspace.barrier() 950 for step in range(0, world_size): 951 remote_rank = (rank - step) % world_size 952 dst_buf = workspace.get_buffer( 953 remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank 954 ) 955 dst_buf.copy_(chunks[remote_rank]) 956 workspace.barrier() 957 958 buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype) 959 ret = buf.unflatten(0, (world_size, -1)) 960 if reduce_op == "sum": 961 ret = ret.sum(dim=0) 962 elif reduce_op == "avg": 963 ret = ret.mean(dim=0) 964 else: 965 raise ValueError(f"reduce_op ({reduce_op}) is not supported") 966 torch._C._distributed_c10d._register_work(ret, Work()) 967 return ret 968 969 970@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA") 971def _low_contention_reduce_scatter( 972 tensor: torch.Tensor, 973 reduce_op: str, 974 group_name: str, 975) -> torch.Tensor: 976 """ 977 Performs reduce-scatter with symmetric memory in a low-contention fashion. 978 979 This implementation performs a P2P-based all-to-all followed by an offline 980 reduction. 981 982 When `tensor` is already in symmetric memory: 983 - Pull-based all-to-all is used. 984 - No symmetric memory workspace is required. 985 986 When `tensor` is not in symmetric memory: 987 - Push-based all-to-all is used. 988 - Symmetric memory workspace size requirement: the size of `tensor`. 989 990 SM-usage: 991 - SM-based copy of the rank's own chunk for the all-to-all. 992 - Reduction on the all-to-all result. 993 994 TODO(yifu): the SM-based copy can be avoided with a list-based reduction 995 kernel. 996 """ 997 symm_mem = _SymmetricMemory.rendezvous(tensor) 998 if symm_mem is not None: 999 return _low_contention_reduce_scatter_with_symm_mem_input( 1000 tensor, reduce_op, symm_mem 1001 ) 1002 else: 1003 workspace = get_symm_mem_workspace( 1004 group_name, tensor.numel() * tensor.element_size() 1005 ) 1006 return _low_contention_reduce_scatter_with_workspace( 1007 tensor, reduce_op, workspace 1008 ) 1009