# xref: /aosp_15_r20/external/pytorch/test/distributed/test_launcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: distributed"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerfrom contextlib import closing
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch.distributed as dist
8*da0073e9SAndroid Build Coastguard Workerimport torch.distributed.launch as launch
9*da0073e9SAndroid Build Coastguard Workerfrom torch.distributed.elastic.utils import get_socket_with_port
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# Exit with a success status when this build reports no distributed support,
# so that nothing below this point is executed.
if not dist.is_available():
    unavailable_notice = "Distributed not available, skipping tests"
    print(unavailable_notice, file=sys.stderr)
    sys.exit(0)
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
17*da0073e9SAndroid Build Coastguard Worker    run_tests,
18*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_DEV_DBG_ASAN,
19*da0073e9SAndroid Build Coastguard Worker    TestCase,
20*da0073e9SAndroid Build Coastguard Worker)
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
def path(script):
    """Return ``script`` joined onto the directory containing this file.

    Args:
        script: Path fragment (for example ``"bin/test_script.py"``) to
            resolve against this module's directory.

    Returns:
        The result of ``os.path.join`` on this file's directory and ``script``.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, script)
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
# Under the dev/debug ASAN configuration the whole module is skipped: report
# why on stderr and exit with a success status.
if TEST_WITH_DEV_DBG_ASAN:
    asan_skip_notice = "Skip ASAN as torch + multiprocessing spawn have known issues"
    print(asan_skip_notice, file=sys.stderr)
    sys.exit(0)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
class TestDistributedLaunch(TestCase):
    """Smoke test that drives ``torch.distributed.launch`` programmatically."""

    def test_launch_user_script(self):
        """Run ``bin/test_script.py`` through ``launch.main`` on one node.

        The test contains no explicit assertions: it succeeds when
        ``launch.main`` returns without raising.
        """
        nnodes = 1
        nproc_per_node = 4
        # Obtain a socket from the elastic utilities (presumably bound to a
        # free port -- confirm against get_socket_with_port), read the port
        # number it is bound to, and close it before handing the number to
        # the launcher.
        with closing(get_socket_with_port()) as sock:
            master_port = sock.getsockname()[1]
        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--monitor-interval=1",
            "--start-method=spawn",
            "--master-addr=localhost",
            f"--master-port={master_port}",
            "--node-rank=0",
            "--use-env",
            path("bin/test_script.py"),
        ]
        launch.main(args)
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed directly as a script: hand control to the shared test runner.
    run_tests()
58