xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/rpc_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import os
4import sys
5import unittest
6from typing import Dict, List, Type
7
8from torch.testing._internal.common_distributed import MultiProcessTestCase
9from torch.testing._internal.common_utils import (
10    TEST_WITH_DEV_DBG_ASAN,
11    find_free_port,
12    IS_SANDCASTLE,
13)
14from torch.testing._internal.distributed.ddp_under_dist_autograd_test import (
15    CudaDdpComparisonTest,
16    DdpComparisonTest,
17    DdpUnderDistAutogradTest,
18)
19from torch.testing._internal.distributed.nn.api.remote_module_test import (
20    CudaRemoteModuleTest,
21    RemoteModuleTest,
22    ThreeWorkersRemoteModuleTest,
23)
24from torch.testing._internal.distributed.rpc.dist_autograd_test import (
25    DistAutogradTest,
26    CudaDistAutogradTest,
27    FaultyAgentDistAutogradTest,
28    TensorPipeAgentDistAutogradTest,
29    TensorPipeCudaDistAutogradTest
30)
31from torch.testing._internal.distributed.rpc.dist_optimizer_test import (
32    DistOptimizerTest,
33)
34from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import (
35    JitDistAutogradTest,
36)
37from torch.testing._internal.distributed.rpc.jit.rpc_test import JitRpcTest
38from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
39    JitFaultyAgentRpcTest,
40)
41from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
42    RpcAgentTestFixture,
43)
44from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
45    FaultyAgentRpcTest,
46)
47from torch.testing._internal.distributed.rpc.rpc_test import (
48    CudaRpcTest,
49    RpcTest,
50    TensorPipeAgentRpcTest,
51    TensorPipeAgentCudaRpcTest,
52)
53from torch.testing._internal.distributed.rpc.examples.parameter_server_test import ParameterServerTest
54from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
55    ReinforcementLearningRpcTest,
56)
57
58
def _check_and_set_tcp_init():
    """Publish the rendezvous address and port when TCP init is requested.

    Does nothing unless the ``RPC_INIT_WITH_TCP`` environment variable is
    exactly ``"1"``. In that case ``MASTER_ADDR`` is set to the loopback
    address and ``MASTER_PORT`` to a port obtained from ``find_free_port()``.
    """
    tcp_init_requested = os.environ.get("RPC_INIT_WITH_TCP") == "1"
    if not tcp_init_requested:
        return

    # This has to happen in the parent before any subprocess is spawned:
    # if every process looked for a free port on its own, different
    # processes could end up with different ports.
    os.environ["MASTER_ADDR"] = '127.0.0.1'
    os.environ["MASTER_PORT"] = str(find_free_port())

67
def _check_and_unset_tcp_init():
    """Remove the rendezvous variables used for TCP init, if it is enabled.

    Does nothing unless the ``RPC_INIT_WITH_TCP`` environment variable is
    exactly ``"1"``. In that case ``MASTER_ADDR`` and ``MASTER_PORT`` are
    removed from the environment. Calling this when either variable is
    already absent is a no-op rather than an error.
    """
    if os.environ.get("RPC_INIT_WITH_TCP") != "1":
        return

    # pop() with a default instead of `del`: this runs during test teardown,
    # and a KeyError here (variable already removed by the test, or cleanup
    # invoked twice) would replace the real test outcome with a teardown error.
    for var_name in ("MASTER_ADDR", "MASTER_PORT"):
        os.environ.pop(var_name, None)

73
74# The tests for the RPC module need to cover multiple possible combinations:
75# - different aspects of the API, each one having its own suite of tests;
76# - different agents (ProcessGroup, TensorPipe, ...);
77# To avoid a combinatorial explosion in code size, and to prevent forgetting to
78# add a combination, these are generated automatically by the code in this file.
79# Here, we collect all the test suites that we need to cover.
80# We then have one separate file for each agent, from which
81# we call the generate_tests function of this file, passing to it a fixture for
82# the agent, which then gets mixed-in with each test suite.
83
@unittest.skipIf(
    TEST_WITH_DEV_DBG_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
)
class SpawnHelper(MultiProcessTestCase):
    """Test-case base that prepares the TCP-init environment and spawns workers.

    Intended to be the last base of the classes built by ``generate_tests``.
    The whole class is skipped when ``TEST_WITH_DEV_DBG_ASAN`` is set.
    """

    def setUp(self):
        """Run the base set-up, export TCP-init variables, then spawn processes."""
        super().setUp()
        # The address/port must be fixed in this process before spawning so
        # that all subprocesses inherit the same values.
        _check_and_set_tcp_init()
        self._spawn_processes()

    def tearDown(self):
        """Undo the TCP-init environment changes, then run the base tear-down."""
        # try/finally so that the base-class tear-down always runs, even if
        # cleaning up the environment raises; the exception still propagates.
        try:
            _check_and_unset_tcp_init()
        finally:
            super().tearDown()

