xref: /aosp_15_r20/external/pytorch/test/test_multiprocessing_spawn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: multiprocessing"]
2
3import os
4import pickle
5import random
6import signal
7import sys
8import time
9import unittest
10
11import torch.multiprocessing as mp
12
13from torch.testing._internal.common_utils import (
14    IS_WINDOWS,
15    NO_MULTIPROCESSING_SPAWN,
16    run_tests,
17    TestCase,
18)
19
20def _test_success_func(i):
21    pass
22
23
24def _test_success_single_arg_func(i, arg):
25    if arg:
26        arg.put(i)
27
28
29def _test_exception_single_func(i, arg):
30    if i == arg:
31        raise ValueError("legitimate exception from process %d" % i)
32    time.sleep(1.0)
33
34
35def _test_exception_all_func(i):
36    time.sleep(random.random() / 10)
37    raise ValueError("legitimate exception from process %d" % i)
38
39
def _test_terminate_signal_func(i):
    """Worker 0 sends SIGABRT to its own process.

    All workers, including worker 0 if it survives the signal, then sleep
    for one second.
    """
    is_victim = i == 0
    if is_victim:
        own_pid = os.getpid()
        os.kill(own_pid, signal.SIGABRT)
    time.sleep(1.0)
44
45
46def _test_terminate_exit_func(i, arg):
47    if i == 0:
48        sys.exit(arg)
49    time.sleep(1.0)
50
51
52def _test_success_first_then_exception_func(i, arg):
53    if i == 0:
54        return
55    time.sleep(0.1)
56    raise ValueError("legitimate exception")
57
58
59def _test_nested_child_body(i, ready_queue, nested_child_sleep):
60    ready_queue.put(None)
61    time.sleep(nested_child_sleep)
62
63
def _test_infinite_task(i):
    """Worker body that never returns; it must be stopped externally (e.g. by a signal)."""
    keep_running = True
    while keep_running:
        time.sleep(1)
67
68
69def _test_process_exit(idx):
70    sys.exit(12)
71
72
def _test_nested(i, pids_queue, nested_child_sleep, start_method):
    """Launch two non-daemon grandchildren, report their pids, then SIGTERM self.

    Args:
        i: worker index (unused).
        pids_queue: queue on which the list of grandchild pids is published.
        nested_child_sleep: seconds each grandchild sleeps after signalling ready.
        start_method: multiprocessing start method used for the grandchildren.
    """
    child_count = 2
    ready_q = mp.get_context(start_method).Queue()
    children = mp.start_processes(
        fn=_test_nested_child_body,
        args=(ready_q, nested_child_sleep),
        nprocs=child_count,
        join=False,
        daemon=False,
        start_method=start_method,
    )
    pids_queue.put(children.pids())

    # Block until every grandchild has reported in, so that each one has
    # called prctl(2) to register a parent death signal.
    for _ in range(child_count):
        ready_q.get()

    # Kill self. This should take down the child processes as well.
    my_pid = os.getpid()
    os.kill(my_pid, signal.SIGTERM)
