1# mypy: allow-untyped-defs
2
3import torch.distributed.rpc as rpc
4from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
5    RpcAgentTestFixture,
6)
7from torch.testing._internal.common_distributed import (
8    tp_transports,
9)
10
11
12class TensorPipeRpcAgentTestFixture(RpcAgentTestFixture):
13    @property
14    def rpc_backend(self):
15        return rpc.backend_registry.BackendType[
16            "TENSORPIPE"
17        ]
18
19    @property
20    def rpc_backend_options(self):
21        return rpc.backend_registry.construct_rpc_backend_options(
22            self.rpc_backend,
23            init_method=self.init_method,
24            _transports=tp_transports()
25        )
26
27    def get_shutdown_error_regex(self):
28        # FIXME Once we consolidate the error messages returned by the
29        # TensorPipe agent put some more specific regex here.
30        error_regexes = [".*"]
31        return "|".join([f"({error_str})" for error_str in error_regexes])
32
33    def get_timeout_error_regex(self):
34        return "RPC ran for more than"
35