xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
8
9import os
10import tempfile
11from base64 import b64encode
12from datetime import timedelta
13from typing import Callable, cast, ClassVar
14from unittest import mock, TestCase
15
16from rendezvous_backend_test import RendezvousBackendTestMixin
17
18from torch.distributed import FileStore, TCPStore
19from torch.distributed.elastic.rendezvous import (
20    RendezvousConnectionError,
21    RendezvousError,
22    RendezvousParameters,
23)
24from torch.distributed.elastic.rendezvous.c10d_rendezvous_backend import (
25    C10dRendezvousBackend,
26    create_backend,
27)
28from torch.distributed.elastic.utils.distributed import get_free_port
29
30
class TCPStoreBackendTest(TestCase, RendezvousBackendTestMixin):
    """Run the ``RendezvousBackendTestMixin`` tests on top of a ``TCPStore``.

    One store is created for the whole class in ``setUpClass``; ``setUp``
    deletes the rendezvous state key before each test and builds a new
    ``C10dRendezvousBackend`` for run id ``dummy_run_id``.
    """

    # Shared by every test in the class; assigned in setUpClass.
    _store: ClassVar[TCPStore]

    @classmethod
    def setUpClass(cls) -> None:
        """Create the class-wide master ``TCPStore`` on localhost, port 0."""
        master_store = TCPStore(  # type: ignore[call-arg]
            "localhost", 0, is_master=True
        )
        cls._store = master_store

    def setUp(self) -> None:
        """Clear the state key and attach a fresh backend to the shared store."""
        state_key = "torch.rendezvous.dummy_run_id"
        # The store is reused across tests, so remove whatever an earlier
        # test may have left under the state key.
        self._store.delete_key(state_key)

        backend = C10dRendezvousBackend(self._store, "dummy_run_id")
        self._backend = backend

    def _corrupt_state(self) -> None:
        """Overwrite the state key with the literal string ``non_base64``.

        NOTE(review): not called in this file; presumably a hook used by
        ``RendezvousBackendTestMixin`` -- confirm against the mixin.
        """
        state_key = "torch.rendezvous.dummy_run_id"
        self._store.set(state_key, "non_base64")
47
class FileStoreBackendTest(TestCase, RendezvousBackendTestMixin):
    """Run the ``RendezvousBackendTestMixin`` tests on top of a ``FileStore``.

    Each test gets its own temporary file and its own store: ``setUp``
    creates them and ``tearDown`` removes the file again.
    """

    # Assigned per test instance in setUp, so this is an instance attribute
    # rather than a ClassVar.
    _store: FileStore

    def setUp(self) -> None:
        """Create a temp file, a ``FileStore`` on it, and a backend on the store."""
        fd, path = tempfile.mkstemp()
        # mkstemp() hands back an already-open OS-level descriptor. Only the
        # path is used from here on, so close the descriptor right away;
        # otherwise every test leaks one file descriptor.
        os.close(fd)
        self._path = path

        # Currently, filestore doesn't implement a delete_key method, so a new
        # filestore has to be initialized for every test in order to have a
        # clean slate.
        self._store = FileStore(path)
        self._backend = C10dRendezvousBackend(self._store, "dummy_run_id")

    def tearDown(self) -> None:
        """Delete the temporary file created in ``setUp``."""
        os.remove(self._path)

    def _corrupt_state(self) -> None:
        """Overwrite the state key with the literal string ``non_base64``.

        NOTE(review): not called in this file; presumably a hook used by
        ``RendezvousBackendTestMixin`` -- confirm against the mixin.
        """
        self._store.set("torch.rendezvous.dummy_run_id", "non_base64")
