# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import unittest
import uuid

from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.etcd_rendezvous import create_rdzv_handler
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer


if os.getenv("CIRCLECI"):
    # Stop importing this module (exit status 0) when running under CircleCI;
    # tracked by task T85992919.
    print("T85992919 temporarily disabling in circle ci", file=sys.stderr)
    sys.exit(0)


class EtcdRendezvousTest(unittest.TestCase):
    """Tests creating etcd rendezvous handlers against a shared etcd server."""

    # Optional settings exercised by the "additional params" style tests.
    # Unpacked with ** at each use site, so callers never mutate this dict.
    _ADDITIONAL_CONFIG = {
        "timeout": 60,
        "last_call_timeout": 30,
        "protocol": "http",
    }

    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

    def _make_rdzv_params(self, run_id, **extra_config):
        """Build single-node etcd ``RendezvousParameters`` for the shared server.

        Args:
            run_id: the rendezvous run id.
            **extra_config: extra keyword options forwarded unchanged to
                ``RendezvousParameters`` (e.g. ``timeout``, ``protocol``).

        Returns:
            A ``RendezvousParameters`` with ``backend="etcd"``, the shared
            server's endpoint, and ``min_nodes == max_nodes == 1``.
        """
        return RendezvousParameters(
            backend="etcd",
            endpoint=self._etcd_server.get_endpoint(),
            run_id=run_id,
            min_nodes=1,
            max_nodes=1,
            **extra_config,
        )

    def test_etcd_rdzv_basic_params(self):
        """
        Check that we can create the handler with a minimum set of
        params
        """
        rdzv_params = self._make_rdzv_params(str(uuid.uuid4()))
        etcd_rdzv = create_rdzv_handler(rdzv_params)
        self.assertIsNotNone(etcd_rdzv)

    def test_etcd_rdzv_additional_params(self):
        """Check that optional settings are accepted and the run id round-trips."""
        run_id = str(uuid.uuid4())
        rdzv_params = self._make_rdzv_params(run_id, **self._ADDITIONAL_CONFIG)

        etcd_rdzv = create_rdzv_handler(rdzv_params)

        self.assertIsNotNone(etcd_rdzv)
        self.assertEqual(run_id, etcd_rdzv.get_run_id())

    def test_get_backend(self):
        """Check that the created handler reports "etcd" as its backend."""
        run_id = str(uuid.uuid4())
        rdzv_params = self._make_rdzv_params(run_id, **self._ADDITIONAL_CONFIG)

        etcd_rdzv = create_rdzv_handler(rdzv_params)

        self.assertEqual("etcd", etcd_rdzv.get_backend())