xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/timer/local_timer_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: r2p"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8import multiprocessing as mp
9import signal
10import time
11import unittest
12import unittest.mock as mock
13
14import torch.distributed.elastic.timer as timer
15from torch.distributed.elastic.timer.api import TimerRequest
16from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue
17from torch.testing._internal.common_utils import (
18    IS_MACOS,
19    IS_WINDOWS,
20    run_tests,
21    TEST_WITH_DEV_DBG_ASAN,
22    TEST_WITH_TSAN,
23    TestCase,
24)
25
26
# timer is not supported on Windows or macOS; these tests are also skipped when
# TEST_WITH_DEV_DBG_ASAN is set
28if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
29    # func2 should time out
30    def func2(n, mp_queue):
31        if mp_queue is not None:
32            timer.configure(timer.LocalTimerClient(mp_queue))
33        if n > 0:
34            with timer.expires(after=0.1):
35                func2(n - 1, None)
36                time.sleep(0.2)
37
38    class LocalTimerTest(TestCase):
39        def setUp(self):
40            super().setUp()
41            self.ctx = mp.get_context("spawn")
42            self.mp_queue = self.ctx.Queue()
43            self.max_interval = 0.01
44            self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
45            self.server.start()
46
47        def tearDown(self):
48            super().tearDown()
49            self.server.stop()
50
51        def test_exception_propagation(self):
52            with self.assertRaises(Exception, msg="foobar"):
53                with timer.expires(after=1):
54                    raise Exception("foobar")  # noqa: TRY002
55
56        def test_no_client(self):
57            # no timer client configured; exception expected
58            timer.configure(None)
59            with self.assertRaises(RuntimeError):
60                with timer.expires(after=1):
61                    pass
62
63        def test_client_interaction(self):
64            # no timer client configured but one passed in explicitly
65            # no exception expected
66            timer_client = timer.LocalTimerClient(self.mp_queue)
67            timer_client.acquire = mock.MagicMock(wraps=timer_client.acquire)
68            timer_client.release = mock.MagicMock(wraps=timer_client.release)
69            with timer.expires(after=1, scope="test", client=timer_client):
70                pass
71
72            timer_client.acquire.assert_called_once_with("test", mock.ANY)
73            timer_client.release.assert_called_once_with("test")
74
75        def test_happy_path(self):
76            timer.configure(timer.LocalTimerClient(self.mp_queue))
77            with timer.expires(after=0.5):
78                time.sleep(0.1)
79
80        def test_get_timer_recursive(self):
81            """
82            If a function acquires a countdown timer with default scope,
83            then recursive calls to the function should re-acquire the
84            timer rather than creating a new one. That is only the last
85            recursive call's timer will take effect.
86            """
87            self.server.start()
88            timer.configure(timer.LocalTimerClient(self.mp_queue))
89
90            # func should not time out
91            def func(n):
92                if n > 0:
93                    with timer.expires(after=0.1):
94                        func(n - 1)
95                        time.sleep(0.05)
96
97            func(4)
98
99            p = self.ctx.Process(target=func2, args=(2, self.mp_queue))
100            p.start()
101            p.join()
102            self.assertEqual(-signal.SIGKILL, p.exitcode)
103
104        @staticmethod
105        def _run(mp_queue, timeout, duration):
106            client = timer.LocalTimerClient(mp_queue)
107            timer.configure(client)
108
109            with timer.expires(after=timeout):
110                time.sleep(duration)
111
112        @unittest.skipIf(TEST_WITH_TSAN, "test is tsan incompatible")
113        def test_timer(self):
114            timeout = 0.1
115            duration = 1
116            p = mp.Process(target=self._run, args=(self.mp_queue, timeout, duration))
117            p.start()
118            p.join()
119            self.assertEqual(-signal.SIGKILL, p.exitcode)
120
121    def _enqueue_on_interval(mp_queue, n, interval, sem):
122        """
123        enqueues ``n`` timer requests into ``mp_queue`` one element per
124        interval seconds. Releases the given semaphore once before going to work.
125        """
126        sem.release()
127        for i in range(0, n):
128            mp_queue.put(TimerRequest(i, "test_scope", 0))
129            time.sleep(interval)
130
131
# timer is not supported on Windows or macOS; these tests are also skipped when
# TEST_WITH_DEV_DBG_ASAN is set
133if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
134
135    class MultiprocessingRequestQueueTest(TestCase):
136        def test_get(self):
137            mp_queue = mp.Queue()
138            request_queue = MultiprocessingRequestQueue(mp_queue)
139
140            requests = request_queue.get(1, timeout=0.01)
141            self.assertEqual(0, len(requests))
142
143            request = TimerRequest(1, "test_scope", 0)
144            mp_queue.put(request)
145            requests = request_queue.get(2, timeout=0.01)
146            self.assertEqual(1, len(requests))
147            self.assertIn(request, requests)
148
149        @unittest.skipIf(
150            TEST_WITH_TSAN,
151            "test incompatible with tsan",
152        )
153        def test_get_size(self):
154            """
155            Creates a "producer" process that enqueues ``n`` elements
156            every ``interval`` seconds. Asserts that a ``get(n, timeout=n*interval+delta)``
157            yields all ``n`` elements.
158            """
159            mp_queue = mp.Queue()
160            request_queue = MultiprocessingRequestQueue(mp_queue)
161            n = 10
162            interval = 0.1
163            sem = mp.Semaphore(0)
164
165            p = mp.Process(
166                target=_enqueue_on_interval, args=(mp_queue, n, interval, sem)
167            )
168            p.start()
169
170            sem.acquire()  # blocks until the process has started to run the function
171            timeout = interval * (n + 1)
172            start = time.time()
173            requests = request_queue.get(n, timeout=timeout)
174            self.assertLessEqual(time.time() - start, timeout + interval)
175            self.assertEqual(n, len(requests))
176
177        def test_get_less_than_size(self):
178            """
179            Tests slow producer.
180            Creates a "producer" process that enqueues ``n`` elements
181            every ``interval`` seconds. Asserts that a ``get(n, timeout=(interval * n/2))``
182            yields at most ``n/2`` elements.
183            """
184            mp_queue = mp.Queue()
185            request_queue = MultiprocessingRequestQueue(mp_queue)
186            n = 10
187            interval = 0.1
188            sem = mp.Semaphore(0)
189
190            p = mp.Process(
191                target=_enqueue_on_interval, args=(mp_queue, n, interval, sem)
192            )
193            p.start()
194
195            sem.acquire()  # blocks until the process has started to run the function
196            requests = request_queue.get(n, timeout=(interval * (n / 2)))
197            self.assertLessEqual(n / 2, len(requests))
198
199
# timer is not supported on Windows or macOS; these tests are also skipped when
# TEST_WITH_DEV_DBG_ASAN is set
201if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
202
203    class LocalTimerServerTest(TestCase):
204        def setUp(self):
205            super().setUp()
206            self.mp_queue = mp.Queue()
207            self.max_interval = 0.01
208            self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
209
210        def tearDown(self):
211            super().tearDown()
212            self.server.stop()
213
214        def test_watchdog_call_count(self):
215            """
216            checks that the watchdog function ran wait/interval +- 1 times
217            """
218            self.server._run_watchdog = mock.MagicMock(wraps=self.server._run_watchdog)
219
220            wait = 0.1
221
222            self.server.start()
223            time.sleep(wait)
224            self.server.stop()
225            watchdog_call_count = self.server._run_watchdog.call_count
226            self.assertGreaterEqual(
227                watchdog_call_count, int(wait / self.max_interval) - 1
228            )
229            self.assertLessEqual(watchdog_call_count, int(wait / self.max_interval) + 1)
230
231        def test_watchdog_empty_queue(self):
232            """
233            checks that the watchdog can run on an empty queue
234            """
235            self.server._run_watchdog()
236
237        def _expired_timer(self, pid, scope):
238            expired = time.time() - 60
239            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=expired)
240
241        def _valid_timer(self, pid, scope):
242            valid = time.time() + 60
243            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=valid)
244
245        def _release_timer(self, pid, scope):
246            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)
247
248        @mock.patch("os.kill")
249        def test_expired_timers(self, mock_os_kill):
250            """
251            tests that a single expired timer on a process should terminate
252            the process and clean up all pending timers that was owned by the process
253            """
254            test_pid = -3
255            self.mp_queue.put(self._expired_timer(pid=test_pid, scope="test1"))
256            self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test2"))
257
258            self.server._run_watchdog()
259
260            self.assertEqual(0, len(self.server._timers))
261            mock_os_kill.assert_called_once_with(test_pid, signal.SIGKILL)
262
263        @mock.patch("os.kill")
264        def test_acquire_release(self, mock_os_kill):
265            """
266            tests that:
267            1. a timer can be acquired then released (should not terminate process)
268            2. a timer can be vacuously released (e.g. no-op)
269            """
270            test_pid = -3
271            self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test1"))
272            self.mp_queue.put(self._release_timer(pid=test_pid, scope="test1"))
273            self.mp_queue.put(self._release_timer(pid=test_pid, scope="test2"))
274
275            self.server._run_watchdog()
276
277            self.assertEqual(0, len(self.server._timers))
278            mock_os_kill.assert_not_called()
279
280        @mock.patch("os.kill")
281        def test_valid_timers(self, mock_os_kill):
282            """
283            tests that valid timers are processed correctly and the process is left alone
284            """
285            self.mp_queue.put(self._valid_timer(pid=-3, scope="test1"))
286            self.mp_queue.put(self._valid_timer(pid=-3, scope="test2"))
287            self.mp_queue.put(self._valid_timer(pid=-2, scope="test1"))
288            self.mp_queue.put(self._valid_timer(pid=-2, scope="test2"))
289
290            self.server._run_watchdog()
291
292            self.assertEqual(4, len(self.server._timers))
293            self.assertTrue((-3, "test1") in self.server._timers)
294            self.assertTrue((-3, "test2") in self.server._timers)
295            self.assertTrue((-2, "test1") in self.server._timers)
296            self.assertTrue((-2, "test2") in self.server._timers)
297            mock_os_kill.assert_not_called()
298
299
if __name__ == "__main__":
    # Script entry point: hand off to ``run_tests`` from
    # torch.testing._internal.common_utils to run this module's test cases.
    run_tests()
302