94
class _TestMultiProcessing:
    """Start-method-agnostic tests for ``torch.multiprocessing.start_processes``.

    Mixed into concrete ``TestCase`` subclasses, each of which overrides
    ``start_method`` with the multiprocessing start method to exercise.
    """

    start_method = None

    def test_success(self):
        mp.start_processes(
            _test_success_func,
            nprocs=2,
            start_method=self.start_method,
        )

    def test_success_non_blocking(self):
        ctx = mp.start_processes(
            _test_success_func,
            nprocs=2,
            join=False,
            start_method=self.start_method,
        )

        # Call join() once per worker (nprocs=2); after all processes have
        # joined, a further join() must return True.
        for _ in range(2):
            ctx.join(timeout=None)
        self.assertTrue(ctx.join(timeout=None))

    def test_first_argument_index(self):
        results = mp.get_context(self.start_method).SimpleQueue()
        mp.start_processes(
            _test_success_single_arg_func,
            args=(results,),
            nprocs=2,
            start_method=self.start_method,
        )
        received = [results.get(), results.get()]
        self.assertEqual([0, 1], sorted(received))

    def test_exception_single(self):
        nprocs = 2
        for failing_rank in range(nprocs):
            pattern = "\nValueError: legitimate exception from process %d$" % failing_rank
            with self.assertRaisesRegex(Exception, pattern):
                mp.start_processes(
                    _test_exception_single_func,
                    args=(failing_rank,),
                    nprocs=nprocs,
                    start_method=self.start_method,
                )

    def test_exception_all(self):
        pattern = "\nValueError: legitimate exception from process (0|1)$"
        with self.assertRaisesRegex(Exception, pattern):
            mp.start_processes(
                _test_exception_all_func,
                nprocs=2,
                start_method=self.start_method,
            )

    def test_terminate_signal(self):
        if IS_WINDOWS:
            # Windows has no negative (signal) exit codes, so the abort shows
            # up as a positive exit code instead. Exit code 22 means
            # "ERROR_BAD_COMMAND".
            expected = "process 0 terminated with exit code 22"
        else:
            # Termination by a signal is expressed as a negative exit code in
            # multiprocessing, so the message names the signal. SIGABRT is
            # aliased with SIGIOT.
            expected = "process 0 terminated with signal (SIGABRT|SIGIOT)"

        with self.assertRaisesRegex(Exception, expected):
            mp.start_processes(
                _test_terminate_signal_func,
                nprocs=2,
                start_method=self.start_method,
            )

    def test_terminate_exit(self):
        code = 123
        expected = "process 0 terminated with exit code %d" % code
        with self.assertRaisesRegex(Exception, expected):
            mp.start_processes(
                _test_terminate_exit_func,
                args=(code,),
                nprocs=2,
                start_method=self.start_method,
            )

    def test_success_first_then_exception(self):
        # The worker ignores its extra argument; it is passed only to match
        # the worker's (i, arg) signature.
        unused_arg = 123
        with self.assertRaisesRegex(Exception, "ValueError: legitimate exception"):
            mp.start_processes(
                _test_success_first_then_exception_func,
                args=(unused_arg,),
                nprocs=2,
                start_method=self.start_method,
            )

    @unittest.skipIf(
        sys.platform != "linux",
        "Only runs on Linux; requires prctl(2)",
    )
    def _test_nested(self):
        # NOTE: the leading underscore keeps unittest from collecting this test.
        pid_channel = mp.get_context(self.start_method).Queue()
        grandchild_sleep = 20.0
        launcher = mp.start_processes(
            fn=_test_nested,
            args=(pid_channel, grandchild_sleep, self.start_method),
            nprocs=1,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

        def _reap_one_dead(pid_list):
            # Remove the first pid that no longer names a live process, if any.
            for pid in pid_list:
                try:
                    os.kill(pid, 0)
                except ProcessLookupError:
                    pid_list.remove(pid)
                    return

        # Wait for nested children to terminate in time.
        remaining = pid_channel.get()
        t0 = time.time()
        while remaining:
            _reap_one_dead(remaining)
            # Fails if any nested child is still alive after
            # (grandchild_sleep / 2) seconds, so this loop is bounded by
            # that time.
            self.assertLess(time.time() - t0, grandchild_sleep / 2)
            time.sleep(0.1)
196
@unittest.skipIf(
    NO_MULTIPROCESSING_SPAWN,
    "Disabled for environments that don't support the spawn start method")
class SpawnTest(TestCase, _TestMultiProcessing):
    """Shared multiprocessing tests under 'spawn', plus ``mp.spawn`` error-type checks."""

    start_method = 'spawn'

    def test_exception_raises(self):
        # An exception raised inside a worker must surface in the parent as
        # ProcessRaisedException. Use a worker that raises deliberately
        # instead of relying on a call-signature mismatch to produce an error.
        with self.assertRaisesRegex(
            mp.ProcessRaisedException,
            "ValueError: legitimate exception from process 0",
        ):
            mp.spawn(_test_exception_all_func, args=(), nprocs=1)

    def test_signal_raises(self):
        context = mp.spawn(_test_infinite_task, args=(), nprocs=1, join=False)
        for pid in context.pids():
            os.kill(pid, signal.SIGTERM)
        with self.assertRaises(mp.ProcessExitedException):
            context.join()

    def _test_process_exited(self):
        # NOTE: the leading underscore keeps unittest from collecting this test.
        with self.assertRaises(mp.ProcessExitedException) as cm:
            mp.spawn(_test_process_exit, args=(), nprocs=1)
        # Checked after the ``with`` block: statements following the raising
        # call inside it never run, and the caught exception is exposed as
        # ``cm.exception`` (not as an attribute of the context manager).
        self.assertEqual(12, cm.exception.exit_code)
218
219
@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ForkTest(TestCase, _TestMultiProcessing):
    """Run the shared multiprocessing tests with the 'fork' start method."""

    start_method = "fork"
226
227
@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing):
    """Run the shared multiprocessing tests under forkserver with parallel start on."""

    # Without an explicit start method the mixin's ``None`` selects the
    # platform default context. torch.multiprocessing only honors
    # ENV_VAR_PARALLEL_START for 'forkserver', so the env var set in setUp
    # would otherwise have no effect.
    start_method = 'forkserver'
    orig_paralell_env_val = None

    def setUp(self):
        super().setUp()
        # Remember the caller's setting so tearDown can restore it exactly.
        self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
        os.environ[mp.ENV_VAR_PARALLEL_START] = "1"

    def tearDown(self):
        super().tearDown()
        if self.orig_paralell_env_val is None:
            # pop() rather than del: tolerate the variable already being gone.
            os.environ.pop(mp.ENV_VAR_PARALLEL_START, None)
        else:
            os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
