# mypy: allow-untyped-defs
import logging
from contextlib import contextmanager
from typing import cast

from . import api, TensorPipeAgent


logger = logging.getLogger(__name__)


@contextmanager
def _group_membership_management(store, name, is_join):
    """Serialize a rank's join/leave operation through a token kept in ``store``.

    The context manager loops until it manages to swap the shared
    ``"RpcGroupManagementToken"`` key from ``""`` to this caller's own token,
    runs the wrapped ``with`` body while holding it, and then releases it.

    Args:
        store: Store object providing ``compare_set``, ``set`` and ``wait``.
        name: Name used to build this caller's token string.
        is_join: ``True`` when joining the group, ``False`` when leaving;
            only affects the token string.

    Raises:
        RuntimeError: Re-raised from ``store.wait`` (after logging) when
            waiting for the current token holder fails.
    """
    token_key = "RpcGroupManagementToken"
    join_or_leave = "join" if is_join else "leave"
    my_token = f"Token_for_{name}_{join_or_leave}"
    while True:
        # Try to take the token: swap "" -> my_token. The decoded result is
        # compared against my_token to decide whether we now hold it.
        returned = store.compare_set(token_key, "", my_token).decode()
        if returned == my_token:
            try:
                # Critical section: run the body of the ``with`` statement.
                yield
            finally:
                # Release even when the wrapped body raises; otherwise the
                # token would stay held and every other caller would block
                # in ``store.wait`` below until it errors out.
                store.set(token_key, "")
                # Callers that saw our token are waiting on this key.
                store.set(my_token, "Done")
            break
        else:
            # Someone else holds the token; wait on the key named after the
            # holder's token, then retry the compare_set.
            try:
                store.wait([returned])
            except RuntimeError:
                logger.error(
                    "Group membership token %s timed out waiting for %s to be released.",
                    my_token,
                    returned,
                )
                raise


def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group membership change to the current RPC agent.

    All arguments are passed through unchanged to the agent's
    ``_update_group_membership`` method, and its result is returned as-is.
    """
    # The cast only informs the type checker; it performs no runtime check.
    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    return agent._update_group_membership(
        worker_info, my_devices, reverse_device_map, is_join
    )