xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/control_plane.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os
2from contextlib import contextmanager, ExitStack
3from typing import Generator
4
5from torch.distributed.elastic.multiprocessing.errors import record
6
7
# Names exported by ``from ... import *``.
__all__ = [
    "worker_main",
]

# Name of the environment variable that ``worker_main`` reads to obtain the
# socket path for the worker server; the server is only started when it is set.
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
13
14
@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
    """
    Run a ``_WorkerServer`` for the duration of the ``with`` block.

    The server is constructed with ``socket_path`` on entry and its
    ``shutdown()`` method is called on exit, whether the block finishes
    normally or raises.

    Args:
        socket_path: value passed straight to the ``_WorkerServer`` constructor.
    """
    # Imported lazily so the C extension is only touched when a server is
    # actually requested.
    from torch._C._distributed_c10d import _WorkerServer

    worker_server = _WorkerServer(socket_path)
    with ExitStack() as cleanup:
        # Registered callbacks run when the stack unwinds, including on error.
        cleanup.callback(worker_server.shutdown)
        yield
24
25
@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    Context manager that wraps the main entry function of a worker.

    It layers the existing ``errors.record`` logic with a ``_WorkerServer``
    that exposes handlers via a unix socket. The server is only started when
    the ``TORCH_WORKER_SERVER_SOCKET`` environment variable is set; its value
    is used as the socket path, and the server is shut down when the context
    exits.

    Example

    ::

     @worker_main()
     def main():
         pass

     if __name__ == "__main__":
         main()

    """
    # NOTE(review): ``record`` decorates the generator function itself, so it
    # wraps the call that creates the generator rather than the code running
    # inside the ``with`` body. Confirm that exceptions raised by the wrapped
    # main function are still recorded as intended.
    socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
    if socket_path is None:
        # No socket configured: nothing to start or tear down.
        yield
        return

    with _worker_server(socket_path):
        yield
53