xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/timer/api_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: r2p"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8import unittest
9import unittest.mock as mock
10
11from torch.distributed.elastic.timer import TimerServer
12from torch.distributed.elastic.timer.api import RequestQueue, TimerRequest
13
14
class MockRequestQueue(RequestQueue):
    """Canned ``RequestQueue`` stub: always reports and hands out two requests."""

    def size(self):
        """Report a fixed queue size of two."""
        return 2

    def get(self, size, timeout):
        """Return the same two timer requests; ``size`` and ``timeout`` are ignored."""
        first_request = TimerRequest(1, "test_1", 0)
        second_request = TimerRequest(2, "test_2", 0)
        return [first_request, second_request]
21
22
class MockTimerServer(TimerServer):
    """
    Test double for ``TimerServer`` with scripted reaping outcomes.

    ``_reap_worker`` behaves as follows:

    * worker 1: raises ``RuntimeError("test error")``
    * worker 2: returns ``True`` (reaped successfully)
    * worker 3: returns ``False`` (reaping failed)
    * any other worker id: returns ``None``

    ``get_expired_timers`` always reports two expired timers for each of
    workers 1, 2 and 3, whatever deadline is passed in.
    """

    def __init__(self, request_queue, max_interval):
        # Pure pass-through to the base-class constructor.
        super().__init__(request_queue, max_interval)

    def register_timers(self, timer_requests):
        """No-op; the test replaces this method with a mock."""

    def clear_timers(self, worker_ids):
        """No-op; the test replaces this method with a mock."""

    def get_expired_timers(self, deadline):
        """Return ``{worker_id: [two TimerRequests]}`` for worker ids 1..3."""
        expired = {}
        for worker_id in (1, 2, 3):
            expired[worker_id] = [
                TimerRequest(worker_id, f"test_{worker_id}_0", 0),
                TimerRequest(worker_id, f"test_{worker_id}_1", 0),
            ]
        return expired

    def _reap_worker(self, worker_id):
        """Raise for worker 1, succeed for worker 2, fail for worker 3."""
        if worker_id == 1:
            raise RuntimeError("test error")
        if worker_id == 2:
            return True
        if worker_id == 3:
            return False
        # Unscripted worker ids yield None, as in the implicit fall-through.
        return None
57
58
class TimerApiTest(unittest.TestCase):
    """Exercises one ``_run_watchdog()`` pass against the mock server and queue."""

    @mock.patch.object(MockTimerServer, "register_timers")
    @mock.patch.object(MockTimerServer, "clear_timers")
    def test_run_watchdog(self, mock_clear_timers, mock_register_timers):
        """
        Run the watchdog once and check how it used the queue and the timers.

        ``MockTimerServer._reap_worker()`` raises for worker 1, returns True
        for worker 2 and returns False for worker 3; the test expects
        ``clear_timers`` to be called with workers 1 and 2 only.
        """
        max_interval = 1
        # Wrapping the queue in a Mock records calls while keeping its canned
        # return values.
        request_queue = mock.Mock(wraps=MockRequestQueue())
        server = MockTimerServer(request_queue, max_interval)
        server._run_watchdog()

        # This must come first: the size() call below adds a second call.
        request_queue.size.assert_called_once()
        reported_size = request_queue.size()
        request_queue.get.assert_called_with(reported_size, max_interval)

        expected_requests = request_queue.get(2, 1)
        mock_register_timers.assert_called_with(expected_requests)

        mock_clear_timers.assert_called_with({1, 2})
77