# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import multiprocessing as mp
import signal
import time
import unittest
import unittest.mock as mock

import torch.distributed.elastic.timer as timer
from torch.distributed.elastic.timer.api import TimerRequest
from torch.distributed.elastic.timer.local_timer import MultiprocessingRequestQueue
from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_WINDOWS,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TEST_WITH_TSAN,
    TestCase,
)


# timer is not supported on windows or macos (also skipped under dev-dbg-asan)
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):
    # func2 should time out
    def func2(n, mp_queue):
        """
        Recursively acquires a 0.1s timer ``n`` times, sleeping 0.2s after each
        nested call returns. Only the outermost call passes ``mp_queue``; it
        configures the process-wide timer client before acquiring any timer.
        """
        if mp_queue is not None:
            timer.configure(timer.LocalTimerClient(mp_queue))
        if n > 0:
            with timer.expires(after=0.1):
                func2(n - 1, None)
                # sleeping longer than the timeout while the timer is held
                time.sleep(0.2)

    class LocalTimerTest(TestCase):
        """
        End-to-end tests of ``timer.expires`` against a running
        ``LocalTimerServer`` that shares ``self.mp_queue`` with the clients.
        """

        def setUp(self):
            super().setUp()
            self.ctx = mp.get_context("spawn")
            self.mp_queue = self.ctx.Queue()
            self.max_interval = 0.01
            self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)
            self.server.start()

        def tearDown(self):
            super().tearDown()
            self.server.stop()

        def test_exception_propagation(self):
            """
            An exception raised inside ``timer.expires`` must reach the caller.
            """
            # A client is passed explicitly so that the only exception that can
            # be raised is the one from the body: with no client configured
            # ``timer.expires`` raises RuntimeError (see ``test_no_client``),
            # which would also satisfy a bare ``assertRaises(Exception)``.
            timer_client = timer.LocalTimerClient(self.mp_queue)
            with self.assertRaisesRegex(Exception, "foobar"):
                with timer.expires(after=1, client=timer_client):
                    raise Exception("foobar")  # noqa: TRY002

        def test_no_client(self):
            # no timer client configured; exception expected
            timer.configure(None)
            with self.assertRaises(RuntimeError):
                with timer.expires(after=1):
                    pass

        def test_client_interaction(self):
            # no timer client configured but one passed in explicitly
            # no exception expected
            timer_client = timer.LocalTimerClient(self.mp_queue)
            # wrap (rather than replace) so the real calls still happen
            timer_client.acquire = mock.MagicMock(wraps=timer_client.acquire)
            timer_client.release = mock.MagicMock(wraps=timer_client.release)
            with timer.expires(after=1, scope="test", client=timer_client):
                pass

            timer_client.acquire.assert_called_once_with("test", mock.ANY)
            timer_client.release.assert_called_once_with("test")

        def test_happy_path(self):
            """
            A body that finishes before the timeout completes normally.
            """
            timer.configure(timer.LocalTimerClient(self.mp_queue))
            with timer.expires(after=0.5):
                time.sleep(0.1)

        def test_get_timer_recursive(self):
            """
            If a function acquires a countdown timer with default scope,
            then recursive calls to the function should re-acquire the
            timer rather than creating a new one. That is only the last
            recursive call's timer will take effect.
            """
            # setUp has already started the server, so it is not started
            # again here.
            timer.configure(timer.LocalTimerClient(self.mp_queue))

            # func should not time out
            def func(n):
                if n > 0:
                    with timer.expires(after=0.1):
                        func(n - 1)
                        time.sleep(0.05)

            func(4)

            p = self.ctx.Process(target=func2, args=(2, self.mp_queue))
            p.start()
            p.join()
            self.assertEqual(-signal.SIGKILL, p.exitcode)

        @staticmethod
        def _run(mp_queue, timeout, duration):
            """
            Child-process entry point: configures a client on ``mp_queue`` and
            sleeps ``duration`` seconds under a ``timeout``-second timer.
            """
            client = timer.LocalTimerClient(mp_queue)
            timer.configure(client)

            with timer.expires(after=timeout):
                time.sleep(duration)

        @unittest.skipIf(TEST_WITH_TSAN, "test is tsan incompatible")
        def test_timer(self):
            """
            A child that sleeps past its timeout is expected to exit with SIGKILL.
            """
            timeout = 0.1
            duration = 1
            p = mp.Process(target=self._run, args=(self.mp_queue, timeout, duration))
            p.start()
            p.join()
            self.assertEqual(-signal.SIGKILL, p.exitcode)

    def _enqueue_on_interval(mp_queue, n, interval, sem):
        """
        enqueues ``n`` timer requests into ``mp_queue`` one element per
        interval seconds. Releases the given semaphore once before going to work.
        """
        sem.release()
        for i in range(n):
            mp_queue.put(TimerRequest(i, "test_scope", 0))
            time.sleep(interval)


