1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9import io 10import multiprocessing as mp 11import os 12import runpy 13import shutil 14import subprocess 15import sys 16import tempfile 17import uuid 18from contextlib import closing, redirect_stderr, redirect_stdout 19from unittest import mock 20from unittest.mock import MagicMock, Mock, patch 21 22import torch.distributed.run as launch 23from torch.distributed.elastic.agent.server.api import RunResult, WorkerState 24from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs 25from torch.distributed.elastic.multiprocessing.errors import ChildFailedError 26from torch.distributed.elastic.utils import get_socket_with_port 27from torch.distributed.elastic.utils.distributed import get_free_port 28from torch.testing._internal.common_utils import ( 29 run_tests, 30 skip_but_pass_in_sandcastle_if, 31 TEST_WITH_DEV_DBG_ASAN, 32 TestCase, 33) 34 35 36def launch_in_proc(args): 37 launch.main(args) 38 39 40def path(script): 41 return os.path.join(os.path.dirname(__file__), script) 42 43 44def get_child_pids(pid): 45 pgrep = subprocess.Popen(args=f"pgrep -P {pid}", shell=True, stdout=subprocess.PIPE) 46 pgrep.wait() 47 out = pgrep.stdout.read().decode("utf-8").rstrip().split("\n") 48 pids = [] 49 for pid in out: 50 if pid: 51 pids.append(int(pid)) 52 return pids 53 54 55def pid_exists(pid): 56 try: 57 os.kill(pid, 0) 58 return True 59 except OSError: 60 return False 61 62 63class MockException(Exception): 64 pass 65 66 67class ElasticLaunchTest(TestCase): 68 def setUp(self): 69 self.test_dir = tempfile.mkdtemp() 70 71 # remove any lingering environment variables 72 for env in os.environ.keys(): 73 if env.startswith("PET_"): 74 del os.environ[env] 75 76 # set a sentinel env var on the parent proc 77 # this should be present on the child and gets 78 # asserted in ``bin/test_script.py`` 79 os.environ["TEST_SENTINEL_PARENT"] = "FOOBAR" 80 81 def tearDown(self): 82 shutil.rmtree(self.test_dir) 83 84 def test_launch_user_script_python(self): 85 self._test_launch_user_script_python() 86 87 def _test_launch_user_script_python(self): 88 run_id = str(uuid.uuid4().int) 89 nnodes = 1 90 nproc_per_node = 4 91 world_size = nnodes * nproc_per_node 92 args = [ 93 f"--nnodes={nnodes}", 94 f"--nproc-per-node={nproc_per_node}", 95 f"--rdzv-id={run_id}", 96 "--monitor-interval=1", 97 "--start-method=spawn", 98 path("bin/test_script.py"), 99 f"--touch-file-dir={self.test_dir}", 100 ] 101 launch.main(args) 102 103 # make sure all the workers ran 104 # each worker touches a file with its global rank as the name 105 self.assertSetEqual( 106 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 107 ) 108 109 def test_launch_user_script_python_caffe2_bc(self): 110 nnodes = 1 111 nproc_per_node = 4 112 world_size = nnodes * nproc_per_node 113 sock = get_socket_with_port() 114 with closing(sock): 115 master_port = sock.getsockname()[1] 116 args = [ 117 f"--nnodes={nnodes}", 118 f"--nproc-per-node={nproc_per_node}", 119 "--monitor-interval=1", 120 "--start-method=spawn", 121 "--master-addr=localhost", 122 f"--master-port={master_port}", 123 "--node-rank=0", 124 path("bin/test_script.py"), 125 f"--touch-file-dir={self.test_dir}", 126 ] 127 launch.main(args) 128 129 # make sure all the workers ran 130 # each worker touches a file with its global rank as the name 131 self.assertSetEqual( 132 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 133 ) 134 135 @skip_but_pass_in_sandcastle_if( 136 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 137 ) 138 def test_launch_user_script_bash(self): 139 run_id = str(uuid.uuid4().int) 140 nnodes = 1 141 nproc_per_node = 4 142 world_size = nnodes * nproc_per_node 143 args = [ 144 f"--nnodes={nnodes}", 145 f"--nproc-per-node={nproc_per_node}", 146 f"--rdzv-id={run_id}", 147 "--monitor-interval=1", 148 "--start-method=spawn", 149 "--no-python", 150 ] 151 152 script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] 153 154 with self.assertRaises(ValueError): 155 # --no-python cannot be used with --module 156 launch.main(args + ["--module"] + script_args) 157 158 launch.main(args + script_args) 159 160 # make sure all the workers ran 161 # each worker touches a file with its global rank as the name 162 self.assertSetEqual( 163 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 164 ) 165 166 @skip_but_pass_in_sandcastle_if( 167 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 168 ) 169 def test_launch_user_script_default_nproc(self): 170 run_id = str(uuid.uuid4().int) 171 nnodes = 1 172 world_size = 1 173 args = [ 174 f"--nnodes={nnodes}", 175 f"--rdzv-id={run_id}", 176 "--monitor-interval=1", 177 "--start-method=spawn", 178 "--no-python", 179 ] 180 181 script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] 182 183 with self.assertRaises(ValueError): 184 # --no-python cannot be used with --module 185 launch.main(args + ["--module"] + script_args) 186 187 launch.main(args + script_args) 188 189 # make sure all the workers ran 190 # each worker touches a file with its global rank as the name 191 self.assertSetEqual( 192 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 193 ) 194 195 @skip_but_pass_in_sandcastle_if( 196 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 197 ) 198 def test_launch_with_env_vars(self): 199 run_id = str(uuid.uuid4().int) 200 nnodes = 1 201 nproc_per_node = 4 202 world_size = nnodes * nproc_per_node 203 204 os.environ["PET_NNODES"] = str(nnodes) 205 os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node) 206 os.environ["PET_RDZV_ID"] = run_id 207 os.environ["PET_MONITOR_INTERVAL"] = "1" 208 os.environ["PET_START_METHOD"] = "spawn" 209 os.environ["PET_NO_PYTHON"] = "1" 210 211 script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] 212 213 with self.assertRaises(ValueError): 214 # --no-python cannot be used with --module 215 os.environ["PET_MODULE"] = "1" 216 launch.main(script_args) 217 218 os.environ["PET_MODULE"] = "0" 219 launch.main(script_args) 220 221 # make sure all the workers ran 222 # each worker touches a file with its global rank as the name 223 self.assertSetEqual( 224 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 225 ) 226 227 def _test_nproc_launch_configuration(self, nproc_type, expected_number): 228 run_id = str(uuid.uuid4().int) 229 nnodes = 1 230 231 args = [ 232 f"--nnodes={nnodes}", 233 f"--nproc-per-node={nproc_type}", 234 f"--rdzv-id={run_id}", 235 "--monitor-interval=1", 236 "--start-method=spawn", 237 "--no-python", 238 ] 239 240 script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] 241 242 launch.main(args + script_args) 243 244 world_size = nnodes * expected_number 245 # make sure all the workers ran 246 # each worker touches a file with its global rank as the name 247 self.assertSetEqual( 248 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 249 ) 250 251 @skip_but_pass_in_sandcastle_if( 252 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 253 ) 254 @patch("torch.cuda.is_available", return_value=False) 255 def test_nproc_launch_auto_configurations(self, _mock1): 256 self._test_nproc_launch_configuration("auto", os.cpu_count()) 257 258 @skip_but_pass_in_sandcastle_if( 259 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 260 ) 261 def test_nproc_launch_number_configurations(self): 262 self._test_nproc_launch_configuration("4", 4) 263 264 @skip_but_pass_in_sandcastle_if( 265 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 266 ) 267 def test_nproc_launch_unknown_configurations(self): 268 with self.assertRaises(ValueError): 269 self._test_nproc_launch_configuration("unknown", 4) 270 271 @skip_but_pass_in_sandcastle_if( 272 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 273 ) 274 @patch("torch.cuda.is_available", return_value=True) 275 @patch("torch.cuda.device_count", return_value=3) 276 def test_nproc_gpu_launch_configurations(self, _mock1, _mock2): 277 self._test_nproc_launch_configuration("auto", 3) 278 self._test_nproc_launch_configuration("gpu", 3) 279 280 @skip_but_pass_in_sandcastle_if( 281 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 282 ) 283 def test_launch_elastic(self): 284 run_id = str(uuid.uuid4().int) 285 min_nodes = 1 286 max_nodes = 2 287 nproc_per_node = 4 288 # we are only launching 1 node (even though max = 2) 289 world_size = nproc_per_node 290 args = [ 291 f"--nnodes={min_nodes}:{max_nodes}", 292 f"--nproc-per-node={nproc_per_node}", 293 "--rdzv-backend=c10d", 294 f"--rdzv-endpoint=localhost:{get_free_port()}", 295 "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", 296 f"--rdzv-id={run_id}", 297 "--monitor-interval=1", 298 "--start-method=spawn", 299 path("bin/test_script.py"), 300 f"--touch-file-dir={self.test_dir}", 301 ] 302 launch.main(args) 303 304 # make sure all the workers ran 305 # each worker touches a file with its global rank as the name 306 self.assertSetEqual( 307 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 308 ) 309 310 @mock.patch("torch.distributed.elastic.events.record") 311 @skip_but_pass_in_sandcastle_if( 312 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 313 ) 314 def test_launch_elastic_worker_raise_exception(self, record_mock): 315 """ 316 Asserts that when the worker program fails and lancher raieses exception 317 to indicate that worker process failed 318 319 """ 320 run_id = str(uuid.uuid4().int) 321 min_nodes = 1 322 max_nodes = 2 323 nproc_per_node = 4 324 args = [ 325 f"--nnodes={min_nodes}:{max_nodes}", 326 f"--nproc-per-node={nproc_per_node}", 327 "--rdzv-backend=c10d", 328 f"--rdzv-endpoint=localhost:{get_free_port()}", 329 "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", 330 f"--rdzv-id={run_id}", 331 "--monitor-interval=1", 332 "--max-restarts=0", 333 "--start-method=spawn", 334 path("bin/test_script.py"), 335 "--fail", 336 ] 337 with self.assertRaises(ChildFailedError): 338 launch.main(args) 339 340 record_mock.assert_called_once() 341 342 @skip_but_pass_in_sandcastle_if( 343 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 344 ) 345 @mock.patch( 346 "torch.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent.run" 347 ) 348 @mock.patch("torch.distributed.elastic.events.record") 349 def test_launch_elastic_agent_raise_exception(self, record_mock, mock_agent_run): 350 """ 351 Asserts that when the agent raises an exception 352 the launcher re-raises the original exception 353 """ 354 run_id = str(uuid.uuid4().int) 355 min_nodes = 1 356 max_nodes = 2 357 nproc_per_node = 4 358 args = [ 359 f"--nnodes={min_nodes}:{max_nodes}", 360 f"--nproc-per-node={nproc_per_node}", 361 "--rdzv-backend=c10d", 362 f"--rdzv-endpoint=localhost:{get_free_port()}", 363 "--rdzv_conf=timeout=5", 364 f"--rdzv-id={run_id}", 365 "--monitor-interval=1", 366 "--max-restarts=0", 367 "--start-method=spawn", 368 path("bin/test_script.py"), 369 f"--touch-file-dir={self.test_dir}", 370 ] 371 372 mock_agent_run.side_effect = MockException 373 with self.assertRaises(MockException): 374 launch.main(args) 375 record_mock.assert_called_once() 376 377 @skip_but_pass_in_sandcastle_if( 378 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 379 ) 380 def test_launch_standalone(self): 381 nnodes = 1 382 nproc_per_node = 4 383 world_size = nnodes * nproc_per_node 384 args = [ 385 f"--nnodes={nnodes}", 386 f"--nproc-per-node={nproc_per_node}", 387 "--standalone", 388 "--monitor-interval=1", 389 "--start-method=spawn", 390 path("bin/test_script.py"), 391 f"--touch-file-dir={self.test_dir}", 392 ] 393 launch.main(args) 394 395 # make sure all the workers ran 396 # each worker touches a file with its global rank as the name 397 self.assertSetEqual( 398 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 399 ) 400 401 @skip_but_pass_in_sandcastle_if( 402 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 403 ) 404 def test_launch_run_path(self): 405 nnodes = 1 406 nproc_per_node = 4 407 world_size = nnodes * nproc_per_node 408 args = [ 409 "--run-path", 410 f"--nnodes={nnodes}", 411 f"--nproc-per-node={nproc_per_node}", 412 "--monitor-interval=1", 413 "--start-method=spawn", 414 path("bin/test_script.py"), 415 f"--touch-file-dir={self.test_dir}", 416 ] 417 launch.main(args) 418 419 # make sure all the workers ran 420 # each worker touches a file with its global rank as the name 421 self.assertSetEqual( 422 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 423 ) 424 425 @skip_but_pass_in_sandcastle_if( 426 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 427 ) 428 def test_launch_elastic_multiple_agents(self): 429 run_id = str(uuid.uuid4().int) 430 min_nodes = 1 431 max_nodes = 2 432 nproc_per_node = 4 433 nnodes = 2 434 world_size = nnodes * nproc_per_node 435 args = [ 436 f"--nnodes={min_nodes}:{max_nodes}", 437 f"--nproc-per-node={nproc_per_node}", 438 "--rdzv-backend=c10d", 439 f"--rdzv-endpoint=localhost:{get_free_port()}", 440 "--rdzv_conf=timeout=5", 441 f"--rdzv-id={run_id}", 442 "--monitor-interval=1", 443 "--start-method=spawn", 444 path("bin/test_script.py"), 445 f"--touch-file-dir={self.test_dir}", 446 ] 447 procs = [] 448 for _ in range(nnodes - 1): 449 p = mp.Process(target=launch.main, args=[args]) 450 procs.append(p) 451 p.start() 452 launch.main(args) 453 for i in range(nnodes - 1): 454 p = procs[i] 455 p.join() 456 self.assertEqual(0, p.exitcode) 457 458 # make sure all the workers ran 459 # each worker touches a file with its global rank as the name 460 self.assertSetEqual( 461 {str(i) for i in range(world_size)}, set(os.listdir(self.test_dir)) 462 ) 463 464 def test_min_max_nodes_parse(self): 465 min_nodes, max_nodes = launch.parse_min_max_nnodes("1") 466 self.assertEqual(min_nodes, max_nodes) 467 self.assertEqual(1, min_nodes) 468 min_nodes, max_nodes = launch.parse_min_max_nnodes("2:20") 469 self.assertEqual(2, min_nodes) 470 self.assertEqual(20, max_nodes) 471 with self.assertRaises(RuntimeError): 472 launch.parse_min_max_nnodes("2:20:30") 473 474 @patch("torch.distributed.launcher.api.LocalElasticAgent") 475 def test_launch_shutdown(self, agent_mock_cls): 476 nnodes = 1 477 nproc_per_node = 4 478 args = [ 479 f"--nnodes={nnodes}", 480 f"--nproc-per-node={nproc_per_node}", 481 "--monitor-interval=1", 482 "--start-method=spawn", 483 path("bin/test_script.py"), 484 f"--touch-file-dir={self.test_dir}", 485 ] 486 agent_mock = Mock() 487 agent_mock.run.return_value = RunResult(WorkerState.SUCCEEDED) 488 agent_mock_cls.return_value = agent_mock 489 rdzv_handler_mock = Mock() 490 with patch( 491 "torch.distributed.elastic.rendezvous.registry.get_rendezvous_handler" 492 ) as param_mock: 493 param_mock.return_value = rdzv_handler_mock 494 launch.main(args) 495 rdzv_handler_mock.shutdown.assert_called_once() 496 497 @skip_but_pass_in_sandcastle_if( 498 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 499 ) 500 def test_is_torchelastic_launched(self): 501 # launch test script with torchelastic and validate that 502 # torch.distributed.is_torchelastic_launched() returns True 503 504 out_file = f"{os.path.join(self.test_dir, 'out')}" 505 launch.main( 506 [ 507 "--run-path", 508 "--nnodes=1", 509 "--nproc-per-node=1", 510 "--monitor-interval=1", 511 path("bin/test_script_is_torchelastic_launched.py"), 512 f"--out-file={out_file}", 513 ] 514 ) 515 516 with open(out_file) as fp: 517 is_torchelastic_launched = fp.readline() 518 self.assertEqual("True", is_torchelastic_launched) 519 520 @patch("torch.distributed.run.metadata") 521 @skip_but_pass_in_sandcastle_if( 522 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 523 ) 524 def test_is_torchelastic_launched_with_logs_spec_defined(self, metadata_mock): 525 # mock the entrypoint API to avoid version issues. 526 entrypoints = MagicMock() 527 metadata_mock.entry_points.return_value = entrypoints 528 529 group = MagicMock() 530 entrypoints.select.return_value = group 531 532 ep = MagicMock() 533 ep.load.return_value = DefaultLogsSpecs 534 535 group.select.return_value = ep 536 group.__getitem__.return_value = ep 537 538 out_file = f"{os.path.join(self.test_dir, 'out')}" 539 if os.path.exists(out_file): 540 os.remove(out_file) 541 launch.main( 542 [ 543 "--run-path", 544 "--nnodes=1", 545 "--nproc-per-node=1", 546 "--monitor-interval=1", 547 "--logs_specs=default", 548 path("bin/test_script_is_torchelastic_launched.py"), 549 f"--out-file={out_file}", 550 ] 551 ) 552 553 with open(out_file) as fp: 554 is_torchelastic_launched = fp.readline() 555 self.assertEqual("True", is_torchelastic_launched) 556 557 @skip_but_pass_in_sandcastle_if( 558 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 559 ) 560 def test_logs_logs_spec_entrypoint_must_be_defined(self): 561 with self.assertRaises(ValueError): 562 launch.main( 563 [ 564 "--run-path", 565 "--nnodes=1", 566 "--nproc-per-node=1", 567 "--monitor-interval=1", 568 "--logs_specs=DOESNOT_EXIST", 569 path("bin/test_script_is_torchelastic_launched.py"), 570 ] 571 ) 572 573 def test_is_not_torchelastic_launched(self): 574 # launch test script without torchelastic and validate that 575 # torch.distributed.is_torchelastic_launched() returns False 576 577 out_file = f"{os.path.join(self.test_dir, 'out')}" 578 579 # need to run the script with runpy in the same interpreter 580 # as the test because otherwise (depending on the environment) 581 # it will not find torch as a dependency 582 with patch.object( 583 sys, 584 "argv", 585 [ 586 path("bin/test_script_is_torchelastic_launched.py"), 587 f"--out-file={out_file}", 588 ], 589 ): 590 runpy.run_path(sys.argv[0], run_name="__main__") 591 with open(out_file) as fp: 592 is_torchelastic_launched = fp.readline() 593 self.assertEqual("False", is_torchelastic_launched) 594 595 @skip_but_pass_in_sandcastle_if( 596 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 597 ) 598 def test_init_method_tcp_with_torchelastic(self): 599 port = get_free_port() 600 launch.main( 601 [ 602 "--run-path", 603 "--nnodes=1", 604 "--nproc-per-node=4", 605 "--master-addr=localhost", 606 f"--master-port={port}", 607 "--monitor-interval=1", 608 path("bin/test_script_init_method.py"), 609 f"--init-method=tcp://localhost:{port}", 610 ] 611 ) 612 # nothing to validate, just make sure it runs 613 614 @skip_but_pass_in_sandcastle_if( 615 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 616 ) 617 def test_init_method_env_with_torchelastic(self): 618 port = get_free_port() 619 launch.main( 620 [ 621 "--run-path", 622 "--nnodes=1", 623 "--nproc-per-node=4", 624 "--master-addr=localhost", 625 f"--master-port={port}", 626 "--monitor-interval=1", 627 path("bin/test_script_init_method.py"), 628 "--init-method=env://", 629 ] 630 ) 631 # nothing to validate, just make sure it runs 632 633 def test_capture_logs_using_default_logs_specs(self): 634 run_id = str(uuid.uuid4().int) 635 nnodes = 1 636 nproc_per_node = 4 637 args = [ 638 f"--nnodes={nnodes}", 639 f"--nproc-per-node={nproc_per_node}", 640 f"--rdzv-id={run_id}", 641 "--redirect=3", 642 "--tee=3", 643 "--monitor-interval=1", 644 "--start-method=spawn", 645 "--no-python", 646 ] 647 648 script_args = [path("bin/test_script.sh"), f"{self.test_dir}"] 649 650 captured_out = io.StringIO() 651 captured_err = io.StringIO() 652 with redirect_stdout(captured_out), redirect_stderr(captured_err): 653 with patch.dict( 654 os.environ, {"TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE": "[rank${rank}]: "} 655 ): 656 launch.main(args + script_args) 657 658 for i in range(nproc_per_node): 659 self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue()) 660 661 662if __name__ == "__main__": 663 run_tests() 664