# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Tests for the ``EtcdServer`` helper, standalone and as a rendezvous backend.

Every test starts an ``EtcdServer``, exercises it, and stops it again in a
``finally`` clause so the server is shut down even when an assertion fails.
"""
import os
import sys
import unittest

import etcd

from torch.distributed.elastic.rendezvous.etcd_rendezvous import (
    EtcdRendezvous,
    EtcdRendezvousHandler,
)
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer


# T85992919: these tests are disabled when running on CircleCI. The whole
# process exits with status 0 at import time, so nothing below is executed there.
if os.environ.get("CIRCLECI"):
    print("T85992919 temporarily disabling in circle ci", file=sys.stderr)
    sys.exit(0)


class EtcdServerTest(unittest.TestCase):
    """Smoke tests covering ``EtcdServer`` lifecycle and a one-worker rendezvous."""

    def test_etcd_server_start_stop(self):
        # Start a server, verify the host/port/endpoint it reports and that its
        # client exposes a version, then shut the server down.
        etcd_server = EtcdServer()
        etcd_server.start()

        try:
            port = etcd_server.get_port()
            host = etcd_server.get_host()

            self.assertGreater(port, 0)
            self.assertEqual("localhost", host)

            # The endpoint is expected to be exactly "<host>:<port>".
            expected_endpoint = f"{host}:{port}"
            self.assertEqual(expected_endpoint, etcd_server.get_endpoint())

            server_version = etcd_server.get_client().version
            self.assertIsNotNone(server_version)
        finally:
            etcd_server.stop()

    def test_etcd_server_with_rendezvous(self):
        # Run a rendezvous that requires exactly one worker against a fresh
        # server; the single participant must receive a store, rank 0 and a
        # world size of 1.
        etcd_server = EtcdServer()
        etcd_server.start()

        try:
            etcd_client = etcd.Client(
                etcd_server.get_host(), etcd_server.get_port()
            )

            rendezvous = EtcdRendezvous(
                client=etcd_client,
                prefix="test",
                run_id=1,
                num_min_workers=1,
                num_max_workers=1,
                timeout=60,
                last_call_timeout=30,
            )
            handler = EtcdRendezvousHandler(rendezvous)
            result = handler.next_rendezvous()

            self.assertIsNotNone(result.store)
            self.assertEqual(0, result.rank)
            self.assertEqual(1, result.world_size)
        finally:
            etcd_server.stop()