# mypy: allow-untyped-defs

import sys
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from functools import partial, reduce

import torch
import torch.distributed as dist
import weakref
from torch._C._distributed_c10d import (
    _create_work_from_future,
    AllgatherOptions,
    AllreduceOptions,
    AllToAllOptions,
    BarrierOptions,
    BroadcastOptions,
    ReduceScatterOptions,
    ScatterOptions,
    Store,
    ReduceOp,
)
from torch.distributed.distributed_c10d import _CollOp, _store_based_barrier, P2POp
from torch.futures import Future
from torch.utils import _pytree as pytree

"""
TODO:
Lots of missing collectives.
Collectives validation.
Make timeout robust by making collectives respect the test deadline.
Make tests robust by making collectives interruptible.
We need some synchronization around cleanup to ensure that timed-out ranks don't cause spurious failures.

"""


def flatten_list(lst):
    return pytree.tree_leaves(lst)


def ret_work(ret):
    fut = Future()
    fut.set_result(ret)
    return _create_work_from_future(fut)

def binop_reduce(tensors, op):
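    """
    Stack ``tensors`` and reduce them across the rank dimension (``dim=0``) with ``op``.

    Illustrative sketch (small CPU tensors):
    >>> binop_reduce([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], torch.sum)
    tensor([4., 6.])
    """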
    res = op(torch.stack(tensors), dim=0)
    if isinstance(res, torch.Tensor):
        return res
    # min/max return a namedtuple
    return res.values

def bitwise_reduce(tensors, op):
    return reduce(op, tensors)

_reduce_ops = {
    ReduceOp.SUM: partial(binop_reduce, op=torch.sum),
    ReduceOp.AVG: partial(binop_reduce, op=torch.mean),
    ReduceOp.PRODUCT: partial(binop_reduce, op=torch.prod),
    ReduceOp.MIN: partial(binop_reduce, op=torch.min),
    ReduceOp.MAX: partial(binop_reduce, op=torch.max),
    ReduceOp.BAND: partial(bitwise_reduce, op=torch.bitwise_and),
    ReduceOp.BOR: partial(bitwise_reduce, op=torch.bitwise_or),
    ReduceOp.BXOR: partial(bitwise_reduce, op=torch.bitwise_xor),
}
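# Maps c10d ReduceOp values to local torch reductions; AllReduce.work below looks up
# the matching callable and applies it to the list of per-rank tensors.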

class AllToAll:
    @torch.no_grad()
    def work(self, data):
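        # data[rank] holds the (output_tensor_list, input_tensor_list) pair joined by
        # that rank; dest's output slot src receives src's input slot dest.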
        world_size = len(data)
        for dest_rank in range(world_size):
            output_tensor_list, _ = data[dest_rank]
            for src_rank in range(world_size):
                _, input_tensor_list = data[src_rank]
                output_tensor_list[src_rank].copy_(input_tensor_list[dest_rank])

class AllToAllBase:
    @torch.no_grad()
    def work(self, data):
        world_size = len(data)
        for dest_rank in range(world_size):
            output_buffer, _, output_split_sizes, _ = data[dest_rank]

            output_indexes = self._size_cumsum(output_buffer.size(0), output_split_sizes, world_size)

            for src_rank in range(world_size):
                _, input_buffer, _, input_split_sizes = data[src_rank]
                input_indexes = self._size_cumsum(input_buffer.size(0), input_split_sizes, world_size)

                output_buffer[output_indexes[src_rank]:output_indexes[src_rank + 1]].copy_(
                    input_buffer[input_indexes[dest_rank]:input_indexes[dest_rank + 1]]
                )

    def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, List[int], None], world_size: int) -> torch.Tensor:
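        """
        Return the cumulative offsets used to slice a flat buffer into per-rank
        chunks; an empty or ``None`` ``sizes`` means an even split.

        Illustrative sketch:
        >>> AllToAllBase()._size_cumsum(6, None, 3)
        tensor([0, 2, 4, 6])
        >>> AllToAllBase()._size_cumsum(6, [1, 2, 3], 3)
        tensor([0, 1, 3, 6])
        """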
        if sizes is None or len(sizes) == 0:
            sizes = torch.full(
                (world_size,), buf_size // world_size, dtype=torch.int64
            )
        if not isinstance(sizes, torch.Tensor):
            sizes = torch.tensor(sizes, dtype=torch.int64)
        assert sizes.dtype == torch.int64
        sizes = torch.cumsum(
            torch.cat(
                (
                    torch.tensor([0], dtype=torch.int64, device=sizes.device), sizes
                ),
                dim=0
            ),
            dim=0
        )
        return sizes

class AllReduce:
    def __init__(self, op):
        if op.op not in _reduce_ops:
            raise NotImplementedError(
                f"AllReduce op {op.op} not supported on multithreaded pg for now."
            )
        self.op = op.op

    @torch.no_grad()
    def work(self, data):
        for i in range(len(data[0])):
            tensors = []
            # use rank 0's device as the device for the reduction
            rank_0_device = data[0][i].device
            # collect the i-th tensor from every rank into a list,
            # moving each one onto rank 0's device
            for src_rank in range(0, len(data)):
                tensors.append(data[src_rank][i].to(rank_0_device))

            # now mimic the reduce across all ranks
            res = _reduce_ops[self.op](tensors)

            # copy the reduced value back to each rank
            for src_rank in range(len(data)):
                data[src_rank][i].copy_(res.to(data[src_rank][i].device))


class AllGather:
    @torch.no_grad()
    def work(self, data):
        for src_rank in range(len(data)):
            in_tensor_list = data[src_rank][1]
            # Can't handle all_gather with multiple input tensors
            assert len(in_tensor_list) == 1
            src_tensor = in_tensor_list[0]

            for dest in data:
                dest_tensor = dest[0][0][src_rank]
                dest_tensor.copy_(src_tensor)


class Scatter:
    def __init__(self, src):
        self.src = src

    @torch.no_grad()
    def work(self, data):
        src_in_tensor_list = data[self.src][1]
        # Can't handle scatter with multiple input tensor lists
        assert len(src_in_tensor_list) == 1
        src_in_tensors = src_in_tensor_list[0]

        for rank, each_rank_data in enumerate(data):
            out_tensor_list = each_rank_data[0]
            # Can't handle scatter with multiple output tensors
            assert len(out_tensor_list) == 1
            dest_tensor = out_tensor_list[0]
            dest_tensor.copy_(src_in_tensors[rank])


class Gather:
    def __init__(self, dst):
        self.dst = dst

    @torch.no_grad()
    def work(self, data):
        # Can't handle gather with multiple tensor lists
        assert len(data[self.dst][0]) == 1
        out_tensor_list = data[self.dst][0][0]
        for rank, each_rank_data in enumerate(data):
            src_in_tensor_list = each_rank_data[1]
            # Can't handle gather with multiple tensor lists
            assert len(src_in_tensor_list) == 1
            dest_tensor = out_tensor_list[rank]
            dest_tensor.copy_(src_in_tensor_list[0])

class ReduceScatter:
    def __init__(self, op):
        if op != dist.ReduceOp.SUM and op != dist.ReduceOp.AVG:
            raise NotImplementedError(f"ReduceScatter does not support {op}")
        self.op = op

    @torch.no_grad()
    def work(self, data):
        start_reduction = [False for _ in range(len(data))]
        for each_rank_data in data:
            # Can't handle reduce_scatter with multiple scatter lists
            assert len(each_rank_data[1]) == 1
            to_scatter = each_rank_data[1][0]
            for i in range(len(to_scatter)):
                dest_tensor_on_rank_i = data[i][0]
                # Can't handle reduce_scatter with multiple output tensors
                assert len(dest_tensor_on_rank_i) == 1
                dst_tensor_device = dest_tensor_on_rank_i[0].device
                if not start_reduction[i]:
                    dest_tensor_on_rank_i[0].copy_(to_scatter[i].to(dst_tensor_device))
                    start_reduction[i] = True
                else:
                    dest_tensor_on_rank_i[0].add_(to_scatter[i].to(dst_tensor_device))
        if self.op == dist.ReduceOp.AVG:
            num_ranks = len(data)
            for each_rank_data in data:
                each_rank_data[0][0] /= num_ranks


class Broadcast:
    def __init__(self, src):
        self.src = src

    @torch.no_grad()
    def work(self, data):
        in_tensor_list = flatten_list(data[self.src])
        for i in range(len(data)):
            out_tensor_list = flatten_list(data[i])
            for j in range(len(in_tensor_list)):
                out_tensor_list[j].copy_(in_tensor_list[j])


class Collective:
    def __init__(self, world_size, collective, pg):
        self._world_size = world_size
        self._collective = collective

        self._start_cond = threading.Condition()
        self._done_cond = threading.Condition()

        self._data = [None] * world_size
        self._count = 0
        self._done = False

        self._pg = pg

    def join(self, rank, data):
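        # Rendezvous: every rank deposits its tensors in self._data and bumps the count;
        # rank 0 waits until all ranks have arrived, runs the collective over the gathered
        # data, and then wakes the remaining ranks, which block on the "done" condition.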
        with self._start_cond:
            self._data[rank] = data
            self._count += 1

            # notify rank 0
            if self._count == self._world_size:
                if rank > 0:
                    self._start_cond.notify()

            if rank == 0:
                self._start_cond.wait_for(
                    lambda: self._count == self._world_size or self._pg._terminate.is_set()
                )
                # SystemExit is a subclass of BaseException but not of Exception, so it
                # can be distinguished from normal exceptions raised by program errors
                # and hidden from the exception queue
                if self._pg._terminate.is_set():
                    sys.exit("Test termination event occurs.")

        with self._done_cond:
            # wait for rank 0 to finish
            if rank > 0:
                self._done_cond.wait_for(lambda: self._done or self._pg._terminate.is_set())
                if self._pg._terminate.is_set():
                    sys.exit("Test termination event occurs.")
            else:
                # copy data around
                self._collective.work(self._data)
                self._done = True
                self._done_cond.notify_all()
        return ret_work(data)


class ProcessLocalGroup(dist.ProcessGroup):
    _coll_lock = threading.Lock()
    _cur_coll_on_pgs = {}

    _terminate = threading.Event()

    @classmethod
    def _start_coll(cls, collective, pg):
        with cls._coll_lock:
            # pg_name is unique, so we use it to record the mapping between pg and collective
            if pg.pg_name not in cls._cur_coll_on_pgs:
                cls._cur_coll_on_pgs[pg.pg_name] = Collective(pg.size(), collective, cls)
            return cls._cur_coll_on_pgs[pg.pg_name]

    @classmethod
    def _end_coll(cls, collective, pg):
        # This is racily called by all ranks, so only one will work
        with cls._coll_lock:
            if pg.pg_name in cls._cur_coll_on_pgs and cls._cur_coll_on_pgs[pg.pg_name] == collective:
                cls._cur_coll_on_pgs.pop(pg.pg_name)

    @classmethod
    def exception_handle(cls, exc):
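        # Abort all in-flight collectives: set the terminate event and wake any rank
        # blocked on a start/done condition so it can exit.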
        cls._terminate.set()
        for coll in cls._cur_coll_on_pgs.values():
            with coll._start_cond:
                coll._start_cond.notify()
            with coll._done_cond:
                coll._done_cond.notify_all()

    @classmethod
    def reset(cls):
        with cls._coll_lock:
            cls._cur_coll_on_pgs = {}
            cls._terminate.clear()

    def alltoall_base(
        self,
        output_buffer: torch.Tensor,
        input_buffer: torch.Tensor,
        output_split_sizes: Optional[List[int]],
        input_split_sizes: Optional[List[int]],
        opts=AllToAllOptions()
    ) -> torch.Tensor:
        coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
        res = coll.join(self._rank, (output_buffer, input_buffer, output_split_sizes, input_split_sizes))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def alltoall(self, output_tensor_list, input_tensor_list, opts=AllToAllOptions()):
        coll = ProcessLocalGroup._start_coll(AllToAll(), self)
        res = coll.join(self._rank, (output_tensor_list, input_tensor_list))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def allreduce(self, tensor_list, opts=AllreduceOptions()):
        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
        res = coll.join(self._rank, tensor_list)
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()):
        coll = ProcessLocalGroup._start_coll(AllReduce(opts.reduceOp), self)
        res = coll.join(self._rank, tensor_list)
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def barrier(self, opts=BarrierOptions()):
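        # A dummy all-reduce over a scalar doubles as the barrier; opts are ignored.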
        return self.allreduce(tensor_list=[torch.ones(1)])

    def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
        coll = ProcessLocalGroup._start_coll(AllGather(), self)
        res = coll.join(self._rank, (output_tensors, input_tensor))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def _allgather_base(self, output_tensor, input_tensor, opts=AllgatherOptions()):
        tensor_list = list(torch.chunk(output_tensor, self._world_size))
        return self.allgather([tensor_list], [input_tensor], opts)

    def broadcast(self, tensor_list, opts=BroadcastOptions()):
        coll = ProcessLocalGroup._start_coll(Broadcast(opts.rootRank), self)
        res = coll.join(self._rank, tensor_list)
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()):
        coll = ProcessLocalGroup._start_coll(Scatter(opts.rootRank), self)
        res = coll.join(self._rank, (output_tensors, input_tensors))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def gather(self, output_tensors, input_tensors, opts=ScatterOptions()):
        coll = ProcessLocalGroup._start_coll(Gather(opts.rootRank), self)
        res = coll.join(self._rank, (output_tensors, input_tensors))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()):
        coll = ProcessLocalGroup._start_coll(ReduceScatter(opts.reduceOp), self)
        res = coll.join(self._rank, (output_tensor, scatter_list))
        ProcessLocalGroup._end_coll(coll, self)
        return res

    def _reduce_scatter_base(self, output_tensor, input_tensor, opts=ReduceScatterOptions()):
        tensor_list = list(torch.chunk(input_tensor, self._world_size))
        return self.reduce_scatter([output_tensor], [tensor_list], opts)

    def reduce_scatter_tensor_coalesced(self, output_tensors, input_tensors, opts=ReduceScatterOptions()):
        works = [
            self._reduce_scatter_base(output_tensor, input_tensor, opts)
            for output_tensor, input_tensor
            in zip(output_tensors, input_tensors)
        ]
        for work in works[:-1]:
            work.wait()
        return works[-1]

    def allgather_into_tensor_coalesced(self, output_tensor_list, input_tensor_list, opts=AllgatherOptions()):
        res = None
        for o_t, i_t in zip(output_tensor_list, input_tensor_list):
            res = self._allgather_base(o_t, i_t)
        return res

    def __init__(self, rank, world_size):
        super().__init__(rank, world_size)
        self._rank = rank
        self._world_size = world_size
        world = dist.distributed_c10d._world
        if isinstance(world, ThreadLocalWorld):
            world = world._get_world()
        self._world = weakref.ref(world)
        self._ctx = torch.autograd.set_multithreading_enabled(False)

    def size(self):
        return self._world_size

    @property
    def pg_name(self):
        """
        Return the globally registered name of the current pg in the world.
        """
        return self._world().pg_names[self]

    @property
    def group_name(self):
        return self.pg_name

    def getBackendName(self):
        return "threaded"

    def __repr__(self):
        return f"ThreadedPG world_size:{self._world_size} rank:{self._rank}"


def _create_threaded_pg(prefix_store, rank, world_size, timeout):
    pg = ProcessLocalGroup(rank, world_size)
    # https://github.com/pytorch/pytorch/pull/103033 made the store-based barrier optional.
    # When a device mesh involves sub groups and the store-based barrier is not enabled in
    # c10d, different threads may be initializing different groups concurrently (even though
    # the threaded pg's actual collectives are assumed to be single threaded), which leads
    # to race conditions.
    # For example, with a mesh of [[0, 1], [2, 3]], the sub groups along dim 0 and dim 1
    # are initialized independently in different threads.
    # In that case we can no longer rely on class or global variables; we have to rely on
    # the store-based barrier to make sure each group is ready before collectives are
    # invoked in any of the groups.

    # The prefix store is already per group, so we pass an empty name here.
    _store_based_barrier(rank, prefix_store, "", world_size, timeout)
    return pg


dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda"])
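# Registering under the name "threaded" lets tests pass backend="threaded" to
# init_process_group; every "rank" then runs as a thread of the current process.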


@dataclass
class WorldData:
    default_pg: dist.ProcessGroup
    pg_map: Dict[dist.ProcessGroup, Tuple[str, Optional[Store]]]
    pg_names: Dict[dist.ProcessGroup, str]
    pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
    pg_backend_config: Dict[dist.ProcessGroup, str]
    group_count: int
    tags_to_pg: Dict[str, List[dist.ProcessGroup]]
    pg_to_tag: Dict[dist.ProcessGroup, str]
    pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
    pg_default_device: Dict[dist.ProcessGroup, torch.device]


class ThreadLocalWorld:
    _world = threading.local()

    def _get_world(self) -> WorldData:
        if not hasattr(ThreadLocalWorld._world, "world"):
            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {}, {}, {})
        return ThreadLocalWorld._world.world

    @property
    def default_pg(self):
        return self._get_world().default_pg

    @default_pg.setter
    def default_pg(self, value):
        self._get_world().default_pg = value

    @property
    def pg_map(self):
        return self._get_world().pg_map

    @property
    def pg_names(self):
        return self._get_world().pg_names

    @property
    def pg_group_ranks(self):
        return self._get_world().pg_group_ranks

    @property
    def pg_backend_config(self):
        return self._get_world().pg_backend_config

    @property
    def group_count(self) -> int:
        return self._get_world().group_count

    @group_count.setter
    def group_count(self, value):
        self._get_world().group_count = value

    @property
    def tags_to_pg(self):
        return self._get_world().tags_to_pg

    @property
    def pg_to_tag(self):
        return self._get_world().pg_to_tag

    @property
    def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
        return self._get_world().pg_coalesce_state

    @property
    def pg_default_device(self) -> Dict[dist.ProcessGroup, torch.device]:
        return self._get_world().pg_default_device


_old_pg_world = None
_ctx_manager = None


def _install_threaded_pg():
    global _old_pg_world
    global _ctx_manager
    _old_pg_world = dist.distributed_c10d._world
    dist.distributed_c10d._world = ThreadLocalWorld()
    _ctx_manager = torch.autograd.set_multithreading_enabled(False)

    return dist.distributed_c10d._world


def _uninstall_threaded_pg():
    dist.distributed_c10d._world = _old_pg_world
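
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes a caller that
    # installs the thread-local world, spawns one thread per rank against a shared
    # HashStore, and then issues ordinary c10d collectives. The _demo_* names below
    # are illustrative only.
    from torch.distributed import HashStore

    _install_threaded_pg()
    _demo_world_size = 2
    _demo_store = HashStore()

    def _demo_worker(rank):
        dist.init_process_group(
            backend="threaded", rank=rank, world_size=_demo_world_size, store=_demo_store
        )
        t = torch.ones(2) * (rank + 1)
        # default op is SUM: rank 0 contributes ones, rank 1 contributes twos
        dist.all_reduce(t)
        assert torch.equal(t, torch.full((2,), 3.0))
        dist.destroy_process_group()

    _demo_threads = [
        threading.Thread(target=_demo_worker, args=(r,)) for r in range(_demo_world_size)
    ]
    for _t in _demo_threads:
        _t.start()
    for _t in _demo_threads:
        _t.join()
    _uninstall_threaded_pg()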