1# mypy: allow-untyped-defs 2 3import torch.distributed.rpc as rpc 4from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( 5 RpcAgentTestFixture, 6) 7from torch.testing._internal.common_distributed import ( 8 tp_transports, 9) 10 11 12class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture): 13 @property 14 def rpc_backend(self): 15 return rpc.backend_registry.BackendType[ 16 "TENSORPIPE" 17 ] 18 19 @property 20 def rpc_backend_options(self): 21 return rpc.backend_registry.construct_rpc_backend_options( 22 self.rpc_backend, 23 init_method=self.init_method, 24 _transports=tp_transports() 25 ) 26 27 def get_shutdown_error_regex(self): 28 # FIXME Once we consolidate the error messages returned by the 29 # TensorPipe agent put some more specific regex here. 30 error_regexes = [".*"] 31 return "|".join([f"({error_str})" for error_str in error_regexes]) 32 33 def get_timeout_error_regex(self): 34 return "RPC ran for more than" 35