xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/rendezvous/etcd_rendezvous_backend_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: r2p"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import subprocess
10from base64 import b64encode
11from typing import cast, ClassVar
12from unittest import TestCase
13
14from etcd import EtcdKeyNotFound  # type: ignore[import]
15from rendezvous_backend_test import RendezvousBackendTestMixin
16
17from torch.distributed.elastic.rendezvous import (
18    RendezvousConnectionError,
19    RendezvousParameters,
20)
21from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import (
22    create_backend,
23    EtcdRendezvousBackend,
24)
25from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
26from torch.distributed.elastic.rendezvous.etcd_store import EtcdStore
27
28
class EtcdRendezvousBackendTest(TestCase, RendezvousBackendTestMixin):
    """Runs the shared rendezvous backend tests against ``EtcdRendezvousBackend``.

    A single ``EtcdServer`` is started for the whole class; every test gets a
    fresh client and a backend rooted at ``/dummy_prefix``.
    """

    _server: ClassVar[EtcdServer]

    @classmethod
    def setUpClass(cls) -> None:
        # Publish the server on the class before starting it, then start it
        # with its stderr output discarded.
        cls._server = server = EtcdServer()
        server.start(stderr=subprocess.DEVNULL)

    @classmethod
    def tearDownClass(cls) -> None:
        server = cls._server
        server.stop()

    def setUp(self) -> None:
        client = self._server.get_client()
        self._client = client

        # Remove whatever an earlier test left under the prefix; a missing
        # key simply means there is nothing to clean up.
        try:
            client.delete("/dummy_prefix", recursive=True, dir=True)
        except EtcdKeyNotFound:
            pass

        run_id = "dummy_run_id"
        key_prefix = "/dummy_prefix"
        self._backend = EtcdRendezvousBackend(client, run_id, key_prefix)

    def _corrupt_state(self) -> None:
        # Overwrite the run's key with a value that is not valid base64.
        state_key = "/dummy_prefix/dummy_run_id"
        self._client.write(state_key, "non_base64")
56
57
class CreateBackendTest(TestCase):
    """Tests for ``create_backend`` using the etcd server started in ``setUpClass``."""

    _server: ClassVar[EtcdServer]

    @classmethod
    def setUpClass(cls) -> None:
        # One server instance is shared by every test in this class; its
        # stderr output is discarded.
        cls._server = EtcdServer()
        cls._server.start(stderr=subprocess.DEVNULL)

    @classmethod
    def tearDownClass(cls) -> None:
        cls._server.stop()

    def setUp(self) -> None:
        # Baseline parameters; individual tests mutate them as needed. The
        # protocol is given in mixed case and the read timeout as a string.
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint=self._server.get_endpoint(),
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            protocol="hTTp",
            read_timeout="10",
        )

        # Read timeout the created store's client is expected to report.
        self._expected_read_timeout = 10

    def test_create_backend_returns_backend(self) -> None:
        # Also invoked by the "not specified" tests below after they have
        # adjusted ``self._params`` / ``self._expected_read_timeout``.
        backend, store = create_backend(self._params)

        self.assertEqual(backend.name, "etcd-v2")

        self.assertIsInstance(store, EtcdStore)

        etcd_store = cast(EtcdStore, store)

        self.assertEqual(etcd_store.client.read_timeout, self._expected_read_timeout)  # type: ignore[attr-defined]

        # Use an independent client to inspect what the backend and the store
        # actually wrote to etcd.
        client = self._server.get_client()

        backend.set_state(b"dummy_state")

        result = client.get("/torch/elastic/rendezvous/" + self._params.run_id)

        # The state is expected base64-encoded, with a TTL of at most 7200.
        self.assertEqual(result.value, b64encode(b"dummy_state").decode())
        self.assertLessEqual(result.ttl, 7200)

        store.set("dummy_key", "dummy_value")

        # Both the store key and its value are expected base64-encoded.
        result = client.get("/torch/elastic/store/" + b64encode(b"dummy_key").decode())

        self.assertEqual(result.value, b64encode(b"dummy_value").decode())

    def test_create_backend_returns_backend_if_protocol_is_not_specified(self) -> None:
        del self._params.config["protocol"]

        self.test_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
        self,
    ) -> None:
        del self._params.config["read_timeout"]

        # Without an explicit value the client is expected to report 60.
        self._expected_read_timeout = 60

        self.test_create_backend_returns_backend()

    def test_create_backend_raises_error_if_etcd_is_unreachable(self) -> None:
        self._params.endpoint = "dummy:1234"

        # Dots are escaped so the pattern matches only the literal message.
        with self.assertRaisesRegex(
            RendezvousConnectionError,
            r"^The connection to etcd has failed\. See inner exception for details\.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_protocol_is_invalid(self) -> None:
        self._params.config["protocol"] = "dummy"

        # Dots are escaped so the pattern matches only the literal message.
        with self.assertRaisesRegex(
            ValueError, r"^The protocol must be HTTP or HTTPS\.$"
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_read_timeout_is_invalid(self) -> None:
        # Zero and negative timeouts must both be rejected.
        for read_timeout in ["0", "-10"]:
            with self.subTest(read_timeout=read_timeout):
                self._params.config["read_timeout"] = read_timeout

                # Dots are escaped so the pattern matches only the literal
                # message.
                with self.assertRaisesRegex(
                    ValueError, r"^The read timeout must be a positive integer\.$"
                ):
                    create_backend(self._params)
149