# timer is not supported on windows or macos (also skipped under dev-dbg-asan)
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):

    class MultiprocessingRequestQueueTest(TestCase):
        """
        Tests ``MultiprocessingRequestQueue.get`` on top of an ``mp.Queue``.
        """

        def test_get(self):
            """
            ``get`` returns nothing on an empty queue and returns the single
            queued request even when more were asked for.
            """
            mp_queue = mp.Queue()
            request_queue = MultiprocessingRequestQueue(mp_queue)

            requests = request_queue.get(1, timeout=0.01)
            self.assertEqual(0, len(requests))

            request = TimerRequest(1, "test_scope", 0)
            mp_queue.put(request)
            requests = request_queue.get(2, timeout=0.01)
            self.assertEqual(1, len(requests))
            self.assertIn(request, requests)

        @unittest.skipIf(
            TEST_WITH_TSAN,
            "test incompatible with tsan",
        )
        def test_get_size(self):
            """
            Creates a "producer" process that enqueues ``n`` elements
            every ``interval`` seconds. Asserts that a ``get(n, timeout=n*interval+delta)``
            yields all ``n`` elements.
            """
            mp_queue = mp.Queue()
            request_queue = MultiprocessingRequestQueue(mp_queue)
            n = 10
            interval = 0.1
            sem = mp.Semaphore(0)

            p = mp.Process(
                target=_enqueue_on_interval, args=(mp_queue, n, interval, sem)
            )
            p.start()
            # reap the producer even if an assertion below fails
            self.addCleanup(p.join)

            sem.acquire()  # blocks until the process has started to run the function
            timeout = interval * (n + 1)
            start = time.time()
            requests = request_queue.get(n, timeout=timeout)
            self.assertLessEqual(time.time() - start, timeout + interval)
            self.assertEqual(n, len(requests))

        def test_get_less_than_size(self):
            """
            Tests slow producer.
            Creates a "producer" process that enqueues ``n`` elements
            every ``interval`` seconds. Asserts that a ``get(n, timeout=(interval * n/2))``
            yields at least ``n/2`` elements.
            """
            mp_queue = mp.Queue()
            request_queue = MultiprocessingRequestQueue(mp_queue)
            n = 10
            interval = 0.1
            sem = mp.Semaphore(0)

            p = mp.Process(
                target=_enqueue_on_interval, args=(mp_queue, n, interval, sem)
            )
            p.start()
            # the producer keeps running after ``get`` times out; reap it
            self.addCleanup(p.join)

            sem.acquire()  # blocks until the process has started to run the function
            requests = request_queue.get(n, timeout=(interval * (n / 2)))
            self.assertLessEqual(n / 2, len(requests))


# timer is not supported on windows or macos (also skipped under dev-dbg-asan)
if not (IS_WINDOWS or IS_MACOS or TEST_WITH_DEV_DBG_ASAN):

    class LocalTimerServerTest(TestCase):
        """
        Tests ``LocalTimerServer`` by feeding ``TimerRequest`` objects directly
        into its queue and invoking the watchdog by hand.
        """

        def setUp(self):
            super().setUp()
            self.mp_queue = mp.Queue()
            self.max_interval = 0.01
            # the server is created but not started; tests start it if needed
            self.server = timer.LocalTimerServer(self.mp_queue, self.max_interval)

        def tearDown(self):
            super().tearDown()
            self.server.stop()

        def test_watchdog_call_count(self):
            """
            checks that the watchdog function ran wait/interval +- 1 times
            """
            self.server._run_watchdog = mock.MagicMock(wraps=self.server._run_watchdog)

            wait = 0.1

            self.server.start()
            time.sleep(wait)
            self.server.stop()
            watchdog_call_count = self.server._run_watchdog.call_count
            self.assertGreaterEqual(
                watchdog_call_count, int(wait / self.max_interval) - 1
            )
            self.assertLessEqual(watchdog_call_count, int(wait / self.max_interval) + 1)

        def test_watchdog_empty_queue(self):
            """
            checks that the watchdog can run on an empty queue
            """
            self.server._run_watchdog()

        def _expired_timer(self, pid, scope):
            """Builds a request whose expiration time is 60 seconds in the past."""
            expired = time.time() - 60
            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=expired)

        def _valid_timer(self, pid, scope):
            """Builds a request whose expiration time is 60 seconds in the future."""
            valid = time.time() + 60
            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=valid)

        def _release_timer(self, pid, scope):
            """Builds a request with ``expiration_time=-1``, used here as a release."""
            return TimerRequest(worker_id=pid, scope_id=scope, expiration_time=-1)

        @mock.patch("os.kill")
        def test_expired_timers(self, mock_os_kill):
            """
            tests that a single expired timer on a process should terminate
            the process and clean up all pending timers that were owned by the process
            """
            test_pid = -3
            self.mp_queue.put(self._expired_timer(pid=test_pid, scope="test1"))
            self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test2"))

            self.server._run_watchdog()

            self.assertEqual(0, len(self.server._timers))
            mock_os_kill.assert_called_once_with(test_pid, signal.SIGKILL)

        @mock.patch("os.kill")
        def test_acquire_release(self, mock_os_kill):
            """
            tests that:
            1. a timer can be acquired then released (should not terminate process)
            2. a timer can be vacuously released (e.g. no-op)
            """
            test_pid = -3
            self.mp_queue.put(self._valid_timer(pid=test_pid, scope="test1"))
            self.mp_queue.put(self._release_timer(pid=test_pid, scope="test1"))
            self.mp_queue.put(self._release_timer(pid=test_pid, scope="test2"))

            self.server._run_watchdog()

            self.assertEqual(0, len(self.server._timers))
            mock_os_kill.assert_not_called()

        @mock.patch("os.kill")
        def test_valid_timers(self, mock_os_kill):
            """
            tests that valid timers are processed correctly and the process is left alone
            """
            self.mp_queue.put(self._valid_timer(pid=-3, scope="test1"))
            self.mp_queue.put(self._valid_timer(pid=-3, scope="test2"))
            self.mp_queue.put(self._valid_timer(pid=-2, scope="test1"))
            self.mp_queue.put(self._valid_timer(pid=-2, scope="test2"))

            self.server._run_watchdog()

            self.assertEqual(4, len(self.server._timers))
            self.assertIn((-3, "test1"), self.server._timers)
            self.assertIn((-3, "test2"), self.server._timers)
            self.assertIn((-2, "test1"), self.server._timers)
            self.assertIn((-2, "test2"), self.server._timers)
            mock_os_kill.assert_not_called()


if __name__ == "__main__":
    run_tests()