96
97
98# This list contains test suites that are agent-agnostic and that only verify
99# compliance with the generic RPC interface specification. These tests should
100# *not* make use of implementation details of a specific agent (options,
101# attributes, ...). These test suites will be instantiated multiple times, once
102# for each agent (except the faulty agent, which is special).
# Trailing comments give the module each suite is imported from, relative to
# torch.testing._internal.distributed.
GENERIC_TESTS = [
    RpcTest,  # rpc.rpc_test
    ParameterServerTest,  # rpc.examples.parameter_server_test
    DistAutogradTest,  # rpc.dist_autograd_test
    DistOptimizerTest,  # rpc.dist_optimizer_test
    JitRpcTest,  # rpc.jit.rpc_test
    JitDistAutogradTest,  # rpc.jit.dist_autograd_test
    RemoteModuleTest,  # nn.api.remote_module_test
    ThreeWorkersRemoteModuleTest,  # nn.api.remote_module_test
    DdpUnderDistAutogradTest,  # ddp_under_dist_autograd_test
    DdpComparisonTest,  # ddp_under_dist_autograd_test
    ReinforcementLearningRpcTest,  # rpc.examples.reinforcement_learning_rpc_test
]
# Suites whose names mark them as CUDA variants, kept in a separate list from
# GENERIC_TESTS. NOTE(review): presumably so callers can instantiate them
# separately from the non-CUDA suites -- confirm against the per-agent files.
GENERIC_CUDA_TESTS = [
    CudaRpcTest,  # rpc.rpc_test
    CudaDistAutogradTest,  # rpc.dist_autograd_test
    CudaRemoteModuleTest,  # nn.api.remote_module_test
    CudaDdpComparisonTest,  # ddp_under_dist_autograd_test
]

122
123
124# This list contains test suites that will only be run on the TensorPipeAgent.
125# These suites should be standalone, and separate from the ones in the generic
126# list (not subclasses of those!).
# Trailing comments give the module each suite is imported from, relative to
# torch.testing._internal.distributed.
TENSORPIPE_TESTS = [
    TensorPipeAgentRpcTest,  # rpc.rpc_test
    TensorPipeAgentDistAutogradTest,  # rpc.dist_autograd_test
]
# TensorPipe suites whose names mark them as CUDA variants, kept in their own
# list. NOTE(review): presumably consumed separately from TENSORPIPE_TESTS --
# confirm against the TensorPipe test file.
TENSORPIPE_CUDA_TESTS = [
    TensorPipeAgentCudaRpcTest,  # rpc.rpc_test
    TensorPipeCudaDistAutogradTest,  # rpc.dist_autograd_test
]

135
136
137# This list contains test suites that will only be run on the faulty RPC agent.
138# That agent is special as it's only used to perform fault injection in order to
139# verify the error handling behavior. Thus the faulty agent will only run the
140# suites in this list, which were designed to test such behaviors, and not the
141# ones in the generic list.
# Trailing comments give the module each suite is imported from, relative to
# torch.testing._internal.distributed.
FAULTY_AGENT_TESTS = [
    FaultyAgentRpcTest,  # rpc.faulty_agent_rpc_test
    FaultyAgentDistAutogradTest,  # rpc.dist_autograd_test
    JitFaultyAgentRpcTest,  # rpc.jit.rpc_test_faulty
]

147
148
def generate_tests(
    prefix: str,
    mixin: Type[RpcAgentTestFixture],
    tests: List[Type[RpcAgentTestFixture]],
    module_name: str,
) -> Dict[str, Type[RpcAgentTestFixture]]:
    """Build one concrete test class per suite by mixing in an agent fixture.

    Args:
        prefix: String prepended to each suite's ``__name__`` to form the name
            of the generated class.
        mixin: Concrete ``RpcAgentTestFixture`` subclass that specializes the
            suites for one particular agent.
        tests: Test suites written against the abstract, agent-agnostic
            ``RpcAgentTestFixture``.
        module_name: Name of the calling module. It is stamped onto each
            generated class as ``__module__`` so that the class looks like it
            belongs to the caller, which pickling relies on.

    Returns:
        A mapping from generated class name to class object, suitable for
        inserting into the calling module's global namespace. Each class
        derives from ``(suite, mixin, SpawnHelper)``, in that order.

    When both ``IS_SANDCASTLE`` and ``TEST_WITH_DEV_DBG_ASAN`` are truthy, no
    class is generated; a message per suite is written to stderr instead.
    """
    generated: Dict[str, Type[RpcAgentTestFixture]] = {}

    for test_class in tests:
        if not IS_SANDCASTLE or not TEST_WITH_DEV_DBG_ASAN:
            class_name = f"{prefix}{test_class.__name__}"
            bases = (test_class, mixin, SpawnHelper)
            combined = type(class_name, bases, {})
            # Re-home the class in the caller's module after creation.
            combined.__module__ = module_name
            generated[class_name] = combined
        else:
            print(
                f'Skipping test {test_class} on sandcastle for the following reason: '
                'Skip dev-asan as torch + multiprocessing spawn have known issues', file=sys.stderr)

    return generated
182