# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import unittest.mock as mock

from torch.distributed.elastic.timer import TimerServer
from torch.distributed.elastic.timer.api import RequestQueue, TimerRequest


class MockRequestQueue(RequestQueue):
    """Request queue stub with a fixed size and a fixed payload."""

    def size(self):
        """Always report two pending requests."""
        return 2

    def get(self, size, timeout):
        """Return the same two timer requests; ``size`` and ``timeout`` are ignored."""
        return [TimerRequest(1, "test_1", 0), TimerRequest(2, "test_2", 0)]


class MockTimerServer(TimerServer):
    """
    Mock implementation of TimerServer for testing purposes.
    This mock has the following behavior:

    1. reaping worker 1 throws
    2. reaping worker 2 succeeds
    3. reaping worker 3 fails (returns ``False``)

    For each of the workers 1 - 3, two expired timers are returned.
    """

    # No ``__init__`` override: the inherited constructor is used as-is.

    def register_timers(self, timer_requests):
        """No-op; patched out by the tests to capture the arguments."""

    def clear_timers(self, worker_ids):
        """No-op; patched out by the tests to capture the arguments."""

    def get_expired_timers(self, deadline):
        """Return two expired timers for each of workers 1, 2 and 3.

        ``deadline`` is ignored.
        """
        return {
            i: [TimerRequest(i, f"test_{i}_0", 0), TimerRequest(i, f"test_{i}_1", 0)]
            for i in range(1, 4)
        }

    def _reap_worker(self, worker_id):
        """Raise for worker 1, return ``True`` for worker 2, ``False`` for worker 3.

        Any other ``worker_id`` falls through and implicitly returns ``None``.
        """
        if worker_id == 1:
            raise RuntimeError("test error")
        elif worker_id == 2:
            return True
        elif worker_id == 3:
            return False


class TimerApiTest(unittest.TestCase):
    @mock.patch.object(MockTimerServer, "register_timers")
    @mock.patch.object(MockTimerServer, "clear_timers")
    def test_run_watchdog(self, mock_clear_timers, mock_register_timers):
        """
        tests that when a ``_reap_worker()`` method throws an exception
        for a particular worker_id, the timers for successfully reaped workers
        are cleared properly
        """
        max_interval = 1
        request_queue = mock.Mock(wraps=MockRequestQueue())
        timer_server = MockTimerServer(request_queue, max_interval)
        timer_server._run_watchdog()

        # Build the expected values from a fresh, un-mocked queue so that no
        # extra calls are recorded on the mock that is being inspected.
        reference_queue = MockRequestQueue()
        batch_size = reference_queue.size()
        expected_requests = reference_queue.get(batch_size, max_interval)

        request_queue.size.assert_called_once()
        request_queue.get.assert_called_once_with(batch_size, max_interval)
        mock_register_timers.assert_called_once_with(expected_requests)
        # Worker 1 (reap raised) and worker 2 (reap succeeded) are expected to
        # be cleared; worker 3 (reap returned False) is not.
        mock_clear_timers.assert_called_once_with({1, 2})


if __name__ == "__main__":
    unittest.main()