xref: /aosp_15_r20/external/pytorch/test/distributed/test_distributed_spawn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: distributed"]
2
3import os
4import sys
5
6import torch
7import torch.distributed as dist
8
9
# Disallow TF32 for CUDA matmuls in this test module so results use full FP32 matmul precision.
torch.backends.cuda.matmul.allow_tf32 = False

# Without torch.distributed there is nothing to exercise: report it and exit successfully.
# This check runs before the distributed test-suite imports further down the file.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    raise SystemExit(0)
15
16from torch.testing._internal.common_utils import (
17    NO_MULTIPROCESSING_SPAWN,
18    run_tests,
19    TEST_WITH_DEV_DBG_ASAN,
20)
21from torch.testing._internal.distributed.distributed_test import (
22    DistributedTest,
23    TestDistBackend,
24)
25
26
# Configurations in which these spawn-based tests are not run: print the reason
# to stderr and exit with status 0 instead of failing.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    raise SystemExit(0)
elif NO_MULTIPROCESSING_SPAWN:
    print("Spawn not available, skipping tests.", file=sys.stderr)
    raise SystemExit(0)
37
_allowed_backends = ("gloo", "nccl", "ucc")

# BACKEND, WORLD_SIZE and TEMP_DIR must all be provided by the caller. Fail fast
# with copy-pasteable repro instructions rather than a bare KeyError later on
# (BACKEND, for example, is read with os.environ["BACKEND"] right after this).
if any(var not in os.environ for var in ("BACKEND", "WORLD_SIZE", "TEMP_DIR")):
    # TODO can we actually have `run_tests.py` emit the complete instructions when it prints a repro command?
    raise RuntimeError(
        "Missing expected env vars for `test_distributed_spawn.py`.  Please ensure to specify the following:\n"
        f"'BACKEND' = one of {_allowed_backends}\n"
        "'WORLD_SIZE' = int >= 2\n"
        "'TEMP_DIR' specifying a directory containing a barrier file named 'barrier'.\n\n"
        f"e.g.\ntouch /tmp/barrier && TEMP_DIR=/tmp BACKEND='nccl' WORLD_SIZE=2 python {__file__}",
    )
52
BACKEND = os.environ["BACKEND"]

if BACKEND in _allowed_backends:

    class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase):
        """Shared distributed test suite, run with worker processes started in ``setUp``.

        Only defined when ``BACKEND`` is one of ``_allowed_backends``; for any other
        value the module defines no test class, so no tests run.
        """

        def setUp(self):
            super().setUp()
            self._spawn_processes()
            # ``torch.backends.cudnn.flags`` is a generator-based context manager.
            # Calling ``__enter__()`` on an unreferenced temporary lets CPython free
            # it at once, which closes the generator and runs its ``finally`` block,
            # restoring the previous cuDNN flags immediately. Keep a reference and
            # exit it as a test cleanup so the flags hold for the test and are then
            # restored.
            cudnn_flags = torch.backends.cudnn.flags(enabled=True, allow_tf32=False)
            cudnn_flags.__enter__()
            self.addCleanup(cudnn_flags.__exit__, None, None, None)

else:
    # Route to stderr like the file's other skip/diagnostic messages.
    print(f"Invalid backend {BACKEND}. Tests will not be run!", file=sys.stderr)
65
66
# When executed directly, run this module's tests through `run_tests()` from
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
69