import os
from contextlib import contextmanager, ExitStack
from typing import Generator

from torch.distributed.elastic.multiprocessing.errors import record


__all__ = [
    "worker_main",
]

# Name of the environment variable that holds the unix socket path for the
# worker server. When it is unset, ``worker_main`` does not start a server.
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"


@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
    """
    Keep a ``_WorkerServer`` alive for the duration of the ``with`` block.

    The server is constructed with ``socket_path`` on entry and
    ``shutdown()`` is always called on exit, including when the body raises.

    Args:
        socket_path: path handed to the ``_WorkerServer`` constructor.
    """
    # Imported lazily so that importing this module does not require the
    # C extension until a server is actually requested.
    from torch._C._distributed_c10d import _WorkerServer

    server = _WorkerServer(socket_path)
    try:
        yield
    finally:
        server.shutdown()


# NOTE(review): ``record`` wraps the generator function itself, so it only
# observes the call that creates the generator -- confirm that exceptions
# raised inside the ``with`` body are still recorded as intended.
@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    This is a context manager that wraps your main entry function. This combines
    the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that
    exposes handlers via a unix socket specified by
    ``TORCH_WORKER_SERVER_SOCKET``.

    If ``TORCH_WORKER_SERVER_SOCKET`` is not set in the environment, no server
    is started and the wrapped code simply runs.

    Example

    ::

        @worker_main()
        def main():
            pass

        if __name__ == "__main__":
            main()

    """
    # ExitStack lets the server be entered conditionally while still being
    # torn down when the block exits.
    with ExitStack() as stack:
        socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
        if socket_path is not None:
            stack.enter_context(_worker_server(socket_path))

        yield