66
67
68class CreateBackendTest(TestCase):
69    def setUp(self) -> None:
70        # For testing, the default parameters used are for tcp. If a test
71        # uses parameters for file store, we set the self._params to
72        # self._params_filestore.
73
74        port = get_free_port()
75        self._params = RendezvousParameters(
76            backend="dummy_backend",
77            endpoint=f"localhost:{port}",
78            run_id="dummy_run_id",
79            min_nodes=1,
80            max_nodes=1,
81            is_host="true",
82            store_type="tCp",
83            read_timeout="10",
84        )
85
86        _, tmp_path = tempfile.mkstemp()
87
88        # Parameters for filestore testing.
89        self._params_filestore = RendezvousParameters(
90            backend="dummy_backend",
91            endpoint=tmp_path,
92            run_id="dummy_run_id",
93            min_nodes=1,
94            max_nodes=1,
95            store_type="fIlE",
96        )
97        self._expected_endpoint_file = tmp_path
98        self._expected_temp_dir = tempfile.gettempdir()
99
100        self._expected_endpoint_host = "localhost"
101        self._expected_endpoint_port = port
102        self._expected_store_type = TCPStore
103        self._expected_read_timeout = timedelta(seconds=10)
104
105    def tearDown(self) -> None:
106        os.remove(self._expected_endpoint_file)
107
108    def _run_test_with_store(self, store_type: str, test_to_run: Callable):
109        """
110        Use this function to specify the store type to use in a test. If
111        not used, the test will default to TCPStore.
112        """
113        if store_type == "file":
114            self._params = self._params_filestore
115            self._expected_store_type = FileStore
116            self._expected_read_timeout = timedelta(seconds=300)
117
118        test_to_run()
119
120    def _assert_create_backend_returns_backend(self) -> None:
121        backend, store = create_backend(self._params)
122
123        self.assertEqual(backend.name, "c10d")
124
125        self.assertIsInstance(store, self._expected_store_type)
126
127        typecast_store = cast(self._expected_store_type, store)
128        self.assertEqual(typecast_store.timeout, self._expected_read_timeout)  # type: ignore[attr-defined]
129        if self._expected_store_type == TCPStore:
130            self.assertEqual(typecast_store.host, self._expected_endpoint_host)  # type: ignore[attr-defined]
131            self.assertEqual(typecast_store.port, self._expected_endpoint_port)  # type: ignore[attr-defined]
132        if self._expected_store_type == FileStore:
133            if self._params.endpoint:
134                self.assertEqual(typecast_store.path, self._expected_endpoint_file)  # type: ignore[attr-defined]
135            else:
136                self.assertTrue(typecast_store.path.startswith(self._expected_temp_dir))  # type: ignore[attr-defined]
137
138        backend.set_state(b"dummy_state")
139
140        state = store.get("torch.rendezvous." + self._params.run_id)
141
142        self.assertEqual(state, b64encode(b"dummy_state"))
143
144    def test_create_backend_returns_backend(self) -> None:
145        for store_type in ["tcp", "file"]:
146            with self.subTest(store_type=store_type):
147                self._run_test_with_store(
148                    store_type, self._assert_create_backend_returns_backend
149                )
150
151    def test_create_backend_returns_backend_if_is_host_is_false(self) -> None:
152        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
153            self._expected_endpoint_host, self._expected_endpoint_port, is_master=True
154        )
155
156        self._params.config["is_host"] = "false"
157
158        self._assert_create_backend_returns_backend()
159
160    def test_create_backend_returns_backend_if_is_host_is_not_specified(self) -> None:
161        del self._params.config["is_host"]
162
163        self._assert_create_backend_returns_backend()
164
165    def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
166        self,
167    ) -> None:
168        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
169            self._expected_endpoint_host, self._expected_endpoint_port, is_master=True
170        )
171
172        del self._params.config["is_host"]
173
174        self._assert_create_backend_returns_backend()
175
176    def test_create_backend_returns_backend_if_endpoint_port_is_not_specified(
177        self,
178    ) -> None:
179        # patch default port and pass endpoint with no port specified
180        with mock.patch(
181            "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.DEFAULT_PORT",
182            self._expected_endpoint_port,
183        ):
184            self._params.endpoint = self._expected_endpoint_host
185
186            self._assert_create_backend_returns_backend()
187
188    def test_create_backend_returns_backend_if_endpoint_file_is_not_specified(
189        self,
190    ) -> None:
191        self._params_filestore.endpoint = ""
192
193        self._run_test_with_store("file", self._assert_create_backend_returns_backend)
194
195    def test_create_backend_returns_backend_if_store_type_is_not_specified(
196        self,
197    ) -> None:
198        del self._params.config["store_type"]
199
200        self._expected_store_type = TCPStore
201        if not self._params.get("read_timeout"):
202            self._expected_read_timeout = timedelta(seconds=60)
203
204        self._assert_create_backend_returns_backend()
205
206    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
207        self,
208    ) -> None:
209        del self._params.config["read_timeout"]
210
211        self._expected_read_timeout = timedelta(seconds=60)
212
213        self._assert_create_backend_returns_backend()
214
215    def test_create_backend_raises_error_if_store_is_unreachable(self) -> None:
216        self._params.config["is_host"] = "false"
217        self._params.config["read_timeout"] = "2"
218
219        with self.assertRaisesRegex(
220            RendezvousConnectionError,
221            r"^The connection to the C10d store has failed. See inner exception for details.$",
222        ):
223            create_backend(self._params)
224
225    def test_create_backend_raises_error_if_endpoint_is_invalid(self) -> None:
226        for is_host in [True, False]:
227            with self.subTest(is_host=is_host):
228                self._params.config["is_host"] = str(is_host)
229
230                self._params.endpoint = "dummy_endpoint"
231
232                with self.assertRaisesRegex(
233                    RendezvousConnectionError,
234                    r"^The connection to the C10d store has failed. See inner exception for "
235                    r"details.$",
236                ):
237                    create_backend(self._params)
238
239    def test_create_backend_raises_error_if_store_type_is_invalid(self) -> None:
240        self._params.config["store_type"] = "dummy_store_type"
241
242        with self.assertRaisesRegex(
243            ValueError,
244            r"^Invalid store type given. Currently only supports file and tcp.$",
245        ):
246            create_backend(self._params)
247
248    def test_create_backend_raises_error_if_read_timeout_is_invalid(self) -> None:
249        for read_timeout in ["0", "-10"]:
250            with self.subTest(read_timeout=read_timeout):
251                self._params.config["read_timeout"] = read_timeout
252
253                with self.assertRaisesRegex(
254                    ValueError, r"^The read timeout must be a positive integer.$"
255                ):
256                    create_backend(self._params)
257
258    @mock.patch("tempfile.mkstemp")
259    def test_create_backend_raises_error_if_tempfile_creation_fails(
260        self, tempfile_mock
261    ) -> None:
262        tempfile_mock.side_effect = OSError("test error")
263        # Set the endpoint to empty so it defaults to creating a temp file
264        self._params_filestore.endpoint = ""
265        with self.assertRaisesRegex(
266            RendezvousError,
267            r"The file creation for C10d store has failed. See inner exception for details.",
268        ):
269            create_backend(self._params_filestore)
270
271    @mock.patch(
272        "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.FileStore"
273    )
274    def test_create_backend_raises_error_if_file_path_is_invalid(
275        self, filestore_mock
276    ) -> None:
277        filestore_mock.side_effect = RuntimeError("test error")
278        self._params_filestore.endpoint = "bad file path"
279        with self.assertRaisesRegex(
280            RendezvousConnectionError,
281            r"^The connection to the C10d store has failed. See inner exception for "
282            r"details.$",
283        ):
284            create_backend(self._params_filestore)
285