xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import torch
4import time
5import torch.distributed.rpc as rpc
6from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs
7from torch.testing._internal.dist_utils import (
8    dist_init,
9    wait_until_pending_futures_and_users_flushed,
10    wait_until_owners_and_forks_on_rank,
11    worker_name,
12)
13from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
14    RpcAgentTestFixture,
15)
16
def my_sleep_func(seconds=1):
    """Sleep for ``seconds`` seconds, then return the tensor product 1 * 1.

    Serves as a deliberately slow Python function to invoke over RPC.
    """
    time.sleep(seconds)
    lhs = torch.tensor(1)
    rhs = torch.tensor(1)
    return torch.mul(lhs, rhs)
20
@torch.jit.script
def my_script_func(tensor):
    """Return ``tensor`` added to itself; compiled with TorchScript."""
    doubled = torch.add(tensor, tensor)
    return doubled
24
def add_rref_to_value(rref, value):
    """Fetch the value behind ``rref`` via ``to_here()`` and add ``value`` to it."""
    fetched = rref.to_here()
    return fetched + value
27
28class FaultyAgentRpcTest(RpcAgentTestFixture):
29
30    # no faulty_messages defined so this fails all retryable messages - see
31    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
32    @dist_init(messages_to_delay={})
33    def test_check_failed_messages(self):
34        if self.rank == 0:
35            dst_worker_b = worker_name((self.rank + 1) % self.world_size)
36            dst_worker_c = worker_name((self.rank + 2) % self.world_size)
37
38            # Worker0 sends RPC to Worker1 and creates an RRef there
39            rref = rpc.remote(dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2)))
40            # Worker0 sends an RPC to Worker2 with the RRef as an arg
41            rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2)))
42            # check if the output is as expected
43            self.assertEqual(rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2)))
44        # explicitly delete all User RRefs
45        _delete_all_user_and_unforked_owner_rrefs()
46
47    @dist_init
48    def test_verify_backend_options(self):
49        self.assertEqual(self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE)
50        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
51        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
52        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
53        self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2)
54        self.assertEqual(self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
55
56    @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"])
57    def test_custom_faulty_messages(self):
58        self.assertEqual(
59            {"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"},
60            set(self.rpc_backend_options.messages_to_fail),
61        )
62
63    @dist_init(faulty_messages=[])
64    def test_no_faulty_messages(self):
65        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0)
66
67    @dist_init(messages_to_delay={"SCRIPT_CALL": 1.5})
68    def test_custom_messages_to_delay(self):
69        self.assertEqual(self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5})
70
71    def _test_remote_message_dropped_pickle(self, dst=None):
72        if self.rank != 0:
73            return
74        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
75        dst_worker = f"worker{dst_rank}"
76        # Since we fail python_remote_call messages synchronously, the future
77        # corresponding to this remote call will be marked with an error when
78        # this function returns.
79        rref = rpc.remote(dst_worker, my_sleep_func, args=(1,))
80        # Call to ensure pending callbacks are run.
81        wait_until_pending_futures_and_users_flushed()
82        # Attempt to fork the RRef should raise an error indicating the rpc.remote timeout.
83        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
84            rref._serialize()
85        # Test that using RRef as arg over RPC (which forks) results in the same
86        # error
87        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
88            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1))
89
90    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
91    def test_remote_message_dropped_pickle(self):
92        self._test_remote_message_dropped_pickle()
93
94    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
95    def test_remote_message_dropped_pickle_to_self(self):
96        self._test_remote_message_dropped_pickle(self.rank)
97
98
99    def _test_remote_message_dropped_timeout(self, func, args, dst=None):
100        if self.rank != 0:
101            return
102
103        # test the case where rpc.remote() message creation is completely dropped.
104        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
105        dst_worker = f"worker{dst_rank}"
106        # Since we fail python_remote_call messages synchronously, the future
107        # corresponding to this remote call will be marked with an error when
108        # this function returns.
109        rref = rpc.remote(dst_worker, func, args=args)
110        # Call to ensure pending callbacks are run.
111        wait_until_pending_futures_and_users_flushed()
112        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
113            rref.to_here()
114        # Note: during shutdown, logs will indicate "Could not find OwnerRRef..."
115        # on the owning nodes, this is expected because the OwnerRRef was never
116        # successfully created. Therefore, delAllUsers will work as expected.
117
118    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
119    def test_builtin_remote_message_dropped_timeout(self):
120        func = torch.add
121        args = (torch.tensor(1), torch.tensor(1))
122        self._test_remote_message_dropped_timeout(func, args)
123
124    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
125    def test_builtin_remote_message_dropped_timeout_to_self(self):
126        func = torch.add
127        args = (torch.tensor(1), torch.tensor(1))
128        self._test_remote_message_dropped_timeout(func, args, dst=0)
129
130    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
131    def test_udf_remote_message_dropped_timeout(self):
132        func = my_sleep_func
133        args = (2,)
134        self._test_remote_message_dropped_timeout(func, args)
135
136    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
137    def test_udf_remote_message_dropped_timeout_to_self(self):
138        func = my_sleep_func
139        args = (2,)
140        self._test_remote_message_dropped_timeout(func, args, dst=0)
141
142    def _test_remote_message_delay_timeout(self, func, args, dst=None):
143        if self.rank != 0:
144            return
145        # Test the case where remote message is eventually processed on the owner,
146        # but the future on the creator times out before the response comes back.
147        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
148        dst_worker = f"worker{dst_rank}"
149        # 10 ms timeout
150        rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
151        # Future corresponding to the remote creation should time out.
152        expected_error = self.get_timeout_error_regex()
153        with self.assertRaisesRegex(RuntimeError, expected_error):
154            rref._get_future().wait()
155
156        # Call to ensure pending callbacks are run.
157        wait_until_pending_futures_and_users_flushed()
158        # to_here() should now pick up that rpc.remote() creation has failed.
159        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
160            rref.to_here()
161
162        # Test the case where rpc.remote() times out, but to_here() has already
163        # started blocking before.
164        # NOTE: we only test this when not sending to self, as to_here() calls
165        # calls localValue(), which does not send an RPC and thus does not have
166        # a timeout. This can be supported by allowing future.wait() to
167        # take in an optional timeout (https://github.com/pytorch/pytorch/issues/39280)
168        if dst_rank != self.rank:
169            slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)
170
171            with self.assertRaisesRegex(RuntimeError, expected_error):
172                # to_here() should raise timeout error, since it does not know about the
173                # status of rpc.remote().
174                slow_rref.to_here(0.001)
175        # Note: If we proceed with shutdown, UserRRef will send out a RRefUserDelete
176        # but this can be a noop since it may not exist on the owner yet. Later,
177        # the owner can process the RRef creation and wait for the delete message,
178        # thus leading to a timeout.
179        # Therefore, we wait until we get notification that pending owners have
180        # been confirmed before sending out RRefUserDeletes.
181        if dst_rank != self.rank:
182            wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)
183
184    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
185    def test_udf_remote_message_delay_timeout(self):
186        func = my_sleep_func
187        args = (2,)
188        self._test_remote_message_delay_timeout(func, args)
189
190    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
191    def test_udf_remote_message_delay_timeout_to_self(self):
192        func = my_sleep_func
193        args = (1,)
194        self._test_remote_message_delay_timeout(func, args, dst=0)
195
196    @dist_init(
197        faulty_messages=[],
198        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
199    )
200    def test_remote_message_builtin_delay_timeout(self):
201        func = torch.add
202        args = (torch.tensor(1), torch.tensor(1))
203        self._test_remote_message_delay_timeout(func, args)
204
205    @dist_init(
206        faulty_messages=[],
207        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
208    )
209    def test_remote_message_builtin_delay_timeout_to_self(self):
210        func = torch.add
211        args = (torch.tensor(1), torch.tensor(1))
212        self._test_remote_message_delay_timeout(func, args, dst=0)
213
214    @dist_init(
215        faulty_messages=[],
216        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
217    )
218    def test_remote_message_script_delay_timeout(self):
219        func = my_script_func
220        args = (torch.tensor(1),)
221        self._test_remote_message_delay_timeout(func, args)
222
223    @dist_init(
224        faulty_messages=[],
225        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
226    )
227    def test_remote_message_script_delay_timeout_to_self(self):
228        func = my_script_func
229        args = (torch.tensor(1),)
230        self._test_remote_message_delay_timeout(func, args, dst=0)
231
232    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
233    def test_rref_to_here_timeout(self):
234        if self.rank != 0:
235            return
236
237        dst_rank = (self.rank + 1) % self.world_size
238        dst_worker = f"worker{dst_rank}"
239        rref = rpc.remote(
240            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
241        )
242        expected_error = self.get_timeout_error_regex()
243        with self.assertRaisesRegex(RuntimeError, expected_error):
244            rref.to_here(0.01)
245
246        rref.to_here()
247
248    @dist_init(faulty_messages=[])
249    def test_rpc_builtin_timeout(self):
250        next_rank = (self.rank + 1) % self.world_size
251        dst_worker = worker_name(next_rank)
252        expected_error = self.get_timeout_error_regex()
253        # PYTHON_CALL message types which correspond to Python UDF over RPC
254        # by default get a delay (see faulty_rpc_agent_test_fixture)
255        with self.assertRaisesRegex(RuntimeError, expected_error):
256            rpc.rpc_sync(
257                dst_worker,
258                torch.add,
259                args=(torch.tensor(1), torch.tensor(1)),
260                timeout=1,
261            )
262
263        fut = rpc.rpc_async(
264            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
265        )
266        with self.assertRaisesRegex(RuntimeError, expected_error):
267            fut.wait()
268
269        # Ensure that the currently set default timeout is large enough such
270        # that RPCs with delays still complete.
271        fut = rpc.rpc_async(
272            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
273        )
274        fut.wait()
275
276        # Ensure timeout if we set a new default and don't override
277        rpc._set_rpc_timeout(0.001)
278        fut = rpc.rpc_async(
279            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
280        )
281        with self.assertRaisesRegex(RuntimeError, expected_error):
282            fut.wait()
283
284        # Ensure run to completion if we specify timeout of 0
285        fut = rpc.rpc_async(
286            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
287        )
288        fut.wait()
289        # Reset for clean shutdown
290        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
291
292    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
293    def test_rpc_script_timeout(self):
294        next_rank = (self.rank + 1) % self.world_size
295        dst_worker = worker_name(next_rank)
296        expected_error = self.get_timeout_error_regex()
297        with self.assertRaisesRegex(RuntimeError, expected_error):
298            rpc.rpc_sync(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
299
300        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1)
301        with self.assertRaisesRegex(RuntimeError, expected_error):
302            fut.wait()
303
304        # Ensure that the currently set default timeout is large enough such
305        # that RPCs with delays still complete.
306        fut = rpc.rpc_async(
307            dst_worker, my_script_func, args=(torch.tensor(1),)
308        )
309        fut.wait()
310
311        # Ensure timeout if we set a new default and don't override
312        rpc._set_rpc_timeout(0.001)
313        fut = rpc.rpc_async(
314            dst_worker, my_script_func, args=(torch.tensor(1),)
315        )
316        with self.assertRaisesRegex(RuntimeError, expected_error):
317            fut.wait()
318
319        # Ensure run to completion if we specify timeout of 0
320        rpc._set_rpc_timeout(0.001)
321        fut = rpc.rpc_async(
322            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0
323        )
324        fut.wait()
325        # Reset for clean shutdown
326        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
327