1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: distributed"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerfrom contextlib import closing 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch.distributed as dist 8*da0073e9SAndroid Build Coastguard Workerimport torch.distributed.launch as launch 9*da0073e9SAndroid Build Coastguard Workerfrom torch.distributed.elastic.utils import get_socket_with_port 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerif not dist.is_available(): 13*da0073e9SAndroid Build Coastguard Worker print("Distributed not available, skipping tests", file=sys.stderr) 14*da0073e9SAndroid Build Coastguard Worker sys.exit(0) 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 17*da0073e9SAndroid Build Coastguard Worker run_tests, 18*da0073e9SAndroid Build Coastguard Worker TEST_WITH_DEV_DBG_ASAN, 19*da0073e9SAndroid Build Coastguard Worker TestCase, 20*da0073e9SAndroid Build Coastguard Worker) 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Workerdef path(script): 24*da0073e9SAndroid Build Coastguard Worker return os.path.join(os.path.dirname(__file__), script) 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Workerif TEST_WITH_DEV_DBG_ASAN: 28*da0073e9SAndroid Build Coastguard Worker print( 29*da0073e9SAndroid Build Coastguard Worker "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr 30*da0073e9SAndroid Build Coastguard Worker ) 31*da0073e9SAndroid Build Coastguard Worker sys.exit(0) 32*da0073e9SAndroid 
class TestDistributedLaunch(TestCase):
    """Smoke test that drives ``torch.distributed.launch`` in-process."""

    def test_launch_user_script(self):
        """Run ``bin/test_script.py`` through ``launch.main`` on one node.

        The launcher is invoked with a single node and four processes per
        node, using the ``spawn`` start method and ``--use-env``. The test
        makes no assertions of its own; it only calls ``launch.main(args)``.
        """
        nnodes = 1
        nproc_per_node = 4

        # Read the port number from the helper's socket, then close the
        # socket before the launcher is started.
        # NOTE(review): presumably the socket is released so the launched job
        # can bind the same port — confirm against get_socket_with_port.
        with closing(get_socket_with_port()) as sock:
            master_port = sock.getsockname()[1]

        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--monitor-interval=1",
            "--start-method=spawn",
            "--master-addr=localhost",
            f"--master-port={master_port}",
            "--node-rank=0",
            "--use-env",
            path("bin/test_script.py"),
        ]
        launch.main(args)


if __name__ == "__main__":
    run_tests()