# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
from base64 import b64encode
from datetime import timedelta
from typing import Callable, cast, ClassVar
from unittest import mock, TestCase

from rendezvous_backend_test import RendezvousBackendTestMixin

from torch.distributed import FileStore, TCPStore
from torch.distributed.elastic.rendezvous import (
    RendezvousConnectionError,
    RendezvousError,
    RendezvousParameters,
)
from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import (
    C10dRendezvousBackend,
    create_backend,
)
from torch.distributed.elastic.utils.distributed import get_free_port


def _make_temp_file() -> str:
    """Create an empty temporary file and return its path.

    ``tempfile.mkstemp`` returns an open OS-level file descriptor together
    with the path. The tests only need the path, so the descriptor is closed
    right away instead of being leaked for the rest of the test process.
    """
    fd, path = tempfile.mkstemp()
    os.close(fd)
    return path


class TCPStoreBackendTest(TestCase, RendezvousBackendTestMixin):
    """Runs the shared rendezvous backend tests against a ``TCPStore``."""

    # One master TCPStore is shared by every test in this class; ``setUp``
    # deletes the rendezvous key so each test still starts from a clean slate.
    _store: ClassVar[TCPStore]

    @classmethod
    def setUpClass(cls) -> None:
        cls._store = TCPStore("localhost", 0, is_master=True)  # type: ignore[call-arg]

    def setUp(self) -> None:
        # Make sure we have a clean slate.
        self._store.delete_key("torch.rendezvous.dummy_run_id")

        self._backend = C10dRendezvousBackend(self._store, "dummy_run_id")

    def _corrupt_state(self) -> None:
        """Store a value under the rendezvous key that is not valid base64."""
        self._store.set("torch.rendezvous.dummy_run_id", "non_base64")


class FileStoreBackendTest(TestCase, RendezvousBackendTestMixin):
    """Runs the shared rendezvous backend tests against a ``FileStore``."""

    # A fresh store is created for every test in ``setUp``, so this is an
    # instance attribute (not a ClassVar as in the TCP variant).
    _store: FileStore

    def setUp(self) -> None:
        self._path = _make_temp_file()

        # Currently, filestore doesn't implement a delete_key method, so a new
        # filestore has to be initialized for every test in order to have a
        # clean slate.
        self._store = FileStore(self._path)
        self._backend = C10dRendezvousBackend(self._store, "dummy_run_id")

    def tearDown(self) -> None:
        os.remove(self._path)

    def _corrupt_state(self) -> None:
        """Store a value under the rendezvous key that is not valid base64."""
        self._store.set("torch.rendezvous.dummy_run_id", "non_base64")


