# mypy: allow-untyped-defs
import sys
import warnings
from typing import cast, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch.distributed.device_mesh import DeviceMesh
from torch.fx.experimental.proxy_tensor import get_proxy_mode

from . import _functional_collectives_impl as fun_col_impl


try:
    from torch.utils._cxx_pytree import tree_map_only
except ImportError:
    from torch.utils._pytree import tree_map_only  # type: ignore[no-redef]


if torch._running_with_deploy():

    def is_torchdynamo_compiling():
        """Can't import torchdynamo in torchdeploy builds currently."""
        return False

else:
    try:
        from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
    except Exception:
        warnings.warn(
            "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
        )

        def is_torchdynamo_compiling():
            return False


"""
New traceable, functional collectives.
RFC: https://github.com/pytorch/pytorch/issues/93173

  compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
  eager: execute these 'functional' ops, which return AsyncCollectiveTensor subclasses that
         automatically call .wait() on the underlying/hidden async 'work' object only when fed to
         a downstream op.

Issues:
* Where should these ops live? Putting them in existing torch.distributed files breaks `import torch`.
* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
"""

"""
Functional collectives are asynchronous only, and we perform implicit stream synchronization
on behalf of the user.

We use AsyncCollectiveTensor to wrap the result tensor of a collective; it lets us observe the
first usage of the tensor and insert cross-stream synchronization at the right place.

The above are the easy bits; the hard one is how we match the Work object returned by
c10d with the tensor that AsyncCollectiveTensor wraps. We allocate the tensor inside the collective
op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
dispatcher, which might call other implementations that are allowed to change the returned
tensor - even return a tensor with a different shape (see ``torch.vmap``).

This means the caller of our ops receives a Tensor that is not guaranteed to be the one
allocated by our implementations, and that makes pairing the AsyncCollectiveTensor with the original
tensor a lot harder. This pairing is needed so we can look up the Work object to use.

Originally, we tried a WeakKeyDictionary to map from Tensor to Work, but because Tensor's
identity is not stable across dispatch, the op caller would end up with a different Tensor
instance that would not match any key in the dictionary.

With Tensor identity out of the question, we decided to use the tensor's data pointer, which
should be stable across all the Tensor changes done during dispatch.

We keep a dictionary from tensor::data_ptr to Work that we insert into right after we call into c10d.

We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait().

Finally, we set up a finalizer against the tensor wrapper to observe it getting collected, so we
can clean up stale entries in the dictionary.

To eliminate the possibility of races, we have a global version counter that is used by the finalizer.

As a wise man once said: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)

"""

"""
Functional collectives can accept any of these types to describe the ranks participating in collectives.

The different types will be desugared to a canonical format.
"""
RANK_TYPES = Union[
    List[int],
    List[List[int]],
    dist.ProcessGroup,
    DeviceMesh,
    Tuple["dist.tensor.DeviceMesh", int],
    str,
]
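
"""
For illustration, all of the following are valid ways to describe the participating ranks
(a minimal sketch; it assumes an initialized default process group over ranks 0..3 and a
CUDA device per rank):

    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    ranks = [0, 1, 2, 3]                   # List[int]: one SPMD group
    mesh2d = [[0, 1], [2, 3]]              # List[List[int]]: MPMD, each sublist is a group
    pg = dist.group.WORLD                  # ProcessGroup
    mesh = init_device_mesh("cuda", (4,))  # DeviceMesh (1D)
    mesh_dim = (mesh, 0)                   # (DeviceMesh, int): one dimension of a mesh
    name = pg.group_name                   # str: a registered group name

Each of these can be passed as the `group` argument of the collectives below.
"""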


"""
User-facing APIs for functional collectives
-------------------------------------------

These APIs are called by user code and are expected to work both in eager execution and under compilation,
but there are significant differences in how the two modes are implemented underneath.

Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via the wait_tensor() op)
just before the tensor is first used.  Compiled tracing currently relies on the compiler to perform this optimization,
and cannot yet correctly trace the AsyncTensor wrapper class.  In the future, these paths may be unified
if sufficient subclass support is added in dynamo.

Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.

Here's how it works under torch.compile/dynamo:
all_reduce(...)
  |--> _expand_group(...)               - desugars processgroup into canonical/traceable format
  |--> c10d_functional.all_reduce(...)  - dynamo captures this op call, doesn't trace deeper
  |--> _maybe_wrap_tensor(...)          - wait_tensor() op is immediately called, no AsyncTensor subclass needed

And under eager execution:
all_reduce(...)
  |--> _expand_group(...)               - same as above, but less critical for eager
  |--> c10d_functional.all_reduce(...)  - dispatches to real kernel OR records op in trace
  |--> _maybe_wrap_tensor(...)          - AsyncTensor wrapper applied to returned tensor,
                                          which issues wait_tensor() at the time of first use
"""


def wait_tensor(tensor):
    """
    Wait on a tensor returned by the collective ops.

    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
    """
    return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]


def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
    """
    Broadcasts the tensor to all processes in the given process group.

    Args:
        src (int): Source rank
        group (ProcessGroup or List[int]): The process group to work on.
        tag (str, optional): A unique identifier for the collective. Default: empty string
    """
    group_name = _resolve_group_name(group, tag)
    tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
    return _maybe_wrap_tensor(tensor)


def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result.

    The input tensor is left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
    return _maybe_wrap_tensor(tensor)
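
"""
Example (a minimal sketch, assuming a process group is already initialized and the input
lives on the right device):

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    grad = torch.randn(8, 8, device="cuda")
    averaged = funcol.all_reduce(grad, "avg", dist.group.WORLD)
    # `averaged` may be an AsyncCollectiveTensor; .wait() (or any downstream op) forces completion.
    if isinstance(averaged, funcol.AsyncCollectiveTensor):
        averaged = averaged.wait()
"""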


def all_gather_tensor(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Gather tensor data from all machines and concatenate over ``gather_dim``.

    Note that the underlying op only gathers over dim 0; other values of ``gather_dim``
    are handled by chunking and re-concatenating after the gather.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    assert self.is_contiguous()
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        self, group_size, group_name
    )
    res = _maybe_wrap_tensor(tensor)
    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
    if gather_dim != 0:
        # torch.cat accesses the data, so we need to wait here anyway; doing the wait first
        # and then chunk + cat avoids going through ACT dispatching logic again
        if isinstance(res, AsyncCollectiveTensor):
            res = res.wait()  # type: ignore[attr-defined]
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res
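
"""
Example (a minimal sketch, assuming 2 ranks and an initialized default group): each rank
contributes a (2, 4) shard; gathering over dim 0 yields (4, 4), while gather_dim=1 yields
(2, 8) via the chunk + cat fixup above:

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    shard = torch.randn(2, 4, device="cuda")
    rows = funcol.all_gather_tensor(shard, gather_dim=0, group=dist.group.WORLD)  # (4, 4)
    cols = funcol.all_gather_tensor(shard, gather_dim=1, group=dist.group.WORLD)  # (2, 8)
"""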


def all_gather_tensor_autograd(
    self: torch.Tensor,
    gather_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Gather tensor data from all machines and concatenate over ``gather_dim``.

    Note that the underlying op only gathers over dim 0; other values of ``gather_dim``
    are handled by chunking and re-concatenating after the gather.

    This function is the same as all_gather_tensor but will propagate the
    backwards gradient across workers.

    See all_gather_tensor for more details on usage.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor(
        self, group_size, group_name
    )
    res = _FromTorchTensor.apply(tensor)
    # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
    if gather_dim != 0:
        # torch.cat accesses the data, so we need to wait here anyway; doing the wait first
        # and then chunk + cat avoids going through ACT dispatching logic again
        if isinstance(res, AsyncCollectiveTensor):
            res = res.wait()  # type: ignore[attr-defined]
        res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
    return res


def reduce_scatter_tensor(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensor is left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
        self,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )
    res = _maybe_wrap_tensor(tensor)
    return res
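
"""
Example (a minimal sketch, assuming 4 ranks): the scatter dimension must be divisible by
the group size; each rank receives one reduced shard of the input:

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    full = torch.randn(8, 16, device="cuda")  # dim 0 (8) is divisible by the 4 ranks
    shard = funcol.reduce_scatter_tensor(full, "sum", scatter_dim=0, group=dist.group.WORLD)
    # shard has shape (2, 16): the sum over ranks of the corresponding (2, 16) slice.
"""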


def reduce_scatter_tensor_autograd(
    self: torch.Tensor,
    reduceOp: str,
    scatter_dim: int,
    group: RANK_TYPES,
    tag: str = "",
):
    """
    Reduces the tensor data across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    This function is the same as reduce_scatter_tensor but will propagate the
    backwards gradient across workers.

    Currently only the "sum" reduceOp is supported.

    See reduce_scatter_tensor for more details on usage.
    """

    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert (
        self.size(scatter_dim) % group_size == 0
    ), f"input dimension {scatter_dim} ({self.size(scatter_dim)}) must be a multiple of group_size {group_size}"
    if scatter_dim != 0:
        tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
        self = torch.cat(tensor_list)

    tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor(
        self,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )
    res = _FromTorchTensor.apply(tensor)
    return res


def all_reduce_coalesced(
    self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result.

    All tensors in the input list are left unmodified.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    tensor_list = torch.ops._c10d_functional.all_reduce_coalesced(  # type: ignore[attr-defined]
        self,
        reduceOp.lower(),
        group_name,
    )
    return list(map(_maybe_wrap_tensor, tensor_list))


def all_gather_into_tensor_coalesced(
    self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
    """
    Gather a list of tensors from all machines.

    Note that it only supports gathering over dim 0.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(  # type: ignore[attr-defined]
        self,
        group_size,
        group_name,
    )
    return list(map(_maybe_wrap_tensor, tensor_list))


def reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduceOp: str,
    scatter_dim: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> List[torch.Tensor]:
    """
    Reduces a list of tensors across all machines in such a way that all get
    the final result, then scatters the results to the corresponding ranks.

    The input tensors are left unmodified.
    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)

    assert len(scatter_dim) == len(inputs)
    for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
        assert (
            tensor.size(dim) % group_size == 0
        ), f"input dimension {dim} ({tensor.size(dim)}) must be a multiple of group_size {group_size} for tensor at index {idx}"
        if dim != 0:
            tensor_list = torch.chunk(tensor, group_size, dim=dim)
            inputs[idx] = torch.cat(tensor_list)

    tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(  # type: ignore[attr-defined]
        inputs,
        reduceOp.lower(),
        group_size,
        group_name,  # type: ignore[possibly-undefined]
    )

    return list(map(_maybe_wrap_tensor, tensor_list))


# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
# Today, this maps 1:1 with "aten ops that are views".
def _is_view_op(tgt):
    assert isinstance(tgt, torch._ops.OpOverload)
    schema = tgt._schema
    if len(schema.arguments) > 0:
        first_arg = schema.arguments[0]
        # check if op is a view
        return first_arg.alias_info is not None and not first_arg.alias_info.is_write


def all_to_all_single(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Each process splits the input tensor and scatters the splits to all processes
    in the group, then concatenates the received tensors from all processes in the
    group and returns a single output tensor.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh

    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
    that information and perform collective algebraic optimization. Use other forms of input for that.
    """
    if output_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
        ), output_split_sizes
    if input_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
        ), input_split_sizes
    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [self.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes
    tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
        self,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
    return _maybe_wrap_tensor(tensor)
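
"""
Example (a minimal sketch, assuming 2 ranks): with the default split sizes, each rank sends
an equal slice of dim 0 to every other rank; explicit (and possibly uneven) splits can be
passed once the peers have agreed on them out of band:

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    x = torch.arange(8, device="cuda").float()
    # Even exchange: rank r receives everyone's r-th chunk of size 4.
    y = funcol.all_to_all_single(x, None, None, dist.group.WORLD)

    # Uneven exchange: sizes must be communicated ahead of time (shown hard-coded here).
    out_splits, in_splits = [3, 5], [6, 2]
    z = funcol.all_to_all_single(torch.randn(8, device="cuda"), out_splits, in_splits, dist.group.WORLD)
"""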


def all_to_all_single_autograd(
    self: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Same as all_to_all_single but supports autograd.
    """
    if output_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
        ), output_split_sizes
    if input_split_sizes is not None:
        assert all(
            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
        ), input_split_sizes

    group_name = _resolve_group_name(group, tag)
    group_size = c10d._get_group_size_by_name(group_name)
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [self.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes
    tensor = torch.ops._c10d_functional_autograd.all_to_all_single(  # type: ignore[attr-defined]
        self,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
    return _FromTorchTensor.apply(tensor)


def permute_tensor(
    self: torch.Tensor,
    src_dst: List[int],
    group: RANK_TYPES,
    tag: str = "",
) -> torch.Tensor:
    """
    Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
    be defined such that src_dst[m] == n means m sends to n.

    Group can be one of:
        List[int]: ranks participating in the collective.
        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
        DeviceMesh: Do an SPMD collective over all ranks of the mesh
        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
    """
    t, rankset, group_size = _expand_group(group, tag)
    local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)

    output_split_sizes = [0] * group_size
    input_split_sizes = [0] * group_size
    for src, dst in enumerate(src_dst):
        if src == dist.get_rank(local_pg):
            input_split_sizes[dst] = self.numel()
        if dst == dist.get_rank(local_pg):
            output_split_sizes[src] = self.numel()

    return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
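
"""
Example (a minimal sketch, assuming 4 ranks): src_dst[m] == n means rank m sends its tensor
to rank n, so the list below rotates tensors one rank to the right:

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    t = torch.full((4,), float(dist.get_rank()), device="cuda")
    rotated = funcol.permute_tensor(t, [1, 2, 3, 0], dist.group.WORLD)
    # After waiting, rank r holds the tensor contributed by rank (r - 1) % 4.
"""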

class AsyncCollectiveTensor(torch.Tensor):
    r"""
    A Tensor wrapper subclass that is used to trigger a call to wait
    prior to first use of the underlying tensor.
    Use it inside functional collective pytorch wrappers like the following:
    def functional_collective(self, group, tag):
        tag, rankset, group_size = _expand_group(group, tag)
        tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
        return _maybe_wrap_tensor(tensor)
    """
    elem: torch.Tensor
    completed: bool

    __slots__ = ["elem", "completed"]

    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )
        r.elem = elem
        r.completed = False
        return r

    def __tensor_flatten__(self):
        return ["elem"], None

    def tolist(self):
        return self.trigger_wait().tolist()

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        elem = inner_tensors["elem"]
        return AsyncCollectiveTensor(elem)

    def __repr__(self):
        return f"AsyncCollectiveTensor({self.trigger_wait()})"

    def trigger_wait(self):
        if not self.completed:
            out = wait_tensor(self.elem)
            self.completed = True
            return out
        else:
            return self.elem

    def wait(self) -> torch.Tensor:
        return wait_tensor(self.elem)

    def _get_acs_underlying_tensor(self):
        """This method enables _functional_collectives_impl to test whether a tensor is an ACS."""
        return self.elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func == torch.ops.aten.view.default:
            # Fast-path aten.view, since a lot of view-related ops go to aten.view
            # eventually; this avoids pytree slowdown
            res = func(args[0].elem, args[1])
            wrapper_res = AsyncCollectiveTensor(res)
            return wrapper_res

        is_view_op = _is_view_op(func)

        def unwrap(e: AsyncCollectiveTensor):
            # wait_tensor is idempotent and will do stream sync only once
            if not is_view_op:
                return e.trigger_wait()
            return e.elem

        def wrap(e: torch.Tensor):
            # re-wrap outputs of view ops so later uses still trigger the wait
            assert not isinstance(e, AsyncCollectiveTensor)
            res = AsyncCollectiveTensor(e)
            return res

        unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
        unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)

        # we don't wrap the result as it doesn't need to be waited on.
        out = func(*unwrapped_args, **unwrapped_kwargs)

        # View ops don't require a sync, so we should re-wrap the outputs.
        if is_view_op:
            out = tree_map_only(torch.Tensor, wrap, out)

        return out

    def numpy(self):
        return self.wait().numpy()
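
"""
A sketch of how the wrapper behaves in practice (assuming an initialized group): the
returned AsyncCollectiveTensor looks like a regular tensor, and the first non-view op on it
routes through __torch_dispatch__, which calls wait_tensor() before running the real kernel:

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    y = funcol.all_reduce(torch.ones(4, device="cuda"), "sum", dist.group.WORLD)
    view = y.view(2, 2)   # view ops stay wrapped, no sync yet
    total = view.sum()    # non-view op unwraps and waits first
    explicit = y.wait()   # or force completion explicitly
"""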


"""
Utils and infrastructure for tracing support
"""


def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
    """
    _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.

    By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
    torchdynamo and can still interoperate with processgroup objects or other untraceable forms.
    """
    # This hack has to be defined _inside_ _expand_group to avoid the
    # graph_break [('torch.* op returned non-Tensor int')] error,
    # caused by the `cast_*` functions being treated as 'torch.*' ops (iiuc)
    if TYPE_CHECKING:

        def cast_listlistint(x):
            return cast(List[List[int]], x)

        def cast_listint(x):
            return cast(List[int], x)

    else:
        # fake cast op for use at runtime since dynamo doesn't support real cast
        # also, dynamo didn't like encountering 'typing' objects:
        # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
        def cast_listlistint(x):
            return x

        def cast_listint(x):
            return x

    rankset: List[int]
    if isinstance(group, list):
        if isinstance(group[0], list):
            nested_list = cast_listlistint(group)
            rankset = []
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(
                        f"group sizes must be identical, found {group_size} and {len(rs)}"
                    )
                group_size = len(rs)
        else:
            rankset = cast_listint(group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        # TODO: it should run collective in the whole mesh instead of dim 0
        tag, rankset, _ = group._dim_group_infos[0]
        group_size = len(rankset)
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            tag, rankset, _ = dmesh._dim_group_infos[dim]
            group_size = len(rankset)
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    else:
        raise ValueError(
            "Invalid type for group, must be one of List, ProcessGroup, DeviceMesh or (DeviceMesh, int)."
        )

    return (tag, rankset, group_size)
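
"""
For illustration, a few desugared results (a minimal sketch; the exact tag strings depend on
how the groups were created):

    tag, rankset, group_size = _expand_group([0, 1, 2, 3])
    # -> ("", [0, 1, 2, 3], 4)

    tag, rankset, group_size = _expand_group([[0, 1], [2, 3]])
    # -> ("", [0, 1, 2, 3], 2)   # MPMD: two groups of size 2

    tag, rankset, group_size = _expand_group(dist.group.WORLD)
    # -> (<group tag>, <group ranks>, <group size>)
"""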


def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
    """
    Given a group in RANK_TYPES, return the group name.
    """
    # `tag` will be deprecated. See details in:
    # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
    if isinstance(group, dist.ProcessGroup):
        return group.group_name
    elif isinstance(group, str):
        return group
    elif isinstance(group, DeviceMesh):
        assert (
            group.ndim == 1
        ), "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        return group._dim_group_infos[0][2]
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            return dmesh._dim_group_infos[dim][2]
        else:
            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
    elif isinstance(group, list):
        if not is_torchdynamo_compiling():
            warnings.warn(
                "The combination of ranks + tag as process group "
                "identifier has been deprecated. Please switch to "
                "using ProcessGroup, DeviceMesh, or group name instead.",
                FutureWarning,
                stacklevel=3,
            )
        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
    else:
        raise ValueError(f"Unsupported group type: {type(group)}, {group}")


class _FromTorchTensor(torch.autograd.Function):
    """
    _FromTorchTensor allows autograd to propagate from a normal Tensor to an
    AsyncCollectiveTensor.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,  # pyre-ignore[2]: Parameter must be annotated.
        input: torch.Tensor,
    ) -> torch.Tensor:
        return _maybe_wrap_tensor(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return grad_output


def _are_we_tracing() -> bool:
    if is_torchdynamo_compiling():
        return True
    # If functionalization is turned on, we are almost definitely compiling/tracing.
    # (In particular, AOTAutograd traces a model once with functionalization on
    #  but proxy tracing turned off, so this is how we detect it).
    if (
        torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        is not None
    ):
        return True
    return get_proxy_mode() is not None


def _maybe_wrap_tensor(self) -> torch.Tensor:
    if _are_we_tracing():
        return wait_tensor(self)
    res = AsyncCollectiveTensor(self)
    return cast(torch.Tensor, res)


def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
    def mk_out_tensor(shard):
        out_size = list(shard.size())
        out_size[0] *= group_size
        out_tensor = shard.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in self]


# We now register meta kernels to deal with tracing
def _broadcast_meta(self, *args):
    return torch.empty_like(self)


def _all_reduce_meta(self, *args):
    return torch.empty_like(self)


def _wait_tensor_meta(self, *args):
    return torch.empty_like(self)


def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
    out_size = list(shard.size())
    out_size[0] *= group_size
    return shard.new_empty(out_size)


def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
    out_size = list(input.size())
    out_size[0] //= group_size
    return input.new_empty(out_size)


def _all_reduce_coalesced_meta(self, *args):
    return [torch.empty_like(t) for t in self]


def _all_reduce__meta(inp, *args):
    return inp


def _broadcast__meta(inp, *args):
    return inp


def _all_reduce_coalesced__meta(inputs, *args):
    return inputs


def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
    def mk_out_tensor(input):
        out_size = list(input.size())
        out_size[0] //= group_size
        out_tensor = input.new_empty(out_size)
        return out_tensor

    return [mk_out_tensor(t) for t in inputs]


# NB: We often say all_to_all has dynamic output size, but this is not
# technically true: instead, what typically happens is you manually
# communicate the output_split_sizes ahead of time (which is dynamic),
# but then you pass those sizes explicitly, and the all to all itself
# isn't dynamic, it just follows the specified output splits
def _all_to_all_single_meta(
    input, output_split_sizes, input_split_sizes, *args, **kwargs
):
    if output_split_sizes is None:
        return input.new_empty(input.size())
    else:
        for s in output_split_sizes:
            torch._check_is_size(s)
        out_size = list(input.size())
        out_size[0] = sum(output_split_sizes)
        return input.new_empty(out_size)


def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_native_meta(input, group_size, group_name):
    shape = list(input.size())
    shape[0] *= group_size
    return input.new_empty(shape)


def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
    return [
        _all_gather_into_tensor_native_meta(input, group_size, group_name)
        for input in inputs
    ]


def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
    shape = list(inp.size())
    shape[0] //= group_size
    return inp.new_empty(shape)


def _reduce_scatter_tensor_coalesced_native_meta(
    inputs, reduce_op, group_size, group_name
):
    return [
        _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
        for inp in inputs
    ]


if not torch._running_with_deploy():
    # Library MUST be defined at module scope or it doesn't work
    # Creating a "DEF" Library always crashes torch::deploy so we create our
    # Library instances here guarded against running inside it
    lib_impl = torch.library.Library("_c10d_functional", "IMPL")
    lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
    lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
    lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
    lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
    lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
    )
    lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
    lib_impl.impl(
        "all_gather_into_tensor_coalesced",
        _all_gather_into_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
    lib_impl.impl(
        "reduce_scatter_tensor_coalesced",
        _reduce_scatter_tensor_coalesced_native_meta,
        "Meta",
    )
    lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
    lib_impl.impl("broadcast", _broadcast_meta, "Meta")
    lib_impl.impl("broadcast_", _broadcast__meta, "Meta")

    # mark these ops as having side effects so that they won't be removed by DCE
    torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
    torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)

    # Register legacy ops for backward compatibility
    # TODO(yifu): remove these in functional collective beta release
    legacy_lib = torch.library.Library("c10d_functional", "DEF")
    legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
    ops_defs = [
        "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "wait_tensor(Tensor self) -> Tensor",
        "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
        "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
        "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
        "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
        "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
    ]

    my_module = sys.modules[__name__]
    for op_def in ops_defs:
        op_name = op_def[0 : op_def.index("(")]
        backend_impl = getattr(fun_col_impl, f"_{op_name}")
        legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
        legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")

else:
    warnings.warn(
        "PyTorch Distributed functional collectives do not work with torch::deploy."
    )


"""
Dynamo remappings allow seamless translation from non-functional collectives of a supportable form into
functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph.

We implement this by writing a decomposition and teaching dynamo how to associate it with a corresponding op via
the mapping dict below.

These schemas intentionally match the torch.distributed.distributed_c10d.* ops that we are trying to remap from.
"""


def all_gather_tensor_inplace(
    output_tensor: torch.Tensor,
    input_tensor: torch.Tensor,
    group,  # TODO add a type,
    async_op: bool = False,
    tag: str = "",
    gather_dim: int = 0,
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))


def reduce_scatter_tensor_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    op: str = "sum",  # TODO type is actually c10d ReduceOp. is this ok?
    group=None,  # TODO add a type
    async_op: bool = False,
    scatter_dim: int = 0,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))


REDUCE_OP_TO_STR = {
    dist.ReduceOp.SUM: "sum",
    dist.ReduceOp.AVG: "avg",
    dist.ReduceOp.PRODUCT: "product",
    dist.ReduceOp.MIN: "min",
    dist.ReduceOp.MAX: "max",
    dist.ReduceOp.BAND: "band",
    dist.ReduceOp.BOR: "bor",
    dist.ReduceOp.BXOR: "bxor",
}


def all_reduce_inplace(
    tensor: torch.Tensor,
    op: str = "sum",
    group=None,
    async_op: bool = False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return tensor.copy_(all_reduce(tensor, op, group, tag))


def all_to_all_inplace(
    output: torch.Tensor,
    input: torch.Tensor,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"

    group = group or dist.group.WORLD
    assert group is not None

    return output.copy_(
        all_to_all_single(
            input,
            output_split_sizes,
            input_split_sizes,
            group,
            tag,
        )
    )


def all_gather_inplace(
    tensor_list: List[torch.Tensor],
    tensor: torch.Tensor,
    group=None,
    async_op=False,
    tag: str = "",
):
    assert (
        not async_op
    ), "Can't remap async version of inplace op to functional collective"
    assert all(
        t.size(0) == tensor.size(0) for t in tensor_list
    ), "Remapping variable size all_gather is not yet supported"

    group = group or dist.group.WORLD
    assert group is not None

    output = all_gather_tensor(tensor, 0, group, tag)

    # Use aten.slice instead of aten.split because the latter causes
    # tensor.size(0) to be unnecessarily baked in when it's a SymInt.
    output_splits = []
    offset = 0
    for t in tensor_list:
        output_splits.append(output[offset : offset + t.size(0)])
        offset += t.size(0)
    for dst, src in zip(tensor_list, output_splits):
        dst.copy_(src)
    return tensor_list


from torch.distributed.distributed_c10d import (
    _all_gather_base as legacy_all_gather_base,
    _reduce_scatter_base as legacy_reduce_scatter_base,
    all_gather as legacy_all_gather,
    all_gather_into_tensor as legacy_allgather,
    all_reduce as legacy_allreduce,
    all_to_all_single as legacy_all_to_all_single,
    reduce_scatter_tensor as legacy_reducescatter,
)


# This dict maps the legacy c10d collectives to the functional remaps that dynamo is allowed to use.
# Each remapped function must accept the same args/kwargs 1:1 as the function it maps from.
traceable_collective_remaps = {
    legacy_allgather: all_gather_tensor_inplace,
    legacy_reducescatter: reduce_scatter_tensor_inplace,
    legacy_allreduce: all_reduce_inplace,
    legacy_all_to_all_single: all_to_all_inplace,
    legacy_all_gather: all_gather_inplace,
    legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
    legacy_all_gather_base: all_gather_tensor_inplace,
}