1# mypy: allow-untyped-defs 2 3import torch.distributed.rpc as rpc 4import torch.distributed.rpc._testing # noqa: F401 5from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( 6 RpcAgentTestFixture, 7) 8 9# The following message types are currently retried in the RREF protocol and 10# distributed autograd. Thus only these messages should be tested with the 11# Faulty RPC Agent. 12retryable_message_types = ["RREF_FORK_REQUEST", 13 "RREF_CHILD_ACCEPT", 14 "RREF_USER_DELETE", 15 "CLEANUP_AUTOGRAD_CONTEXT_REQ"] 16 17# The following messages incur the corresponding delay in seconds while being 18# processed in FaultyTensorPipeAgent's enqueueSend() function. 19default_messages_to_delay = { 20 "PYTHON_CALL": 1.5, # Python UDF 21 "SCRIPT_CALL": 1.5, # Script/Builtin 22} 23 24class FaultyRpcAgentTestFixture(RpcAgentTestFixture): 25 def __init__(self, *args, **kwargs): 26 super().__init__(*args, **kwargs) 27 self.messages_to_fail = retryable_message_types 28 self.messages_to_delay = default_messages_to_delay 29 30 @property 31 def rpc_backend(self): 32 return rpc.backend_registry.BackendType[ 33 "FAULTY_TENSORPIPE" 34 ] 35 36 @property 37 def rpc_backend_options(self): 38 return rpc.backend_registry.construct_rpc_backend_options( 39 self.rpc_backend, 40 init_method=self.init_method, 41 num_worker_threads=8, 42 num_fail_sends=3, 43 messages_to_fail=self.messages_to_fail, 44 messages_to_delay=self.messages_to_delay, 45 ) 46 47 def setup_fault_injection(self, faulty_messages, messages_to_delay): 48 if faulty_messages is not None: 49 self.messages_to_fail = faulty_messages 50 if messages_to_delay is not None: 51 self.messages_to_delay = messages_to_delay 52 53 def get_shutdown_error_regex(self): 54 error_regexes = [ 55 "Exception in thread pool task", 56 "Connection reset by peer", 57 "Connection closed by peer" 58 ] 59 return "|".join([f"({error_str})" for error_str in error_regexes]) 60 61 def get_timeout_error_regex(self): 62 return "RPC ran for more than" 63