# Owner(s): ["oncall: distributed"]
"""Smoke test for the (deprecated) ``torch.distributed.launch`` CLI.

Launches a trivial user script on a single node with several workers and
verifies that ``launch.main`` completes without raising.
"""
import os
import sys
from contextlib import closing

import torch.distributed as dist
import torch.distributed.launch as launch
from torch.distributed.elastic.utils import get_socket_with_port

# Bail out early (successfully) on builds compiled without distributed
# support — nothing below is usable in that case.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


def path(script):
    """Return the absolute path of *script* relative to this test file's directory."""
    return os.path.join(os.path.dirname(__file__), script)


# ASAN + multiprocessing "spawn" is a known-bad combination; skip the whole
# module rather than report spurious failures.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr
    )
    sys.exit(0)


class TestDistributedLaunch(TestCase):
    def test_launch_user_script(self):
        """Run the launcher end-to-end: 1 node x 4 workers on a free port."""
        nnodes = 1
        nproc_per_node = 4
        # Grab a free port from the OS, then release the socket so the
        # launcher can bind it. NOTE(review): this is inherently racy — the
        # port could be reclaimed between close and bind — but it is the
        # standard trick for tests.
        sock = get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]
        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--monitor-interval=1",
            "--start-method=spawn",
            "--master-addr=localhost",
            f"--master-port={master_port}",
            "--node-rank=0",
            "--use-env",
            path("bin/test_script.py"),
        ]
        # launch.main raises on any worker failure, which fails the test.
        launch.main(args)


if __name__ == "__main__":
    run_tests()