# xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/rendezvous/utils_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
import socket
import threading
import time
from datetime import timedelta
from typing import List
from unittest import TestCase
from unittest.mock import patch

from torch.distributed.elastic.rendezvous.utils import (
    _delay,
    _matches_machine_hostname,
    _parse_rendezvous_config,
    _PeriodicTimer,
    _try_parse_port,
    parse_rendezvous_endpoint,
)


class UtilsTest(TestCase):
    """Tests for the config/endpoint parsing and hostname helpers of the rendezvous utils."""

    @staticmethod
    def _invalid_port_pattern(endpoint: str) -> str:
        """Return an anchored regex matching the invalid-port error for ``endpoint``.

        The endpoint is escaped because it contains regex metacharacters such
        as ``.``.
        """
        return (
            rf"^The port number of the rendezvous endpoint '{re.escape(endpoint)}' "
            r"must be an integer between 0 and 65536\.$"
        )

    def test_parse_rendezvous_config_returns_dict(self) -> None:
        expected_config = {
            "a": "dummy1",
            "b": "dummy2",
            "c": "dummy3=dummy4",
            "d": "dummy5/dummy6",
        }

        # Surrounding whitespace is dropped and only the first "=" of a pair
        # separates the key from its value.
        config = _parse_rendezvous_config(
            " b= dummy2  ,c=dummy3=dummy4,  a =dummy1,d=dummy5/dummy6"
        )

        self.assertEqual(config, expected_config)

    def test_parse_rendezvous_returns_empty_dict_if_str_is_empty(self) -> None:
        config_strs = ["", "   "]

        for config_str in config_strs:
            with self.subTest(config_str=config_str):
                config = _parse_rendezvous_config(config_str)

                self.assertEqual(config, {})

    def test_parse_rendezvous_raises_error_if_str_is_invalid(self) -> None:
        config_strs = [
            "a=dummy1,",
            "a=dummy1,,c=dummy2",
            "a=dummy1,   ,c=dummy2",
            "a=dummy1,=  ,c=dummy2",
            "a=dummy1, = ,c=dummy2",
            "a=dummy1,  =,c=dummy2",
            " ,  ",
        ]

        for config_str in config_strs:
            with self.subTest(config_str=config_str):
                # The dots are escaped so they match the literal "..." and the
                # final "." of the message instead of any character.
                with self.assertRaisesRegex(
                    ValueError,
                    r"^The rendezvous configuration string must be in format "
                    r"<key1>=<value1>,\.\.\.,<keyN>=<valueN>\.$",
                ):
                    _parse_rendezvous_config(config_str)

    def test_parse_rendezvous_raises_error_if_value_is_empty(self) -> None:
        config_strs = [
            "b=dummy1,a,c=dummy2",
            "b=dummy1,c=dummy2,a",
            "b=dummy1,a=,c=dummy2",
            "  a ",
        ]

        for config_str in config_strs:
            with self.subTest(config_str=config_str):
                with self.assertRaisesRegex(
                    ValueError,
                    r"^The rendezvous configuration option 'a' must have a value "
                    r"specified\.$",
                ):
                    _parse_rendezvous_config(config_str)

    def test_try_parse_port_returns_port(self) -> None:
        port = _try_parse_port("123")

        self.assertEqual(port, 123)

    def test_try_parse_port_returns_none_if_str_is_invalid(self) -> None:
        # Surrounding whitespace makes a port string invalid.
        port_strs = [
            "",
            "   ",
            "  1",
            "1  ",
            " 1 ",
            "abc",
        ]

        for port_str in port_strs:
            with self.subTest(port_str=port_str):
                port = _try_parse_port(port_str)

                self.assertIsNone(port)

    def test_parse_rendezvous_endpoint_returns_tuple(self) -> None:
        endpoints = [
            "dummy.com:0",
            "dummy.com:123",
            "dummy.com:65535",
            "dummy-1.com:0",
            "dummy-1.com:123",
            "dummy-1.com:65535",
            "123.123.123.123:0",
            "123.123.123.123:123",
            "123.123.123.123:65535",
            "[2001:db8::1]:0",
            "[2001:db8::1]:123",
            "[2001:db8::1]:65535",
        ]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                host, port = parse_rendezvous_endpoint(endpoint, default_port=123)

                expected_host, expected_port = endpoint.rsplit(":", 1)

                # Bracketed IPv6 hosts are expected back without the brackets.
                if expected_host[0] == "[" and expected_host[-1] == "]":
                    expected_host = expected_host[1:-1]

                self.assertEqual(host, expected_host)
                self.assertEqual(port, int(expected_port))

    def test_parse_rendezvous_endpoint_returns_tuple_if_endpoint_has_no_port(
        self,
    ) -> None:
        endpoints = ["dummy.com", "dummy-1.com", "123.123.123.123", "[2001:db8::1]"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                host, port = parse_rendezvous_endpoint(endpoint, default_port=123)

                expected_host = endpoint

                # Bracketed IPv6 hosts are expected back without the brackets.
                if expected_host[0] == "[" and expected_host[-1] == "]":
                    expected_host = expected_host[1:-1]

                self.assertEqual(host, expected_host)
                self.assertEqual(port, 123)

    def test_parse_rendezvous_endpoint_returns_tuple_if_endpoint_is_empty(self) -> None:
        endpoints = ["", "  "]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                # Pass the subtest's endpoint so the whitespace-only case is
                # exercised too; both must fall back to localhost.
                host, port = parse_rendezvous_endpoint(endpoint, default_port=123)

                self.assertEqual(host, "localhost")
                self.assertEqual(port, 123)

    def test_parse_rendezvous_endpoint_raises_error_if_hostname_is_invalid(
        self,
    ) -> None:
        endpoints = ["~", "dummy.com :123", "~:123", ":123"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                # The endpoint is escaped because it may contain regex
                # metacharacters such as ".".
                with self.assertRaisesRegex(
                    ValueError,
                    rf"^The hostname of the rendezvous endpoint '{re.escape(endpoint)}' "
                    r"must be a dot-separated list of labels, an IPv4 address, or an "
                    r"IPv6 address\.$",
                ):
                    parse_rendezvous_endpoint(endpoint, default_port=123)

    def test_parse_rendezvous_endpoint_raises_error_if_port_is_invalid(self) -> None:
        endpoints = ["dummy.com:", "dummy.com:abc", "dummy.com:-123", "dummy.com:-"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                with self.assertRaisesRegex(
                    ValueError, self._invalid_port_pattern(endpoint)
                ):
                    parse_rendezvous_endpoint(endpoint, default_port=123)

    def test_parse_rendezvous_endpoint_raises_error_if_port_is_too_big(self) -> None:
        endpoints = ["dummy.com:65536", "dummy.com:70000"]

        for endpoint in endpoints:
            with self.subTest(endpoint=endpoint):
                with self.assertRaisesRegex(
                    ValueError, self._invalid_port_pattern(endpoint)
                ):
                    parse_rendezvous_endpoint(endpoint, default_port=123)

    def test_matches_machine_hostname_returns_true_if_hostname_is_loopback(
        self,
    ) -> None:
        hosts = [
            "localhost",
            "127.0.0.1",
            "::1",
            "0000:0000:0000:0000:0000:0000:0000:0001",
        ]

        for host in hosts:
            with self.subTest(host=host):
                self.assertTrue(_matches_machine_hostname(host))

    def test_matches_machine_hostname_returns_true_if_hostname_is_machine_hostname(
        self,
    ) -> None:
        host = socket.gethostname()

        self.assertTrue(_matches_machine_hostname(host))

    def test_matches_machine_hostname_returns_true_if_hostname_is_machine_fqdn(
        self,
    ) -> None:
        host = socket.getfqdn()

        self.assertTrue(_matches_machine_hostname(host))

    def test_matches_machine_hostname_returns_true_if_hostname_is_machine_address(
        self,
    ) -> None:
        addr_list = socket.getaddrinfo(
            socket.gethostname(), None, proto=socket.IPPROTO_TCP
        )

        # Element 4 of each getaddrinfo entry is the sockaddr; its first item is the address.
        for addr in (addr_info[4][0] for addr_info in addr_list):
            with self.subTest(addr=addr):
                self.assertTrue(_matches_machine_hostname(addr))

    def test_matches_machine_hostname_returns_false_if_hostname_does_not_match(
        self,
    ) -> None:
        hosts = ["dummy", "0.0.0.0", "::2"]

        for host in hosts:
            with self.subTest(host=host):
                self.assertFalse(_matches_machine_hostname(host))

    def test_delay_suspends_thread(self) -> None:
        # Both a single duration and a tuple are accepted; the tuple is
        # presumably a (min, max) range, so either form must sleep >= 0.2s.
        for seconds in 0.2, (0.2, 0.4):
            with self.subTest(seconds=seconds):
                time1 = time.monotonic()

                _delay(seconds)  # type: ignore[arg-type]

                time2 = time.monotonic()

                self.assertGreaterEqual(time2 - time1, 0.2)

    # The two side effects feed successive getaddrinfo calls; both resolve to
    # 1.2.3.4, so the two (differently named) hosts share an IP address.
    @patch(
        "socket.getaddrinfo",
        side_effect=[
            [(None, None, 0, "a_host", ("1.2.3.4", 0))],
            [(None, None, 0, "a_different_host", ("1.2.3.4", 0))],
        ],
    )
    def test_matches_machine_hostname_returns_true_if_ip_address_match_between_hosts(
        self,
        _getaddrinfo_mock,
    ) -> None:
        self.assertTrue(_matches_machine_hostname("a_host"))

    # The two side effects feed successive getaddrinfo calls and resolve to
    # different IP addresses (1.2.3.4 vs 1.2.3.5), so the hosts must not match.
    @patch(
        "socket.getaddrinfo",
        side_effect=[
            [(None, None, 0, "a_host", ("1.2.3.4", 0))],
            [(None, None, 0, "another_host_with_different_ip", ("1.2.3.5", 0))],
        ],
    )
    def test_matches_machine_hostname_returns_false_if_ip_address_not_match_between_hosts(
        self,
        _getaddrinfo_mock,
    ) -> None:
        self.assertFalse(_matches_machine_hostname("a_host"))


class PeriodicTimerTest(TestCase):
    """Tests for ``_PeriodicTimer``, which runs a function repeatedly on a background thread."""

    def test_start_can_be_called_only_once(self) -> None:
        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)
        # Stop the background thread even if an assertion below fails.
        self.addCleanup(timer.cancel)

        timer.start()

        with self.assertRaisesRegex(RuntimeError, r"^The timer has already started\.$"):
            timer.start()

    def test_cancel_can_be_called_multiple_times(self) -> None:
        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)

        timer.start()

        # Repeated cancel() calls must not raise.
        timer.cancel()
        timer.cancel()

    def test_cancel_stops_background_thread(self) -> None:
        name = "PeriodicTimer_CancelStopsBackgroundThreadTest"

        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)
        # Stop the background thread even if the first assertion fails.
        self.addCleanup(timer.cancel)

        timer.set_name(name)

        timer.start()

        # The name given via set_name() lets us locate the background thread.
        self.assertTrue(any(t.name == name for t in threading.enumerate()))

        timer.cancel()

        self.assertTrue(all(t.name != name for t in threading.enumerate()))

    def test_delete_stops_background_thread(self) -> None:
        name = "PeriodicTimer_DeleteStopsBackgroundThreadTest"

        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)

        timer.set_name(name)

        timer.start()

        self.assertTrue(any(t.name == name for t in threading.enumerate()))

        # No addCleanup(timer.cancel) here on purpose: the bound method would
        # hold a reference to the timer, so ``del`` would not release it.
        del timer

        self.assertTrue(all(t.name != name for t in threading.enumerate()))

    def test_set_name_cannot_be_called_after_start(self) -> None:
        timer = _PeriodicTimer(timedelta(seconds=1), lambda: None)
        # Stop the background thread even if an assertion below fails.
        self.addCleanup(timer.cancel)

        timer.start()

        with self.assertRaisesRegex(RuntimeError, r"^The timer has already started\.$"):
            timer.set_name("dummy_name")

    def test_timer_calls_background_thread_at_regular_intervals(self) -> None:
        timer_begin_time: float

        # Call our function every 200ms.
        call_interval = 0.2

        # Keep the log of intervals between each consecutive call.
        actual_call_intervals: List[float] = []

        # Keep the positional argument received by each call.
        received_args: List[object] = []

        # Keep the number of times the function was called.
        call_count = 0

        # In order to prevent a flaky test instead of asserting that the
        # function was called an exact number of times we use a lower bound
        # that is guaranteed to be true for a correct implementation.
        min_required_call_count = 4

        timer_stop_event = threading.Event()

        def log_call(test_case: object) -> None:
            nonlocal timer_begin_time, call_count

            actual_call_intervals.append(time.monotonic() - timer_begin_time)
            received_args.append(test_case)

            call_count += 1
            if call_count == min_required_call_count:
                timer_stop_event.set()

            # The next interval is measured from the end of this call.
            timer_begin_time = time.monotonic()

        # ``self`` is passed as an extra positional argument for log_call.
        timer = _PeriodicTimer(timedelta(seconds=call_interval), log_call, self)

        timer_begin_time = time.monotonic()

        timer.start()

        # Although this is theoretically non-deterministic, if our timer, which
        # has a 200ms call interval, does not get called 4 times in 60 seconds,
        # there is very likely something else going on.
        timer_stop_event.wait(60)

        timer.cancel()

        self.longMessage = False

        self.assertGreaterEqual(
            call_count,
            min_required_call_count,
            f"The function has been called {call_count} time(s) but expected to be called at least "
            f"{min_required_call_count} time(s).",
        )

        for actual_call_interval in actual_call_intervals:
            self.assertGreaterEqual(
                actual_call_interval,
                call_interval,
                f"The interval between two function calls was {actual_call_interval} second(s) but "
                f"expected to be at least {call_interval} second(s).",
            )

        for received_arg in received_args:
            self.assertIs(
                received_arg,
                self,
                "The timer did not forward its positional argument to the function.",
            )
405