1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9import json 10import multiprocessing as mp 11import os 12import shutil 13import signal 14import socket 15import tempfile 16import time 17import unittest 18import uuid 19from dataclasses import dataclass 20from typing import Callable, Dict, List, Optional, Tuple 21from unittest import mock 22from unittest.mock import Mock, patch 23 24import torch 25import torch.distributed as dist 26import torch.distributed.elastic.rendezvous.registry as rdzv_registry 27import torch.distributed.rpc as rpc 28from torch.distributed.elastic.agent.server.api import ( 29 RunResult, 30 WorkerSpec, 31 WorkerState, 32) 33from torch.distributed.elastic.agent.server.local_elastic_agent import ( 34 LocalElasticAgent, 35 TORCHELASTIC_HEALTH_CHECK_PORT, 36 TORCHELASTIC_TIMER_FILE, 37) 38from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std 39from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record 40from torch.distributed.elastic.rendezvous import RendezvousParameters 41from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer 42from torch.distributed.rpc.backend_registry import BackendType 43from torch.testing._internal.common_utils import ( 44 skip_but_pass_in_sandcastle_if, 45 TEST_WITH_DEV_DBG_ASAN, 46 TEST_WITH_TSAN, 47) 48 49 50def init_rpc(name, backend): 51 rank = int(os.environ["RANK"]) 52 world_size = int(os.environ["WORLD_SIZE"]) 53 rpc.init_rpc( 54 name=name, 55 backend=backend, 56 rank=rank, 57 world_size=world_size, 58 ) 59 60 61def rpc_master(msg): 62 init_rpc("master", BackendType.TENSORPIPE) 63 ret = rpc.rpc_sync(to="worker", func=_echo, args=(msg,)) 64 rpc.shutdown() 65 return f"{ret} from worker" 66 67 68def rpc_worker(): 69 init_rpc("worker", BackendType.TENSORPIPE) 70 rpc.shutdown() 71 72 73def _happy_function(): 74 return 75 76 77def _sad_function(): 78 raise RuntimeError("sad because i throw") 79 80 81def dummy_compute() -> torch.Tensor: 82 """ 83 returns a predefined size random Tensor 84 """ 85 return torch.rand(100, 100) 86 87 88def dummy_compute_simulate_rank_failure() -> torch.Tensor: 89 """ 90 fails rank 1 once 91 in other cases, returns a predefined size random Tensor 92 """ 93 if os.environ["RANK"] == "1" and os.environ["TORCHELASTIC_RESTART_COUNT"] == "0": 94 os.kill(os.getpid(), 9) 95 return torch.rand(100, 100) 96 97 98def _fatal_signal_function(expected_error_index: int, sig: int): 99 rank = int(os.environ["RANK"]) 100 if rank == expected_error_index: 101 os.kill(os.getpid(), sig) 102 103 104def _check_master_port_addr_override( 105 expected_master_addr: str, expected_master_port: int 106): 107 actual_master_addr = os.environ["MASTER_ADDR"] 108 actual_master_port = int(os.environ["MASTER_PORT"]) 109 if ( 110 expected_master_addr != actual_master_addr 111 and expected_master_port != actual_master_port 112 ): 113 raise RuntimeError( 114 f"Expected addr: {expected_master_addr}:{expected_master_port}, got addr: {actual_master_addr}:{actual_master_port}" 115 ) 116 117 118def _bipolar_function(): 119 rank = int(os.environ["RANK"]) 120 if rank % 2 == 0: 121 _happy_function() 122 else: 123 _sad_function() 124 125 126def _bipolar_sleep_function(sleep_sec): 127 rank = int(os.environ["RANK"]) 128 if rank % 2 == 0: 129 _sleep(sleep_sec) 130 else: 131 _sad_function() 132 133 134def _dist_sum(wait=0): 135 rank = int(os.environ["RANK"]) 136 world_size = int(os.environ["WORLD_SIZE"]) 137 dist.init_process_group(backend="gloo") 138 t = torch.tensor(rank) 139 140 time.sleep(wait) 141 dist.all_reduce(t, op=dist.reduce_op.SUM) 142 143 expected_sum = sum(range(world_size)) 144 actual = t.item() 145 if expected_sum != actual: 146 raise RuntimeError(f"Expected rank sum {expected_sum}, got {actual}") 147 148 149def _sleep(sleep_sec) -> int: 150 time.sleep(sleep_sec) 151 return int(os.environ["RANK"]) 152 153 154@dataclass 155class RankInfo: 156 rank: int 157 role_rank: int 158 group_rank: int 159 role_world_size: int 160 world_size: int 161 162 163def _get_role_info() -> RankInfo: 164 rank = int(os.environ["RANK"]) 165 role_rank = int(os.environ["ROLE_RANK"]) 166 group_rank = int(os.environ["GROUP_RANK"]) 167 role_world_size = int(os.environ["ROLE_WORLD_SIZE"]) 168 world_size = int(os.environ["WORLD_SIZE"]) 169 return RankInfo(rank, role_rank, group_rank, role_world_size, world_size) 170 171 172def _echo(msg): 173 return msg 174 175 176def _check_env_function(): 177 # just check these env vars exist, os.environ[...] will naturally throw 178 # if the variable does not exist 179 env_vars = [ 180 "RANK", 181 "LOCAL_RANK", 182 "ROLE_RANK", 183 "ROLE_NAME", 184 "GROUP_RANK", 185 "LOCAL_WORLD_SIZE", 186 "ROLE_WORLD_SIZE", 187 "WORLD_SIZE", 188 "GROUP_WORLD_SIZE", 189 "MASTER_ADDR", 190 "MASTER_PORT", 191 "TORCHELASTIC_RESTART_COUNT", 192 "TORCHELASTIC_MAX_RESTARTS", 193 "TORCHELASTIC_RUN_ID", 194 "TORCHELASTIC_USE_AGENT_STORE", 195 "TORCH_NCCL_ASYNC_ERROR_HANDLING", 196 ] 197 for var in env_vars: 198 _ = os.environ[var] 199 200 201def _check_env_value(key: str, expected: str): 202 # checks if the env var ``key`` matches ``value`` 203 # this function is intended to be used as the entrypoint to the elastic run 204 if key not in os.environ: 205 raise RuntimeError(f"Environment variable {key} not found in os.environ") 206 else: 207 actual = os.getenv(key) 208 if expected != actual: 209 raise RuntimeError( 210 f"os.environ['{key}']={actual}" 211 f" does not equal the expected value: {expected}" 212 ) 213 214 215def _check_local_watchdog_setup(key: str, should_exist: bool): 216 if should_exist and key not in os.environ: 217 raise RuntimeError(f"Environment variable {key} not found in os.environ") 218 if not should_exist and key in os.environ: 219 raise RuntimeError(f"Environment variable {key} found in os.environ") 220 221 222def acquire_available_port(): 223 """ 224 Uses sockets to acquire an available port from the os for use. 225 226 Note: To reduce the race condition where another process grabs the port 227 after this function returns an available port, we should aim to use 228 the port as quickly as possible. 229 """ 230 addrs = socket.getaddrinfo( 231 host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM 232 ) 233 234 for addr in addrs: 235 family, type, proto, _, _ = addr 236 try: 237 s = socket.socket(family, type, proto) 238 s.bind(("localhost", 0)) 239 s.listen(0) 240 port = s.getsockname()[1] 241 s.close() 242 return port 243 except OSError as e: 244 s.close() 245 print(f"Socket creation attempt failed: {e}") 246 247 raise RuntimeError("Failed to create a socket") 248 249 250@dataclass 251class Conf: 252 """ 253 Holds arguments to launch an agent (e.g. simulates an agent run on a node). 254 255 """ 256 257 entrypoint: Callable 258 local_world_size: int 259 args: Tuple = () 260 role: str = "default" 261 redirects: Std = Std.NONE 262 tee: Std = Std.NONE 263 264 265class LocalElasticAgentTest(unittest.TestCase): 266 @classmethod 267 def setUpClass(cls): 268 # start a standalone, single process etcd server to use for all tests 269 cls._etcd_server = EtcdServer() 270 cls._etcd_server.start() 271 272 @classmethod 273 def tearDownClass(cls): 274 # stop the standalone etcd server 275 cls._etcd_server.stop() 276 277 def setUp(self): 278 self._test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__) 279 self._run_id = str(uuid.uuid4()).split("-")[0] 280 281 def tearDown(self): 282 shutil.rmtree(self._test_dir) 283 284 def log_dir(self) -> str: 285 return tempfile.mkdtemp(prefix="torchelastic_", dir=self._test_dir) 286 287 def get_worker_spec( 288 self, 289 node_config: Conf, 290 min_nodes=1, 291 max_nodes=1, 292 max_restarts=0, 293 monitor_interval=0.01, 294 master_addr_override: Optional[str] = None, 295 master_port_override: Optional[int] = None, 296 is_host=True, 297 ): 298 rdzv_params = RendezvousParameters( 299 backend=self._backend, 300 endpoint=self._endpoint, 301 run_id=self._run_id, 302 min_nodes=min_nodes, 303 max_nodes=max_nodes, 304 is_host=is_host, 305 ) 306 rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params) 307 return WorkerSpec( 308 role=node_config.role, 309 local_world_size=node_config.local_world_size, 310 entrypoint=node_config.entrypoint, 311 args=node_config.args, 312 rdzv_handler=rdzv_handler, 313 max_restarts=max_restarts, 314 monitor_interval=monitor_interval, 315 master_addr=master_addr_override, 316 master_port=master_port_override, 317 ) 318 319 def get_agent( 320 self, 321 spec: WorkerSpec, 322 node_config: Conf, 323 start_method: str = "spawn", 324 exit_barrier_timeout=5, 325 log_line_prefix_template: Optional[str] = None, 326 ) -> LocalElasticAgent: 327 return LocalElasticAgent( 328 spec, 329 start_method=start_method, 330 exit_barrier_timeout=exit_barrier_timeout, 331 logs_specs=DefaultLogsSpecs( 332 log_dir=self.log_dir(), 333 redirects=node_config.redirects, 334 tee=node_config.tee, 335 ), 336 log_line_prefix_template=log_line_prefix_template, 337 ) 338 339 # pyre-fixme[56]: Pyre was not able to infer the type of the decorator 340 # `torch.distributed.elastic.multiprocessing.errors.record`. 341 @record 342 def run_agent( 343 self, 344 conf: Conf, 345 agent_results: Optional[mp.Queue] = None, # (role, agent_result) 346 min_nodes=1, 347 max_nodes=1, 348 start_method: str = "spawn", 349 max_restarts: int = 0, 350 exit_barrier_timeout=5, 351 master_addr_override: Optional[str] = None, 352 master_port_override: Optional[int] = None, 353 is_host=True, 354 monitor_interval=0.01, 355 log_line_prefix_template: Optional[str] = None, 356 ) -> Optional[RunResult]: 357 """ 358 Runs a single agent. This method can be called either on a separate process 359 or the main test process. When calling this method on a separate process make 360 sure to pass the ``agent_results`` multiprocessing Queue so that the agent's 361 run results can be returned. If ``agent_results`` is omitted, then the 362 run result is returned from the method. 363 """ 364 365 spec = self.get_worker_spec( 366 node_config=conf, 367 min_nodes=min_nodes, 368 max_nodes=max_nodes, 369 max_restarts=max_restarts, 370 master_addr_override=master_addr_override, 371 master_port_override=master_port_override, 372 is_host=is_host, 373 monitor_interval=monitor_interval, 374 ) 375 agent = self.get_agent( 376 spec=spec, 377 node_config=conf, 378 start_method=start_method, 379 exit_barrier_timeout=exit_barrier_timeout, 380 log_line_prefix_template=log_line_prefix_template, 381 ) 382 383 result = agent.run() 384 spec.rdzv_handler.shutdown() 385 386 if agent_results: 387 agent_results.put((conf.role, result)) 388 389 if result.is_failed(): 390 raise ChildFailedError(spec.get_entrypoint_name(), result.failures) 391 else: 392 if not agent_results: 393 return result 394 395 def run_job( 396 self, 397 node_configs: List[Conf], 398 exit_barrier_timeout: int = 5, 399 log_line_prefix_template: Optional[str] = None, 400 ) -> Dict[str, List[RunResult]]: 401 """ 402 Simulates running a distributed job by running multiple agents 403 (one on each process). Agent 0 is run on the main process for 404 test coverage and ease of debugging 405 """ 406 407 nnodes = len(node_configs) 408 409 # each element in this queue holds a tuple (role, RunResult) for each agent 410 agent_results = mp.Queue() 411 412 # run first agent of first config on main process for test coverage + ease of debugging 413 # it is important we loop in reverse order b/c running fn on the main process blocks 414 procs = [] 415 for node_idx in reversed(range(len(node_configs))): 416 conf = node_configs[node_idx] 417 run_agent_args = { 418 "conf": conf, 419 "agent_results": agent_results, 420 "min_nodes": nnodes, 421 "max_nodes": nnodes, 422 "start_method": "spawn", 423 "max_restarts": 0, 424 "exit_barrier_timeout": exit_barrier_timeout, 425 "is_host": node_idx == 0, 426 "log_line_prefix_template": log_line_prefix_template, 427 } 428 p = mp.Process(target=self.run_agent, kwargs=run_agent_args) 429 procs.append(p) 430 p.start() 431 for p in procs: 432 p.join() 433 434 results: Dict[str, List[RunResult]] = {} 435 while not agent_results.empty(): 436 role, run_result = agent_results.get() 437 results.setdefault(role, []).append(run_result) 438 return results 439 440 def run_test_with_backend(self, backend: str, test_to_run: Callable): 441 """ 442 Sets the backend and determines the endpoint before running the 443 given test. 444 445 Note: This method must be invoked to run any test functions that spawn 446 an agent. This is because this function sets the backend and 447 endpoint parameters. 448 """ 449 self._backend = backend 450 451 if self._backend == "etcd-v2" or self._backend == "etcd": 452 self._endpoint = self._etcd_server.get_endpoint() 453 else: 454 # the default is c10d backend 455 self._endpoint = f"localhost:{acquire_available_port()}" 456 457 test_to_run() 458 459 def dummy_compute(self): 460 res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2)) 461 self.assertFalse(res.is_failed()) 462 for return_value in res.return_values.values(): 463 self.assertIsInstance(return_value, torch.Tensor) 464 self.assertEqual((100, 100), return_value.shape) 465 466 @skip_but_pass_in_sandcastle_if( 467 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 468 ) 469 def test_dummy_compute_c10d(self): 470 self.run_test_with_backend(backend="c10d", test_to_run=self.dummy_compute) 471 472 @skip_but_pass_in_sandcastle_if( 473 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 474 ) 475 def test_dummy_compute_etcd(self): 476 self.run_test_with_backend(backend="etcd", test_to_run=self.dummy_compute) 477 478 @skip_but_pass_in_sandcastle_if( 479 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 480 ) 481 def test_dummy_compute_etcd_v2(self): 482 self.run_test_with_backend(backend="etcd-v2", test_to_run=self.dummy_compute) 483 484 def run_happy_function(self): 485 res = self.run_agent(Conf(entrypoint=_happy_function, local_world_size=2)) 486 self.assertFalse(res.is_failed()) 487 self.assertIsNone(res.return_values[0]) 488 self.assertIsNone(res.return_values[1]) 489 490 @skip_but_pass_in_sandcastle_if( 491 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 492 ) 493 def test_run_happy_function_c10d(self): 494 self.run_test_with_backend(backend="c10d", test_to_run=self.run_happy_function) 495 496 @skip_but_pass_in_sandcastle_if( 497 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 498 ) 499 def test_run_happy_function_etcd(self): 500 self.run_test_with_backend(backend="etcd", test_to_run=self.run_happy_function) 501 502 @skip_but_pass_in_sandcastle_if( 503 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 504 ) 505 def test_run_happy_function_etcd_v2(self): 506 self.run_test_with_backend( 507 backend="etcd-v2", test_to_run=self.run_happy_function 508 ) 509 510 def check_master_addr_port_override(self): 511 master_addr = "test_host" 512 master_port = 42 513 res = self.run_agent( 514 Conf( 515 entrypoint=_check_master_port_addr_override, 516 args=(master_addr, master_port), 517 local_world_size=1, 518 ), 519 master_addr_override=master_addr, 520 master_port_override=master_port, 521 ) 522 self.assertFalse(res.is_failed()) 523 self.assertIsNone(res.return_values[0]) 524 525 @skip_but_pass_in_sandcastle_if( 526 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 527 ) 528 def test_check_master_addr_port_override_etcd(self): 529 self.run_test_with_backend( 530 backend="etcd", test_to_run=self.check_master_addr_port_override 531 ) 532 533 @skip_but_pass_in_sandcastle_if( 534 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 535 ) 536 def test_check_master_addr_port_override_etcd_v2(self): 537 self.run_test_with_backend( 538 backend="etcd-v2", test_to_run=self.check_master_addr_port_override 539 ) 540 541 def run_check_env_function(self): 542 # just checks that all env vars that we need to set on the user script 543 # is actually set 544 res = self.run_agent(Conf(entrypoint=_check_env_function, local_world_size=1)) 545 self.assertFalse(res.is_failed()) 546 547 def run_check_nccl_async_error_handling_env(self): 548 # make sure TORCH_NCCL_ASYNC_ERROR_HANDLING set in os.environ is honored 549 with patch.dict(os.environ, {"TORCH_NCCL_ASYNC_ERROR_HANDLING": "0"}): 550 res = self.run_agent( 551 Conf( 552 entrypoint=_check_env_value, 553 local_world_size=1, 554 args=("TORCH_NCCL_ASYNC_ERROR_HANDLING", "0"), 555 ) 556 ) 557 self.assertFalse(res.is_failed()) 558 559 def run_check_nccl_async_error_handling_env_default(self): 560 # if not present in env var it should default to 1 561 res = self.run_agent( 562 Conf( 563 entrypoint=_check_env_value, 564 local_world_size=1, 565 args=("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1"), 566 ) 567 ) 568 self.assertFalse(res.is_failed()) 569 570 def run_agent_local_watchdog_setup_enabled(self): 571 # Set the env for watchdog 572 watchdog_env_name = TORCHELASTIC_TIMER_FILE 573 watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) 574 os.environ[watchdog_env_name] = watchdog_file_path 575 # Run the agent 576 node_conf = Conf( 577 entrypoint=_check_local_watchdog_setup, 578 local_world_size=1, 579 args=(TORCHELASTIC_TIMER_FILE, True), 580 ) 581 spec = self.get_worker_spec(node_conf, max_restarts=2) 582 agent = self.get_agent(spec, node_config=node_conf) 583 res = agent.run() 584 self.assertFalse(res.is_failed()) 585 586 def run_agent_local_watchdog_setup_disabled(self): 587 # Do not set the env for watchdog 588 watchdog_env_name = TORCHELASTIC_TIMER_FILE 589 if watchdog_env_name in os.environ: 590 del os.environ[watchdog_env_name] 591 # Run the agent 592 node_conf = Conf( 593 entrypoint=_check_local_watchdog_setup, 594 local_world_size=1, 595 args=(TORCHELASTIC_TIMER_FILE, False), 596 ) 597 spec = self.get_worker_spec(node_conf, max_restarts=2) 598 agent = self.get_agent(spec, node_config=node_conf) 599 res = agent.run() 600 self.assertFalse(res.is_failed()) 601 602 @skip_but_pass_in_sandcastle_if( 603 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 604 ) 605 def test_run_agent_local_watchdog_setup_enabled_etcd(self): 606 self.run_test_with_backend( 607 backend="etcd", test_to_run=self.run_agent_local_watchdog_setup_enabled 608 ) 609 610 @skip_but_pass_in_sandcastle_if( 611 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 612 ) 613 def test_run_agent_local_watchdog_setup_enabled_c10d(self): 614 self.run_test_with_backend( 615 backend="c10d", test_to_run=self.run_agent_local_watchdog_setup_enabled 616 ) 617 618 @skip_but_pass_in_sandcastle_if( 619 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 620 ) 621 def test_run_agent_local_watchdog_setup_disabled_etcd(self): 622 self.run_test_with_backend( 623 backend="etcd", test_to_run=self.run_agent_local_watchdog_setup_disabled 624 ) 625 626 @skip_but_pass_in_sandcastle_if( 627 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 628 ) 629 def test_run_agent_local_watchdog_setup_disabled_c10d(self): 630 self.run_test_with_backend( 631 backend="c10d", test_to_run=self.run_agent_local_watchdog_setup_disabled 632 ) 633 634 def run_agent_healthcheck_setup_enabled(self): 635 # Set the env for healthcheck 636 healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT 637 os.environ[healthcheck_port_env_name] = "12345" 638 # Run the agent 639 node_conf = Conf( 640 entrypoint=_check_local_watchdog_setup, 641 local_world_size=1, 642 args=(TORCHELASTIC_HEALTH_CHECK_PORT, True), 643 ) 644 spec = self.get_worker_spec(node_conf, max_restarts=2) 645 agent = self.get_agent(spec, node_config=node_conf) 646 res = agent.run() 647 self.assertFalse(res.is_failed()) 648 649 def run_agent_healthcheck_setup_disabled(self): 650 # Do not set the env for healthcheck 651 healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT 652 if healthcheck_port_env_name in os.environ: 653 del os.environ[healthcheck_port_env_name] 654 # Run the agent 655 node_conf = Conf( 656 entrypoint=_check_local_watchdog_setup, 657 local_world_size=1, 658 args=(TORCHELASTIC_HEALTH_CHECK_PORT, False), 659 ) 660 spec = self.get_worker_spec(node_conf, max_restarts=2) 661 agent = self.get_agent(spec, node_config=node_conf) 662 res = agent.run() 663 self.assertFalse(res.is_failed()) 664 665 @skip_but_pass_in_sandcastle_if( 666 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 667 ) 668 def test_run_agent_healthcheck_setup_enabled_etcd(self): 669 self.run_test_with_backend( 670 backend="etcd", test_to_run=self.run_agent_healthcheck_setup_enabled 671 ) 672 673 @skip_but_pass_in_sandcastle_if( 674 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 675 ) 676 def test_run_agent_healthcheck_setup_enabled_c10d(self): 677 self.run_test_with_backend( 678 backend="c10d", test_to_run=self.run_agent_healthcheck_setup_enabled 679 ) 680 681 @skip_but_pass_in_sandcastle_if( 682 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 683 ) 684 def test_run_agent_healthcheck_setup_disabled_etcd(self): 685 self.run_test_with_backend( 686 backend="etcd", test_to_run=self.run_agent_healthcheck_setup_disabled 687 ) 688 689 @skip_but_pass_in_sandcastle_if( 690 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 691 ) 692 def test_run_agent_healthcheck_setup_disabled_c10d(self): 693 self.run_test_with_backend( 694 backend="c10d", test_to_run=self.run_agent_healthcheck_setup_disabled 695 ) 696 697 @skip_but_pass_in_sandcastle_if( 698 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 699 ) 700 def test_run_check_env_function_etcd(self): 701 self.run_test_with_backend( 702 backend="etcd", test_to_run=self.run_check_env_function 703 ) 704 705 @skip_but_pass_in_sandcastle_if( 706 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 707 ) 708 def test_run_check_nccl_async_error_handling_env_c10d(self): 709 self.run_test_with_backend( 710 backend="c10d", test_to_run=self.run_check_nccl_async_error_handling_env 711 ) 712 713 @skip_but_pass_in_sandcastle_if( 714 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 715 ) 716 def test_run_check_nccl_async_error_handling_env_default_c10d(self): 717 self.run_test_with_backend( 718 backend="c10d", 719 test_to_run=self.run_check_nccl_async_error_handling_env_default, 720 ) 721 722 def run_function_with_return_value(self): 723 res = self.run_agent(Conf(entrypoint=_echo, args=("foo",), local_world_size=2)) 724 self.assertFalse(res.is_failed()) 725 self.assertEqual("foo", res.return_values[0]) 726 self.assertEqual("foo", res.return_values[1]) 727 728 @skip_but_pass_in_sandcastle_if( 729 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 730 ) 731 def test_run_function_with_return_value_c10d(self): 732 self.run_test_with_backend( 733 backend="c10d", test_to_run=self.run_function_with_return_value 734 ) 735 736 @skip_but_pass_in_sandcastle_if( 737 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 738 ) 739 def test_run_function_with_return_value_etcd(self): 740 self.run_test_with_backend( 741 backend="etcd", test_to_run=self.run_function_with_return_value 742 ) 743 744 @skip_but_pass_in_sandcastle_if( 745 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 746 ) 747 def test_run_function_with_return_value_etcd_v2(self): 748 self.run_test_with_backend( 749 backend="etcd-v2", test_to_run=self.run_function_with_return_value 750 ) 751 752 def simple_dist_sum(self): 753 res = self.run_agent(Conf(entrypoint=_dist_sum, local_world_size=2)) 754 self.assertFalse(res.is_failed()) 755 # _dist_sum internally checks that the sum computed is valid 756 757 @skip_but_pass_in_sandcastle_if( 758 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 759 ) 760 def test_simple_dist_sum_c10d(self): 761 self.run_test_with_backend(backend="c10d", test_to_run=self.simple_dist_sum) 762 763 @skip_but_pass_in_sandcastle_if( 764 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 765 ) 766 def test_simple_dist_sum_etcd(self): 767 self.run_test_with_backend(backend="etcd", test_to_run=self.simple_dist_sum) 768 769 @skip_but_pass_in_sandcastle_if( 770 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 771 ) 772 def test_simple_dist_sum_etcd_v2(self): 773 self.run_test_with_backend(backend="etcd-v2", test_to_run=self.simple_dist_sum) 774 775 def run_distributed_sum_homogeneous( 776 self, log_line_prefix_template: Optional[str] = None 777 ): 778 node_configs = [ 779 Conf(role="sum", entrypoint=_dist_sum, local_world_size=4, tee=Std.ALL), 780 Conf(role="sum", entrypoint=_dist_sum, local_world_size=4, tee=Std.ALL), 781 ] 782 # When the process method is spawn, the coverage collector hangs 783 # due to getting stuck on the _dist_sum in waiting for TCPStore workers 784 # to join the cluster 785 # TODO(aivanou): t83447589 come up with the proper fix 786 res = self.run_job( 787 node_configs, log_line_prefix_template=log_line_prefix_template 788 ) 789 self.assertEqual(2, len(res["sum"])) 790 ranks = set() 791 for run_results in res["sum"]: 792 self.assertFalse(run_results.is_failed()) 793 ranks.update(run_results.return_values.keys()) 794 self.assertSetEqual(set(range(4 + 4)), ranks) 795 796 @unittest.skipIf( 797 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 798 "test incompatible with dev/dbg asan or tsan", 799 ) 800 def test_run_distributed_sum_homogeneous_c10d(self): 801 self.run_test_with_backend( 802 backend="c10d", test_to_run=self.run_distributed_sum_homogeneous 803 ) 804 805 def test_run_with_custom_log_lines(self): 806 log_line_prefix_template = "[${role_name}-${local_rank}:${rank}]:" 807 self.run_test_with_backend( 808 backend="c10d", 809 test_to_run=lambda: self.run_distributed_sum_homogeneous( 810 log_line_prefix_template 811 ), 812 ) 813 814 @unittest.skipIf( 815 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 816 "test incompatible with dev/dbg asan or tsan", 817 ) 818 def test_run_distributed_sum_homogeneous_etcd(self): 819 self.run_test_with_backend( 820 backend="etcd", test_to_run=self.run_distributed_sum_homogeneous 821 ) 822 823 @unittest.skipIf( 824 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 825 "test incompatible with dev/dbg asan or tsan", 826 ) 827 def test_run_distributed_sum_homogeneous_etcd_v2(self): 828 self.run_test_with_backend( 829 backend="etcd-v2", test_to_run=self.run_distributed_sum_homogeneous 830 ) 831 832 def run_distributed_sum_heterogeneous(self): 833 # sums all ranks on 3 agents; each running 1, 2, 3 workers respectively 834 # sum should be equal to 0 + (1 + 2) + (3 + 4 + 5) = 15 835 # sum asserted inside _dist_sum() 836 node_configs = [ 837 Conf(role="sum", entrypoint=_dist_sum, local_world_size=1), 838 Conf(role="sum", entrypoint=_dist_sum, local_world_size=2), 839 Conf(role="sum", entrypoint=_dist_sum, local_world_size=3), 840 ] 841 # When the process method is spawn, the coverage collector hangs 842 # due to getting stuck on the _dist_sum in waiting for TCPStore workers 843 # to join the cluster 844 # TODO(aivanou): t83447589 come up with the proper fix 845 res = self.run_job(node_configs) 846 self.assertEqual(3, len(res["sum"])) 847 ranks = set() 848 for run_results in res["sum"]: 849 self.assertFalse(run_results.is_failed()) 850 ranks.update(run_results.return_values.keys()) 851 self.assertSetEqual(set(range(1 + 2 + 3)), ranks) 852 853 @skip_but_pass_in_sandcastle_if( 854 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 855 ) 856 def test_run_distributed_sum_heterogeneous_c10d(self): 857 self.run_test_with_backend( 858 backend="c10d", test_to_run=self.run_distributed_sum_heterogeneous 859 ) 860 861 @skip_but_pass_in_sandcastle_if( 862 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 863 ) 864 def test_run_distributed_sum_heterogeneous_etcd(self): 865 self.run_test_with_backend( 866 backend="etcd", test_to_run=self.run_distributed_sum_heterogeneous 867 ) 868 869 @skip_but_pass_in_sandcastle_if( 870 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 871 ) 872 def test_run_distributed_sum_heterogeneous_etcd_v2(self): 873 self.run_test_with_backend( 874 backend="etcd-v2", test_to_run=self.run_distributed_sum_heterogeneous 875 ) 876 877 def run_sad_function(self): 878 """ 879 checks error propagation logic 880 """ 881 replyfile = os.path.join(self._test_dir, "error.json") 882 with mock.patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": replyfile}): 883 with self.assertRaises(ChildFailedError) as cm: 884 self.run_agent(Conf(entrypoint=_sad_function, local_world_size=2)) 885 886 rank, failure = cm.exception.get_first_failure() 887 failure_data = failure.error_file_data["message"] 888 with open(replyfile) as fp: 889 data = json.load(fp)["message"] 890 891 # ran two; both failed; first failure is either rank 0 or 1 892 self.assertTrue(rank in {0, 1}) 893 self.assertTrue(failure.local_rank in {0, 1}) 894 self.assertEqual(1, failure.exitcode) 895 self.assertEqual(data["message"], failure_data["message"]) 896 self.assertEqual(int(data["extraInfo"]["timestamp"]), failure.timestamp) 897 898 @skip_but_pass_in_sandcastle_if( 899 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 900 ) 901 def test_run_sad_function_c10d(self): 902 self.run_test_with_backend(backend="c10d", test_to_run=self.run_sad_function) 903 904 @skip_but_pass_in_sandcastle_if( 905 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 906 ) 907 def test_run_sad_function_etcd(self): 908 self.run_test_with_backend(backend="etcd", test_to_run=self.run_sad_function) 909 910 @skip_but_pass_in_sandcastle_if( 911 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 912 ) 913 def test_run_sad_function_etcd_v2(self): 914 self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_sad_function) 915 916 def run_bipolar_function(self): 917 """ 918 checks agent failure handling logic 919 """ 920 node_conf = Conf(entrypoint=_bipolar_function, local_world_size=4) 921 spec = self.get_worker_spec(node_conf, max_restarts=2) 922 agent = self.get_agent(spec, node_config=node_conf) 923 run_result = agent.run() 924 self.assertTrue(run_result.is_failed()) 925 self.assertEqual(0, agent._remaining_restarts) 926 self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state) 927 self.assertTrue(agent._total_execution_time > 0) 928 929 @skip_but_pass_in_sandcastle_if( 930 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 931 ) 932 def test_run_bipolar_function_c10d(self): 933 self.run_test_with_backend( 934 backend="c10d", test_to_run=self.run_bipolar_function 935 ) 936 937 @skip_but_pass_in_sandcastle_if( 938 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 939 ) 940 def test_run_bipolar_function_etcd(self): 941 self.run_test_with_backend( 942 backend="etcd", test_to_run=self.run_bipolar_function 943 ) 944 945 @skip_but_pass_in_sandcastle_if( 946 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 947 ) 948 def test_run_bipolar_function_etcd_v2(self): 949 self.run_test_with_backend( 950 backend="etcd-v2", test_to_run=self.run_bipolar_function 951 ) 952 953 def correct_rank_assignment_heterogeneous(self): 954 node_configs = [ 955 Conf(role="master", entrypoint=_get_role_info, local_world_size=8), 956 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=1), 957 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=2), 958 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=3), 959 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4), 960 Conf(role="ps", entrypoint=_get_role_info, local_world_size=5), 961 Conf(role="ps", entrypoint=_get_role_info, local_world_size=2), 962 ] 963 results = self.run_job(node_configs) 964 print(f"heterogeneous job result: {results}") 965 self.assertEqual(1, len(results["master"])) 966 self.assertEqual(4, len(results["trainer"])) 967 self.assertEqual(2, len(results["ps"])) 968 self.assert_rank_consistency( 969 results, 970 expected_role_world_sizes={ 971 "master": 8, 972 "trainer": 1 + 2 + 3 + 4, 973 "ps": 5 + 2, 974 }, 975 ) 976 977 @unittest.skipIf( 978 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 979 "test incompatible with dev/dbg asan or tsan", 980 ) 981 def test_correct_rank_assignment_heterogeneous_etcd(self): 982 self.run_test_with_backend( 983 backend="etcd", test_to_run=self.correct_rank_assignment_heterogeneous 984 ) 985 986 @unittest.skipIf( 987 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 988 "test incompatible with dev/dbg asan or tsan", 989 ) 990 def test_correct_rank_assignment_heterogeneous_etcd_v2(self): 991 self.run_test_with_backend( 992 backend="etcd-v2", test_to_run=self.correct_rank_assignment_heterogeneous 993 ) 994 995 def correct_rank_assignment_homogeneous(self): 996 node_configs = [ 997 Conf(role="master", entrypoint=_get_role_info, local_world_size=1), 998 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4), 999 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4), 1000 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4), 1001 Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4), 1002 Conf(role="ps", entrypoint=_get_role_info, local_world_size=3), 1003 Conf(role="ps", entrypoint=_get_role_info, local_world_size=3), 1004 ] 1005 results = self.run_job(node_configs) 1006 print(f"homogeneous job result: {results}") 1007 self.assertEqual(1, len(results["master"])) 1008 self.assertEqual(4, len(results["trainer"])) 1009 self.assertEqual(2, len(results["ps"])) 1010 self.assert_rank_consistency( 1011 results, 1012 expected_role_world_sizes={"master": 1, "trainer": 4 * 4, "ps": 3 * 2}, 1013 ) 1014 1015 @unittest.skipIf( 1016 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1017 "test incompatible with dev/dbg asan or tsan", 1018 ) 1019 def test_correct_rank_assignment_homogeneous_etcd(self): 1020 self.run_test_with_backend( 1021 backend="etcd", test_to_run=self.correct_rank_assignment_homogeneous 1022 ) 1023 1024 @unittest.skipIf( 1025 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1026 "test incompatible with dev/dbg asan or tsan", 1027 ) 1028 def test_correct_rank_assignment_homogeneous_etcd_v2(self): 1029 self.run_test_with_backend( 1030 backend="etcd-v2", test_to_run=self.correct_rank_assignment_homogeneous 1031 ) 1032 1033 def assert_rank_consistency( 1034 self, 1035 run_results: Dict[str, List[RunResult]], 1036 expected_role_world_sizes: Dict[str, int], 1037 ): 1038 """ 1039 Asserts that ranks are consecutive w.r.t role_rank. If local world sizes are 4: 1040 role_rank_0 -> ranks: 0,1,2,3 1041 role_rank_1 -> ranks: 4,5,6,7 1042 ... etc ... 1043 """ 1044 1045 global_ranks: List[int] = [] 1046 # role -> [role_rank,...] 1047 role_ranks: Dict[str, List[int]] = {} 1048 # group rank -> [(rank, role_rank),...] 1049 grouped_ranks: Dict[int, List[Tuple[int, int]]] = {} 1050 1051 # global world size == sum of all the role world sizes 1052 expected_world_size = sum(expected_role_world_sizes.values()) 1053 for role, results in run_results.items(): 1054 for result in results: 1055 res = result.return_values 1056 for role_info in res.values(): 1057 rank = role_info.rank 1058 role_rank = role_info.role_rank 1059 group_rank = role_info.group_rank 1060 role_world_size = role_info.role_world_size 1061 world_size = role_info.world_size 1062 1063 self.assertEqual(expected_world_size, world_size) 1064 self.assertEqual(expected_role_world_sizes[role], role_world_size) 1065 grouped_ranks.setdefault(group_rank, []).append((rank, role_rank)) 1066 role_ranks.setdefault(role, []).append(role_rank) 1067 global_ranks.append(rank) 1068 1069 global_ranks = sorted(global_ranks) 1070 self.assertEqual(list(range(expected_world_size)), global_ranks) 1071 for role, expected_role_world_size in expected_role_world_sizes.items(): 1072 self.assertEqual( 1073 list(range(expected_role_world_size)), sorted(role_ranks[role]) 1074 ) 1075 # Make sure that each agent assigns consecutive ranks to workers 1076 # The first argument is the global_rank and the second argument 1077 # is role_rank 1078 for ranks_lst in grouped_ranks.values(): 1079 self.assert_ranks_sequential(ranks_lst, 0) 1080 self.assert_ranks_sequential(ranks_lst, 1) 1081 1082 def assert_ranks_sequential(self, ranks_pairs, rank_idx): 1083 ranks = sorted(rank_pair[rank_idx] for rank_pair in ranks_pairs) 1084 start_rank, end_rank = ranks[0], ranks[-1] 1085 self.assertEqual(list(range(start_rank, end_rank + 1)), ranks) 1086 1087 def double_agent_fault_tolerance(self): 1088 """ 1089 start ``nnodes`` agents, kill and restart odd ones, validate fault-tolerance works 1090 """ 1091 nnodes = 2 1092 wait = 2 1093 node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2) 1094 agent_results = mp.Queue() 1095 agent_args = { 1096 "conf": node_conf, 1097 "agent_results": agent_results, 1098 "min_nodes": nnodes, 1099 "max_nodes": nnodes, 1100 "max_restarts": 2, 1101 } 1102 1103 procs = [] 1104 for _ in range(nnodes): 1105 p = mp.Process( 1106 target=self.run_agent, 1107 kwargs=agent_args, 1108 ) 1109 procs.append(p) 1110 p.start() 1111 1112 # restart odd agents 1113 for i in range(nnodes): 1114 if i % 2 != 0: 1115 procs[i].kill() 1116 p = mp.Process( 1117 target=self.run_agent, 1118 kwargs=agent_args, 1119 ) 1120 procs[i] = p 1121 p.start() 1122 1123 for i in range(nnodes): 1124 p = procs[i] 1125 p.join() 1126 self.assertEqual(0, p.exitcode) 1127 1128 @unittest.skipIf( 1129 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1130 "test incompatible with dev/dbg asan or tsan", 1131 ) 1132 def test_double_agent_fault_tolerance_etcd(self): 1133 self.run_test_with_backend( 1134 backend="etcd", test_to_run=self.double_agent_fault_tolerance 1135 ) 1136 1137 @unittest.skipIf( 1138 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1139 "test incompatible with dev/dbg asan or tsan", 1140 ) 1141 def test_double_agent_fault_tolerance_etcd_v2(self): 1142 self.run_test_with_backend( 1143 backend="etcd-v2", test_to_run=self.double_agent_fault_tolerance 1144 ) 1145 1146 def no_exit_barrier_on_failure(self): 1147 """ 1148 start ``nnodes`` agents, kill and restart odd ones, validate fault-tolerance works 1149 """ 1150 nnodes = 2 1151 wait = 20 1152 node_conf = Conf( 1153 entrypoint=_bipolar_sleep_function, args=(wait,), local_world_size=2 1154 ) 1155 agent_results = mp.Queue() 1156 monitor_interval_s = 0.5 1157 agent_args = { 1158 "conf": node_conf, 1159 "agent_results": agent_results, 1160 "min_nodes": nnodes, 1161 "max_nodes": nnodes, 1162 "max_restarts": 0, 1163 "exit_barrier_timeout": 300, 1164 "monitor_interval": monitor_interval_s, 1165 } 1166 1167 procs = [] 1168 for _ in range(nnodes): 1169 p = mp.Process( 1170 target=self.run_agent, 1171 kwargs=agent_args, 1172 ) 1173 procs.append(p) 1174 p.start() 1175 1176 # wait for all processes to finish should not take exit barrier time 1177 exit_interval_between_agents = 0 1178 for i in range(nnodes): 1179 p = procs[i] 1180 p.join() 1181 self.assertNotEqual(0, p.exitcode) 1182 exit_interval_between_agents = ( 1183 time.monotonic() - exit_interval_between_agents 1184 ) 1185 1186 # Validate that the processes finish close to each other. 1187 # Using a slightly higher timeout than 2 * monitor_interval (0.01) to make it less flaky 1188 self.assertGreater( 1189 2 * monitor_interval_s, 1190 exit_interval_between_agents, 1191 "Agents are not cleaned up until 2 * monitor_interval", 1192 ) 1193 1194 @unittest.skipIf( 1195 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1196 "test incompatible with dev/dbg asan or tsan", 1197 ) 1198 def test_no_exit_barrier_on_failure(self): 1199 self.run_test_with_backend( 1200 backend="c10d", test_to_run=self.no_exit_barrier_on_failure 1201 ) 1202 1203 def double_agent_elastic(self): 1204 """ 1205 start ``nnodes`` agents, kill odd ones (do not restart), validate 1206 elasticity (scale-down) works. (scale-up covered in fault_tolerance test) 1207 """ 1208 min_nodes = 1 1209 max_nodes = 2 1210 wait = 2 1211 node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2) 1212 agent_results = mp.Queue() 1213 agent_args = { 1214 "conf": node_conf, 1215 "agent_results": agent_results, 1216 "min_nodes": min_nodes, 1217 "max_nodes": max_nodes, 1218 "max_restarts": 2, 1219 } 1220 1221 procs = [] 1222 for _ in range(max_nodes): 1223 p = mp.Process( 1224 target=self.run_agent, 1225 kwargs=agent_args, 1226 ) 1227 procs.append(p) 1228 p.start() 1229 1230 # kill odd agents 1231 for i in range(max_nodes): 1232 if i % 2 != 0: 1233 procs[i].kill() 1234 1235 for i in range(max_nodes): 1236 p = procs[i] 1237 p.join() 1238 if i % 2 == 0: 1239 self.assertEqual(0, p.exitcode) 1240 else: 1241 self.assertEqual(-signal.SIGKILL, p.exitcode) 1242 1243 @unittest.skipIf( 1244 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1245 "test incompatible with dev/dbg asan or tsan", 1246 ) 1247 def test_double_agent_elastic_c10d(self): 1248 self.run_test_with_backend( 1249 backend="c10d", test_to_run=self.double_agent_elastic 1250 ) 1251 1252 @unittest.skipIf( 1253 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1254 "test incompatible with dev/dbg asan or tsan", 1255 ) 1256 def test_double_agent_elastic_etcd(self): 1257 self.run_test_with_backend( 1258 backend="etcd", test_to_run=self.double_agent_elastic 1259 ) 1260 1261 @unittest.skipIf( 1262 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1263 "test incompatible with dev/dbg asan or tsan", 1264 ) 1265 def test_double_agent_elastic_etcd_v2(self): 1266 self.run_test_with_backend( 1267 backend="etcd-v2", test_to_run=self.double_agent_elastic 1268 ) 1269 1270 def torch_rpc(self): 1271 """ 1272 Simple torch rpc example with torchelastic. 1273 Creates two agents (to simulate two node job), 1274 each agent runs a single worker. worker0 calls an rpc_sync on 1275 worker1. 1276 """ 1277 msg = "hello world" 1278 node_configs = [ 1279 Conf( 1280 role="master", 1281 entrypoint=rpc_master, 1282 args=(msg,), 1283 local_world_size=1, 1284 tee=Std.ALL, 1285 ), 1286 Conf( 1287 role="worker", 1288 entrypoint=rpc_worker, 1289 args=(), 1290 local_world_size=1, 1291 tee=Std.ALL, 1292 ), 1293 ] 1294 1295 results = self.run_job(node_configs) 1296 master_retvals = results["master"][0].return_values 1297 # there is only one master but the global rank is not stable 1298 # so compare the master return value as a collection 1299 self.assertEqual([f"{msg} from worker"], list(master_retvals.values())) 1300 1301 @unittest.skipIf( 1302 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1303 "test incompatible with dev/dbg asan or tsan", 1304 ) 1305 def test_torch_rpc_c10d(self): 1306 self.run_test_with_backend(backend="c10d", test_to_run=self.torch_rpc) 1307 1308 @unittest.skipIf( 1309 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1310 "test incompatible with dev/dbg asan or tsan", 1311 ) 1312 def test_torch_rpc_etcd(self): 1313 self.run_test_with_backend(backend="etcd", test_to_run=self.torch_rpc) 1314 1315 @unittest.skipIf( 1316 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1317 "test incompatible with dev/dbg asan or tsan", 1318 ) 1319 def test_torch_rpc_etcd_v2(self): 1320 self.run_test_with_backend(backend="etcd-v2", test_to_run=self.torch_rpc) 1321 1322 def workers_drift_success(self): 1323 """ 1324 two agents (one worker each) finishes within ``sec`` seconds of each other, 1325 exit barrier timeout set to ``sec * 2 * 2``. 1326 """ 1327 1328 sec = 1 1329 node_configs = [ 1330 Conf(role="zzz", entrypoint=_sleep, args=(0 * sec,), local_world_size=1), 1331 Conf(role="zzz", entrypoint=_sleep, args=(2 * sec,), local_world_size=1), 1332 ] 1333 results = self.run_job(node_configs, exit_barrier_timeout=2 * 2 * sec) 1334 for i in range(2): 1335 run_results = results["zzz"][i] 1336 self.assertFalse(run_results.is_failed()) 1337 for rank, output in run_results.return_values.items(): 1338 # _sleep() returns its own rank 1339 self.assertEqual(rank, output) 1340 1341 @unittest.skipIf( 1342 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1343 "test incompatible with dev/dbg asan or tsan", 1344 ) 1345 def test_workers_drift_success_etcd(self): 1346 self.run_test_with_backend( 1347 backend="etcd", test_to_run=self.workers_drift_success 1348 ) 1349 1350 @unittest.skipIf( 1351 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1352 "test incompatible with dev/dbg asan or tsan", 1353 ) 1354 def test_workers_drift_success_etcd_v2(self): 1355 self.run_test_with_backend( 1356 backend="etcd-v2", test_to_run=self.workers_drift_success 1357 ) 1358 1359 def workers_drift_fail(self): 1360 """ 1361 two agents (one worker each) finishes within ``4 x sec`` seconds of each other, 1362 exit barrier timeout set to 0. Exit barriers should NOT fail the job. 1363 """ 1364 sec = 1 1365 node_configs = [ 1366 Conf(role="zzz", entrypoint=_sleep, args=(0 * sec,), local_world_size=1), 1367 Conf(role="zzz", entrypoint=_sleep, args=(4 * sec,), local_world_size=1), 1368 ] 1369 results = self.run_job(node_configs, exit_barrier_timeout=0) 1370 for i in range(2): 1371 run_results = results["zzz"][i] 1372 self.assertFalse(run_results.is_failed()) 1373 for rank, output in run_results.return_values.items(): 1374 # _sleep() returns its own rank 1375 self.assertEqual(rank, output) 1376 1377 @unittest.skipIf( 1378 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1379 "test incompatible with dev/dbg asan or tsan", 1380 ) 1381 def test_workers_drift_fail_etcd(self): 1382 self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_fail) 1383 1384 @unittest.skipIf( 1385 TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN, 1386 "test incompatible with dev/dbg asan or tsan", 1387 ) 1388 def test_workers_drift_fail_etcd_v2(self): 1389 self.run_test_with_backend( 1390 backend="etcd-v2", test_to_run=self.workers_drift_fail 1391 ) 1392 1393 @patch("torch.distributed.elastic.utils.store.barrier") 1394 def barrier_failed(self, barrier_mock): 1395 """ 1396 Failure during the barrier should NOT fail the job. 1397 """ 1398 barrier_mock.side_effect = RuntimeError("test error") 1399 res = self.run_agent(Conf(entrypoint=_happy_function, local_world_size=1)) 1400 self.assertFalse(res.is_failed()) 1401 barrier_mock.assert_called_once() 1402 1403 @skip_but_pass_in_sandcastle_if( 1404 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 1405 ) 1406 def test_barrier_failed_c10d(self): 1407 self.run_test_with_backend(backend="c10d", test_to_run=self.barrier_failed) 1408 1409 @skip_but_pass_in_sandcastle_if( 1410 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 1411 ) 1412 def test_barrier_failed_etcd(self): 1413 self.run_test_with_backend(backend="etcd", test_to_run=self.barrier_failed) 1414 1415 @skip_but_pass_in_sandcastle_if( 1416 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 1417 ) 1418 def test_barrier_failed_etcd_v2(self): 1419 self.run_test_with_backend(backend="etcd-v2", test_to_run=self.barrier_failed) 1420 1421 @patch("torch.distributed.elastic.agent.server.local_elastic_agent.start_processes") 1422 def shutdown_called(self, start_processes_mock): 1423 pcontext_mock = Mock() 1424 pcontext_mock.pids.return_value = {0: 0} 1425 start_processes_mock.return_value = pcontext_mock 1426 node_conf = Conf(entrypoint=_happy_function, local_world_size=1) 1427 spec = self.get_worker_spec(node_conf, max_restarts=0) 1428 agent = self.get_agent(spec, node_config=node_conf) 1429 with patch.object(agent, "_monitor_workers") as monitor_mock: 1430 monitor_mock.return_value = RunResult( 1431 state=WorkerState.SUCCEEDED, return_values={0: 0} 1432 ) 1433 agent.run("worker") 1434 pcontext_mock.close.assert_called_once() 1435 1436 @skip_but_pass_in_sandcastle_if( 1437 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 1438 ) 1439 def test_shutdown_called_c10d(self): 1440 self.run_test_with_backend(backend="c10d", test_to_run=self.shutdown_called) 1441 1442 @skip_but_pass_in_sandcastle_if( 1443 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 1444 ) 1445 def test_shutdown_called_etcd(self): 1446 self.run_test_with_backend(backend="etcd", test_to_run=self.shutdown_called) 1447 1448 @skip_but_pass_in_sandcastle_if( 1449 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 1450 ) 1451 def test_shutdown_called_etcd_v2(self): 1452 self.run_test_with_backend(backend="etcd-v2", test_to_run=self.shutdown_called) 1453 1454 def fail_rank_one_once(self): 1455 res = self.run_agent( 1456 Conf(entrypoint=dummy_compute_simulate_rank_failure, local_world_size=2), 1457 max_restarts=3, 1458 ) 1459 self.assertFalse(res.is_failed()) 1460 for return_value in res.return_values.values(): 1461 self.assertIsInstance(return_value, torch.Tensor) 1462 self.assertEqual((100, 100), return_value.shape) 1463 1464 @skip_but_pass_in_sandcastle_if( 1465 TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" 1466 ) 1467 def test_rank_restart_after_failure(self): 1468 self.run_test_with_backend(backend="c10d", test_to_run=self.fail_rank_one_once) 1469