xref: /aosp_15_r20/external/pytorch/torch/distributed/_symmetric_memory/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2import socket
3import uuid
4from contextlib import contextmanager
5from datetime import timedelta
6from functools import partial
7from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
8
9import torch
10import torch.distributed._functional_collectives as funcol
11import torch.distributed.distributed_c10d as c10d
12from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work
13
14
15_group_name_to_store: Dict[str, c10d.Store] = {}
16
17
18def enable_symm_mem_for_group(group_name: str) -> None:
19    """
20    Enables symmetric memory for a process group.
21
22    Args:
23        group_name (str): the name of the process group.
24    """
25    if group_name in _group_name_to_store:
26        return
27
28    group = c10d._resolve_process_group(group_name)
29    global_ranks = sorted(c10d._world.pg_group_ranks[group].keys())
30    # Different subgroups with the same name should use different stores
31    global_ranks_str = "_".join(map(str, global_ranks))
32    store = c10d.PrefixStore(
33        f"symmetric_memory-{global_ranks_str}",
34        c10d._get_process_group_store(group),
35    )
36    # Use one store-based broadcast to bootstrap a file store from the process
37    # group store and simultaneously verify that all ranks are on the same host.
38    hostname = socket.gethostname()
39    if group.rank() == 0:
40        uid = str(uuid.uuid4())
41        msg = f"{hostname}/{uid}"
42        store.set("init", msg)
43    else:
44        msg = store.get("init").decode("utf-8")
45        tokens = msg.split("/")
46        assert len(tokens) == 2, tokens
47        rank_0_hostname, uid = tokens
48        if hostname != rank_0_hostname:
49            raise RuntimeError(
50                "init_symmetric_memory_for_process_group() failed for "
51                f'group "{group_name}". Rank 0 and rank {group.rank()} '
52                f"are on different hosts ({rank_0_hostname} and {hostname})"
53            )
54    store = torch._C._distributed_c10d.FileStore(f"/tmp/{uid}", group.size())
55    # TODO: check device connectivity
56    _group_name_to_store[group_name] = store
57    _SymmetricMemory.set_group_info(
58        group_name,
59        group.rank(),
60        group.size(),
61        store,
62    )
63
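# Example (illustrative sketch, not exercised by this module): enabling
# symmetric memory for the default process group after a standard NCCL init.
# Assumes one GPU per rank and that all ranks share a host; obtaining the group
# name via `dist.group.WORLD.group_name` is just one possible choice.
#
#     import torch
#     import torch.distributed as dist
#
#     dist.init_process_group("nccl")
#     torch.cuda.set_device(dist.get_rank())
#     enable_symm_mem_for_group(dist.group.WORLD.group_name)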
64
65_is_test_mode: bool = False
66
67
68@contextmanager
69def _test_mode() -> Generator[None, None, None]:
70    """
71    Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops
72    defined in the ``symm_mem`` namespace to use fallback implementations.
73
74    The context manager is not thread safe.
75    """
76    global _is_test_mode
77    prev = _is_test_mode
78    try:
79        _is_test_mode = True
80        yield
81    finally:
82        _is_test_mode = prev
83
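# Example (sketch): exercising a symm_mem op through its fallback
# implementation, e.g. in a test that runs without symmetric memory support.
# `A_shard`, `B`, and `group_name` are placeholders.
#
#     with _test_mode():
#         ag, (mm_out,) = torch.ops.symm_mem.fused_all_gather_matmul(
#             A_shard, [B], gather_dim=0, group_name=group_name
#         )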
84
85def is_symm_mem_enabled_for_group(group_name: str) -> bool:
86    """
87    Check if symmetric memory is enabled for a process group.
88
89    Args:
90        group_name (str): the name of the process group.
91    """
92    return _is_test_mode or group_name in _group_name_to_store
93
94
95_group_name_to_workspace_tensor: Dict[str, Optional[torch.Tensor]] = {}
96
97
98def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
99    """
100    Get the symmetric memory workspace associated with the process group. If
101    ``min_size`` is greater than the size of the workspace currently associated
102    with ``group_name``, the workspace will be re-allocated and re-rendezvous'd.
103
104    Args:
105        group_name (str): the name of the process group.
106        min_size (int): the size requirement for the workspace in bytes.
107
108    Returns:
109        _SymmetricMemory: the symmetric memory workspace associated with the
110        group.
111    """
112    tensor = _group_name_to_workspace_tensor.get(group_name)
113    size = tensor.numel() * tensor.element_size() if tensor is not None else 0
114    if tensor is None or size < min_size:
115        tensor = _SymmetricMemory.empty_strided_p2p(
116            (max(size, min_size),),
117            [1],
118            torch.uint8,
119            torch.device(f"cuda:{torch.cuda.current_device()}"),
120            group_name,
121        )
122        _group_name_to_workspace_tensor[group_name] = tensor
123    return _SymmetricMemory.rendezvous(tensor)
124
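# Example (sketch): requesting a 1 MiB workspace for a group and viewing this
# rank's slice of it as a float32 tensor. Assumes symmetric memory has already
# been enabled for `group_name` via enable_symm_mem_for_group().
#
#     symm_mem = get_symm_mem_workspace(group_name, min_size=1024 * 1024)
#     local_buf = symm_mem.get_buffer(symm_mem.rank, (1024,), torch.float32)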
125
126_backend_stream: Optional[torch.cuda.Stream] = None
127
128
129def _get_backend_stream() -> torch.cuda.Stream:
130    global _backend_stream
131    if _backend_stream is None:
132        _backend_stream = torch.cuda.Stream()
133    return _backend_stream
134
135
136def _pipelined_all_gather_and_consume(
137    shard: torch.Tensor,
138    shard_consumer: Callable[[torch.Tensor, int], None],
139    ag_out: torch.Tensor,
140    group_name: str,
141) -> None:
142    """
143    Perform the following logic with micro-pipelined computation and
144    communication:
145
146        tensor = all_gather_tensor(shard, gather_dim=1, group=group)
147        chunks = tensor.chunk(group.size())
148        for src_rank, chunk in enumerate(chunks):
149            shard_consumer(chunk, src_rank)
150
151    NOTE:
152    - The shard passed to shard_consumer will always be contiguous.
153    """
154    p2p_workspace_size_req = shard.numel() * shard.element_size()
155    symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
156    group_size = symm_mem.world_size
157    rank = symm_mem.rank
158
159    backend_stream = _get_backend_stream()
160    backend_stream.wait_stream(torch.cuda.current_stream())
161    local_p2p_buf = symm_mem.get_buffer(rank, shard.shape, shard.dtype)
162
163    chunks = ag_out.chunk(group_size)
164
165    # While consuming local shard, copy it to the local p2p buffer
166    # in another stream.
167    shard_consumer(shard, rank)
168    chunks[rank].copy_(shard)
169
170    with torch.cuda.stream(backend_stream):
171        local_p2p_buf.copy_(shard)
172        symm_mem.barrier(channel=0)
173    torch.cuda.current_stream().wait_stream(backend_stream)
174
175    # At this point, all ranks have copied their local shard to
176    # their local p2p buffer. Each rank can now copy and consume
177    # remote shards.
178    for step in range(1, group_size):
179        if step % 2 == 0:
180            stream = torch.cuda.current_stream()
181        else:
182            stream = backend_stream
183        remote_rank = (step + rank) % group_size
184        remote_p2p_buf = symm_mem.get_buffer(remote_rank, shard.shape, shard.dtype)
185        with torch.cuda.stream(stream):
186            chunks[remote_rank].copy_(remote_p2p_buf)
187            shard_consumer(chunks[remote_rank], remote_rank)
188
189    with torch.cuda.stream(backend_stream):
190        symm_mem.barrier(channel=group_size % 2)
191    torch.cuda.current_stream().wait_stream(backend_stream)
192
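# Example (sketch): overlapping an all-gather of `A_shard` with per-shard
# matmuls against a fixed `B`, mirroring how _fused_all_gather_matmul_impl uses
# this helper. `A_shard` (2D), `B`, and `group_name` are placeholders.
#
#     group_size = c10d._get_group_size_by_name(group_name)
#     ag_out = A_shard.new_empty(A_shard.shape[0] * group_size, A_shard.shape[1])
#     out_chunks = A_shard.new_empty(ag_out.shape[0], B.shape[1]).chunk(group_size)
#
#     def shard_consumer(shard: torch.Tensor, src_rank: int) -> None:
#         torch.mm(shard, B, out=out_chunks[src_rank])
#
#     _pipelined_all_gather_and_consume(A_shard, shard_consumer, ag_out, group_name)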
193
194def _pipelined_produce_and_all2all(
195    chunk_producer: Callable[[int, torch.Tensor], None],
196    output: torch.Tensor,
197    group_name: str,
198) -> None:
199    """
200    Perform the following logic with micro-pipelined computation and
201    communication:
202
203        chunks = [torch.empty_like(c) for c in output.chunk(group_size)]
204        for dst_rank in range(group_size):
205            chunk_producer(dst_rank, chunks[dst_rank])
206
207        dist.all_to_all_single(output=output, input=torch.cat(chunks))
208    """
209    out_chunks = output.chunk(c10d._get_group_size_by_name(group_name))
210    p2p_workspace_size_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2
211    symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req)
212    group_size = symm_mem.world_size
213    rank = symm_mem.rank
214
215    backend_stream = _get_backend_stream()
216    backend_stream.wait_stream(torch.cuda.current_stream())
217
218    def get_p2p_buf(rank: int, idx: int) -> torch.Tensor:
219        assert idx in (0, 1)
220        offset = 0 if idx == 0 else out_chunks[0].numel()
221        return symm_mem.get_buffer(
222            rank, out_chunks[0].shape, out_chunks[0].dtype, offset
223        )
224
225    # Prepare two local p2p buffers, so that a remote rank can pull the result
226    # of step [i] in one p2p buffer while the local rank can compute the
227    # result of step [i+1] and write it directly to the other p2p buffer.
228    local_p2p_buf_0 = get_p2p_buf(rank, 0)
229    local_p2p_buf_1 = get_p2p_buf(rank, 1)
230
231    for step in range(1, group_size):
232        remote_rank = (rank - step) % group_size
233        if step % 2 == 0:
234            stream = torch.cuda.current_stream()
235            other_stream = backend_stream
236            p2p_buf = local_p2p_buf_1
237            remote_p2p_buf = get_p2p_buf(remote_rank, 1)
238        else:
239            stream = backend_stream
240            other_stream = torch.cuda.current_stream()
241            p2p_buf = local_p2p_buf_0
242            remote_p2p_buf = get_p2p_buf(remote_rank, 0)
243        with torch.cuda.stream(stream):
244            chunk_producer((rank + step) % group_size, p2p_buf)
245            symm_mem.barrier(channel=step % 2)
246            # Make the other stream wait for the barrier on the current
247            # stream to finish before it runs the next chunk_producer, so
248            # that the compute does not delay the barrier.
249            other_stream.wait_stream(stream)
250            out_chunks[remote_rank].copy_(remote_p2p_buf)
251
252    chunk_producer(rank, out_chunks[rank])
253    torch.cuda.current_stream().wait_stream(backend_stream)
254
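# Example (sketch): producing each destination rank's chunk with a matmul and
# exchanging the chunks with an all-to-all, mirroring how
# _fused_matmul_reduce_scatter_impl uses this helper. `A` (2D, rows divisible
# by the group size), `B`, and `group_name` are placeholders.
#
#     group_size = c10d._get_group_size_by_name(group_name)
#     shards = A.chunk(group_size)
#     stacked_partials = A.new_empty(A.shape[0], B.shape[1])
#
#     def chunk_producer(dst_rank: int, out: torch.Tensor) -> None:
#         torch.mm(shards[dst_rank], B, out=out)
#
#     _pipelined_produce_and_all2all(chunk_producer, stacked_partials, group_name)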
255
256lib = torch.library.Library("symm_mem", "DEF")  # noqa: TOR901
257lib.define(
258    "fused_all_gather_matmul(Tensor A, Tensor[] Bs, int gather_dim, str group_name) -> (Tensor, Tensor[])"
259)
260lib.define(
261    "fused_all_gather_scaled_matmul("
262    "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, "
263    "int gather_dim, str group_name, "
264    "Tensor?[] biases, "
265    "Tensor?[] result_scales, "
266    "ScalarType?[] out_dtypes, "
267    "bool[] use_fast_accum) -> (Tensor, Tensor[])"
268)
269lib.define(
270    "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor"
271)
272lib.define(
273    "fused_scaled_matmul_reduce_scatter("
274    "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, "
275    "str reduce_op, int scatter_dim, str group_name, "
276    "Tensor? bias = None, "
277    "Tensor? result_scale = None, "
278    "ScalarType? out_dtype = None, "
279    "bool use_fast_accum = False) -> Tensor"
280)
281lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor")
282lib.define(
283    "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
284)
285
286
287def _fused_all_gather_matmul_impl(
288    mm_out_op: torch._ops.OpOverload,
289    A_shard: torch.Tensor,
290    Bs: List[torch.Tensor],
291    kwargs_list: List[Dict[str, Any]],
292    out_dtypes: List[Optional[torch.dtype]],
293    gather_dim: int,
294    group_name: str,
295) -> Tuple[torch.Tensor, List[torch.Tensor]]:
296    if A_shard.dim() < 2:
297        raise ValueError("A_shard must be a matrix")
298    for B in Bs:
299        if B.dim() != 2:
300            raise ValueError("B must be a matrix")
301    if len(out_dtypes) != len(Bs):
302        raise ValueError("len(out_types) must be the same as len(Bs)")
303    if len(kwargs_list) != len(Bs):
304        raise ValueError("len(kwargs_list) must be the same as len(Bs)")
305    if gather_dim < 0 or gather_dim >= A_shard.dim():
306        raise ValueError("Invalid gather_dim")
307
308    group = c10d._resolve_process_group(group_name)
309
310    # Move the gather_dim to the front and flatten the tensor into a 2D matrix.
311    # The flattened tensor doesn't need to be contiguous (for computation
312    # efficiency), as _pipelined_all_gather_and_consume guarantees that shards
313    # passed to shard_consumer are contiguous.
314    x = A_shard.movedim(gather_dim, 0)
315    leading_dims = [group.size()] + list(x.shape[:-1])
316    x = x.flatten(0, -2)
317
318    # Helper function for reverting the above transformation
319    def unflatten(t: torch.Tensor) -> torch.Tensor:
320        return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim)
321
322    ag_out = x.new_empty(
323        x.shape[0] * group.size(),
324        x.shape[1],
325    )
326    outputs = [
327        x.new_empty(x.shape[0] * group.size(), B.shape[1], dtype=out_dtype or B.dtype)
328        for B, out_dtype in zip(Bs, out_dtypes)
329    ]
330    output_shards = [output.chunk(group.size()) for output in outputs]
331
332    # Computing block-wise matmul along the first dim of A
333    def shard_consumer(shard: torch.Tensor, rank: int) -> None:
334        for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
335            mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank])
336
337    _pipelined_all_gather_and_consume(
338        x,
339        shard_consumer,
340        ag_out,
341        group_name,
342    )
343    return unflatten(ag_out), [unflatten(output) for output in outputs]
344
345
346@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
347def _fused_all_gather_matmul_fallback(
348    A_shard: torch.Tensor,
349    Bs: List[torch.Tensor],
350    gather_dim: int,
351    group_name: str,
352) -> Tuple[torch.Tensor, List[torch.Tensor]]:
353    group_size = c10d._get_group_size_by_name(group_name)
354    A = torch.ops._c10d_functional.all_gather_into_tensor(
355        A_shard.contiguous(), group_size, group_name
356    )
357    A = torch.ops._c10d_functional.wait_tensor(A)
358    A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
359    return A.movedim(0, gather_dim), [
360        torch.matmul(A, B).movedim(0, gather_dim) for B in Bs
361    ]
362
363
364@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA")
365def _fused_all_gather_matmul(
366    A_shard: torch.Tensor,
367    Bs: List[torch.Tensor],
368    gather_dim: int,
369    group_name: str,
370) -> Tuple[torch.Tensor, List[torch.Tensor]]:
371    """
372    Perform the following logic with micro-pipelined computation and
373    communication:
374
375        all_gather_tensor(A_shard, gather_dim, group_name) @ B
376
377    Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is
378    contiguous, no extra copy is required for input layout transformation.
379    Otherwise A_shard needs to be copied once.
380    """
381    if _is_test_mode:
382        return _fused_all_gather_matmul_fallback(A_shard, Bs, gather_dim, group_name)
383
384    with torch.profiler.record_function("fused_all_gather_matmul"):
385        return _fused_all_gather_matmul_impl(
386            torch.ops.aten.mm.out,
387            A_shard,
388            Bs,
389            [{} for B in Bs],
390            [B.dtype for B in Bs],
391            gather_dim,
392            group_name,
393        )
394
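# Example (sketch): calling the fused op through the dispatcher. Restriding
# A_shard first (see restride_A_shard_for_fused_all_gather_matmul below) makes
# A_shard.movedim(gather_dim, 0) contiguous so no extra input copy is needed.
# `A_shard`, `B`, and `group_name` are placeholders.
#
#     A_shard = restride_A_shard_for_fused_all_gather_matmul(A_shard, gather_dim=1)
#     A, (C,) = torch.ops.symm_mem.fused_all_gather_matmul(
#         A_shard, [B], gather_dim=1, group_name=group_name
#     )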
395
396@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta")
397def _fused_all_gather_scaled_matmul_fallback(
398    A_shard: torch.Tensor,
399    Bs: List[torch.Tensor],
400    A_scale: torch.Tensor,
401    B_scales: List[torch.Tensor],
402    gather_dim: int,
403    group_name: str,
404    biases: List[Optional[torch.Tensor]],
405    result_scales: List[Optional[torch.Tensor]],
406    out_dtypes: List[Optional[torch.dtype]],
407    use_fast_accum: List[bool],
408) -> Tuple[torch.Tensor, List[torch.Tensor]]:
409    out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
410
411    group_size = c10d._get_group_size_by_name(group_name)
412    A = torch.ops._c10d_functional.all_gather_into_tensor(
413        A_shard.contiguous(), group_size, group_name
414    )
415    A = torch.ops._c10d_functional.wait_tensor(A)
416    A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1)
417
418    def scaled_matmul(
419        A: torch.Tensor,
420        B: torch.Tensor,
421        A_scale: torch.Tensor,
422        B_scale: torch.Tensor,
423        bias: Optional[torch.Tensor],
424        result_scale: Optional[torch.Tensor],
425        out_dtype: Optional[torch.dtype],
426        use_fast_accum: bool,
427    ) -> torch.Tensor:
428        leading_dims = A.shape[:-1]
429        res = torch.ops.aten._scaled_mm(
430            A.flatten(0, -2), B, A_scale, B_scale, out_dtype=out_dtype
431        )
432        return res.unflatten(0, leading_dims)
433
434    return A.movedim(0, gather_dim), [
435        scaled_matmul(
436            A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum
437        ).movedim(0, gather_dim)
438        for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip(
439            Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum
440        )
441    ]
442
443
444@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA")
445def _fused_all_gather_scaled_matmul(
446    A_shard: torch.Tensor,
447    Bs: List[torch.Tensor],
448    A_scale: torch.Tensor,
449    B_scales: List[torch.Tensor],
450    gather_dim: int,
451    group_name: str,
452    biases: List[Optional[torch.Tensor]],
453    result_scales: List[Optional[torch.Tensor]],
454    out_dtypes: List[Optional[torch.dtype]],
455    use_fast_accum: List[bool],
456) -> Tuple[torch.Tensor, List[torch.Tensor]]:
457    """
458    Perform the following logic with micro-pipelined computation and
459    communication:
460
461        A = all_gather_tensor(A_shard, gather_dim, group_name)
462        leading_dims = A.shape[:-1]
463        res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale)
464        res = res.unflatten(0, leading_dims)
465
466    Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is
467    contiguous, no extra copy is required for input layout transformation.
468    Otherwise A_shard needs to be copied once.
469    """
470    out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
471
472    if len(biases) != len(Bs):
473        raise ValueError("len(biases) must be the same as len(Bs)")
474    if len(result_scales) != len(Bs):
475        raise ValueError("len(result_scales) must be the same as len(Bs)")
476    if len(out_dtypes) != len(Bs):
477        raise ValueError("len(out_dtypes) must be the same as len(Bs)")
478    if len(use_fast_accum) != len(Bs):
479        raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)")
480
481    if _is_test_mode:
482        return _fused_all_gather_scaled_matmul_fallback(
483            A_shard,
484            Bs,
485            A_scale,
486            B_scales,
487            gather_dim,
488            group_name,
489            biases,
490            result_scales,
491            out_dtypes,
492            use_fast_accum,
493        )
494
495    with torch.profiler.record_function("fused_all_gather_scaled_matmul"):
496        return _fused_all_gather_matmul_impl(
497            torch.ops.aten._scaled_mm.out,
498            A_shard,
499            Bs,
500            [
501                {
502                    "scale_a": A_scale,
503                    "scale_b": B_scale,
504                    "bias": bias,
505                    "scale_result": result_scale,
506                    "out_dtype": out_dtype,
507                    "use_fast_accum": fast_accum,
508                }
509                for B_scale, bias, result_scale, out_dtype, fast_accum in zip(
510                    B_scales, biases, result_scales, out_dtypes, use_fast_accum
511                )
512            ],
513            out_dtypes,
514            gather_dim,
515            group_name,
516        )
517
518
519def make_contiguous_for_perm(
520    t: torch.Tensor,
521    perm: List[int],
522) -> torch.Tensor:
523    """
524    Restride `t` such that `t.permute(perm)` is contiguous.
525    """
526    inv_perm = [0] * len(perm)
527    for i, p in enumerate(perm):
528        inv_perm[p] = i
529    return t.permute(perm).contiguous().permute(inv_perm)
530
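# Example (sketch): restride a (4, 8, 16) tensor so that moving dim 1 to the
# front yields a contiguous permutation.
#
#     t = make_contiguous_for_perm(torch.randn(4, 8, 16), [1, 0, 2])
#     assert t.permute(1, 0, 2).is_contiguous()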
531
532def restride_A_shard_for_fused_all_gather_matmul(
533    t: torch.Tensor,
534    gather_dim: int,
535) -> torch.Tensor:
536    """
537    Restride the `A_shard` arg of `fused_all_gather_matmul` for optimal perf.
538    See the doc for `fused_all_gather_matmul` for detail.
539    """
540    perm = list(range(len(t.shape)))
541    perm.insert(0, perm.pop(gather_dim))
542    return make_contiguous_for_perm(t, perm)
543
544
545def _fused_matmul_reduce_scatter_impl(
546    mm_out_op: torch._ops.OpOverload,
547    A: torch.Tensor,
548    B: torch.Tensor,
549    kwargs: Dict[str, Any],
550    out_dtype: Optional[torch.dtype],
551    reduce_op: str,
552    scatter_dim: int,
553    group_name: str,
554) -> torch.Tensor:
555    if A.dim() < 2:
556        raise ValueError("A_shard must be a matrix")
557    if scatter_dim < 0 or scatter_dim >= A.dim():
558        raise ValueError("Invalid gather_dim")
559    if B.dim() != 2:
560        raise ValueError("B must be a matrix")
561    if reduce_op == "sum":
562        reduce_fn = partial(torch.sum, dim=0)
563    elif reduce_op == "avg":
564        reduce_fn = partial(torch.mean, dim=0)
565    else:
566        raise ValueError("reduce_op must be sum or avg")
567
568    group = c10d._resolve_process_group(group_name)
569    out_shape = [*A.shape[:-1], B.shape[1]]
570    out_shape[scatter_dim] //= group.size()
571
572    # Move the scatter_dim to the front and flatten the tensor into a 2D matrix
573    x = A.movedim(scatter_dim, 0)
574    leading_dims = [group.size()] + list(x.shape[:-1])
575    leading_dims[1] //= group.size()
576    x = x.flatten(0, -2)
577    shards = x.chunk(group.size())
578
579    # Computing block-wise matmul along the first dim of A
580    def chunk_producer(rank: int, out: torch.Tensor) -> None:
581        mm_out_op(shards[rank], B, **kwargs, out=out)
582
583    stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype)
584
585    _pipelined_produce_and_all2all(
586        chunk_producer,
587        stacked_partials,
588        group_name,
589    )
590    # Ensures that the transpose and reduction produce a contiguous result
591    # in a single reduction kernel.
592    return reduce_fn(
593        stacked_partials.view(*leading_dims, -1)
594        .movedim(1, scatter_dim + 1)
595        .movedim(0, scatter_dim),
596        dim=scatter_dim,
597    )
598
599
600@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta")
601def _fused_matmul_reduce_scatter_fallback(
602    A: torch.Tensor,
603    B: torch.Tensor,
604    reduce_op: str,
605    scatter_dim: int,
606    group_name: str,
607) -> torch.Tensor:
608    res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
609    res = funcol.wait_tensor(res)
610    return res
611
612
613@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA")
614def _fused_matmul_reduce_scatter(
615    A: torch.Tensor,
616    B: torch.Tensor,
617    reduce_op: str,
618    scatter_dim: int,
619    group_name: str,
620) -> torch.Tensor:
621    """
622    Perform the following logic with micro-pipelined computation and
623    communication:
624
625        reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name)
626
627    Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no
628    extra copy is required for input layout transformation. Otherwise A needs
629    to be copied once.
630    """
631    if _is_test_mode:
632        return _fused_matmul_reduce_scatter_fallback(
633            A, B, reduce_op, scatter_dim, group_name
634        )
635
636    with torch.profiler.record_function("fused_matmul_reduce_scatter"):
637        return _fused_matmul_reduce_scatter_impl(
638            mm_out_op=torch.ops.aten.mm.out,
639            A=A,
640            B=B,
641            kwargs={},
642            out_dtype=A.dtype,
643            reduce_op=reduce_op,
644            scatter_dim=scatter_dim,
645            group_name=group_name,
646        )
647
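# Example (sketch): using the fused op as a drop-in for matmul + reduce-scatter
# in a row-parallel linear layer, scattering along the sequence dimension.
# Shapes and `group_name` are placeholders.
#
#     # A: (batch, seq, k_local), B: (k_local, n)
#     out = torch.ops.symm_mem.fused_matmul_reduce_scatter(
#         A, B, "sum", scatter_dim=1, group_name=group_name
#     )
#     # out: (batch, seq // group_size, n)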
648
649@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "Meta")
650def _fused_scaled_matmul_reduce_scatter_fallback(
651    A: torch.Tensor,
652    B: torch.Tensor,
653    A_scale: torch.Tensor,
654    B_scale: torch.Tensor,
655    reduce_op: str,
656    scatter_dim: int,
657    group_name: str,
658    bias: Optional[torch.Tensor] = None,
659    result_scale: Optional[torch.Tensor] = None,
660    out_dtype: Optional[torch.dtype] = None,
661    use_fast_accum: bool = False,
662) -> torch.Tensor:
663    C = torch._scaled_mm(
664        A.flatten(0, -2).contiguous(),
665        B,
666        A_scale,
667        B_scale,
668        bias,
669        result_scale,
670        out_dtype,
671        use_fast_accum,
672    )
673    C = C.view(*A.shape[:-1], B.shape[1])
674    res = funcol.reduce_scatter_tensor(
675        C,
676        reduce_op,
677        scatter_dim,
678        group_name,
679    )
680    res = funcol.wait_tensor(res)
681    return res
682
683
684@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA")
685def _fused_scaled_matmul_reduce_scatter(
686    A: torch.Tensor,
687    B: torch.Tensor,
688    A_scale: torch.Tensor,
689    B_scale: torch.Tensor,
690    reduce_op: str,
691    scatter_dim: int,
692    group_name: str,
693    bias: Optional[torch.Tensor] = None,
694    result_scale: Optional[torch.Tensor] = None,
695    out_dtype: Optional[torch.dtype] = None,
696    use_fast_accum: bool = False,
697) -> torch.Tensor:
698    if _is_test_mode:
699        return _fused_scaled_matmul_reduce_scatter_fallback(
700            A,
701            B,
702            A_scale,
703            B_scale,
704            reduce_op,
705            scatter_dim,
706            group_name,
707            bias,
708            result_scale,
709            out_dtype,
710            use_fast_accum,
711        )
712    with torch.profiler.record_function("fused_matmul_reduce_scatter"):
713        return _fused_matmul_reduce_scatter_impl(
714            mm_out_op=torch.ops.aten._scaled_mm.out,
715            A=A,
716            B=B,
717            kwargs={
718                "scale_a": A_scale,
719                "scale_b": B_scale,
720                "bias": bias,
721                "scale_result": result_scale,
722                "out_dtype": out_dtype,
723                "use_fast_accum": use_fast_accum,
724            },
725            out_dtype=out_dtype,
726            reduce_op=reduce_op,
727            scatter_dim=scatter_dim,
728            group_name=group_name,
729        )
730
731
732def restride_A_for_fused_matmul_reduce_scatter(
733    t: torch.Tensor,
734    gather_dim: int,
735) -> torch.Tensor:
736    """
737    Restride the `A` arg of `fused_matmul_reduce_scatter` for optimal
738    perf. See the doc for `fused_matmul_reduce_scatter` for detail.
739    """
740    perm = list(range(len(t.shape)))
741    perm.insert(0, perm.pop(gather_dim))
742    return make_contiguous_for_perm(t, perm)
743
744
745def _maybe_convert_scalar_types_to_dtypes(
746    scalar_types: List[Any],
747) -> List[Optional[torch.dtype]]:
748    """
749    When a list of `torch.dtype`s is passed through the dispatcher as
750    `ScalarType[]`, it is converted to a list of scalar type enum values. This
751    function converts it back to a list of `torch.dtype`s.
752    """
753    # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
754    _SCALAR_TYPE_TO_DTYPE = {
755        0: torch.uint8,
756        1: torch.int8,
757        2: torch.short,
758        3: torch.int,
759        4: torch.int64,
760        5: torch.half,
761        6: torch.float,
762        7: torch.double,
763        8: torch.complex32,
764        9: torch.complex64,
765        10: torch.complex128,
766        11: torch.bool,
767        12: torch.qint8,
768        13: torch.quint8,
769        14: torch.qint32,
770        15: torch.bfloat16,
771        16: torch.float8_e5m2,
772        17: torch.float8_e4m3fn,
773        18: torch.float8_e5m2fnuz,
774        19: torch.float8_e4m3fnuz,
775    }
776    if any(not isinstance(x, (type(None), int)) for x in scalar_types):
777        return scalar_types
778
779    dtypes: List[Optional[torch.dtype]] = []
780    for scalar_type in scalar_types:
781        if scalar_type is None:
782            dtypes.append(scalar_type)
783        elif scalar_type not in _SCALAR_TYPE_TO_DTYPE:
784            raise ValueError("Unrecognized scalar type {scalar_type}")
785        else:
786            dtypes.append(_SCALAR_TYPE_TO_DTYPE[scalar_type])
787    return dtypes
788
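# Example (sketch): dispatcher-style scalar type enums are mapped back to
# dtypes, while lists that already contain dtypes pass through unchanged.
#
#     _maybe_convert_scalar_types_to_dtypes([15, None])    # [torch.bfloat16, None]
#     _maybe_convert_scalar_types_to_dtypes([torch.float32, None])  # unchanged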
789
790class Work(_Work):
791    def __init__(self) -> None:
792        super().__init__()
793        self.event = torch.cuda.Event()
794        self.event.record()
795
796    def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
797        self.event.wait()
798        return True
799
800
801"""
802NOTE [low-contention collectives]
803When a collective is overlapped with abundant compute, it makes sense to
804prioritize reducing the contention between the collective and the overlapped
805compute, even at the cost of a slightly slower collective.
806
807Common collective implementations (e.g., NCCL without user buffer
808registration) optimize for throughput with no ambient compute. However, such
809implementations may not be optimal when they are overlapped with compute:
810- These implementations typically fuse the entire collective into a single
811kernel and reserve SM resources based on the most demanding portion of the
812collective, even when a large portion of the collective does not require
813that many resources.
814- These implementations often use SM-based P2P copy as opposed to copy
815engine-based P2P copy. Copy engine-based P2P copy may not have a significant
816advantage when there's no ambient compute. However, it may significantly
817improve overall resource utilization in the presence of ambient compute.
818
819When overlapped with intensive compute (e.g., persistent matmul kernels), the
820SM-usage of a collective can lead to inefficient overlapping.
821
822Low-contention collectives achieve their goals with the following strategies:
823- Use copy engine-based copy whenever possible.
824- Break down portions of a collective with different resource requirements
825into multiple kernels. This improves the overlapping efficiency at the cost
826of additional launching overhead.
827"""
828
829
830@torch.library.impl(lib, "_low_contention_all_gather", "Meta")
831def _low_contention_all_gather_meta(
832    tensor: torch.Tensor,
833    group_name: str,
834) -> torch.Tensor:
835    group_size = c10d._get_group_size_by_name(group_name)
836    return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
837
838
839@torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
840def _low_contention_all_gather(
841    tensor: torch.Tensor,
842    group_name: str,
843) -> torch.Tensor:
844    """
845    Performs all-gather with symmetric memory in a low-contention fashion.
846
847    When `tensor` is already in symmetric memory:
848        - The collective is carried out without using SMs.
849        - No symmetric memory workspace is required.
850
851    When `tensor` is not in symmetric memory:
852        - An extra SM-based copy is performed to copy the input data into the
853          symmetric memory workspace.
854        - Symmetric memory workspace size requirement: the size of `tensor`.
855    """
856    symm_mem = _SymmetricMemory.rendezvous(tensor)
857    if symm_mem is not None:
858        input_is_symm_mem = True
859    else:
860        symm_mem = get_symm_mem_workspace(
861            group_name, tensor.numel() * tensor.element_size()
862        )
863        input_is_symm_mem = False
864
865    rank = symm_mem.rank
866    world_size = symm_mem.world_size
867
868    output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:])
869    chunks = output.chunk(world_size)
870
871    _get_backend_stream().wait_stream(torch.cuda.current_stream())
872    with torch.cuda.stream(_get_backend_stream()):
873        if not input_is_symm_mem:
874            local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype)
875            local_buf.copy_(tensor)
876        # pull
877        symm_mem.barrier()
878        for step in range(0, world_size):
879            remote_rank = (rank - step) % world_size
880            src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype)
881            chunks[remote_rank].copy_(src_buf)
882        symm_mem.barrier()
883        torch._C._distributed_c10d._register_work(output, Work())
884        return output
885
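# Example (sketch): when the input already lives in symmetric memory, the
# all-gather below needs no workspace and no SM-based staging copy. The shape,
# dtype, and `group_name` are placeholders; `local_data` is this rank's shard.
#
#     t = _SymmetricMemory.empty_strided_p2p(
#         (1024,), [1], torch.bfloat16,
#         torch.device(f"cuda:{torch.cuda.current_device()}"), group_name,
#     )
#     _SymmetricMemory.rendezvous(t)
#     t.copy_(local_data)
#     out = torch.ops.symm_mem._low_contention_all_gather(t, group_name)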
886
887@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta")
888def _low_contention_reduce_scatter_meta(
889    tensor: torch.Tensor,
890    reduce_op: str,
891    group_name: str,
892) -> torch.Tensor:
893    group_size = c10d._get_group_size_by_name(group_name)
894    return tensor.unflatten(0, (group_size, -1)).mean(dim=0)
895
896
897def _low_contention_reduce_scatter_with_symm_mem_input(
898    tensor: torch.Tensor,
899    reduce_op: str,
900    symm_mem: _SymmetricMemory,
901) -> torch.Tensor:
902    rank = symm_mem.rank
903    world_size = symm_mem.world_size
904
905    assert tensor.shape[0] % world_size == 0
906    a2a_res = torch.empty_like(tensor)
907    chunks = a2a_res.chunk(world_size)
908
909    _get_backend_stream().wait_stream(torch.cuda.current_stream())
910    with torch.cuda.stream(_get_backend_stream()):
911        # pull + offline reduction
912        symm_mem.barrier()
913        for step in range(0, world_size):
914            remote_rank = (rank - step) % world_size
915            src_buf = symm_mem.get_buffer(
916                remote_rank,
917                chunks[0].shape,
918                chunks[0].dtype,
919                chunks[0].numel() * rank,
920            )
921            chunks[remote_rank].copy_(src_buf)
922        symm_mem.barrier()
923
924        ret = a2a_res.unflatten(0, (world_size, -1))
925        if reduce_op == "sum":
926            ret = ret.sum(dim=0)
927        elif reduce_op == "avg":
928            ret = ret.mean(dim=0)
929        else:
930            raise ValueError(f"reduce_op ({reduce_op}) is not supported")
931        torch._C._distributed_c10d._register_work(ret, Work())
932        return ret
933
934
935def _low_contention_reduce_scatter_with_workspace(
936    tensor: torch.Tensor,
937    reduce_op: str,
938    workspace: _SymmetricMemory,
939) -> torch.Tensor:
940    rank = workspace.rank
941    world_size = workspace.world_size
942
943    assert tensor.shape[0] % world_size == 0
944    chunks = tensor.chunk(world_size)
945
946    _get_backend_stream().wait_stream(torch.cuda.current_stream())
947    with torch.cuda.stream(_get_backend_stream()):
948        # push + offline reduction
949        workspace.barrier()
950        for step in range(0, world_size):
951            remote_rank = (rank - step) % world_size
952            dst_buf = workspace.get_buffer(
953                remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank
954            )
955            dst_buf.copy_(chunks[remote_rank])
956        workspace.barrier()
957
958        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
959        ret = buf.unflatten(0, (world_size, -1))
960        if reduce_op == "sum":
961            ret = ret.sum(dim=0)
962        elif reduce_op == "avg":
963            ret = ret.mean(dim=0)
964        else:
965            raise ValueError(f"reduce_op ({reduce_op}) is not supported")
966        torch._C._distributed_c10d._register_work(ret, Work())
967        return ret
968
969
970@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA")
971def _low_contention_reduce_scatter(
972    tensor: torch.Tensor,
973    reduce_op: str,
974    group_name: str,
975) -> torch.Tensor:
976    """
977    Performs reduce-scatter with symmetric memory in a low-contention fashion.
978
979    This implementation performs a P2P-based all-to-all followed by an offline
980    reduction.
981
982    When `tensor` is already in symmetric memory:
983        - Pull-based all-to-all is used.
984        - No symmetric memory workspace is required.
985
986    When `tensor` is not in symmetric memory:
987        - Push-based all-to-all is used.
988        - Symmetric memory workspace size requirement: the size of `tensor`.
989
990    SM-usage:
991        - SM-based copy of the rank's own chunk for the all-to-all.
992        - Reduction on the all-to-all result.
993
994    TODO(yifu): the SM-based copy can be avoided with a list-based reduction
995    kernel.
996    """
997    symm_mem = _SymmetricMemory.rendezvous(tensor)
998    if symm_mem is not None:
999        return _low_contention_reduce_scatter_with_symm_mem_input(
1000            tensor, reduce_op, symm_mem
1001        )
1002    else:
1003        workspace = get_symm_mem_workspace(
1004            group_name, tensor.numel() * tensor.element_size()
1005        )
1006        return _low_contention_reduce_scatter_with_workspace(
1007            tensor, reduce_op, workspace
1008        )
1009
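# Example (sketch): reduce-scattering a regular (non-symmetric-memory) tensor;
# this takes the push-based path and needs a workspace the size of the input.
# Shapes and `group_name` are placeholders.
#
#     # partials: (world_size * rows_per_rank, cols), same layout on every rank
#     out = torch.ops.symm_mem._low_contention_reduce_scatter(
#         partials, "avg", group_name
#     )
#     # out: (rows_per_rank, cols)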