#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import multiprocessing as mp
import os
import shutil
import signal
import socket
import tempfile
import time
import unittest
import uuid
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
from unittest import mock
from unittest.mock import Mock, patch

import torch
import torch.distributed as dist
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
import torch.distributed.rpc as rpc
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.agent.server.local_elastic_agent import (
    LocalElasticAgent,
    TORCHELASTIC_HEALTH_CHECK_PORT,
    TORCHELASTIC_TIMER_FILE,
)
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, Std
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.rpc.backend_registry import BackendType
from torch.testing._internal.common_utils import (
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
    TEST_WITH_TSAN,
)


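# The functions below are used as worker entrypoints: the agent launches them
# in worker processes, and most of them read their rank and world size from
# the environment variables (RANK, WORLD_SIZE, etc.) that torchelastic sets up.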
def init_rpc(name, backend):
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    rpc.init_rpc(
        name=name,
        backend=backend,
        rank=rank,
        world_size=world_size,
    )


def rpc_master(msg):
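    """Runs the master side of a two-process RPC pair; echoes ``msg`` off the worker."""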
    init_rpc("master", BackendType.TENSORPIPE)
    ret = rpc.rpc_sync(to="worker", func=_echo, args=(msg,))
    rpc.shutdown()
    return f"{ret} from worker"


def rpc_worker():
    init_rpc("worker", BackendType.TENSORPIPE)
    rpc.shutdown()


def _happy_function():
    return


def _sad_function():
    raise RuntimeError("sad because i throw")


def dummy_compute() -> torch.Tensor:
    """
    returns a random Tensor of a predefined size (100x100)
    """
    return torch.rand(100, 100)


def dummy_compute_simulate_rank_failure() -> torch.Tensor:
    """
    fails rank 1 once, on the first attempt;
    in all other cases, returns a random Tensor of a predefined size (100x100)
    """
    if os.environ["RANK"] == "1" and os.environ["TORCHELASTIC_RESTART_COUNT"] == "0":
        os.kill(os.getpid(), signal.SIGKILL)
    return torch.rand(100, 100)


def _fatal_signal_function(expected_error_index: int, sig: int):
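    """Kills the worker whose rank equals ``expected_error_index`` with signal ``sig``."""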
    rank = int(os.environ["RANK"])
    if rank == expected_error_index:
        os.kill(os.getpid(), sig)


def _check_master_port_addr_override(
    expected_master_addr: str, expected_master_port: int
):
    actual_master_addr = os.environ["MASTER_ADDR"]
    actual_master_port = int(os.environ["MASTER_PORT"])
    if (
        expected_master_addr != actual_master_addr
        or expected_master_port != actual_master_port
    ):
        raise RuntimeError(
            f"Expected addr: {expected_master_addr}:{expected_master_port}, got addr: {actual_master_addr}:{actual_master_port}"
        )


def _bipolar_function():
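    """Even ranks run ``_happy_function``; odd ranks run ``_sad_function`` and fail."""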
    rank = int(os.environ["RANK"])
    if rank % 2 == 0:
        _happy_function()
    else:
        _sad_function()


def _bipolar_sleep_function(sleep_sec):
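    """Even ranks sleep for ``sleep_sec`` seconds; odd ranks fail immediately."""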
    rank = int(os.environ["RANK"])
    if rank % 2 == 0:
        _sleep(sleep_sec)
    else:
        _sad_function()


def _dist_sum(wait=0):
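    """
    All-reduces each worker's rank over gloo and validates the result
    against ``sum(range(world_size))``.
    """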
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend="gloo")
    t = torch.tensor(rank)

    time.sleep(wait)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)

    expected_sum = sum(range(world_size))
    actual = t.item()
    if expected_sum != actual:
        raise RuntimeError(f"Expected rank sum {expected_sum}, got {actual}")


def _sleep(sleep_sec) -> int:
    time.sleep(sleep_sec)
    return int(os.environ["RANK"])


@dataclass
class RankInfo:
    rank: int
    role_rank: int
    group_rank: int
    role_world_size: int
    world_size: int


def _get_role_info() -> RankInfo:
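    """Collects this worker's rank bookkeeping from the torchelastic env vars."""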
    rank = int(os.environ["RANK"])
    role_rank = int(os.environ["ROLE_RANK"])
    group_rank = int(os.environ["GROUP_RANK"])
    role_world_size = int(os.environ["ROLE_WORLD_SIZE"])
    world_size = int(os.environ["WORLD_SIZE"])
    return RankInfo(rank, role_rank, group_rank, role_world_size, world_size)


def _echo(msg):
    return msg


def _check_env_function():
    # just check that these env vars exist; os.environ[...] naturally raises
    # a KeyError if the variable does not exist
    env_vars = [
        "RANK",
        "LOCAL_RANK",
        "ROLE_RANK",
        "ROLE_NAME",
        "GROUP_RANK",
        "LOCAL_WORLD_SIZE",
        "ROLE_WORLD_SIZE",
        "WORLD_SIZE",
        "GROUP_WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
        "TORCHELASTIC_RESTART_COUNT",
        "TORCHELASTIC_MAX_RESTARTS",
        "TORCHELASTIC_RUN_ID",
        "TORCHELASTIC_USE_AGENT_STORE",
        "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    ]
    for var in env_vars:
        _ = os.environ[var]


def _check_env_value(key: str, expected: str):
    # checks that the env var ``key`` matches ``expected``;
    # this function is intended to be used as the entrypoint to the elastic run
    if key not in os.environ:
        raise RuntimeError(f"Environment variable {key} not found in os.environ")
    else:
        actual = os.getenv(key)
        if expected != actual:
            raise RuntimeError(
                f"os.environ['{key}']={actual}"
                f" does not equal the expected value: {expected}"
            )


def _check_local_watchdog_setup(key: str, should_exist: bool):
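    """Asserts that env var ``key`` is present iff ``should_exist`` is True."""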
    if should_exist and key not in os.environ:
        raise RuntimeError(f"Environment variable {key} not found in os.environ")
    if not should_exist and key in os.environ:
        raise RuntimeError(f"Environment variable {key} found in os.environ")


def acquire_available_port():
    """
    Uses sockets to acquire an available port from the OS for use.

    Note: To reduce the race condition where another process grabs the port
          after this function returns an available port, we should aim to use
          the port as quickly as possible.
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )

    for addr in addrs:
        family, socktype, proto, _, _ = addr
        s = None
        try:
            s = socket.socket(family, socktype, proto)
            s.bind(("localhost", 0))
            s.listen(0)
            port = s.getsockname()[1]
            s.close()
            return port
        except OSError as e:
            if s is not None:
                s.close()
            print(f"Socket creation attempt failed: {e}")

    raise RuntimeError("Failed to create a socket")


@dataclass
class Conf:
    """
    Holds arguments to launch an agent (e.g. simulates an agent run on a node).
    """

    entrypoint: Callable
    local_world_size: int
    args: Tuple = ()
    role: str = "default"
    redirects: Std = Std.NONE
    tee: Std = Std.NONE
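
# For example, Conf(entrypoint=_echo, args=("foo",), local_world_size=2)
# describes a node that runs two copies of _echo("foo").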


class LocalElasticAgentTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def setUp(self):
        self._test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)
        self._run_id = str(uuid.uuid4()).split("-")[0]

    def tearDown(self):
        shutil.rmtree(self._test_dir)

    def log_dir(self) -> str:
        return tempfile.mkdtemp(prefix="torchelastic_", dir=self._test_dir)

    def get_worker_spec(
        self,
        node_config: Conf,
        min_nodes=1,
        max_nodes=1,
        max_restarts=0,
        monitor_interval=0.01,
        master_addr_override: Optional[str] = None,
        master_port_override: Optional[int] = None,
        is_host=True,
    ):
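        """
        Builds a WorkerSpec whose rendezvous handler uses the
        configured backend and endpoint.
        """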
        rdzv_params = RendezvousParameters(
            backend=self._backend,
            endpoint=self._endpoint,
            run_id=self._run_id,
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            is_host=is_host,
        )
        rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
        return WorkerSpec(
            role=node_config.role,
            local_world_size=node_config.local_world_size,
            entrypoint=node_config.entrypoint,
            args=node_config.args,
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
            master_addr=master_addr_override,
            master_port=master_port_override,
        )

    def get_agent(
        self,
        spec: WorkerSpec,
        node_config: Conf,
        start_method: str = "spawn",
        exit_barrier_timeout=5,
        log_line_prefix_template: Optional[str] = None,
    ) -> LocalElasticAgent:
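        """
        Constructs a LocalElasticAgent that writes worker logs under
        this test's temp dir.
        """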
        return LocalElasticAgent(
            spec,
            start_method=start_method,
            exit_barrier_timeout=exit_barrier_timeout,
            logs_specs=DefaultLogsSpecs(
                log_dir=self.log_dir(),
                redirects=node_config.redirects,
                tee=node_config.tee,
            ),
            log_line_prefix_template=log_line_prefix_template,
        )

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.multiprocessing.errors.record`.
    @record
    def run_agent(
        self,
        conf: Conf,
        agent_results: Optional[mp.Queue] = None,  # (role, agent_result)
        min_nodes=1,
        max_nodes=1,
        start_method: str = "spawn",
        max_restarts: int = 0,
        exit_barrier_timeout=5,
        master_addr_override: Optional[str] = None,
        master_port_override: Optional[int] = None,
        is_host=True,
        monitor_interval=0.01,
        log_line_prefix_template: Optional[str] = None,
    ) -> Optional[RunResult]:
        """
        Runs a single agent. This method can be called either in a separate process
        or in the main test process. When calling this method in a separate process,
        make sure to pass the ``agent_results`` multiprocessing Queue so that the
        agent's run results can be returned. If ``agent_results`` is omitted, then
        the run result is returned from the method.
        """

        spec = self.get_worker_spec(
            node_config=conf,
            min_nodes=min_nodes,
            max_nodes=max_nodes,
            max_restarts=max_restarts,
            master_addr_override=master_addr_override,
            master_port_override=master_port_override,
            is_host=is_host,
            monitor_interval=monitor_interval,
        )
        agent = self.get_agent(
            spec=spec,
            node_config=conf,
            start_method=start_method,
            exit_barrier_timeout=exit_barrier_timeout,
            log_line_prefix_template=log_line_prefix_template,
        )

        result = agent.run()
        spec.rdzv_handler.shutdown()

        if agent_results:
            agent_results.put((conf.role, result))

        if result.is_failed():
            raise ChildFailedError(spec.get_entrypoint_name(), result.failures)
        else:
            if not agent_results:
                return result

    def run_job(
        self,
        node_configs: List[Conf],
        exit_barrier_timeout: int = 5,
        log_line_prefix_template: Optional[str] = None,
    ) -> Dict[str, List[RunResult]]:
        """
        Simulates running a distributed job by running multiple agents,
        each in its own subprocess.
        """

        nnodes = len(node_configs)

        # each element in this queue holds a tuple (role, RunResult) for each agent
        agent_results = mp.Queue()

        # launch all agents in subprocesses; looping in reverse order starts the
        # rendezvous host (node 0, is_host=True) last
        procs = []
        for node_idx in reversed(range(len(node_configs))):
            conf = node_configs[node_idx]
            run_agent_args = {
                "conf": conf,
                "agent_results": agent_results,
                "min_nodes": nnodes,
                "max_nodes": nnodes,
                "start_method": "spawn",
                "max_restarts": 0,
                "exit_barrier_timeout": exit_barrier_timeout,
                "is_host": node_idx == 0,
                "log_line_prefix_template": log_line_prefix_template,
            }
            p = mp.Process(target=self.run_agent, kwargs=run_agent_args)
            procs.append(p)
            p.start()
        for p in procs:
            p.join()

        results: Dict[str, List[RunResult]] = {}
        while not agent_results.empty():
            role, run_result = agent_results.get()
            results.setdefault(role, []).append(run_result)
        return results

    def run_test_with_backend(self, backend: str, test_to_run: Callable):
        """
        Sets the backend and determines the endpoint before running the
        given test.

        Note: This method must be invoked to run any test functions that spawn
              an agent. This is because this function sets the backend and
              endpoint parameters.
        """
        self._backend = backend

        if self._backend == "etcd-v2" or self._backend == "etcd":
            self._endpoint = self._etcd_server.get_endpoint()
        else:
            # the default is c10d backend
            self._endpoint = f"localhost:{acquire_available_port()}"

        test_to_run()

    def dummy_compute(self):
        res = self.run_agent(Conf(entrypoint=dummy_compute, local_world_size=2))
        self.assertFalse(res.is_failed())
        for return_value in res.return_values.values():
            self.assertIsInstance(return_value, torch.Tensor)
            self.assertEqual((100, 100), return_value.shape)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_dummy_compute_c10d(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.dummy_compute)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_dummy_compute_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.dummy_compute)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_dummy_compute_etcd_v2(self):
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.dummy_compute)

    def run_happy_function(self):
        res = self.run_agent(Conf(entrypoint=_happy_function, local_world_size=2))
        self.assertFalse(res.is_failed())
        self.assertIsNone(res.return_values[0])
        self.assertIsNone(res.return_values[1])

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_happy_function_c10d(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_happy_function)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_happy_function_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_happy_function)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_happy_function_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.run_happy_function
        )

    def check_master_addr_port_override(self):
        master_addr = "test_host"
        master_port = 42
        res = self.run_agent(
            Conf(
                entrypoint=_check_master_port_addr_override,
                args=(master_addr, master_port),
                local_world_size=1,
            ),
            master_addr_override=master_addr,
            master_port_override=master_port,
        )
        self.assertFalse(res.is_failed())
        self.assertIsNone(res.return_values[0])

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_check_master_addr_port_override_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.check_master_addr_port_override
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_check_master_addr_port_override_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.check_master_addr_port_override
        )

    def run_check_env_function(self):
        # just checks that all env vars that we need to set on the user script
        # are actually set
        res = self.run_agent(Conf(entrypoint=_check_env_function, local_world_size=1))
        self.assertFalse(res.is_failed())

    def run_check_nccl_async_error_handling_env(self):
        # make sure that a TORCH_NCCL_ASYNC_ERROR_HANDLING value set in os.environ is honored
        with patch.dict(os.environ, {"TORCH_NCCL_ASYNC_ERROR_HANDLING": "0"}):
            res = self.run_agent(
                Conf(
                    entrypoint=_check_env_value,
                    local_world_size=1,
                    args=("TORCH_NCCL_ASYNC_ERROR_HANDLING", "0"),
                )
            )
            self.assertFalse(res.is_failed())

    def run_check_nccl_async_error_handling_env_default(self):
        # if not present in os.environ, it should default to 1
        res = self.run_agent(
            Conf(
                entrypoint=_check_env_value,
                local_world_size=1,
                args=("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1"),
            )
        )
        self.assertFalse(res.is_failed())

    def run_agent_local_watchdog_setup_enabled(self):
        # Set the env for watchdog
        watchdog_env_name = TORCHELASTIC_TIMER_FILE
        watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
        os.environ[watchdog_env_name] = watchdog_file_path
        # Run the agent
        node_conf = Conf(
            entrypoint=_check_local_watchdog_setup,
            local_world_size=1,
            args=(TORCHELASTIC_TIMER_FILE, True),
        )
        spec = self.get_worker_spec(node_conf, max_restarts=2)
        agent = self.get_agent(spec, node_config=node_conf)
        res = agent.run()
        self.assertFalse(res.is_failed())

    def run_agent_local_watchdog_setup_disabled(self):
        # Do not set the env for watchdog
        watchdog_env_name = TORCHELASTIC_TIMER_FILE
        if watchdog_env_name in os.environ:
            del os.environ[watchdog_env_name]
        # Run the agent
        node_conf = Conf(
            entrypoint=_check_local_watchdog_setup,
            local_world_size=1,
            args=(TORCHELASTIC_TIMER_FILE, False),
        )
        spec = self.get_worker_spec(node_conf, max_restarts=2)
        agent = self.get_agent(spec, node_config=node_conf)
        res = agent.run()
        self.assertFalse(res.is_failed())

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_local_watchdog_setup_enabled_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_agent_local_watchdog_setup_enabled
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_local_watchdog_setup_enabled_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_agent_local_watchdog_setup_enabled
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_local_watchdog_setup_disabled_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_agent_local_watchdog_setup_disabled
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_local_watchdog_setup_disabled_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_agent_local_watchdog_setup_disabled
        )

    def run_agent_healthcheck_setup_enabled(self):
        # Set the env for healthcheck
        healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT
        os.environ[healthcheck_port_env_name] = "12345"
        # Run the agent
        node_conf = Conf(
            entrypoint=_check_local_watchdog_setup,
            local_world_size=1,
            args=(TORCHELASTIC_HEALTH_CHECK_PORT, True),
        )
        spec = self.get_worker_spec(node_conf, max_restarts=2)
        agent = self.get_agent(spec, node_config=node_conf)
        res = agent.run()
        self.assertFalse(res.is_failed())

    def run_agent_healthcheck_setup_disabled(self):
        # Do not set the env for healthcheck
        healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT
        if healthcheck_port_env_name in os.environ:
            del os.environ[healthcheck_port_env_name]
        # Run the agent
        node_conf = Conf(
            entrypoint=_check_local_watchdog_setup,
            local_world_size=1,
            args=(TORCHELASTIC_HEALTH_CHECK_PORT, False),
        )
        spec = self.get_worker_spec(node_conf, max_restarts=2)
        agent = self.get_agent(spec, node_config=node_conf)
        res = agent.run()
        self.assertFalse(res.is_failed())

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_healthcheck_setup_enabled_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_agent_healthcheck_setup_enabled
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_healthcheck_setup_enabled_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_agent_healthcheck_setup_enabled
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_healthcheck_setup_disabled_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_agent_healthcheck_setup_disabled
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_agent_healthcheck_setup_disabled_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_agent_healthcheck_setup_disabled
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_check_env_function_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_check_env_function
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_check_nccl_async_error_handling_env_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_check_nccl_async_error_handling_env
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_check_nccl_async_error_handling_env_default_c10d(self):
        self.run_test_with_backend(
            backend="c10d",
            test_to_run=self.run_check_nccl_async_error_handling_env_default,
        )

    def run_function_with_return_value(self):
        res = self.run_agent(Conf(entrypoint=_echo, args=("foo",), local_world_size=2))
        self.assertFalse(res.is_failed())
        self.assertEqual("foo", res.return_values[0])
        self.assertEqual("foo", res.return_values[1])

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_function_with_return_value_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_function_with_return_value
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_function_with_return_value_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_function_with_return_value
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_function_with_return_value_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.run_function_with_return_value
        )

    def simple_dist_sum(self):
        res = self.run_agent(Conf(entrypoint=_dist_sum, local_world_size=2))
        self.assertFalse(res.is_failed())
        # _dist_sum internally checks that the sum computed is valid

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_simple_dist_sum_c10d(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.simple_dist_sum)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_simple_dist_sum_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.simple_dist_sum)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_simple_dist_sum_etcd_v2(self):
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.simple_dist_sum)

    def run_distributed_sum_homogeneous(
        self, log_line_prefix_template: Optional[str] = None
    ):
        node_configs = [
            Conf(role="sum", entrypoint=_dist_sum, local_world_size=4, tee=Std.ALL),
            Conf(role="sum", entrypoint=_dist_sum, local_world_size=4, tee=Std.ALL),
        ]
        # When the start method is spawn, the coverage collector hangs because
        # _dist_sum gets stuck waiting for TCPStore workers to join the cluster
        # TODO(aivanou): t83447589 come up with the proper fix
        res = self.run_job(
            node_configs, log_line_prefix_template=log_line_prefix_template
        )
        self.assertEqual(2, len(res["sum"]))
        ranks = set()
        for run_results in res["sum"]:
            self.assertFalse(run_results.is_failed())
            ranks.update(run_results.return_values.keys())
        self.assertSetEqual(set(range(4 + 4)), ranks)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_run_distributed_sum_homogeneous_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_distributed_sum_homogeneous
        )

    def test_run_with_custom_log_lines(self):
        log_line_prefix_template = "[${role_name}-${local_rank}:${rank}]:"
        self.run_test_with_backend(
            backend="c10d",
            test_to_run=lambda: self.run_distributed_sum_homogeneous(
                log_line_prefix_template
            ),
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_run_distributed_sum_homogeneous_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_distributed_sum_homogeneous
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_run_distributed_sum_homogeneous_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.run_distributed_sum_homogeneous
        )

    def run_distributed_sum_heterogeneous(self):
        # sums all ranks on 3 agents; each running 1, 2, 3 workers respectively
        # sum should be equal to 0 + (1 + 2) + (3 + 4 + 5) = 15
        # sum asserted inside _dist_sum()
        node_configs = [
            Conf(role="sum", entrypoint=_dist_sum, local_world_size=1),
            Conf(role="sum", entrypoint=_dist_sum, local_world_size=2),
            Conf(role="sum", entrypoint=_dist_sum, local_world_size=3),
        ]
        # When the start method is spawn, the coverage collector hangs because
        # _dist_sum gets stuck waiting for TCPStore workers to join the cluster
        # TODO(aivanou): t83447589 come up with the proper fix
        res = self.run_job(node_configs)
        self.assertEqual(3, len(res["sum"]))
        ranks = set()
        for run_results in res["sum"]:
            self.assertFalse(run_results.is_failed())
            ranks.update(run_results.return_values.keys())
        self.assertSetEqual(set(range(1 + 2 + 3)), ranks)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_distributed_sum_heterogeneous_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_distributed_sum_heterogeneous
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_distributed_sum_heterogeneous_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_distributed_sum_heterogeneous
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_distributed_sum_heterogeneous_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.run_distributed_sum_heterogeneous
        )

    def run_sad_function(self):
        """
        checks error propagation logic
        """
        replyfile = os.path.join(self._test_dir, "error.json")
        with mock.patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": replyfile}):
            with self.assertRaises(ChildFailedError) as cm:
                self.run_agent(Conf(entrypoint=_sad_function, local_world_size=2))

            rank, failure = cm.exception.get_first_failure()
            failure_data = failure.error_file_data["message"]
            with open(replyfile) as fp:
                data = json.load(fp)["message"]

                # ran two; both failed; first failure is either rank 0 or 1
                self.assertTrue(rank in {0, 1})
                self.assertTrue(failure.local_rank in {0, 1})
                self.assertEqual(1, failure.exitcode)
                self.assertEqual(data["message"], failure_data["message"])
                self.assertEqual(int(data["extraInfo"]["timestamp"]), failure.timestamp)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_sad_function_c10d(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.run_sad_function)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_sad_function_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.run_sad_function)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_sad_function_etcd_v2(self):
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.run_sad_function)

    def run_bipolar_function(self):
        """
        checks agent failure handling logic
        """
        node_conf = Conf(entrypoint=_bipolar_function, local_world_size=4)
        spec = self.get_worker_spec(node_conf, max_restarts=2)
        agent = self.get_agent(spec, node_config=node_conf)
        run_result = agent.run()
        self.assertTrue(run_result.is_failed())
        self.assertEqual(0, agent._remaining_restarts)
        self.assertEqual(WorkerState.FAILED, agent.get_worker_group().state)
        self.assertTrue(agent._total_execution_time > 0)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_bipolar_function_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.run_bipolar_function
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_bipolar_function_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.run_bipolar_function
        )

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_run_bipolar_function_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.run_bipolar_function
        )

    def correct_rank_assignment_heterogeneous(self):
        node_configs = [
            Conf(role="master", entrypoint=_get_role_info, local_world_size=8),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=1),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=2),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=3),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4),
            Conf(role="ps", entrypoint=_get_role_info, local_world_size=5),
            Conf(role="ps", entrypoint=_get_role_info, local_world_size=2),
        ]
        results = self.run_job(node_configs)
        print(f"heterogeneous job result: {results}")
        self.assertEqual(1, len(results["master"]))
        self.assertEqual(4, len(results["trainer"]))
        self.assertEqual(2, len(results["ps"]))
        self.assert_rank_consistency(
            results,
            expected_role_world_sizes={
                "master": 8,
                "trainer": 1 + 2 + 3 + 4,
                "ps": 5 + 2,
            },
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_correct_rank_assignment_heterogeneous_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.correct_rank_assignment_heterogeneous
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_correct_rank_assignment_heterogeneous_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.correct_rank_assignment_heterogeneous
        )

    def correct_rank_assignment_homogeneous(self):
        node_configs = [
            Conf(role="master", entrypoint=_get_role_info, local_world_size=1),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4),
            Conf(role="trainer", entrypoint=_get_role_info, local_world_size=4),
            Conf(role="ps", entrypoint=_get_role_info, local_world_size=3),
            Conf(role="ps", entrypoint=_get_role_info, local_world_size=3),
        ]
        results = self.run_job(node_configs)
        print(f"homogeneous job result: {results}")
        self.assertEqual(1, len(results["master"]))
        self.assertEqual(4, len(results["trainer"]))
        self.assertEqual(2, len(results["ps"]))
        self.assert_rank_consistency(
            results,
            expected_role_world_sizes={"master": 1, "trainer": 4 * 4, "ps": 3 * 2},
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_correct_rank_assignment_homogeneous_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.correct_rank_assignment_homogeneous
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_correct_rank_assignment_homogeneous_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.correct_rank_assignment_homogeneous
        )

    def assert_rank_consistency(
        self,
        run_results: Dict[str, List[RunResult]],
        expected_role_world_sizes: Dict[str, int],
    ):
        """
        Asserts that ranks are consecutive w.r.t. role_rank. If local world sizes are 4:
        role_rank_0 -> ranks: 0,1,2,3
        role_rank_1 -> ranks: 4,5,6,7
        ... etc ...
        """

        global_ranks: List[int] = []
        # role -> [role_rank,...]
        role_ranks: Dict[str, List[int]] = {}
        # group rank -> [(rank, role_rank),...]
        grouped_ranks: Dict[int, List[Tuple[int, int]]] = {}

        # global world size == sum of all the role world sizes
        expected_world_size = sum(expected_role_world_sizes.values())
        for role, results in run_results.items():
            for result in results:
                res = result.return_values
                for role_info in res.values():
                    rank = role_info.rank
                    role_rank = role_info.role_rank
                    group_rank = role_info.group_rank
                    role_world_size = role_info.role_world_size
                    world_size = role_info.world_size

                    self.assertEqual(expected_world_size, world_size)
                    self.assertEqual(expected_role_world_sizes[role], role_world_size)
                    grouped_ranks.setdefault(group_rank, []).append((rank, role_rank))
                    role_ranks.setdefault(role, []).append(role_rank)
                    global_ranks.append(rank)

        global_ranks = sorted(global_ranks)
        self.assertEqual(list(range(expected_world_size)), global_ranks)
        for role, expected_role_world_size in expected_role_world_sizes.items():
            self.assertEqual(
                list(range(expected_role_world_size)), sorted(role_ranks[role])
            )
        # Make sure that each agent assigns consecutive ranks to workers.
        # In each (rank, role_rank) pair, index 0 is the global rank and
        # index 1 is the role rank.
        for ranks_lst in grouped_ranks.values():
            self.assert_ranks_sequential(ranks_lst, 0)
            self.assert_ranks_sequential(ranks_lst, 1)

    def assert_ranks_sequential(self, ranks_pairs, rank_idx):
        ranks = sorted(rank_pair[rank_idx] for rank_pair in ranks_pairs)
        start_rank, end_rank = ranks[0], ranks[-1]
        self.assertEqual(list(range(start_rank, end_rank + 1)), ranks)

    def double_agent_fault_tolerance(self):
        """
        start ``nnodes`` agents, kill and restart odd ones, validate fault-tolerance works
        """
        nnodes = 2
        wait = 2
        node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2)
        agent_results = mp.Queue()
        agent_args = {
            "conf": node_conf,
            "agent_results": agent_results,
            "min_nodes": nnodes,
            "max_nodes": nnodes,
            "max_restarts": 2,
        }

        procs = []
        for _ in range(nnodes):
            p = mp.Process(
                target=self.run_agent,
                kwargs=agent_args,
            )
            procs.append(p)
            p.start()

        # restart odd agents
        for i in range(nnodes):
            if i % 2 != 0:
                procs[i].kill()
                p = mp.Process(
                    target=self.run_agent,
                    kwargs=agent_args,
                )
                procs[i] = p
                p.start()

        for i in range(nnodes):
            p = procs[i]
            p.join()
            self.assertEqual(0, p.exitcode)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_double_agent_fault_tolerance_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.double_agent_fault_tolerance
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_double_agent_fault_tolerance_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.double_agent_fault_tolerance
        )

    def no_exit_barrier_on_failure(self):
        """
        start ``nnodes`` agents whose odd-ranked workers fail; validate that the
        agents shut down promptly instead of waiting out the exit barrier
        """
        nnodes = 2
        wait = 20
        node_conf = Conf(
            entrypoint=_bipolar_sleep_function, args=(wait,), local_world_size=2
        )
        agent_results = mp.Queue()
        monitor_interval_s = 0.5
        agent_args = {
            "conf": node_conf,
            "agent_results": agent_results,
            "min_nodes": nnodes,
            "max_nodes": nnodes,
            "max_restarts": 0,
            "exit_barrier_timeout": 300,
            "monitor_interval": monitor_interval_s,
        }

        procs = []
        for _ in range(nnodes):
            p = mp.Process(
                target=self.run_agent,
                kwargs=agent_args,
            )
            procs.append(p)
            p.start()

        # waiting for all processes to finish should not take the exit barrier time;
        # after the loop, exit_interval_between_agents holds the gap between the
        # two agents' join times (valid for nnodes == 2)
        exit_interval_between_agents = 0
        for i in range(nnodes):
            p = procs[i]
            p.join()
            self.assertNotEqual(0, p.exitcode)
            exit_interval_between_agents = (
                time.monotonic() - exit_interval_between_agents
            )

        # Validate that the processes finish close to each other.
        # Using a threshold of 2 * monitor_interval (monitor_interval_s = 0.5)
        # to make it less flaky
        self.assertGreater(
            2 * monitor_interval_s,
            exit_interval_between_agents,
            "Agents are not cleaned up until 2 * monitor_interval",
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_no_exit_barrier_on_failure(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.no_exit_barrier_on_failure
        )

    def double_agent_elastic(self):
        """
        start ``nnodes`` agents, kill odd ones (do not restart), validate
        elasticity (scale-down) works. (scale-up covered in fault_tolerance test)
        """
        min_nodes = 1
        max_nodes = 2
        wait = 2
        node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2)
        agent_results = mp.Queue()
        agent_args = {
            "conf": node_conf,
            "agent_results": agent_results,
            "min_nodes": min_nodes,
            "max_nodes": max_nodes,
            "max_restarts": 2,
        }

        procs = []
        for _ in range(max_nodes):
            p = mp.Process(
                target=self.run_agent,
                kwargs=agent_args,
            )
            procs.append(p)
            p.start()

        # kill odd agents
        for i in range(max_nodes):
            if i % 2 != 0:
                procs[i].kill()

        for i in range(max_nodes):
            p = procs[i]
            p.join()
            if i % 2 == 0:
                self.assertEqual(0, p.exitcode)
            else:
                self.assertEqual(-signal.SIGKILL, p.exitcode)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_double_agent_elastic_c10d(self):
        self.run_test_with_backend(
            backend="c10d", test_to_run=self.double_agent_elastic
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_double_agent_elastic_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.double_agent_elastic
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_double_agent_elastic_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.double_agent_elastic
        )

    def torch_rpc(self):
        """
        Simple torch rpc example with torchelastic.
        Creates two agents (to simulate a two-node job); each agent runs a
        single worker. worker0 calls an rpc_sync on worker1.
        """
        msg = "hello world"
        node_configs = [
            Conf(
                role="master",
                entrypoint=rpc_master,
                args=(msg,),
                local_world_size=1,
                tee=Std.ALL,
            ),
            Conf(
                role="worker",
                entrypoint=rpc_worker,
                args=(),
                local_world_size=1,
                tee=Std.ALL,
            ),
        ]

        results = self.run_job(node_configs)
        master_retvals = results["master"][0].return_values
        # there is only one master but the global rank is not stable,
        # so compare the master return value as a collection
        self.assertEqual([f"{msg} from worker"], list(master_retvals.values()))

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_torch_rpc_c10d(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.torch_rpc)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_torch_rpc_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.torch_rpc)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_torch_rpc_etcd_v2(self):
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.torch_rpc)

    def workers_drift_success(self):
        """
        two agents (one worker each) finish within ``2 * sec`` seconds of each other,
        exit barrier timeout set to ``sec * 2 * 2``.
        """

        sec = 1
        node_configs = [
            Conf(role="zzz", entrypoint=_sleep, args=(0 * sec,), local_world_size=1),
            Conf(role="zzz", entrypoint=_sleep, args=(2 * sec,), local_world_size=1),
        ]
        results = self.run_job(node_configs, exit_barrier_timeout=2 * 2 * sec)
        for i in range(2):
            run_results = results["zzz"][i]
            self.assertFalse(run_results.is_failed())
            for rank, output in run_results.return_values.items():
                # _sleep() returns its own rank
                self.assertEqual(rank, output)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_workers_drift_success_etcd(self):
        self.run_test_with_backend(
            backend="etcd", test_to_run=self.workers_drift_success
        )

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_workers_drift_success_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.workers_drift_success
        )

    def workers_drift_fail(self):
        """
        two agents (one worker each) finish within ``4 * sec`` seconds of each other,
        exit barrier timeout set to 0. Exit barriers should NOT fail the job.
        """
        sec = 1
        node_configs = [
            Conf(role="zzz", entrypoint=_sleep, args=(0 * sec,), local_world_size=1),
            Conf(role="zzz", entrypoint=_sleep, args=(4 * sec,), local_world_size=1),
        ]
        results = self.run_job(node_configs, exit_barrier_timeout=0)
        for i in range(2):
            run_results = results["zzz"][i]
            self.assertFalse(run_results.is_failed())
            for rank, output in run_results.return_values.items():
                # _sleep() returns its own rank
                self.assertEqual(rank, output)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_workers_drift_fail_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.workers_drift_fail)

    @unittest.skipIf(
        TEST_WITH_DEV_DBG_ASAN or TEST_WITH_TSAN,
        "test incompatible with dev/dbg asan or tsan",
    )
    def test_workers_drift_fail_etcd_v2(self):
        self.run_test_with_backend(
            backend="etcd-v2", test_to_run=self.workers_drift_fail
        )

    @patch("torch.distributed.elastic.utils.store.barrier")
    def barrier_failed(self, barrier_mock):
        """
        Failure during the barrier should NOT fail the job.
        """
        barrier_mock.side_effect = RuntimeError("test error")
        res = self.run_agent(Conf(entrypoint=_happy_function, local_world_size=1))
        self.assertFalse(res.is_failed())
        barrier_mock.assert_called_once()

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_barrier_failed_c10d(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.barrier_failed)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_barrier_failed_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.barrier_failed)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_barrier_failed_etcd_v2(self):
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.barrier_failed)

    @patch("torch.distributed.elastic.agent.server.local_elastic_agent.start_processes")
    def shutdown_called(self, start_processes_mock):
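        # Stub out process launching so no real workers start; the agent is
        # expected to close its process context exactly once when the run ends.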
        pcontext_mock = Mock()
        pcontext_mock.pids.return_value = {0: 0}
        start_processes_mock.return_value = pcontext_mock
        node_conf = Conf(entrypoint=_happy_function, local_world_size=1)
        spec = self.get_worker_spec(node_conf, max_restarts=0)
        agent = self.get_agent(spec, node_config=node_conf)
        with patch.object(agent, "_monitor_workers") as monitor_mock:
            monitor_mock.return_value = RunResult(
                state=WorkerState.SUCCEEDED, return_values={0: 0}
            )
            agent.run("worker")
        pcontext_mock.close.assert_called_once()

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_shutdown_called_c10d(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.shutdown_called)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_shutdown_called_etcd(self):
        self.run_test_with_backend(backend="etcd", test_to_run=self.shutdown_called)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_shutdown_called_etcd_v2(self):
        self.run_test_with_backend(backend="etcd-v2", test_to_run=self.shutdown_called)

    def fail_rank_one_once(self):
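        # rank 1 kills itself once on the first attempt (restart count 0);
        # with max_restarts=3 the retried run should succeed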
        res = self.run_agent(
            Conf(entrypoint=dummy_compute_simulate_rank_failure, local_world_size=2),
            max_restarts=3,
        )
        self.assertFalse(res.is_failed())
        for return_value in res.return_values.values():
            self.assertIsInstance(return_value, torch.Tensor)
            self.assertEqual((100, 100), return_value.shape)

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_rank_restart_after_failure(self):
        self.run_test_with_backend(backend="c10d", test_to_run=self.fail_rank_one_once)
1469