# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Unit tests for the helpers in ``torch.distributed.elastic.rendezvous.utils``."""

import socket
import threading
import time
from datetime import timedelta
from typing import List
from unittest import TestCase
from unittest.mock import patch

from torch.distributed.elastic.rendezvous.utils import (
    _delay,
    _matches_machine_hostname,
    _parse_rendezvous_config,
    _PeriodicTimer,
    _try_parse_port,
    parse_rendezvous_endpoint,
)


class UtilsTest(TestCase):
    """Tests for config/endpoint parsing, hostname matching and ``_delay``."""

    def test_parse_rendezvous_config_returns_dict(self) -> None:
        expected_config = {
            "a": "dummy1",
            "b": "dummy2",
            "c": "dummy3=dummy4",
            "d": "dummy5/dummy6",
        }

        # Keys/values are surrounded by stray whitespace and values may
        # themselves contain '=' or '/'.
        config = _parse_rendezvous_config(
            " b= dummy2 ,c=dummy3=dummy4, a =dummy1,d=dummy5/dummy6"
        )

        self.assertEqual(config, expected_config)

    def test_parse_rendezvous_returns_empty_dict_if_str_is_empty(self) -> None:
        config_strs = ["", " "]

        for config_str in config_strs:
            with self.subTest(config_str=config_str):
                config = _parse_rendezvous_config(config_str)

                self.assertEqual(config, {})

    def test_parse_rendezvous_raises_error_if_str_is_invalid(self) -> None:
        config_strs = [
            "a=dummy1,",
            "a=dummy1,,c=dummy2",
            "a=dummy1, ,c=dummy2",
            "a=dummy1,= ,c=dummy2",
            "a=dummy1, = ,c=dummy2",
            "a=dummy1, =,c=dummy2",
            " , ",
        ]

        for config_str in config_strs:
            with self.subTest(config_str=config_str):
                with self.assertRaisesRegex(
                    ValueError,
                    r"^The rendezvous configuration string must be in format "
                    r"<key1>=<value1>,...,<keyN>=<valueN>.$",
                ):
                    _parse_rendezvous_config(config_str)

    def test_parse_rendezvous_raises_error_if_value_is_empty(self) -> None:
        config_strs = [
            "b=dummy1,a,c=dummy2",
            "b=dummy1,c=dummy2,a",
            "b=dummy1,a=,c=dummy2",
            " a ",
        ]

        for config_str in config_strs:
            with self.subTest(config_str=config_str):
                with self.assertRaisesRegex(
                    ValueError,
                    r"^The rendezvous configuration option 'a' must have a value specified.$",
                ):
                    _parse_rendezvous_config(config_str)

    def test_try_parse_port_returns_port(self) -> None:
        port = _try_parse_port("123")

        self.assertEqual(port, 123)

    def test_try_parse_port_returns_none_if_str_is_invalid(self) -> None:
        # Surrounding whitespace is deliberately treated as invalid.
        port_strs = [
            "",
            " ",
            " 1",
            "1 ",
            " 1 ",
            "abc",
        ]

        for port_str in port_strs:
            with self.subTest(port_str=port_str):
                port = _try_parse_port(port_str)

                self.assertIsNone(port)

    def test_parse_rendezvous_endpoint_returns_tuple(self) -> None:
        endpoints = [
            "dummy.com:0",
            "dummy.com:123",
            "dummy.com:65535",
            "dummy-1.com:0",
            "dummy-1.com:123",
            "dummy-1.com:65535",
            "123.123.123.123:0",
            "123.123.123.123:123",
            "123.123.123.123:65535",
            "[2001:db8::1]:0",
            "[2001:db8::1]:123",
            "[2001:db8::1]:65535",
        ]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                host, port = parse_rendezvous_endpoint(endpoint, default_port=123)

                expected_host, expected_port = endpoint.rsplit(":", 1)

                # Bracketed IPv6 hosts are expected back without the brackets.
                if expected_host[0] == "[" and expected_host[-1] == "]":
                    expected_host = expected_host[1:-1]

                self.assertEqual(host, expected_host)
                self.assertEqual(port, int(expected_port))

    def test_parse_rendezvous_endpoint_returns_tuple_if_endpoint_has_no_port(
        self,
    ) -> None:
        endpoints = ["dummy.com", "dummy-1.com", "123.123.123.123", "[2001:db8::1]"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                host, port = parse_rendezvous_endpoint(endpoint, default_port=123)

                expected_host = endpoint

                # Bracketed IPv6 hosts are expected back without the brackets.
                if expected_host[0] == "[" and expected_host[-1] == "]":
                    expected_host = expected_host[1:-1]

                self.assertEqual(host, expected_host)
                self.assertEqual(port, 123)

    def test_parse_rendezvous_endpoint_returns_tuple_if_endpoint_is_empty(self) -> None:
        endpoints = ["", " "]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                # Pass the loop's candidate (previously a hard-coded "" was
                # used, so the whitespace-only case was never exercised).
                host, port = parse_rendezvous_endpoint(endpoint, default_port=123)

                self.assertEqual(host, "localhost")
                self.assertEqual(port, 123)

    def test_parse_rendezvous_endpoint_raises_error_if_hostname_is_invalid(
        self,
    ) -> None:
        endpoints = ["~", "dummy.com :123", "~:123", ":123"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                with self.assertRaisesRegex(
                    ValueError,
                    rf"^The hostname of the rendezvous endpoint '{endpoint}' must be a "
                    r"dot-separated list of labels, an IPv4 address, or an IPv6 address.$",
                ):
                    parse_rendezvous_endpoint(endpoint, default_port=123)

    def test_parse_rendezvous_endpoint_raises_error_if_port_is_invalid(self) -> None:
        endpoints = ["dummy.com:", "dummy.com:abc", "dummy.com:-123", "dummy.com:-"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                with self.assertRaisesRegex(
                    ValueError,
                    rf"^The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                    r"between 0 and 65536.$",
                ):
                    parse_rendezvous_endpoint(endpoint, default_port=123)

    def test_parse_rendezvous_endpoint_raises_error_if_port_is_too_big(self) -> None:
        endpoints = ["dummy.com:65536", "dummy.com:70000"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                with self.assertRaisesRegex(
                    ValueError,
                    rf"^The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                    r"between 0 and 65536.$",
                ):
                    parse_rendezvous_endpoint(endpoint, default_port=123)

    def test_matches_machine_hostname_returns_true_if_hostname_is_loopback(
        self,
    ) -> None:
        # Includes the fully expanded form of the IPv6 loopback address.
        hosts = [
            "localhost",
            "127.0.0.1",
            "::1",
            "0000:0000:0000:0000:0000:0000:0000:0001",
        ]

        for host in hosts:
            with self.subTest(host=host):
                self.assertTrue(_matches_machine_hostname(host))

    def test_matches_machine_hostname_returns_true_if_hostname_is_machine_hostname(
        self,
    ) -> None:
        host = socket.gethostname()

        self.assertTrue(_matches_machine_hostname(host))

    def test_matches_machine_hostname_returns_true_if_hostname_is_machine_fqdn(
        self,
    ) -> None:
        host = socket.getfqdn()

        self.assertTrue(_matches_machine_hostname(host))

    def test_matches_machine_hostname_returns_true_if_hostname_is_machine_address(
        self,
    ) -> None:
        addr_list = socket.getaddrinfo(
            socket.gethostname(), None, proto=socket.IPPROTO_TCP
        )

        # Element [4][0] of each getaddrinfo entry is the address string.
        for addr in (addr_info[4][0] for addr_info in addr_list):
            with self.subTest(addr=addr):
                self.assertTrue(_matches_machine_hostname(addr))

    def test_matches_machine_hostname_returns_false_if_hostname_does_not_match(
        self,
    ) -> None:
        hosts = ["dummy", "0.0.0.0", "::2"]

        for host in hosts:
            with self.subTest(host=host):
                self.assertFalse(_matches_machine_hostname(host))

    def test_delay_suspends_thread(self) -> None:
        # `_delay` is exercised with a scalar and with a tuple; the tuple is
        # presumably a (min, max) range -- TODO confirm. Both lower bounds are
        # 0.2 s, so at least that much time must elapse either way.
        for seconds in 0.2, (0.2, 0.4):
            with self.subTest(seconds=seconds):
                time1 = time.monotonic()

                _delay(seconds)  # type: ignore[arg-type]

                time2 = time.monotonic()

                self.assertGreaterEqual(time2 - time1, 0.2)

    # NOTE(review): the two `side_effect` entries are returned by successive
    # `socket.getaddrinfo` calls; presumably the first resolves the queried
    # host and the second the machine's own hostname -- confirm against
    # `_matches_machine_hostname`.
    @patch(
        "socket.getaddrinfo",
        side_effect=[
            [(None, None, 0, "a_host", ("1.2.3.4", 0))],
            [(None, None, 0, "a_different_host", ("1.2.3.4", 0))],
        ],
    )
    def test_matches_machine_hostname_returns_true_if_ip_address_match_between_hosts(
        self,
        _0,
    ) -> None:
        self.assertTrue(_matches_machine_hostname("a_host"))

    @patch(
        "socket.getaddrinfo",
        side_effect=[
            [(None, None, 0, "a_host", ("1.2.3.4", 0))],
            [(None, None, 0, "another_host_with_different_ip", ("1.2.3.5", 0))],
        ],
    )
    def test_matches_machine_hostname_returns_false_if_ip_address_not_match_between_hosts(
        self,
        _0,
    ) -> None:
        self.assertFalse(_matches_machine_hostname("a_host"))


class PeriodicTimerTest(TestCase):
    """Tests for the lifecycle and scheduling behavior of ``_PeriodicTimer``."""

    def test_start_can_be_called_only_once(self) -> None:
        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)

        timer.start()

        with self.assertRaisesRegex(RuntimeError, r"^The timer has already started.$"):
            timer.start()

        timer.cancel()

    def test_cancel_can_be_called_multiple_times(self) -> None:
        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)

        timer.start()

        timer.cancel()
        timer.cancel()

    def test_cancel_stops_background_thread(self) -> None:
        name = "PeriodicTimer_CancelStopsBackgroundThreadTest"

        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)

        # A unique thread name lets the test locate the timer's thread.
        timer.set_name(name)

        timer.start()

        self.assertTrue(any(t.name == name for t in threading.enumerate()))

        timer.cancel()

        self.assertTrue(all(t.name != name for t in threading.enumerate()))

    def test_delete_stops_background_thread(self) -> None:
        name = "PeriodicTimer_DeleteStopsBackgroundThreadTest"

        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)

        timer.set_name(name)

        timer.start()

        self.assertTrue(any(t.name == name for t in threading.enumerate()))

        # Dropping the last reference must stop the thread (presumably via
        # the timer's finalizer; relies on CPython's immediate refcount free).
        del timer

        self.assertTrue(all(t.name != name for t in threading.enumerate()))

    def test_set_name_cannot_be_called_after_start(self) -> None:
        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)

        timer.start()

        with self.assertRaisesRegex(RuntimeError, r"^The timer has already started.$"):
            timer.set_name("dummy_name")

        timer.cancel()

    def test_timer_calls_background_thread_at_regular_intervals(self) -> None:
        timer_begin_time: float

        # Call our function every 200ms.
        call_interval = 0.2

        # Keep the log of intervals between each consecutive call.
        actual_call_intervals: List[float] = []

        # Keep the number of times the function was called.
        call_count = 0

        # In order to prevent a flaky test instead of asserting that the
        # function was called an exact number of times we use a lower bound
        # that is guaranteed to be true for a correct implementation.
        min_required_call_count = 4

        timer_stop_event = threading.Event()

        def log_call(self):
            nonlocal timer_begin_time, call_count

            actual_call_intervals.append(time.monotonic() - timer_begin_time)

            call_count += 1
            if call_count == min_required_call_count:
                timer_stop_event.set()

            # Restart the clock so the next entry measures the gap between
            # consecutive calls.
            timer_begin_time = time.monotonic()

        # `self` is forwarded to `log_call` as its positional argument.
        timer = _PeriodicTimer(timedelta(seconds=call_interval), log_call, self)

        timer_begin_time = time.monotonic()

        timer.start()

        # Although this is theoretically non-deterministic, if our timer, which
        # has a 200ms call interval, does not get called 4 times in 60 seconds,
        # there is very likely something else going on.
        timer_stop_event.wait(60)

        timer.cancel()

        # Show only the custom failure messages below, not the default diff.
        self.longMessage = False

        self.assertGreaterEqual(
            call_count,
            min_required_call_count,
            f"The function has been called {call_count} time(s) but expected to be called at least "
            f"{min_required_call_count} time(s).",
        )

        for actual_call_interval in actual_call_intervals:
            self.assertGreaterEqual(
                actual_call_interval,
                call_interval,
                f"The interval between two function calls was {actual_call_interval} second(s) but "
                f"expected to be at least {call_interval} second(s).",
            )