246
247
@unittest.skipIf(
    IS_WINDOWS,
    "Fork is only available on Unix",
)
class ParallelForkServerPerfTest(TestCase):
    """Check that parallel start shortens forkserver launch time for slow-import workers."""

    def test_forkserver_perf(self):
        start_method = 'forkserver'
        expensive = Expensive()
        nprocs = 4
        orig_parallel_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)

        def _time_start_processes():
            # Wall-clock seconds spent in one start_processes call.
            start = time.perf_counter()
            mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method)
            return time.perf_counter() - start

        # Restore the env var in ``finally`` so that a failed assertion does
        # not leak the parallel-start setting into tests that run afterwards.
        try:
            # test the non parallel case
            os.environ[mp.ENV_VAR_PARALLEL_START] = "0"
            elapsed = _time_start_processes()
            # the elapsed time should be at least {nprocs}x the sleep time
            self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs)

            # test the parallel case
            os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
            elapsed = _time_start_processes()
            # the elapsed time should be less than {nprocs}x the sleep time
            self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs)
        finally:
            if orig_parallel_env_val is None:
                os.environ.pop(mp.ENV_VAR_PARALLEL_START, None)
            else:
                os.environ[mp.ENV_VAR_PARALLEL_START] = orig_parallel_env_val
281
282
class Expensive:
    """Worker target for ParallelForkServerPerfTest whose defining module is slow to load."""

    # Seconds of simulated startup cost; the perf test scales its timing
    # thresholds by this value.
    SLEEP_SECS = 5
    # Simulate startup overhead such as large imports.
    # NOTE: this runs once at class-definition time, i.e. whenever this module
    # is imported (by the test runner or by any process that imports it), not
    # per instance or per call.
    time.sleep(SLEEP_SECS)

    def __init__(self):
        # ~1 MB string payload; presumably makes the bound ``my_call`` costly
        # to ship to child processes — TODO confirm intent.
        self.config: str = "*" * 1000000

    def my_call(self, *args):
        # No-op worker entry point; accepts the worker index (and any extra
        # args) positionally and ignores them.
        pass
293
294
class ErrorTest(TestCase):
    """Tests for the exception types raised by torch.multiprocessing spawning."""

    def test_errors_pickleable(self):
        # Both error types must survive a pickle round trip, e.g. so they can
        # be passed between processes.
        samples = [
            mp.ProcessRaisedException("Oh no!", 1, 1),
            mp.ProcessExitedException("Oh no!", 1, 1, 1),
        ]
        for err in samples:
            serialized = pickle.dumps(err)
            pickle.loads(serialized)
302
303
if __name__ == "__main__":
    # Executed as a script: hand off to PyTorch's common test runner.
    run_tests()
306