# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import subprocess
from base64 import b64encode
from typing import cast, ClassVar
from unittest import TestCase

from etcd import EtcdKeyNotFound  # type: ignore[import]
from rendezvous_backend_test import RendezvousBackendTestMixin

from torch.distributed.elastic.rendezvous import (
    RendezvousConnectionError,
    RendezvousParameters,
)
from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import (
    create_backend,
    EtcdRendezvousBackend,
)
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.rendezvous.etcd_store import EtcdStore


class EtcdRendezvousBackendTest(TestCase, RendezvousBackendTestMixin):
    """Exercise ``EtcdRendezvousBackend`` against a locally started etcd server.

    NOTE(review): the test methods themselves presumably come from
    ``RendezvousBackendTestMixin`` and use ``self._backend`` and
    ``self._corrupt_state()``; the mixin is defined in another module, so
    confirm there.
    """

    # One etcd server shared by every test in this class.
    _server: ClassVar[EtcdServer]

    @classmethod
    def setUpClass(cls) -> None:
        """Start the shared etcd server, discarding its stderr output."""
        cls._server = EtcdServer()
        cls._server.start(stderr=subprocess.DEVNULL)

    @classmethod
    def tearDownClass(cls) -> None:
        """Stop the shared etcd server."""
        cls._server.stop()

    def setUp(self) -> None:
        """Wipe ``/dummy_prefix`` and build a fresh backend for each test."""
        self._client = self._server.get_client()

        # Make sure we have a clean slate.
        try:
            self._client.delete("/dummy_prefix", recursive=True, dir=True)
        except EtcdKeyNotFound:
            # Nothing was stored under the prefix yet; that is already clean.
            pass

        self._backend = EtcdRendezvousBackend(
            self._client, "dummy_run_id", "/dummy_prefix"
        )

    def _corrupt_state(self) -> None:
        """Overwrite the run's key with a value that is not valid base64."""
        self._client.write("/dummy_prefix/dummy_run_id", "non_base64")


class CreateBackendTest(TestCase):
    """Tests for the ``create_backend`` factory of the etcd rendezvous backend."""

    # One etcd server shared by every test in this class.
    _server: ClassVar[EtcdServer]

    @classmethod
    def setUpClass(cls) -> None:
        """Start the shared etcd server, discarding its stderr output."""
        cls._server = EtcdServer()
        cls._server.start(stderr=subprocess.DEVNULL)

    @classmethod
    def tearDownClass(cls) -> None:
        """Stop the shared etcd server."""
        cls._server.stop()

    def setUp(self) -> None:
        """Build the baseline parameters that individual tests then tweak."""
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint=self._server.get_endpoint(),
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            # Mixed case on purpose: the baseline tests expect this spelling
            # of "http" to be accepted.
            protocol="hTTp",
            read_timeout="10",
        )

        # Read timeout the created client is expected to report; tests that
        # remove ``read_timeout`` from the config override this value.
        self._expected_read_timeout = 10

    def test_create_backend_returns_backend(self) -> None:
        """Check the returned backend/store pair and what each writes to etcd."""
        backend, store = create_backend(self._params)

        self.assertEqual(backend.name, "etcd-v2")

        self.assertIsInstance(store, EtcdStore)

        etcd_store = cast(EtcdStore, store)

        self.assertEqual(etcd_store.client.read_timeout, self._expected_read_timeout)  # type: ignore[attr-defined]

        # Read back through an independent client to verify what was stored.
        client = self._server.get_client()

        backend.set_state(b"dummy_state")

        result = client.get("/torch/elastic/rendezvous/" + self._params.run_id)

        # The state is expected to be stored base64-encoded with a bounded TTL.
        self.assertEqual(result.value, b64encode(b"dummy_state").decode())
        self.assertLessEqual(result.ttl, 7200)

        store.set("dummy_key", "dummy_value")

        # Both the store key and the store value are expected base64-encoded.
        result = client.get("/torch/elastic/store/" + b64encode(b"dummy_key").decode())

        self.assertEqual(result.value, b64encode(b"dummy_value").decode())

    def test_create_backend_returns_backend_if_protocol_is_not_specified(self) -> None:
        """Re-run the baseline checks with ``protocol`` removed from the config."""
        del self._params.config["protocol"]

        self.test_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
        self,
    ) -> None:
        """Re-run the baseline checks without ``read_timeout``; expect 60."""
        del self._params.config["read_timeout"]

        self._expected_read_timeout = 60

        self.test_create_backend_returns_backend()

    def test_create_backend_raises_error_if_etcd_is_unreachable(self) -> None:
        """An endpoint with no etcd behind it raises RendezvousConnectionError."""
        self._params.endpoint = "dummy:1234"

        # Periods are escaped so the pattern pins the literal message text.
        with self.assertRaisesRegex(
            RendezvousConnectionError,
            r"^The connection to etcd has failed\. See inner exception for details\.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_protocol_is_invalid(self) -> None:
        """A protocol other than HTTP/HTTPS raises ValueError."""
        self._params.config["protocol"] = "dummy"

        with self.assertRaisesRegex(
            ValueError, r"^The protocol must be HTTP or HTTPS\.$"
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_read_timeout_is_invalid(self) -> None:
        """Zero and negative read timeouts raise ValueError."""
        for read_timeout in ["0", "-10"]:
            with self.subTest(read_timeout=read_timeout):
                self._params.config["read_timeout"] = read_timeout

                with self.assertRaisesRegex(
                    ValueError, r"^The read timeout must be a positive integer\.$"
                ):
                    create_backend(self._params)