class CreateBackendTest(TestCase):
    """Tests for ``create_backend`` with TCP- and file-based C10d stores.

    Each test configures ``self._params`` plus the ``self._expected_*``
    attributes, then calls ``_assert_create_backend_returns_backend`` (or
    expects ``create_backend`` to raise).
    """

    def setUp(self) -> None:
        # For testing, the default parameters used are for tcp. If a test
        # uses parameters for file store, we set the self._params to
        # self._params_filestore.

        port = get_free_port()
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint=f"localhost:{port}",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            is_host="true",
            # Mixed case on purpose: the tests expect this to yield a TCPStore.
            store_type="tCp",
            read_timeout="10",
        )

        tmp_path = _make_temp_file()

        # Parameters for filestore testing.
        self._params_filestore = RendezvousParameters(
            backend="dummy_backend",
            endpoint=tmp_path,
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            # Mixed case on purpose: the tests expect this to yield a FileStore.
            store_type="fIlE",
        )
        self._expected_endpoint_file = tmp_path
        self._expected_temp_dir = tempfile.gettempdir()

        self._expected_endpoint_host = "localhost"
        self._expected_endpoint_port = port
        self._expected_store_type = TCPStore
        self._expected_read_timeout = timedelta(seconds=10)

    def tearDown(self) -> None:
        os.remove(self._expected_endpoint_file)

    def _run_test_with_store(self, store_type: str, test_to_run: Callable) -> None:
        """
        Use this function to specify the store type to use in a test. If
        not used, the test will default to TCPStore.
        """
        if store_type == "file":
            self._params = self._params_filestore
            self._expected_store_type = FileStore
            # No read_timeout is configured for the file store parameters;
            # this is the timeout the tests expect in that case.
            self._expected_read_timeout = timedelta(seconds=300)

        test_to_run()

    def _assert_create_backend_returns_backend(self) -> None:
        """Create a backend from ``self._params`` and verify it.

        Checks the backend name, the store type, timeout and endpoint against
        the ``self._expected_*`` attributes. Then it checks that ``set_state``
        writes the base64-encoded state under ``torch.rendezvous.<run_id>``.
        """
        backend, store = create_backend(self._params)

        self.assertEqual(backend.name, "c10d")

        self.assertIsInstance(store, self._expected_store_type)

        typecast_store = cast(self._expected_store_type, store)
        self.assertEqual(typecast_store.timeout, self._expected_read_timeout)  # type: ignore[attr-defined]
        if self._expected_store_type is TCPStore:
            self.assertEqual(typecast_store.host, self._expected_endpoint_host)  # type: ignore[attr-defined]
            self.assertEqual(typecast_store.port, self._expected_endpoint_port)  # type: ignore[attr-defined]
        elif self._expected_store_type is FileStore:
            if self._params.endpoint:
                self.assertEqual(typecast_store.path, self._expected_endpoint_file)  # type: ignore[attr-defined]
            else:
                # Without an endpoint the backend is expected to fall back to
                # a file in the system temp directory.
                self.assertTrue(typecast_store.path.startswith(self._expected_temp_dir))  # type: ignore[attr-defined]

        backend.set_state(b"dummy_state")

        state = store.get("torch.rendezvous." + self._params.run_id)

        self.assertEqual(state, b64encode(b"dummy_state"))

    def test_create_backend_returns_backend(self) -> None:
        # "tcp" must run first: the "file" case permanently switches the
        # fixture to the file store parameters.
        for store_type in ["tcp", "file"]:
            with self.subTest(store_type=store_type):
                self._run_test_with_store(
                    store_type, self._assert_create_backend_returns_backend
                )

    def test_create_backend_returns_backend_if_is_host_is_false(self) -> None:
        # Keep a master store alive so the non-host backend has a peer to
        # connect to.
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host, self._expected_endpoint_port, is_master=True
        )

        self._params.config["is_host"] = "false"

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified(self) -> None:
        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
        self,
    ) -> None:
        # A master store already occupies the endpoint before create_backend.
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host, self._expected_endpoint_port, is_master=True
        )

        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_port_is_not_specified(
        self,
    ) -> None:
        # patch default port and pass endpoint with no port specified
        with mock.patch(
            "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.DEFAULT_PORT",
            self._expected_endpoint_port,
        ):
            self._params.endpoint = self._expected_endpoint_host

            self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_file_is_not_specified(
        self,
    ) -> None:
        self._params_filestore.endpoint = ""

        self._run_test_with_store("file", self._assert_create_backend_returns_backend)

    def test_create_backend_returns_backend_if_store_type_is_not_specified(
        self,
    ) -> None:
        del self._params.config["store_type"]

        # With no store_type the tests expect a TCPStore by default.
        self._expected_store_type = TCPStore
        if not self._params.get("read_timeout"):
            self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
        self,
    ) -> None:
        del self._params.config["read_timeout"]

        # Timeout the tests expect for a TCP store with no read_timeout set.
        self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_raises_error_if_store_is_unreachable(self) -> None:
        # No master store is started, so connecting as a non-host must fail;
        # the short timeout keeps the test fast.
        self._params.config["is_host"] = "false"
        self._params.config["read_timeout"] = "2"

        with self.assertRaisesRegex(
            RendezvousConnectionError,
            r"^The connection to the C10d store has failed. See inner exception for details.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_endpoint_is_invalid(self) -> None:
        for is_host in [True, False]:
            with self.subTest(is_host=is_host):
                self._params.config["is_host"] = str(is_host)

                self._params.endpoint = "dummy_endpoint"

                with self.assertRaisesRegex(
                    RendezvousConnectionError,
                    r"^The connection to the C10d store has failed. See inner exception for "
                    r"details.$",
                ):
                    create_backend(self._params)

    def test_create_backend_raises_error_if_store_type_is_invalid(self) -> None:
        self._params.config["store_type"] = "dummy_store_type"

        with self.assertRaisesRegex(
            ValueError,
            r"^Invalid store type given. Currently only supports file and tcp.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_read_timeout_is_invalid(self) -> None:
        for read_timeout in ["0", "-10"]:
            with self.subTest(read_timeout=read_timeout):
                self._params.config["read_timeout"] = read_timeout

                with self.assertRaisesRegex(
                    ValueError, r"^The read timeout must be a positive integer.$"
                ):
                    create_backend(self._params)

    @mock.patch("tempfile.mkstemp")
    def test_create_backend_raises_error_if_tempfile_creation_fails(
        self, tempfile_mock
    ) -> None:
        # The patch is only active inside this method; setUp has already
        # created its temp file with the real mkstemp.
        tempfile_mock.side_effect = OSError("test error")
        # Set the endpoint to empty so it defaults to creating a temp file
        self._params_filestore.endpoint = ""
        with self.assertRaisesRegex(
            RendezvousError,
            r"The file creation for C10d store has failed. See inner exception for details.",
        ):
            create_backend(self._params_filestore)

    @mock.patch(
        "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.FileStore"
    )
    def test_create_backend_raises_error_if_file_path_is_invalid(
        self, filestore_mock
    ) -> None:
        filestore_mock.side_effect = RuntimeError("test error")
        self._params_filestore.endpoint = "bad file path"
        with self.assertRaisesRegex(
            RendezvousConnectionError,
            r"^The connection to the C10d store has failed. See inner exception for "
            r"details.$",
        ):
            create_backend(self._params_filestore)