# mypy: allow-untyped-defs

import os
import sys
import unittest
from typing import Dict, List, Type

from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    find_free_port,
    IS_SANDCASTLE,
)
from torch.testing._internal.distributed.ddp_under_dist_autograd_test import (
    CudaDdpComparisonTest,
    DdpComparisonTest,
    DdpUnderDistAutogradTest,
)
from torch.testing._internal.distributed.nn.api.remote_module_test import (
    CudaRemoteModuleTest,
    RemoteModuleTest,
    ThreeWorkersRemoteModuleTest,
)
from torch.testing._internal.distributed.rpc.dist_autograd_test import (
    DistAutogradTest,
    CudaDistAutogradTest,
    FaultyAgentDistAutogradTest,
    TensorPipeAgentDistAutogradTest,
    TensorPipeCudaDistAutogradTest
)
from torch.testing._internal.distributed.rpc.dist_optimizer_test import (
    DistOptimizerTest,
)
from torch.testing._internal.distributed.rpc.jit.dist_autograd_test import (
    JitDistAutogradTest,
)
from torch.testing._internal.distributed.rpc.jit.rpc_test import JitRpcTest
from torch.testing._internal.distributed.rpc.jit.rpc_test_faulty import (
    JitFaultyAgentRpcTest,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.distributed.rpc.faulty_agent_rpc_test import (
    FaultyAgentRpcTest,
)
from torch.testing._internal.distributed.rpc.rpc_test import (
    CudaRpcTest,
    RpcTest,
    TensorPipeAgentRpcTest,
    TensorPipeAgentCudaRpcTest,
)
from torch.testing._internal.distributed.rpc.examples.parameter_server_test import ParameterServerTest
from torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test import (
    ReinforcementLearningRpcTest,
)


def _tcp_init_requested() -> bool:
    """Return True when the RPC_INIT_WITH_TCP environment variable is "1"."""
    return os.environ.get("RPC_INIT_WITH_TCP", None) == "1"


def _check_and_set_tcp_init() -> None:
    """Export MASTER_ADDR / MASTER_PORT when TCP init is requested.

    Does nothing unless RPC_INIT_WITH_TCP is "1".
    """
    # If we are running with TCP init, set main address and port
    # before spawning subprocesses, since different processes could find
    # different ports.
    if _tcp_init_requested():
        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(find_free_port())


def _check_and_unset_tcp_init() -> None:
    """Remove MASTER_ADDR / MASTER_PORT when TCP init is requested.

    Does nothing unless RPC_INIT_WITH_TCP is "1". Removal tolerates a variable
    that is already absent, so this never raises KeyError and is safe to call
    more than once.
    """
    if _tcp_init_requested():
        # pop() with a default instead of `del`: a variable that was already
        # removed (e.g. by the test body) must not make teardown fail.
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)


# The tests for the RPC module need to cover multiple possible combinations:
# - different aspects of the API, each one having its own suite of tests;
# - different agents (ProcessGroup, TensorPipe, ...);
# To avoid a combinatorial explosion in code size, and to prevent forgetting to
# add a combination, these are generated automatically by the code in this file.
# Here, we collect all the test suites that we need to cover.
# We then have one separate file for each agent, from which
# we call the generate_tests function of this file, passing to it a fixture for
# the agent, which then gets mixed-in with each test suite.

@unittest.skipIf(
    TEST_WITH_DEV_DBG_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues"
)
class SpawnHelper(MultiProcessTestCase):
    """MultiProcessTestCase base used as the last base of every generated test class.

    The whole class is skipped when TEST_WITH_DEV_DBG_ASAN is set.
    """

    def setUp(self):
        """Run the parent setUp, export the TCP-init variables, then call _spawn_processes()."""
        super().setUp()
        # Must happen before _spawn_processes() so the port is chosen once,
        # ahead of the subprocesses (see _check_and_set_tcp_init).
        _check_and_set_tcp_init()
        self._spawn_processes()

    def tearDown(self):
        """Drop the TCP-init variables, then run the parent tearDown."""
        try:
            _check_and_unset_tcp_init()
        finally:
            # The parent tearDown must run even if the environment cleanup
            # above fails for any reason.
            super().tearDown()


# This list contains test suites that are agent-agnostic and that only verify
# compliance with the generic RPC interface specification. These tests should
# *not* make use of implementation details of a specific agent (options,
# attributes, ...). These test suites will be instantiated multiple times, once
# for each agent (except the faulty agent, which is special).
GENERIC_TESTS = [
    RpcTest,
    ParameterServerTest,
    DistAutogradTest,
    DistOptimizerTest,
    JitRpcTest,
    JitDistAutogradTest,
    RemoteModuleTest,
    ThreeWorkersRemoteModuleTest,
    DdpUnderDistAutogradTest,
    DdpComparisonTest,
    ReinforcementLearningRpcTest,
]
GENERIC_CUDA_TESTS = [
    CudaRpcTest,
    CudaDistAutogradTest,
    CudaRemoteModuleTest,
    CudaDdpComparisonTest,
]


# This list contains test suites that will only be run on the TensorPipeAgent.
# These suites should be standalone, and separate from the ones in the generic
# list (not subclasses of those!).
TENSORPIPE_TESTS = [
    TensorPipeAgentRpcTest,
    TensorPipeAgentDistAutogradTest,
]
TENSORPIPE_CUDA_TESTS = [
    TensorPipeAgentCudaRpcTest,
    TensorPipeCudaDistAutogradTest,
]


# This list contains test suites that will only be run on the faulty RPC agent.
# That agent is special as it's only used to perform fault injection in order to
# verify the error handling behavior. Thus the faulty agent will only run the
# suites in this list, which were designed to test such behaviors, and not the
# ones in the generic list.
FAULTY_AGENT_TESTS = [
    FaultyAgentRpcTest,
    FaultyAgentDistAutogradTest,
    JitFaultyAgentRpcTest,
]


def generate_tests(
    prefix: str,
    mixin: Type[RpcAgentTestFixture],
    tests: List[Type[RpcAgentTestFixture]],
    module_name: str,
) -> Dict[str, Type[RpcAgentTestFixture]]:
    """Build one concrete test class per suite by combining it with an agent fixture.

    Each entry of `tests` is a suite written against the abstract
    RpcAgentTestFixture; `mixin` is a concrete RpcAgentTestFixture subclass
    that specializes it for one agent. For every suite a new class is created
    whose bases are, in order, the suite, `mixin` and SpawnHelper.

    The new class is named `prefix` followed by the suite's own `__name__`,
    and its `__module__` is overwritten with `module_name` (the name of the
    calling module) so that it looks like it belongs to that module, which is
    necessary for pickling to work on it.

    When both IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN are set, no class is
    generated; a skip notice per suite is written to stderr instead.

    Returns:
        A dictionary mapping each generated class name to the class object,
        suitable for insertion into the calling module's global namespace.
    """
    generated: Dict[str, Type[RpcAgentTestFixture]] = {}
    for suite in tests:
        if not (IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN):
            class_name = f"{prefix}{suite.__name__}"
            bases = (suite, mixin, SpawnHelper)
            combined = type(class_name, bases, {})
            # Re-home the class in the caller's module (needed for pickling).
            combined.__module__ = module_name
            generated[class_name] = combined
        else:
            notice = (
                f'Skipping test {suite} on sandcastle for the following reason: '
                'Skip dev-asan as torch + multiprocessing spawn have known issues'
            )
            print(notice, file=sys.stderr)
    return generated