# mypy: allow-untyped-defs

import time

import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs
from torch.testing._internal.dist_utils import (
    dist_init,
    wait_until_pending_futures_and_users_flushed,
    wait_until_owners_and_forks_on_rank,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


def my_sleep_func(seconds=1):
    time.sleep(seconds)
    return torch.mul(torch.tensor(1), torch.tensor(1))


@torch.jit.script
def my_script_func(tensor):
    return torch.add(tensor, tensor)


def add_rref_to_value(rref, value):
    return rref.to_here() + value


class FaultyAgentRpcTest(RpcAgentTestFixture):

    # no faulty_messages defined, so this fails all retryable messages - see
    # faulty_rpc_agent_test_fixture.py for the list of retryable messages.
    @dist_init(messages_to_delay={})
    def test_check_failed_messages(self):
        if self.rank == 0:
            dst_worker_b = worker_name((self.rank + 1) % self.world_size)
            dst_worker_c = worker_name((self.rank + 2) % self.world_size)

            # Worker0 sends an RPC to Worker1 and creates an RRef there
            rref = rpc.remote(
                dst_worker_b, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2))
            )
            # Worker0 sends an RPC to Worker2 with the RRef as an arg
            rpc.remote(dst_worker_c, add_rref_to_value, args=(rref, torch.ones(2, 2)))
            # Check that the output is as expected
            self.assertEqual(
                rref.to_here(), torch.add(torch.ones(2, 2), torch.ones(2, 2))
            )
        # Explicitly delete all User RRefs
        _delete_all_user_and_unforked_owner_rrefs()

    @dist_init
    def test_verify_backend_options(self):
        self.assertEqual(
            self.rpc_backend, rpc.backend_registry.BackendType.FAULTY_TENSORPIPE
        )
        self.assertEqual(self.rpc_backend_options.num_worker_threads, 8)
        self.assertEqual(self.rpc_backend_options.num_fail_sends, 3)
        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 4)
        self.assertEqual(len(self.rpc_backend_options.messages_to_delay), 2)
        self.assertEqual(
            self.rpc_backend_options.rpc_timeout, rpc.constants.DEFAULT_RPC_TIMEOUT_SEC
        )

    @dist_init(faulty_messages=["RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"])
    def test_custom_faulty_messages(self):
        self.assertEqual(
            {"RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT"},
            set(self.rpc_backend_options.messages_to_fail),
        )

    @dist_init(faulty_messages=[])
    def test_no_faulty_messages(self):
        self.assertEqual(len(self.rpc_backend_options.messages_to_fail), 0)

    @dist_init(messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_custom_messages_to_delay(self):
        self.assertEqual(
            self.rpc_backend_options.messages_to_delay, {"SCRIPT_CALL": 1.5}
        )

    def _test_remote_message_dropped_pickle(self, dst=None):
        if self.rank != 0:
            return
        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        # Since we fail python_remote_call messages synchronously, the future
        # corresponding to this remote call will be marked with an error when
        # this function returns.
        rref = rpc.remote(dst_worker, my_sleep_func, args=(1,))
        # Call to ensure pending callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # Attempting to fork the RRef should raise an error indicating the
        # rpc.remote() timeout.
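        # Note: both checks below exercise RRef forking: _serialize() directly,
        # then sending the RRef as an RPC argument (which forks it), so both
        # should surface the same "RRef creation" error.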
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref._serialize()
        # Test that using the RRef as an arg over RPC (which forks it) results
        # in the same error.
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 1))

    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
    def test_remote_message_dropped_pickle(self):
        self._test_remote_message_dropped_pickle()

    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
    def test_remote_message_dropped_pickle_to_self(self):
        self._test_remote_message_dropped_pickle(self.rank)

    def _test_remote_message_dropped_timeout(self, func, args, dst=None):
        if self.rank != 0:
            return

        # Test the case where the rpc.remote() creation message is completely
        # dropped.
        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        # Since we fail python_remote_call messages synchronously, the future
        # corresponding to this remote call will be marked with an error when
        # this function returns.
        rref = rpc.remote(dst_worker, func, args=args)
        # Call to ensure pending callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref.to_here()
        # Note: during shutdown, logs will indicate "Could not find OwnerRRef..."
        # on the owning nodes; this is expected because the OwnerRRef was never
        # successfully created. Therefore, delAllUsers will work as expected.

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_builtin_remote_message_dropped_timeout(self):
        func = torch.add
        args = (torch.tensor(1), torch.tensor(1))
        self._test_remote_message_dropped_timeout(func, args)

    @dist_init(faulty_messages=["SCRIPT_REMOTE_CALL"])
    def test_builtin_remote_message_dropped_timeout_to_self(self):
        func = torch.add
        args = (torch.tensor(1), torch.tensor(1))
        self._test_remote_message_dropped_timeout(func, args, dst=0)

    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
    def test_udf_remote_message_dropped_timeout(self):
        func = my_sleep_func
        args = (2,)
        self._test_remote_message_dropped_timeout(func, args)

    @dist_init(faulty_messages=["PYTHON_REMOTE_CALL"])
    def test_udf_remote_message_dropped_timeout_to_self(self):
        func = my_sleep_func
        args = (2,)
        self._test_remote_message_dropped_timeout(func, args, dst=0)

    def _test_remote_message_delay_timeout(self, func, args, dst=None):
        if self.rank != 0:
            return
        # Test the case where the remote message is eventually processed on the
        # owner, but the future on the creator times out before the response
        # comes back.
        dst_rank = dst if dst is not None else (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        # 1 ms timeout
        rref = rpc.remote(dst_worker, func, args=args, timeout=0.001)
        # Future corresponding to the remote creation should time out.
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref._get_future().wait()

        # Call to ensure pending callbacks are run.
        wait_until_pending_futures_and_users_flushed()
        # to_here() should now pick up that the rpc.remote() creation has failed.
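        # Unlike the dropped-message cases above, the owner does eventually
        # process the creation; the creator-side future simply recorded the
        # timeout first.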
        with self.assertRaisesRegex(RuntimeError, "RRef creation"):
            rref.to_here()

        # Test the case where rpc.remote() times out, but to_here() has already
        # started blocking before that.
        # NOTE: we only test this when not sending to self, since to_here() on a
        # local RRef calls localValue(), which does not send an RPC and thus
        # does not have a timeout. This can be supported by allowing
        # future.wait() to take an optional timeout
        # (https://github.com/pytorch/pytorch/issues/39280).
        if dst_rank != self.rank:
            slow_rref = rpc.remote(dst_worker, func, args=args, timeout=2)

            with self.assertRaisesRegex(RuntimeError, expected_error):
                # to_here() should raise a timeout error, since it does not know
                # about the status of rpc.remote().
                slow_rref.to_here(0.001)
        # Note: if we proceed with shutdown, the UserRRef will send out an
        # RRefUserDelete, but this can be a noop since it may not exist on the
        # owner yet. Later, the owner could process the RRef creation and then
        # wait for the delete message, leading to a timeout. Therefore, we wait
        # until we get notification that pending owners have been confirmed
        # before sending out RRefUserDeletes.
        if dst_rank != self.rank:
            wait_until_owners_and_forks_on_rank(2, 2, rank=dst_rank)

    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
    def test_udf_remote_message_delay_timeout(self):
        func = my_sleep_func
        args = (2,)
        self._test_remote_message_delay_timeout(func, args)

    @dist_init(faulty_messages=[], messages_to_delay={"PYTHON_REMOTE_CALL": 2})
    def test_udf_remote_message_delay_timeout_to_self(self):
        func = my_sleep_func
        args = (1,)
        self._test_remote_message_delay_timeout(func, args, dst=0)

    @dist_init(
        faulty_messages=[],
        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
    )
    def test_remote_message_builtin_delay_timeout(self):
        func = torch.add
        args = (torch.tensor(1), torch.tensor(1))
        self._test_remote_message_delay_timeout(func, args)

    @dist_init(
        faulty_messages=[],
        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
    )
    def test_remote_message_builtin_delay_timeout_to_self(self):
        func = torch.add
        args = (torch.tensor(1), torch.tensor(1))
        self._test_remote_message_delay_timeout(func, args, dst=0)

    @dist_init(
        faulty_messages=[],
        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
    )
    def test_remote_message_script_delay_timeout(self):
        func = my_script_func
        args = (torch.tensor(1),)
        self._test_remote_message_delay_timeout(func, args)

    @dist_init(
        faulty_messages=[],
        messages_to_delay={"SCRIPT_REMOTE_CALL": 2, "SCRIPT_RREF_FETCH_CALL": 1},
    )
    def test_remote_message_script_delay_timeout_to_self(self):
        func = my_script_func
        args = (torch.tensor(1),)
        self._test_remote_message_delay_timeout(func, args, dst=0)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_RREF_FETCH_CALL": 1})
    def test_rref_to_here_timeout(self):
        if self.rank != 0:
            return

        dst_rank = (self.rank + 1) % self.world_size
        dst_worker = f"worker{dst_rank}"
        rref = rpc.remote(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rref.to_here(0.01)

        rref.to_here()
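
    # The two remaining tests exercise rpc_sync()/rpc_async() timeouts directly
    # (no RRefs): per-call timeouts, the process-wide default set via
    # rpc._set_rpc_timeout(), and timeout=0 meaning an infinite timeout.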

    @dist_init(faulty_messages=[])
    def test_rpc_builtin_timeout(self):
        next_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(next_rank)
        expected_error = self.get_timeout_error_regex()
        # PYTHON_CALL message types which correspond to Python UDF over RPC
        # by default get a delay (see faulty_rpc_agent_test_fixture)
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc.rpc_sync(
                dst_worker,
                torch.add,
                args=(torch.tensor(1), torch.tensor(1)),
                timeout=1,
            )

        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=1
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure that the currently set default timeout is large enough such
        # that RPCs with delays still complete.
        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        fut.wait()

        # Ensure timeout if we set a new default and don't override
        rpc._set_rpc_timeout(0.001)
        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1))
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure run to completion if we specify timeout of 0
        fut = rpc.rpc_async(
            dst_worker, torch.add, args=(torch.tensor(1), torch.tensor(1)), timeout=0
        )
        fut.wait()
        # Reset for clean shutdown
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)

    @dist_init(faulty_messages=[], messages_to_delay={"SCRIPT_CALL": 1.5})
    def test_rpc_script_timeout(self):
        next_rank = (self.rank + 1) % self.world_size
        dst_worker = worker_name(next_rank)
        expected_error = self.get_timeout_error_regex()
        with self.assertRaisesRegex(RuntimeError, expected_error):
            rpc.rpc_sync(
                dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1
            )

        fut = rpc.rpc_async(
            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=1
        )
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure that the currently set default timeout is large enough such
        # that RPCs with delays still complete.
        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
        fut.wait()

        # Ensure timeout if we set a new default and don't override
        rpc._set_rpc_timeout(0.001)
        fut = rpc.rpc_async(dst_worker, my_script_func, args=(torch.tensor(1),))
        with self.assertRaisesRegex(RuntimeError, expected_error):
            fut.wait()

        # Ensure run to completion if we specify timeout of 0
        rpc._set_rpc_timeout(0.001)
        fut = rpc.rpc_async(
            dst_worker, my_script_func, args=(torch.tensor(1),), timeout=0
        )
        fut.wait()
        # Reset for clean shutdown
        rpc._set_rpc_timeout(rpc.constants.DEFAULT_RPC_TIMEOUT_SEC)
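

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the test suite, never called): how a
# per-call RPC timeout surfaces outside this faulty-agent harness, using only
# the public torch.distributed.rpc API. The "worker{rank}" naming scheme and
# the assumption that rendezvous env vars (MASTER_ADDR/MASTER_PORT) are already
# set are choices made for this example; the tests above get all of that from
# dist_init and the fixture instead.
# ---------------------------------------------------------------------------
def _example_rpc_timeout_sketch(rank, world_size):
    # Every participating process initializes RPC under the assumed
    # "worker{rank}" naming scheme.
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        try:
            # A per-call timeout shorter than the remote sleep turns the call
            # into a RuntimeError, mirroring the assertRaisesRegex checks above.
            rpc.rpc_sync("worker1", my_sleep_func, args=(2,), timeout=0.5)
        except RuntimeError:
            pass  # expected: the call timed out
    # Graceful shutdown waits for outstanding work before tearing down agents.
    rpc.shutdown()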