xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/utils/util_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
import datetime
from multiprocessing.pool import ThreadPool
from typing import List, Optional, Union
from unittest import mock

import torch.distributed as dist
import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.utils.logging import get_logger
from torch.testing._internal.common_utils import run_tests, TestCase
19
20
class MockStore:
    """In-memory stand-in for a store that records every call made on it.

    Each method appends a tuple describing the call to ``self.ops`` so tests
    can assert on the exact sequence of store operations. Read methods return
    canned values instead of real data, and ``wait`` never blocks.
    """

    # Seconds reported by the ``timeout`` property; tests compare against it
    # to check that the pre-existing timeout is set back afterwards.
    _TEST_TIMEOUT = 1234

    def __init__(self) -> None:
        # Ordered log of recorded calls, e.g. ("set", key, value).
        self.ops: List[tuple] = []

    def set_timeout(self, timeout: datetime.timedelta) -> None:
        """Record a timeout change (callers pass a ``timedelta``)."""
        self.ops.append(("set_timeout", timeout))

    @property
    def timeout(self) -> datetime.timedelta:
        """Record the read and return a fixed ``_TEST_TIMEOUT``-second timeout."""
        self.ops.append(("timeout",))

        return datetime.timedelta(seconds=self._TEST_TIMEOUT)

    def set(self, key: str, value: Union[str, bytes]) -> None:
        """Record a write; ``value`` may be ``str`` or ``bytes``."""
        self.ops.append(("set", key, value))

    def get(self, key: str) -> str:
        """Record a read of ``key`` and return the placeholder ``"value"``."""
        self.ops.append(("get", key))
        return "value"

    def multi_get(self, keys: List[str]) -> List[str]:
        """Record a batched read and return ``"value"`` once per key."""
        self.ops.append(("multi_get", keys))
        return ["value"] * len(keys)

    def add(self, key: str, val: int) -> int:
        """Record an increment and always report a counter value of 3.

        NOTE: 3 presumably matches the world size of 3 used by the tests in
        this module, so every caller is treated as the last member to check in.
        """
        self.ops.append(("add", key, val))
        return 3

    def wait(
        self, keys: List[str], timeout: Optional[datetime.timedelta] = None
    ) -> None:
        """Record a wait on ``keys``; never blocks.

        ``timeout`` mirrors the optional second argument accepted by a real
        store's ``wait``. It is only included in the recorded tuple when it is
        given, so expectations of the form ``("wait", keys)`` keep matching.
        """
        if timeout is None:
            self.ops.append(("wait", keys))
        else:
            self.ops.append(("wait", keys, timeout))
