# Owner(s): ["module: multiprocessing"]

import os
import pickle
import random
import signal
import sys
import time
import unittest

import torch.multiprocessing as mp

from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    NO_MULTIPROCESSING_SPAWN,
    run_tests,
    TestCase,
)


def _test_success_func(i):
    pass


def _test_success_single_arg_func(i, arg):
    if arg:
        arg.put(i)


def _test_exception_single_func(i, arg):
    if i == arg:
        raise ValueError("legitimate exception from process %d" % i)
    time.sleep(1.0)


def _test_exception_all_func(i):
    time.sleep(random.random() / 10)
    raise ValueError("legitimate exception from process %d" % i)


def _test_terminate_signal_func(i):
    if i == 0:
        os.kill(os.getpid(), signal.SIGABRT)
    time.sleep(1.0)


def _test_terminate_exit_func(i, arg):
    if i == 0:
        sys.exit(arg)
    time.sleep(1.0)


def _test_success_first_then_exception_func(i, arg):
    if i == 0:
        return
    time.sleep(0.1)
    raise ValueError("legitimate exception")


def _test_nested_child_body(i, ready_queue, nested_child_sleep):
    ready_queue.put(None)
    time.sleep(nested_child_sleep)


def _test_infinite_task(i):
    while True:
        time.sleep(1)


def _test_process_exit(idx):
    sys.exit(12)


def _test_nested(i, pids_queue, nested_child_sleep, start_method):
    context = mp.get_context(start_method)
    nested_child_ready_queue = context.Queue()
    nprocs = 2
    mp_context = mp.start_processes(
        fn=_test_nested_child_body,
        args=(nested_child_ready_queue, nested_child_sleep),
        nprocs=nprocs,
        join=False,
        daemon=False,
        start_method=start_method,
    )
    pids_queue.put(mp_context.pids())

    # Wait for both children to have started, to ensure that they
    # have called prctl(2) to register a parent death signal.
    for _ in range(nprocs):
        nested_child_ready_queue.get()

    # Kill self. This should take down the child processes as well.
    os.kill(os.getpid(), signal.SIGTERM)


