xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import torch.distributed.rpc as rpc
4import torch.distributed.rpc._testing  # noqa: F401
5from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
6    RpcAgentTestFixture,
7)
8
# Only these message types are currently retried by the RRef protocol and by
# distributed autograd, so they are the only ones that should be exercised
# with the Faulty RPC Agent.
retryable_message_types = [
    "RREF_FORK_REQUEST",
    "RREF_CHILD_ACCEPT",
    "RREF_USER_DELETE",
    "CLEANUP_AUTOGRAD_CONTEXT_REQ",
]
16
# Message types that FaultyTensorPipeAgent holds back in its enqueueSend()
# function, each mapped to the delay (in seconds) it incurs there.
default_messages_to_delay = dict(
    PYTHON_CALL=1.5,  # Python UDF
    SCRIPT_CALL=1.5,  # Script/Builtin
)
23
class FaultyRpcAgentTestFixture(RpcAgentTestFixture):
    """RPC test fixture that runs tests on the ``FAULTY_TENSORPIPE`` backend.

    The backend options built by :attr:`rpc_backend_options` inject failures
    for the message types in ``messages_to_fail`` and delays (in seconds) for
    the message types in ``messages_to_delay``. These attributes start as
    copies of the module-level ``retryable_message_types`` and
    ``default_messages_to_delay``. A test can override either one through
    :meth:`setup_fault_injection`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Store copies rather than the module-level objects themselves, so an
        # in-place edit through one fixture instance cannot leak into the
        # shared defaults (and from there into every later test).
        self.messages_to_fail = list(retryable_message_types)
        self.messages_to_delay = dict(default_messages_to_delay)

    @property
    def rpc_backend(self):
        """Return the ``FAULTY_TENSORPIPE`` entry of the RPC backend registry."""
        return rpc.backend_registry.BackendType[
            "FAULTY_TENSORPIPE"
        ]

    @property
    def rpc_backend_options(self):
        """Build backend options for :attr:`rpc_backend`.

        Uses this fixture's ``init_method``, 8 worker threads,
        ``num_fail_sends=3``, and the current ``messages_to_fail`` and
        ``messages_to_delay`` settings.
        """
        # NOTE(review): num_fail_sends presumably caps how many times each
        # message in messages_to_fail is failed before it is let through --
        # confirm against the faulty backend's options class.
        return rpc.backend_registry.construct_rpc_backend_options(
            self.rpc_backend,
            init_method=self.init_method,
            num_worker_threads=8,
            num_fail_sends=3,
            messages_to_fail=self.messages_to_fail,
            messages_to_delay=self.messages_to_delay,
        )

    def setup_fault_injection(self, faulty_messages, messages_to_delay):
        """Override the fault-injection settings for the current test.

        Args:
            faulty_messages: message types to fail. ``None`` keeps the
                current ``messages_to_fail``.
            messages_to_delay: mapping of message type to delay in seconds.
                ``None`` keeps the current ``messages_to_delay``.
        """
        if faulty_messages is not None:
            self.messages_to_fail = faulty_messages
        if messages_to_delay is not None:
            self.messages_to_delay = messages_to_delay

    def get_shutdown_error_regex(self):
        """Return a regex matching any error expected during agent shutdown.

        Each expected message is wrapped in a capture group, and the groups
        are joined with ``|``.
        """
        error_regexes = [
            "Exception in thread pool task",
            "Connection reset by peer",
            "Connection closed by peer"
        ]
        return "|".join(f"({error_str})" for error_str in error_regexes)

    def get_timeout_error_regex(self):
        """Return a regex matching the error text produced by an RPC timeout."""
        return "RPC ran for more than"
62        return "RPC ran for more than"
63