xref: /aosp_15_r20/external/pytorch/torch/distributed/rpc/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3from contextlib import contextmanager
4from typing import cast
5
6from . import api, TensorPipeAgent
7
8
9logger = logging.getLogger(__name__)
10
11
12@contextmanager
13def _group_membership_management(store, name, is_join):
14    token_key = "RpcGroupManagementToken"
15    join_or_leave = "join" if is_join else "leave"
16    my_token = f"Token_for_{name}_{join_or_leave}"
17    while True:
18        # Retrieve token from store to signal start of rank join/leave critical section
19        returned = store.compare_set(token_key, "", my_token).decode()
20        if returned == my_token:
21            # Yield to the function this context manager wraps
22            yield
23            # Finished, now exit and release token
24            # Update from store to signal end of rank join/leave critical section
25            store.set(token_key, "")
26            # Other will wait for this token to be set before they execute
27            store.set(my_token, "Done")
28            break
29        else:
30            # Store will wait for the token to be released
31            try:
32                store.wait([returned])
33            except RuntimeError:
34                logger.error(
35                    "Group membership token %s timed out waiting for %s to be released.",
36                    my_token,
37                    returned,
38                )
39                raise
40
41
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group membership change to the current RPC agent.

    Fetches the agent via ``api._get_current_rpc_agent()`` and delegates to
    its ``_update_group_membership`` with every argument passed through
    unchanged.

    Returns:
        Whatever the agent's ``_update_group_membership`` returns.
    """
    current_agent = api._get_current_rpc_agent()
    # ``cast`` has no runtime effect; it only tells type checkers that the
    # agent is a TensorPipeAgent, which exposes _update_group_membership.
    tensorpipe_agent = cast(TensorPipeAgent, current_agent)
    return tensorpipe_agent._update_group_membership(
        worker_info, my_devices, reverse_device_map, is_join
    )
48