53
54
class StoreUtilTest(TestCase):
    """Tests for ``store_util.get_all``, ``synchronize`` and ``barrier``.

    Most tests drive the helpers against ``MockStore`` and compare the exact
    sequence of recorded store calls. The ``*_hash_store`` and rank-tracing
    tests run several ranks concurrently on a real ``dist.HashStore`` using a
    thread pool.
    """

    def test_get_all_rank_0(self):
        mock_store = MockStore()
        world_size = 3

        store_util.get_all(mock_store, 0, "test/store", world_size)

        expected_ops = [
            ("multi_get", ["test/store0", "test/store1", "test/store2"]),
            ("add", "test/store/finished/num_members", 1),
            ("set", "test/store/finished/last_member", "<val_ignored>"),
            # Unlike the other ranks, rank 0 also waits on last_member.
            ("wait", ["test/store/finished/last_member"]),
        ]
        self.assertListEqual(mock_store.ops, expected_ops)

    def test_get_all_rank_n(self):
        world_size = 3
        mock_store = MockStore()

        store_util.get_all(mock_store, 1, "test/store", world_size)

        # A non-zero rank checks in but records no wait.
        expected_ops = [
            ("multi_get", ["test/store0", "test/store1", "test/store2"]),
            ("add", "test/store/finished/num_members", 1),
            ("set", "test/store/finished/last_member", "<val_ignored>"),
        ]
        self.assertListEqual(mock_store.ops, expected_ops)

    def test_synchronize(self):
        payload = b"data0"
        mock_store = MockStore()

        store_util.synchronize(mock_store, payload, 0, 3, key_prefix="test/store")

        expected_ops = [
            # The store timeout is read, then overridden with 300 seconds ...
            ("timeout",),
            ("set_timeout", datetime.timedelta(seconds=300)),
            ("set", "test/store0", payload),
            ("multi_get", ["test/store0", "test/store1", "test/store2"]),
            ("add", "test/store/finished/num_members", 1),
            ("set", "test/store/finished/last_member", "<val_ignored>"),
            ("wait", ["test/store/finished/last_member"]),
            # ... and set back to the value read at the start.
            ("set_timeout", datetime.timedelta(seconds=mock_store._TEST_TIMEOUT)),
        ]
        self.assertListEqual(mock_store.ops, expected_ops)

    def test_synchronize_hash_store(self) -> None:
        num_ranks = 4
        hash_store = dist.HashStore()

        def sync_rank(rank: int):
            return store_util.synchronize(
                hash_store, f"data{rank}", rank, num_ranks, key_prefix="test/store"
            )

        with ThreadPool(num_ranks) as pool:
            results = pool.map(sync_rank, range(num_ranks))

        # Every rank gets back all ranks' payloads, in rank order, as bytes.
        per_rank = [f"data{r}".encode() for r in range(num_ranks)]
        self.assertListEqual(results, [per_rank] * num_ranks)

    def test_barrier(self):
        mock_store = MockStore()

        store_util.barrier(mock_store, 3, key_prefix="test/store")

        expected_ops = [
            ("timeout",),
            ("set_timeout", datetime.timedelta(seconds=300)),
            ("add", "test/store/num_members", 1),
            ("set", "test/store/last_member", "<val_ignored>"),
            ("wait", ["test/store/last_member"]),
            ("set_timeout", datetime.timedelta(seconds=mock_store._TEST_TIMEOUT)),
        ]
        self.assertListEqual(mock_store.ops, expected_ops)

    def test_barrier_timeout_rank_tracing(self):
        world_size = 3
        hash_store = dist.HashStore()

        def barrier_error(rank: int):
            """Run a traced barrier for ``rank``; return the error text or ""."""
            try:
                store_util.barrier(
                    hash_store,
                    world_size,
                    key_prefix="test/store",
                    barrier_timeout=0.1,
                    rank=rank,
                    rank_tracing_decoder=lambda x: f"Rank {x} host",
                    trace_timeout=0.01,
                )
            except Exception as err:
                return str(err)
            else:
                return ""

        # Only ranks 0..world_size-2 call barrier; the last rank never joins.
        with ThreadPool(world_size - 1) as pool:
            messages: List[str] = pool.map(barrier_error, range(world_size - 1))

        self.assertTrue(any("missing_ranks=[Rank 2 host]" in m for m in messages))

        self.assertTrue(
            any(
                "check rank 0 (Rank 0 host) for missing rank info" in m
                for m in messages
            )
        )

    def test_barrier_timeout_operations(self):
        import torch

        StoreError = torch._C._DistStoreError

        world_size = 3
        mock_store = MockStore()

        def traced_barrier(rank: int) -> None:
            store_util.barrier(
                mock_store,
                world_size,
                key_prefix="test/store",
                barrier_timeout=1,
                rank=rank,
                rank_tracing_decoder=lambda x: f"Rank {x} host",
                trace_timeout=0.1,
            )

        # Rank 0: the first mocked wait raises; the remaining two succeed.
        with mock.patch.object(mock_store, "wait") as wait_mock:
            wait_mock.side_effect = [StoreError("test"), None, None]

            with self.assertRaises(StoreError):
                traced_barrier(0)

        rank0_ops = [
            ("timeout",),
            ("set_timeout", datetime.timedelta(seconds=1)),
            ("add", "test/store/num_members", 1),
            ("set", "test/store/last_member", "<val_ignored>"),
            # wait for last member is mocked
            ("set", "test/store0/TRACE", "<val_ignored>"),
            # wait for each rank is mocked
            ("set", "test/store/TRACING_GATE", "<val_ignored>"),
        ]
        self.assertListEqual(mock_store.ops, rank0_ops)

        # Rank 1: start from a clean log; the first mocked wait raises and the
        # second succeeds.
        mock_store.ops = []
        with mock.patch.object(mock_store, "wait") as wait_mock:
            wait_mock.side_effect = [
                StoreError("test"),
                None,
            ]

            with self.assertRaises(StoreError):
                traced_barrier(1)

        rank1_ops = [
            ("timeout",),
            ("set_timeout", datetime.timedelta(seconds=1)),
            ("add", "test/store/num_members", 1),
            ("set", "test/store/last_member", "<val_ignored>"),
            ("set", "test/store1/TRACE", "<val_ignored>"),
            # wait for gate is mocked
        ]
        self.assertListEqual(mock_store.ops, rank1_ops)

    def test_barrier_hash_store(self) -> None:
        num_ranks = 4
        hash_store = dist.HashStore()

        def join_barrier(_rank: int):
            store_util.barrier(hash_store, num_ranks, key_prefix="test/store")

        with ThreadPool(num_ranks) as pool:
            results = pool.map(join_barrier, range(num_ranks))

        # join_barrier returns None, and pool.map re-raises any worker error,
        # so success yields exactly one None per rank.
        self.assertEqual(results, [None] * num_ranks)
252
253
class UtilTest(TestCase):
    """Tests for how ``get_logger`` chooses the logger's name."""

    def test_get_logger_different(self):
        first, second = get_logger("name1"), get_logger("name2")
        self.assertNotEqual(first.name, second.name)

    def test_get_logger(self):
        # Without a name, the logger is named after the calling module.
        logger = get_logger()
        self.assertEqual(__name__, logger.name)

    def test_get_logger_none(self):
        # Passing None explicitly behaves like omitting the name.
        logger = get_logger(None)
        self.assertEqual(__name__, logger.name)

    def test_get_logger_custom_name(self):
        custom = "test.module"
        logger = get_logger(custom)
        self.assertEqual(custom, logger.name)
270        self.assertEqual("test.module", logger1.name)
271
272
if __name__ == "__main__":
    # Executed as a script rather than imported: hand off to torch's runner.
    run_tests()
275