class _TestMultiProcessing:
    start_method = None

    def test_success(self):
        mp.start_processes(_test_success_func, nprocs=2, start_method=self.start_method)

    def test_success_non_blocking(self):
        mp_context = mp.start_processes(_test_success_func, nprocs=2, join=False, start_method=self.start_method)

        # After all processes (nprocs=2) have joined, join() must return True.
        mp_context.join(timeout=None)
        mp_context.join(timeout=None)
        self.assertTrue(mp_context.join(timeout=None))

    def test_first_argument_index(self):
        context = mp.get_context(self.start_method)
        queue = context.SimpleQueue()
        mp.start_processes(_test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method)
        self.assertEqual([0, 1], sorted([queue.get(), queue.get()]))

    def test_exception_single(self):
        nprocs = 2
        for i in range(nprocs):
            with self.assertRaisesRegex(
                Exception,
                "\nValueError: legitimate exception from process %d$" % i,
            ):
                mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method)

    def test_exception_all(self):
        with self.assertRaisesRegex(
            Exception,
            "\nValueError: legitimate exception from process (0|1)$",
        ):
            mp.start_processes(_test_exception_all_func, nprocs=2, start_method=self.start_method)

    def test_terminate_signal(self):
        # SIGABRT is aliased with SIGIOT
        message = "process 0 terminated with signal (SIGABRT|SIGIOT)"

        # Termination by a signal is expressed as a negative exit code
        # in multiprocessing, so we know it was a signal that caused the exit.
        # This doesn't appear to exist on Windows, where the exit code is always
        # positive, and therefore results in a different exception message.
        # Exit code 22 means "ERROR_BAD_COMMAND".
        if IS_WINDOWS:
            message = "process 0 terminated with exit code 22"

        with self.assertRaisesRegex(Exception, message):
            mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method)

    def test_terminate_exit(self):
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "process 0 terminated with exit code %d" % exitcode,
        ):
            mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method)

    def test_success_first_then_exception(self):
        exitcode = 123
        with self.assertRaisesRegex(
            Exception,
            "ValueError: legitimate exception",
        ):
            mp.start_processes(_test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method)

    @unittest.skipIf(
        sys.platform != "linux",
        "Only runs on Linux; requires prctl(2)",
    )
    def _test_nested(self):
        context = mp.get_context(self.start_method)
        pids_queue = context.Queue()
        nested_child_sleep = 20.0
        mp_context = mp.start_processes(
            fn=_test_nested,
            args=(pids_queue, nested_child_sleep, self.start_method),
            nprocs=1,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

        # Wait for nested children to terminate in time
        pids = pids_queue.get()
        start = time.time()
        while len(pids) > 0:
            for pid in pids:
                try:
                    os.kill(pid, 0)
                except ProcessLookupError:
                    pids.remove(pid)
                    break

            # This assert fails if any nested child process is still
            # alive after (nested_child_sleep / 2) seconds. By
            # extension, this test times out with an assertion error
            # after (nested_child_sleep / 2) seconds.
            self.assertLess(time.time() - start, nested_child_sleep / 2)
            time.sleep(0.1)


@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase, _TestMultiProcessing):
    start_method = 'spawn'

    def test_exception_raises(self):
        with self.assertRaises(mp.ProcessRaisedException):
            mp.spawn(_test_success_first_then_exception_func, args=(), nprocs=1)

    def test_signal_raises(self):
        context = mp.spawn(_test_infinite_task, args=(), nprocs=1, join=False)
        for pid in context.pids():
            os.kill(pid, signal.SIGTERM)
        with self.assertRaises(mp.ProcessExitedException):
            context.join()

    def _test_process_exited(self):
        with self.assertRaises(mp.ProcessExitedException) as e:
            mp.spawn(_test_process_exit, args=(), nprocs=1)
        self.assertEqual(12, e.exception.exit_code)


@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ForkTest(TestCase, _TestMultiProcessing):
    start_method = 'fork'


@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
    orig_parallel_env_val = None

    def setUp(self):
        super().setUp()
        self.orig_parallel_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"

    def tearDown(self):
        super().tearDown()
        if self.orig_parallel_env_val is None:
            del os.environ[mp.ENV_VAR_PARALLEL_START]
        else:
            os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_parallel_env_val


@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):

    def test_forkserver_perf(self):

        start_method = 'forkserver'
        expensive = Expensive()
        nprocs = 4
        orig_parallel_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)

        # Test the non-parallel case
        os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
        start = time.perf_counter()
        mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
        elapsed = time.perf_counter() - start
        # The elapsed time should be at least {nprocs}x the sleep time
        self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)

        # Test the parallel case
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
        start = time.perf_counter()
        mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
        elapsed = time.perf_counter() - start
        # The elapsed time should be less than {nprocs}x the sleep time
        self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs)

        if orig_parallel_env_val is None:
            del os.environ[mp.ENV_VAR_PARALLEL_START]
        else:
            os.environ[mp.ENV_VAR_PARALLEL_START] = orig_parallel_env_val


class Expensive:
    SLEEP_SECS = 5
    # Simulate startup overhead such as large imports; this sleep runs in the
    # class body, i.e. whenever the module is imported.
    time.sleep(SLEEP_SECS)

    def __init__(self):
        self.config: str = "*" * 1000000

    def my_call(self, *args):
        pass


class ErrorTest(TestCase):
    def test_errors_pickleable(self):
        for error in (
            mp.ProcessRaisedException("Oh no!", 1, 1),
            mp.ProcessExitedException("Oh no!", 1, 1, 1),
        ):
            pickle.loads(pickle.dumps(error))


if __name__ == '__main__':
    run_tests()