# mypy: allow-untyped-defs

from contextlib import contextmanager
from datetime import timedelta
from functools import (
    partial,
    wraps,
)

import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d


class MockProcessGroup(dist.ProcessGroup):
    """
    Minimal ``ProcessGroup`` subclass used as a stand-in for a real backend.

    It only forwards ``rank`` and ``world`` to the base class and reports a
    backend name; no communication primitives are defined here.
    """

    def __init__(self, rank, world):
        super().__init__(rank, world)

    def getBackendName(self):
        """Return the name under which this backend is registered."""
        return "mock_process_group"


def create_mock_pg(prefix_store, rank, world_size, timeout):
    """
    Backend factory registered with ``dist.Backend.register_backend``.

    ``prefix_store`` and ``timeout`` are accepted to match the signature the
    registry calls with, but are not used.
    """
    return MockProcessGroup(rank, world_size)


dist.Backend.register_backend('mock_process_group', create_mock_pg)


def mock_init_dist(rank, world_size):
    """
    Initialize the default process group with the mock backend.

    Args:
        rank: rank to report for the current process.
        world_size: world size to report.

    Raises:
        AssertionError: if a default process group is already initialized.
    """
    # !!! WARNING !!!
    # Kids don't try this at home, this is a cute pile of hacks that
    # depends on a small mountain of c10d internals
    assert not dist.is_initialized(), "default process group is already initialized"
    store = dist.HashStore()
    # Trick _store_based_barrier into believing everyone else already checked-in
    # Zero is the group index
    store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1)
    dist.init_process_group(
        backend="mock_process_group",
        rank=rank,
        world_size=world_size,
        store=store,
        group_name="fake",
        timeout=timedelta(seconds=1))


@contextmanager
def with_dist(rank=0, world_size=2):
    """
    Context manager that initializes c10d with a fake process group.

    The process group is destroyed on exit, even if the body raises.
    """
    mock_init_dist(rank=rank, world_size=world_size)
    try:
        yield
    finally:
        dist.destroy_process_group()


def with_fake_comms(func=None, rank=0, world_size=2):
    """
    Function wrapper that inits a fake process group designed for testing.
    Right now only querying for world size is available

    Usable both bare (``@with_fake_comms``) and with arguments
    (``@with_fake_comms(rank=1, world_size=4)``). The wrapped callable is
    expected to be a method: its first positional argument is passed through
    as ``self``. The wrapped function's return value is propagated to the
    caller.
    """
    if func is None:
        # Called with keyword arguments only: return a decorator that
        # remembers them and waits for the function.
        return partial(with_fake_comms, rank=rank, world_size=world_size)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with with_dist(rank, world_size):
            # Return the result so decorating a function does not silently
            # turn its return value into None.
            return func(self, *args, **kwargs)
    return wrapper