1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9 10import datetime 11from multiprocessing.pool import ThreadPool 12from typing import List 13from unittest import mock 14 15import torch.distributed as dist 16import torch.distributed.elastic.utils.store as store_util 17from torch.distributed.elastic.utils.logging import get_logger 18from torch.testing._internal.common_utils import run_tests, TestCase 19 20 21class MockStore: 22 _TEST_TIMEOUT = 1234 23 24 def __init__(self) -> None: 25 self.ops = [] 26 27 def set_timeout(self, timeout: float) -> None: 28 self.ops.append(("set_timeout", timeout)) 29 30 @property 31 def timeout(self) -> datetime.timedelta: 32 self.ops.append(("timeout",)) 33 34 return datetime.timedelta(seconds=self._TEST_TIMEOUT) 35 36 def set(self, key: str, value: str) -> None: 37 self.ops.append(("set", key, value)) 38 39 def get(self, key: str) -> str: 40 self.ops.append(("get", key)) 41 return "value" 42 43 def multi_get(self, keys: List[str]) -> List[str]: 44 self.ops.append(("multi_get", keys)) 45 return ["value"] * len(keys) 46 47 def add(self, key: str, val: int) -> int: 48 self.ops.append(("add", key, val)) 49 return 3 50 51 def wait(self, keys: List[str]) -> None: 52 self.ops.append(("wait", keys)) 53 54 55class StoreUtilTest(TestCase): 56 def test_get_all_rank_0(self): 57 world_size = 3 58 59 store = MockStore() 60 61 store_util.get_all(store, 0, "test/store", world_size) 62 63 self.assertListEqual( 64 store.ops, 65 [ 66 ("multi_get", ["test/store0", "test/store1", "test/store2"]), 67 ("add", "test/store/finished/num_members", 1), 68 ("set", "test/store/finished/last_member", "<val_ignored>"), 69 ("wait", ["test/store/finished/last_member"]), 70 ], 71 ) 72 73 def test_get_all_rank_n(self): 74 store = MockStore() 75 world_size = 3 76 store_util.get_all(store, 1, "test/store", world_size) 77 78 self.assertListEqual( 79 store.ops, 80 [ 81 ("multi_get", ["test/store0", "test/store1", "test/store2"]), 82 ("add", "test/store/finished/num_members", 1), 83 ("set", "test/store/finished/last_member", "<val_ignored>"), 84 ], 85 ) 86 87 def test_synchronize(self): 88 store = MockStore() 89 90 data = b"data0" 91 store_util.synchronize(store, data, 0, 3, key_prefix="test/store") 92 93 self.assertListEqual( 94 store.ops, 95 [ 96 ("timeout",), 97 ("set_timeout", datetime.timedelta(seconds=300)), 98 ("set", "test/store0", data), 99 ("multi_get", ["test/store0", "test/store1", "test/store2"]), 100 ("add", "test/store/finished/num_members", 1), 101 ("set", "test/store/finished/last_member", "<val_ignored>"), 102 ("wait", ["test/store/finished/last_member"]), 103 ("set_timeout", datetime.timedelta(seconds=store._TEST_TIMEOUT)), 104 ], 105 ) 106 107 def test_synchronize_hash_store(self) -> None: 108 N = 4 109 110 store = dist.HashStore() 111 112 def f(i: int): 113 return store_util.synchronize( 114 store, f"data{i}", i, N, key_prefix="test/store" 115 ) 116 117 with ThreadPool(N) as pool: 118 out = pool.map(f, range(N)) 119 120 self.assertListEqual(out, [[f"data{i}".encode() for i in range(N)]] * N) 121 122 def test_barrier(self): 123 store = MockStore() 124 125 store_util.barrier(store, 3, key_prefix="test/store") 126 127 self.assertListEqual( 128 store.ops, 129 [ 130 ("timeout",), 131 ("set_timeout", datetime.timedelta(seconds=300)), 132 ("add", "test/store/num_members", 1), 133 ("set", "test/store/last_member", "<val_ignored>"), 134 ("wait", ["test/store/last_member"]), 135 ("set_timeout", datetime.timedelta(seconds=store._TEST_TIMEOUT)), 136 ], 137 ) 138 139 def test_barrier_timeout_rank_tracing(self): 140 N = 3 141 142 store = dist.HashStore() 143 144 def run_barrier_for_rank(i: int): 145 try: 146 store_util.barrier( 147 store, 148 N, 149 key_prefix="test/store", 150 barrier_timeout=0.1, 151 rank=i, 152 rank_tracing_decoder=lambda x: f"Rank {x} host", 153 trace_timeout=0.01, 154 ) 155 except Exception as e: 156 return str(e) 157 return "" 158 159 with ThreadPool(N - 1) as pool: 160 outputs: List[str] = pool.map(run_barrier_for_rank, range(N - 1)) 161 162 self.assertTrue(any("missing_ranks=[Rank 2 host]" in msg for msg in outputs)) 163 164 self.assertTrue( 165 any( 166 "check rank 0 (Rank 0 host) for missing rank info" in msg 167 for msg in outputs 168 ) 169 ) 170 171 def test_barrier_timeout_operations(self): 172 import torch 173 174 DistStoreError = torch._C._DistStoreError 175 176 N = 3 177 store = MockStore() 178 179 # rank 0 180 with mock.patch.object(store, "wait") as wait_mock: 181 wait_mock.side_effect = [DistStoreError("test"), None, None] 182 183 with self.assertRaises(DistStoreError): 184 store_util.barrier( 185 store, 186 N, 187 key_prefix="test/store", 188 barrier_timeout=1, 189 rank=0, 190 rank_tracing_decoder=lambda x: f"Rank {x} host", 191 trace_timeout=0.1, 192 ) 193 194 self.assertListEqual( 195 store.ops, 196 [ 197 ("timeout",), 198 ("set_timeout", datetime.timedelta(seconds=1)), 199 ("add", "test/store/num_members", 1), 200 ("set", "test/store/last_member", "<val_ignored>"), 201 # wait for last member is mocked 202 ("set", "test/store0/TRACE", "<val_ignored>"), 203 # wait for each rank is mocked 204 ("set", "test/store/TRACING_GATE", "<val_ignored>"), 205 ], 206 ) 207 208 # rank 1 209 with mock.patch.object(store, "wait") as wait_mock: 210 store.ops = [] 211 212 wait_mock.side_effect = [ 213 DistStoreError("test"), 214 None, 215 ] 216 217 with self.assertRaises(DistStoreError): 218 store_util.barrier( 219 store, 220 N, 221 key_prefix="test/store", 222 barrier_timeout=1, 223 rank=1, 224 rank_tracing_decoder=lambda x: f"Rank {x} host", 225 trace_timeout=0.1, 226 ) 227 228 self.assertListEqual( 229 store.ops, 230 [ 231 ("timeout",), 232 ("set_timeout", datetime.timedelta(seconds=1)), 233 ("add", "test/store/num_members", 1), 234 ("set", "test/store/last_member", "<val_ignored>"), 235 ("set", "test/store1/TRACE", "<val_ignored>"), 236 # wait for gate is mocked 237 ], 238 ) 239 240 def test_barrier_hash_store(self) -> None: 241 N = 4 242 243 store = dist.HashStore() 244 245 def f(i: int): 246 store_util.barrier(store, N, key_prefix="test/store") 247 248 with ThreadPool(N) as pool: 249 out = pool.map(f, range(N)) 250 251 self.assertEqual(out, [None] * N) 252 253 254class UtilTest(TestCase): 255 def test_get_logger_different(self): 256 logger1 = get_logger("name1") 257 logger2 = get_logger("name2") 258 self.assertNotEqual(logger1.name, logger2.name) 259 260 def test_get_logger(self): 261 logger1 = get_logger() 262 self.assertEqual(__name__, logger1.name) 263 264 def test_get_logger_none(self): 265 logger1 = get_logger(None) 266 self.assertEqual(__name__, logger1.name) 267 268 def test_get_logger_custom_name(self): 269 logger1 = get_logger("test.module") 270 self.assertEqual("test.module", logger1.name) 271 272 273if __name__ == "__main__": 274 run_tests() 275