xref: /aosp_15_r20/external/pytorch/test/distributed/launcher/test_run.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import io
10import multiprocessing as mp
11import os
12import runpy
13import shutil
14import subprocess
15import sys
16import tempfile
17import uuid
18from contextlib import closing, redirect_stderr, redirect_stdout
19from unittest import mock
20from unittest.mock import MagicMock, Mock, patch
21
22import torch.distributed.run as launch
23from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
24from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
25from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
26from torch.distributed.elastic.utils import get_socket_with_port
27from torch.distributed.elastic.utils.distributed import get_free_port
28from torch.testing._internal.common_utils import (
29    run_tests,
30    skip_but_pass_in_sandcastle_if,
31    TEST_WITH_DEV_DBG_ASAN,
32    TestCase,
33)
34
35
def launch_in_proc(args):
    """Invoke the launcher entry point ``launch.main`` with ``args`` directly.

    Args:
        args: list of command-line style arguments handed to ``launch.main``.
    """
    launch.main(args)
38
39
def path(script):
    """Resolve ``script`` against the directory that contains this module.

    Args:
        script: path of a helper script, relative to this file's directory.

    Returns:
        ``script`` joined onto ``os.path.dirname(__file__)``.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, script)
42
43
def get_child_pids(pid):
    """Return the process ids that ``pgrep -P <pid>`` reports for ``pid``.

    Args:
        pid: process id of the parent whose children are looked up.

    Returns:
        A list of ints, one per PID printed by ``pgrep``. The list is empty
        when ``pgrep`` prints nothing or when the ``pgrep`` binary is missing.
    """
    try:
        # An argument list (no shell) means ``pid`` is never interpreted by a
        # shell and no intermediate ``sh`` process shows up as a child of the
        # caller. ``subprocess.run`` also drains and closes the stdout pipe.
        completed = subprocess.run(
            ["pgrep", "-P", str(pid)], stdout=subprocess.PIPE, check=False
        )
    except FileNotFoundError:
        # Without pgrep the previous shell-based lookup produced no output
        # either, so keep reporting "no children" instead of raising.
        return []
    return [int(token) for token in completed.stdout.decode("utf-8").split()]
53
54
def pid_exists(pid):
    """Report whether a process with id ``pid`` exists.

    The check sends signal ``0`` via ``os.kill``, which performs the
    existence and permission checks without delivering a signal.

    Args:
        pid: process id to probe.

    Returns:
        ``True`` when the signal call succeeds or is rejected only for lack
        of permission; ``False`` for any other ``OSError``.
    """
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM means the target exists but belongs to someone we are not
        # allowed to signal, so it must not be reported as missing.
        return True
    except OSError:
        # ESRCH (no such process) and any remaining OS-level failure.
        return False
    return True
61
62
class MockException(Exception):
    """Marker exception used by the tests as a mocked ``side_effect``."""
65
66
67class ElasticLaunchTest(TestCase):
68    def setUp(self):
69        self.test_dir = tempfile.mkdtemp()
70
71        # remove any lingering environment variables
72        for env in os.environ.keys():
73            if env.startswith("PET_"):
74                del os.environ[env]
75
76        # set a sentinel env var on the parent proc
77        # this should be present on the child and gets
78        # asserted in ``bin/test_script.py``
79        os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR"
80
81    def tearDown(self):
82        shutil.rmtree(self.test_dir)
83
84    def test_launch_user_script_python(self):
85        self._test_launch_user_script_python()
86
87    def _test_launch_user_script_python(self):
88        run_id = str(uuid.uuid4().int)
89        nnodes = 1
90        nproc_per_node = 4
91        world_size = nnodes * nproc_per_node
92        args = [
93            f"--nnodes={nnodes}",
94            f"--nproc-per-node={nproc_per_node}",
95            f"--rdzv-id={run_id}",
96            "--monitor-interval=1",
97            "--start-method=spawn",
98            path("bin/test_script.py"),
99            f"--touch-file-dir={self.test_dir}",
100        ]
101        launch.main(args)
102
103        # make sure all the workers ran
104        # each worker touches a file with its global rank as the name
105        self.assertSetEqual(
106            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
107        )
108
109    def test_launch_user_script_python_caffe2_bc(self):
110        nnodes = 1
111        nproc_per_node = 4
112        world_size = nnodes * nproc_per_node
113        sock = get_socket_with_port()
114        with closing(sock):
115            master_port = sock.getsockname()[1]
116        args = [
117            f"--nnodes={nnodes}",
118            f"--nproc-per-node={nproc_per_node}",
119            "--monitor-interval=1",
120            "--start-method=spawn",
121            "--master-addr=localhost",
122            f"--master-port={master_port}",
123            "--node-rank=0",
124            path("bin/test_script.py"),
125            f"--touch-file-dir={self.test_dir}",
126        ]
127        launch.main(args)
128
129        # make sure all the workers ran
130        # each worker touches a file with its global rank as the name
131        self.assertSetEqual(
132            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
133        )
134
135    @skip_but_pass_in_sandcastle_if(
136        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
137    )
138    def test_launch_user_script_bash(self):
139        run_id = str(uuid.uuid4().int)
140        nnodes = 1
141        nproc_per_node = 4
142        world_size = nnodes * nproc_per_node
143        args = [
144            f"--nnodes={nnodes}",
145            f"--nproc-per-node={nproc_per_node}",
146            f"--rdzv-id={run_id}",
147            "--monitor-interval=1",
148            "--start-method=spawn",
149            "--no-python",
150        ]
151
152        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
153
154        with self.assertRaises(ValueError):
155            # --no-python cannot be used with --module
156            launch.main(args + ["--module"] + script_args)
157
158        launch.main(args + script_args)
159
160        # make sure all the workers ran
161        # each worker touches a file with its global rank as the name
162        self.assertSetEqual(
163            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
164        )
165
166    @skip_but_pass_in_sandcastle_if(
167        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
168    )
169    def test_launch_user_script_default_nproc(self):
170        run_id = str(uuid.uuid4().int)
171        nnodes = 1
172        world_size = 1
173        args = [
174            f"--nnodes={nnodes}",
175            f"--rdzv-id={run_id}",
176            "--monitor-interval=1",
177            "--start-method=spawn",
178            "--no-python",
179        ]
180
181        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
182
183        with self.assertRaises(ValueError):
184            # --no-python cannot be used with --module
185            launch.main(args + ["--module"] + script_args)
186
187        launch.main(args + script_args)
188
189        # make sure all the workers ran
190        # each worker touches a file with its global rank as the name
191        self.assertSetEqual(
192            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
193        )
194
195    @skip_but_pass_in_sandcastle_if(
196        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
197    )
198    def test_launch_with_env_vars(self):
199        run_id = str(uuid.uuid4().int)
200        nnodes = 1
201        nproc_per_node = 4
202        world_size = nnodes * nproc_per_node
203
204        os.environ["PET_NNODES"] = str(nnodes)
205        os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node)
206        os.environ["PET_RDZV_ID"] = run_id
207        os.environ["PET_MONITOR_INTERVAL"] = "1"
208        os.environ["PET_START_METHOD"] = "spawn"
209        os.environ["PET_NO_PYTHON"] = "1"
210
211        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
212
213        with self.assertRaises(ValueError):
214            # --no-python cannot be used with --module
215            os.environ["PET_MODULE"] = "1"
216            launch.main(script_args)
217
218        os.environ["PET_MODULE"] = "0"
219        launch.main(script_args)
220
221        # make sure all the workers ran
222        # each worker touches a file with its global rank as the name
223        self.assertSetEqual(
224            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
225        )
226
227    def _test_nproc_launch_configuration(self, nproc_type, expected_number):
228        run_id = str(uuid.uuid4().int)
229        nnodes = 1
230
231        args = [
232            f"--nnodes={nnodes}",
233            f"--nproc-per-node={nproc_type}",
234            f"--rdzv-id={run_id}",
235            "--monitor-interval=1",
236            "--start-method=spawn",
237            "--no-python",
238        ]
239
240        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
241
242        launch.main(args + script_args)
243
244        world_size = nnodes * expected_number
245        # make sure all the workers ran
246        # each worker touches a file with its global rank as the name
247        self.assertSetEqual(
248            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
249        )
250
251    @skip_but_pass_in_sandcastle_if(
252        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
253    )
254    @patch("torch.cuda.is_available", return_value=False)
255    def test_nproc_launch_auto_configurations(self, _mock1):
256        self._test_nproc_launch_configuration("auto", os.cpu_count())
257
258    @skip_but_pass_in_sandcastle_if(
259        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
260    )
261    def test_nproc_launch_number_configurations(self):
262        self._test_nproc_launch_configuration("4", 4)
263
264    @skip_but_pass_in_sandcastle_if(
265        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
266    )
267    def test_nproc_launch_unknown_configurations(self):
268        with self.assertRaises(ValueError):
269            self._test_nproc_launch_configuration("unknown", 4)
270
271    @skip_but_pass_in_sandcastle_if(
272        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
273    )
274    @patch("torch.cuda.is_available", return_value=True)
275    @patch("torch.cuda.device_count", return_value=3)
276    def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
277        self._test_nproc_launch_configuration("auto", 3)
278        self._test_nproc_launch_configuration("gpu", 3)
279
280    @skip_but_pass_in_sandcastle_if(
281        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
282    )
283    def test_launch_elastic(self):
284        run_id = str(uuid.uuid4().int)
285        min_nodes = 1
286        max_nodes = 2
287        nproc_per_node = 4
288        # we are only launching 1 node (even though max = 2)
289        world_size = nproc_per_node
290        args = [
291            f"--nnodes={min_nodes}:{max_nodes}",
292            f"--nproc-per-node={nproc_per_node}",
293            "--rdzv-backend=c10d",
294            f"--rdzv-endpoint=localhost:{get_free_port()}",
295            "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
296            f"--rdzv-id={run_id}",
297            "--monitor-interval=1",
298            "--start-method=spawn",
299            path("bin/test_script.py"),
300            f"--touch-file-dir={self.test_dir}",
301        ]
302        launch.main(args)
303
304        # make sure all the workers ran
305        # each worker touches a file with its global rank as the name
306        self.assertSetEqual(
307            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
308        )
309
310    @mock.patch("torch.distributed.elastic.events.record")
311    @skip_but_pass_in_sandcastle_if(
312        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
313    )
314    def test_launch_elastic_worker_raise_exception(self, record_mock):
315        """
316        Asserts that when the worker program fails and lancher raieses exception
317        to indicate that worker process failed
318
319        """
320        run_id = str(uuid.uuid4().int)
321        min_nodes = 1
322        max_nodes = 2
323        nproc_per_node = 4
324        args = [
325            f"--nnodes={min_nodes}:{max_nodes}",
326            f"--nproc-per-node={nproc_per_node}",
327            "--rdzv-backend=c10d",
328            f"--rdzv-endpoint=localhost:{get_free_port()}",
329            "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
330            f"--rdzv-id={run_id}",
331            "--monitor-interval=1",
332            "--max-restarts=0",
333            "--start-method=spawn",
334            path("bin/test_script.py"),
335            "--fail",
336        ]
337        with self.assertRaises(ChildFailedError):
338            launch.main(args)
339
340        record_mock.assert_called_once()
341
342    @skip_but_pass_in_sandcastle_if(
343        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
344    )
345    @mock.patch(
346        "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run"
347    )
348    @mock.patch("torch.distributed.elastic.events.record")
349    def test_launch_elastic_agent_raise_exception(self, record_mock, mock_agent_run):
350        """
351        Asserts that when the agent raises an exception
352        the launcher re-raises the original exception
353        """
354        run_id = str(uuid.uuid4().int)
355        min_nodes = 1
356        max_nodes = 2
357        nproc_per_node = 4
358        args = [
359            f"--nnodes={min_nodes}:{max_nodes}",
360            f"--nproc-per-node={nproc_per_node}",
361            "--rdzv-backend=c10d",
362            f"--rdzv-endpoint=localhost:{get_free_port()}",
363            "--rdzv_conf=timeout=5",
364            f"--rdzv-id={run_id}",
365            "--monitor-interval=1",
366            "--max-restarts=0",
367            "--start-method=spawn",
368            path("bin/test_script.py"),
369            f"--touch-file-dir={self.test_dir}",
370        ]
371
372        mock_agent_run.side_effect = MockException
373        with self.assertRaises(MockException):
374            launch.main(args)
375        record_mock.assert_called_once()
376
377    @skip_but_pass_in_sandcastle_if(
378        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
379    )
380    def test_launch_standalone(self):
381        nnodes = 1
382        nproc_per_node = 4
383        world_size = nnodes * nproc_per_node
384        args = [
385            f"--nnodes={nnodes}",
386            f"--nproc-per-node={nproc_per_node}",
387            "--standalone",
388            "--monitor-interval=1",
389            "--start-method=spawn",
390            path("bin/test_script.py"),
391            f"--touch-file-dir={self.test_dir}",
392        ]
393        launch.main(args)
394
395        # make sure all the workers ran
396        # each worker touches a file with its global rank as the name
397        self.assertSetEqual(
398            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
399        )
400
401    @skip_but_pass_in_sandcastle_if(
402        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
403    )
404    def test_launch_run_path(self):
405        nnodes = 1
406        nproc_per_node = 4
407        world_size = nnodes * nproc_per_node
408        args = [
409            "--run-path",
410            f"--nnodes={nnodes}",
411            f"--nproc-per-node={nproc_per_node}",
412            "--monitor-interval=1",
413            "--start-method=spawn",
414            path("bin/test_script.py"),
415            f"--touch-file-dir={self.test_dir}",
416        ]
417        launch.main(args)
418
419        # make sure all the workers ran
420        # each worker touches a file with its global rank as the name
421        self.assertSetEqual(
422            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
423        )
424
425    @skip_but_pass_in_sandcastle_if(
426        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
427    )
428    def test_launch_elastic_multiple_agents(self):
429        run_id = str(uuid.uuid4().int)
430        min_nodes = 1
431        max_nodes = 2
432        nproc_per_node = 4
433        nnodes = 2
434        world_size = nnodes * nproc_per_node
435        args = [
436            f"--nnodes={min_nodes}:{max_nodes}",
437            f"--nproc-per-node={nproc_per_node}",
438            "--rdzv-backend=c10d",
439            f"--rdzv-endpoint=localhost:{get_free_port()}",
440            "--rdzv_conf=timeout=5",
441            f"--rdzv-id={run_id}",
442            "--monitor-interval=1",
443            "--start-method=spawn",
444            path("bin/test_script.py"),
445            f"--touch-file-dir={self.test_dir}",
446        ]
447        procs = []
448        for _ in range(nnodes - 1):
449            p = mp.Process(target=launch.main, args=[args])
450            procs.append(p)
451            p.start()
452        launch.main(args)
453        for i in range(nnodes - 1):
454            p = procs[i]
455            p.join()
456            self.assertEqual(0, p.exitcode)
457
458        # make sure all the workers ran
459        # each worker touches a file with its global rank as the name
460        self.assertSetEqual(
461            {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir))
462        )
463
464    def test_min_max_nodes_parse(self):
465        min_nodes, max_nodes = launch.parse_min_max_nnodes("1")
466        self.assertEqual(min_nodes, max_nodes)
467        self.assertEqual(1, min_nodes)
468        min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20")
469        self.assertEqual(2, min_nodes)
470        self.assertEqual(20, max_nodes)
471        with self.assertRaises(RuntimeError):
472            launch.parse_min_max_nnodes("2:20:30")
473
474    @patch("torch.distributed.launcher.api.LocalElasticAgent")
475    def test_launch_shutdown(self, agent_mock_cls):
476        nnodes = 1
477        nproc_per_node = 4
478        args = [
479            f"--nnodes={nnodes}",
480            f"--nproc-per-node={nproc_per_node}",
481            "--monitor-interval=1",
482            "--start-method=spawn",
483            path("bin/test_script.py"),
484            f"--touch-file-dir={self.test_dir}",
485        ]
486        agent_mock = Mock()
487        agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED)
488        agent_mock_cls.return_value = agent_mock
489        rdzv_handler_mock = Mock()
490        with patch(
491            "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler"
492        ) as param_mock:
493            param_mock.return_value = rdzv_handler_mock
494            launch.main(args)
495            rdzv_handler_mock.shutdown.assert_called_once()
496
497    @skip_but_pass_in_sandcastle_if(
498        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
499    )
500    def test_is_torchelastic_launched(self):
501        # launch test script with torchelastic and validate that
502        # torch.distributed.is_torchelastic_launched() returns True
503
504        out_file = f"{os.path.join(self.test_dir, 'out')}"
505        launch.main(
506            [
507                "--run-path",
508                "--nnodes=1",
509                "--nproc-per-node=1",
510                "--monitor-interval=1",
511                path("bin/test_script_is_torchelastic_launched.py"),
512                f"--out-file={out_file}",
513            ]
514        )
515
516        with open(out_file) as fp:
517            is_torchelastic_launched = fp.readline()
518            self.assertEqual("True", is_torchelastic_launched)
519
520    @patch("torch.distributed.run.metadata")
521    @skip_but_pass_in_sandcastle_if(
522        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
523    )
524    def test_is_torchelastic_launched_with_logs_spec_defined(self, metadata_mock):
525        # mock the entrypoint API to avoid version issues.
526        entrypoints = MagicMock()
527        metadata_mock.entry_points.return_value = entrypoints
528
529        group = MagicMock()
530        entrypoints.select.return_value = group
531
532        ep = MagicMock()
533        ep.load.return_value = DefaultLogsSpecs
534
535        group.select.return_value = ep
536        group.__getitem__.return_value = ep
537
538        out_file = f"{os.path.join(self.test_dir, 'out')}"
539        if os.path.exists(out_file):
540            os.remove(out_file)
541        launch.main(
542            [
543                "--run-path",
544                "--nnodes=1",
545                "--nproc-per-node=1",
546                "--monitor-interval=1",
547                "--logs_specs=default",
548                path("bin/test_script_is_torchelastic_launched.py"),
549                f"--out-file={out_file}",
550            ]
551        )
552
553        with open(out_file) as fp:
554            is_torchelastic_launched = fp.readline()
555            self.assertEqual("True", is_torchelastic_launched)
556
557    @skip_but_pass_in_sandcastle_if(
558        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
559    )
560    def test_logs_logs_spec_entrypoint_must_be_defined(self):
561        with self.assertRaises(ValueError):
562            launch.main(
563                [
564                    "--run-path",
565                    "--nnodes=1",
566                    "--nproc-per-node=1",
567                    "--monitor-interval=1",
568                    "--logs_specs=DOESNOT_EXIST",
569                    path("bin/test_script_is_torchelastic_launched.py"),
570                ]
571            )
572
573    def test_is_not_torchelastic_launched(self):
574        # launch test script without torchelastic and validate that
575        # torch.distributed.is_torchelastic_launched() returns False
576
577        out_file = f"{os.path.join(self.test_dir, 'out')}"
578
579        # need to run the script with runpy in the same interpreter
580        # as the test because otherwise (depending on the environment)
581        # it will not find torch as a dependency
582        with patch.object(
583            sys,
584            "argv",
585            [
586                path("bin/test_script_is_torchelastic_launched.py"),
587                f"--out-file={out_file}",
588            ],
589        ):
590            runpy.run_path(sys.argv[0], run_name="__main__")
591            with open(out_file) as fp:
592                is_torchelastic_launched = fp.readline()
593                self.assertEqual("False", is_torchelastic_launched)
594
595    @skip_but_pass_in_sandcastle_if(
596        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
597    )
598    def test_init_method_tcp_with_torchelastic(self):
599        port = get_free_port()
600        launch.main(
601            [
602                "--run-path",
603                "--nnodes=1",
604                "--nproc-per-node=4",
605                "--master-addr=localhost",
606                f"--master-port={port}",
607                "--monitor-interval=1",
608                path("bin/test_script_init_method.py"),
609                f"--init-method=tcp://localhost:{port}",
610            ]
611        )
612        # nothing to validate, just make sure it runs
613
614    @skip_but_pass_in_sandcastle_if(
615        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
616    )
617    def test_init_method_env_with_torchelastic(self):
618        port = get_free_port()
619        launch.main(
620            [
621                "--run-path",
622                "--nnodes=1",
623                "--nproc-per-node=4",
624                "--master-addr=localhost",
625                f"--master-port={port}",
626                "--monitor-interval=1",
627                path("bin/test_script_init_method.py"),
628                "--init-method=env://",
629            ]
630        )
631        # nothing to validate, just make sure it runs
632
633    def test_capture_logs_using_default_logs_specs(self):
634        run_id = str(uuid.uuid4().int)
635        nnodes = 1
636        nproc_per_node = 4
637        args = [
638            f"--nnodes={nnodes}",
639            f"--nproc-per-node={nproc_per_node}",
640            f"--rdzv-id={run_id}",
641            "--redirect=3",
642            "--tee=3",
643            "--monitor-interval=1",
644            "--start-method=spawn",
645            "--no-python",
646        ]
647
648        script_args = [path("bin/test_script.sh"), f"{self.test_dir}"]
649
650        captured_out = io.StringIO()
651        captured_err = io.StringIO()
652        with redirect_stdout(captured_out), redirect_stderr(captured_err):
653            with patch.dict(
654                os.environ, {"TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE": "[rank${rank}]: "}
655            ):
656                launch.main(args + script_args)
657
658        for i in range(nproc_per_node):
659            self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue())
660
661
# When this file is executed directly, hand control to ``run_tests``
# (imported from torch.testing._internal.common_utils above).
if __name__ == "__main__":
    run_tests()
664