xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import os
4from abc import ABC, abstractmethod
5
6import torch.testing._internal.dist_utils
7
8
class RpcAgentTestFixture(ABC):
    """Abstract base describing the RPC agent setup a test class runs against.

    Concrete fixtures must supply ``rpc_backend``, ``rpc_backend_options``,
    ``get_shutdown_error_regex`` and ``get_timeout_error_regex``. The
    file-based init method also reads ``self.file_name``, which this class
    does not define (presumably provided by the test case it is mixed into;
    verify against the concrete fixtures).
    """

    @property
    def world_size(self) -> int:
        """Number of participants; fixed at 4 in this base class."""
        return 4

    @property
    def init_method(self):
        """Return the init-method URL used for rendezvous.

        When the ``RPC_INIT_WITH_TCP`` environment variable is exactly
        ``"1"``, a ``tcp://`` URL is built from ``MASTER_ADDR`` and
        ``MASTER_PORT`` (a missing variable raises ``KeyError``). In every
        other case the file-based ``file_init_method`` is returned.
        """
        if os.environ.get("RPC_INIT_WITH_TCP") != "1":
            return self.file_init_method

        # Read the address before the port, each exactly once.
        addr = os.environ["MASTER_ADDR"]
        port = os.environ["MASTER_PORT"]
        return f"tcp://{addr}:{port}"

    @property
    def file_init_method(self):
        """Return ``INIT_METHOD_TEMPLATE`` filled in with ``self.file_name``."""
        template = torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE
        return template.format(file_name=self.file_name)

    @property
    @abstractmethod
    def rpc_backend(self):
        """RPC backend for this fixture; subclasses must provide it."""

    @property
    @abstractmethod
    def rpc_backend_options(self):
        """Options for the RPC backend; subclasses must provide them."""

    def setup_fault_injection(self, faulty_messages, messages_to_delay):  # noqa: B027
        """Hook used by dist_init to prepare the faulty agent.

        Intentionally a no-op here, so agents without fault injection need
        not override it.
        """

    # The shutdown sequence is not well defined, so tests that simulate errors
    # by shutting down the remote end may observe any of several errors.
    @abstractmethod
    def get_shutdown_error_regex(self):
        """
        Return the error messages that RPC agents may produce in tests that
        check for failures. The result is matched against the errors actually
        raised to confirm that the failure surfaced properly.
        """

    @abstractmethod
    def get_timeout_error_regex(self):
        """
        Return a partial string identifying the error expected when an RPC
        has timed out. Intended for assertRaisesRegex() so that timeout tests
        can check that the right error was raised.
        """
64