xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/rendezvous/etcd_rendezvous_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: r2p"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8import os
9import sys
10import unittest
11import uuid
12
13from torch.distributed.elastic.rendezvous import RendezvousParameters
14from torch.distributed.elastic.rendezvous.etcd_rendezvous import create_rdzv_handler
15from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
16
17
# Tracking task T85992919: this suite is switched off on CircleCI for now.
# When the CIRCLECI environment variable holds a non-empty value, the module
# reports the reason on stderr and terminates with exit status 0 at import
# time, before any test class is defined.
if os.environ.get("CIRCLECI"):
    print("T85992919 temporarily disabling in circle ci", file=sys.stderr)
    sys.exit(0)
21
22
class EtcdRendezvousTest(unittest.TestCase):
    """Tests for creating an etcd-backed rendezvous handler.

    One standalone etcd server is started in ``setUpClass`` and shared by
    every test in the class. Tests that do not pin a ``run_id`` get a fresh
    random UUID string from ``_make_rdzv_params``.
    """

    # Optional etcd settings passed through ``RendezvousParameters`` as extra
    # keyword options by the tests that exercise non-default configuration.
    _ADDITIONAL_PARAMS = {"timeout": 60, "last_call_timeout": 30, "protocol": "http"}

    @classmethod
    def setUpClass(cls):
        # Start a standalone, single-process etcd server to use for all tests.
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # Stop the standalone etcd server started in setUpClass.
        cls._etcd_server.stop()

    def _make_rdzv_params(self, run_id=None, **kwargs):
        """Build single-node ``RendezvousParameters`` for the shared etcd server.

        Args:
            run_id: Rendezvous run id. A random UUID string is used when omitted.
            **kwargs: Extra options forwarded unchanged to
                ``RendezvousParameters`` (e.g. ``timeout``, ``protocol``).

        Returns:
            ``RendezvousParameters`` with ``backend="etcd"``, the shared
            server's endpoint, and ``min_nodes == max_nodes == 1``.
        """
        if run_id is None:
            run_id = str(uuid.uuid4())
        return RendezvousParameters(
            backend="etcd",
            endpoint=str(self._etcd_server.get_endpoint()),
            run_id=run_id,
            min_nodes=1,
            max_nodes=1,
            **kwargs,
        )

    def test_etcd_rdzv_basic_params(self):
        """
        Check that we can create the handler with a minimum set of
        params
        """
        etcd_rdzv = create_rdzv_handler(self._make_rdzv_params())
        self.assertIsNotNone(etcd_rdzv)

    def test_etcd_rdzv_additional_params(self):
        """Check that optional settings are accepted and the run id is preserved."""
        run_id = str(uuid.uuid4())
        rdzv_params = self._make_rdzv_params(run_id=run_id, **self._ADDITIONAL_PARAMS)

        etcd_rdzv = create_rdzv_handler(rdzv_params)

        self.assertIsNotNone(etcd_rdzv)
        self.assertEqual(run_id, etcd_rdzv.get_run_id())

    def test_get_backend(self):
        """Check that the created handler reports ``"etcd"`` as its backend."""
        rdzv_params = self._make_rdzv_params(**self._ADDITIONAL_PARAMS)

        etcd_rdzv = create_rdzv_handler(rdzv_params)

        self.assertEqual("etcd", etcd_rdzv.get_backend())
84