xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/distributed_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3from contextlib import contextmanager
4from datetime import timedelta
5from functools import (
6    partial,
7    wraps,
8)
9
10import torch.distributed as dist
11import torch.distributed.distributed_c10d as c10d
12
class MockProcessGroup(dist.ProcessGroup):
    """Minimal ``dist.ProcessGroup`` subclass used as a stand-in backend.

    It only forwards ``rank`` and ``world`` to the base class and reports a
    fixed backend name; no collective operations are implemented here.
    """

    def getBackendName(self):
        """Return the name under which this fake backend identifies itself."""
        return "mock_process_group"

    def __init__(self, rank, world):
        """Initialise the base process group with ``rank`` and ``world``."""
        super().__init__(rank, world)
20
def create_mock_pg(prefix_store, rank, world_size, timeout):
    """Backend factory that builds a ``MockProcessGroup``.

    ``prefix_store`` and ``timeout`` are accepted so the signature matches
    what the backend registration passes in, but they are not used.
    """
    mock_pg = MockProcessGroup(rank, world_size)
    return mock_pg
23
# Runs at import time: makes ``create_mock_pg`` available under the backend
# name "mock_process_group", which ``mock_init_dist`` below passes to
# ``dist.init_process_group``.
dist.Backend.register_backend('mock_process_group', create_mock_pg)
25
def mock_init_dist(rank, world_size):
    """Initialise the default process group with the mock backend.

    WARNING: this is a deliberate pile of hacks that leans on c10d
    internals; do not copy it outside of tests.

    Args:
        rank: rank to report for the current process.
        world_size: world size to report for the fake group.
    """
    assert not dist.is_initialized()

    fake_store = dist.HashStore()

    # Pre-seed the barrier counter so _store_based_barrier believes every
    # other rank has already checked in. The trailing 0 is the group index.
    barrier_key = f"{c10d.STORE_BASED_BARRIER_PREFIX}:0"
    already_checked_in = world_size - 1
    fake_store.add(barrier_key, already_checked_in)

    init_kwargs = {
        "backend": "mock_process_group",
        "rank": rank,
        "world_size": world_size,
        "store": fake_store,
        "group_name": "fake",
        "timeout": timedelta(seconds=1),
    }
    dist.init_process_group(**init_kwargs)
42
@contextmanager
def with_dist(rank=0, world_size=2):
    """
    Context manager that initializes c10d with a fake process group.

    On entry it calls ``mock_init_dist(rank, world_size)``; on exit, whether
    or not the body raised, it calls ``dist.destroy_process_group()``.
    """
    mock_init_dist(rank=rank, world_size=world_size)
    try:
        # Nothing is handed to the ``with`` body; it just runs with the
        # fake group in place.
        yield
    finally:
        dist.destroy_process_group()
53
def with_fake_comms(func=None, rank=0, world_size=2):
    """
    Function wrapper that inits a fake process group designed for testing.
    Right now only querying for world size is available.

    Usable both bare (``@with_fake_comms``) and with arguments
    (``@with_fake_comms(rank=1, world_size=4)``).

    Args:
        func: the method to wrap; ``None`` when the decorator is called with
            keyword arguments only, in which case a decorator is returned.
        rank: rank passed to ``with_dist``.
        world_size: world size passed to ``with_dist``.

    Returns:
        The wrapped method, or a ``partial`` of this function awaiting
        ``func`` when ``func`` is ``None``.
    """
    if func is None:
        # Called as ``@with_fake_comms(...)``: defer until the function arrives.
        return partial(with_fake_comms, rank=rank, world_size=world_size)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with with_dist(rank, world_size):
            # Propagate the wrapped function's result instead of dropping it.
            return func(self, *args, **kwargs)

    return wrapper
67