# mypy: allow-untyped-defs
try:
    from urllib.parse import urlparse, urlunparse
except ImportError as e:
    raise ImportError(
        "urllib cannot be found, urlparse from python2 is no longer supported."
    ) from e

import numbers
import os
import sys
from datetime import timedelta
from typing import Callable, Dict, Iterator, Optional, Tuple

from torch.distributed import FileStore, PrefixStore, Store, TCPStore

from .constants import default_pg_timeout


# Maps a URL scheme (e.g. "tcp") to the generator function that performs the
# rendezvous for that scheme.
_rendezvous_handlers: Dict[str, Callable[..., Iterator[Tuple[Store, int, int]]]] = {}

__all__ = ["register_rendezvous_handler", "rendezvous"]


def register_rendezvous_handler(scheme, handler):
    """
    Register a new rendezvous handler.

    Before we can run collective algorithms, participating processes
    need to find each other and exchange information to be able to
    communicate. We call this process rendezvous.

    The outcome of the rendezvous process is a triplet containing a
    shared key/value store, the rank of the process, and the total
    number of participating processes.

    If none of the bundled rendezvous methods apply to your execution
    environment you can opt to register your own rendezvous handler.
    Pick a unique name and use the URL scheme to identify it when
    calling the `rendezvous()` function.

    Args:
        scheme (str): URL scheme to identify your rendezvous handler.
        handler (function): Handler that is invoked when the
            `rendezvous()` function is called with a URL that uses
            the corresponding scheme. It must be a generator function
            that yields the triplet.

    Raises:
        RuntimeError: if a handler is already registered for ``scheme``.
    """
    # The registry is only mutated, never rebound, so no ``global`` is needed.
    if scheme in _rendezvous_handlers:
        raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered")
    _rendezvous_handlers[scheme] = handler


def _query_to_dict(query: str) -> Dict[str, str]:
    """
    Convert a URL query string into a dictionary of string keys and values.

    A query of the form ``"rank=0&world_size=1"`` is converted into
    ``{"rank": "0", "world_size": "1"}``. Values are kept as strings; callers
    convert them as needed. Empty segments (e.g. from ``"a=1&&b=2"`` or an
    empty query) are skipped.

    Each pair is split on the first ``=`` only, so a value that itself
    contains ``=`` is preserved in full.

    Raises:
        IndexError: if a non-empty segment contains no ``=``.
    """
    return {
        pair[0]: pair[1]
        for pair in (
            segment.split("=", 1) for segment in filter(None, query.split("&"))
        )
    }


def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
    """
    Return whether the libuv TCPStore backend was requested.

    The ``use_libuv`` URL parameter takes precedence; if it is absent the
    ``USE_LIBUV`` environment variable is consulted, and if that is absent
    too the value defaults to ``"1"``. Only the exact string ``"1"`` enables
    libuv.
    """
    # libuv is the default backend for TCPStore. To enable the non-libuv backend,
    # user can explicitly specify ``use_libuv=0`` in the URL parameter.
    return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1"


def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwargs):
    """
    Fold ``rank``/``world_size`` into ``url`` and dispatch to the scheme's handler.

    Args:
        url: rendezvous URL; its scheme selects the registered handler.
        rank: rank to embed in the query string, or ``-1`` to leave it out.
        world_size_opt: world size to embed in the query string. ``-1`` leaves
            it out. ``None`` means the size is resolved here: it is taken from
            the ``WORLD_SIZE`` environment variable for ``env://`` URLs
            (``RANK`` is likewise read from the environment), and otherwise
            stays ``-1``; in this case ``world_size`` is always written into
            the query string.
        **kwargs: forwarded unchanged to the handler.

    Returns:
        Whatever the registered handler returns for the rewritten URL.

    Raises:
        AssertionError: if arguments need to be embedded but the URL already
            carries ``rank`` or ``world_size`` in its query string.
        RuntimeError: if no handler is registered for the URL scheme.
    """
    result = urlparse(url)
    if world_size_opt is None:
        world_size = -1
        if result.scheme == "env":
            rank = int(os.environ.get("RANK", rank))
            # If the world_size env variable is not present then it is a dynamic group
            world_size = int(os.environ.get("WORLD_SIZE", world_size))
    else:
        world_size = world_size_opt
    if rank != -1 or world_size != -1 or world_size_opt is None:
        query_dict = _query_to_dict(result.query)
        assert (
            "rank" not in query_dict and "world_size" not in query_dict
        ), f"The url: {url} has node-specific arguments(rank, world_size) already."
        if rank != -1:
            query_dict["rank"] = str(rank)
        if world_size != -1 or world_size_opt is None:
            query_dict["world_size"] = str(world_size)
        # Re-serialize the (possibly extended) query and rebuild the URL.
        result = result._replace(
            query="&".join(f"{k}={v}" for k, v in query_dict.items())
        )
        url = urlunparse(result)

    if result.scheme not in _rendezvous_handlers:
        raise RuntimeError(f"No rendezvous handler for {result.scheme}://")
    return _rendezvous_handlers[result.scheme](url, **kwargs)


def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    """
    Validate the arguments and run the rendezvous handler selected by ``url``.

    Args:
        url: rendezvous URL whose scheme names a registered handler.
        rank: rank of this process, or ``-1`` if it is not passed explicitly.
        world_size: number of participating processes, or ``-1`` if it is
            not passed explicitly.
        **kwargs: forwarded unchanged to the handler.

    Returns:
        The result of calling the registered handler.

    Raises:
        RuntimeError: if ``url`` is neither ``str`` nor ``bytes``, if ``rank``
            or ``world_size`` is not an integer, or if no handler is registered
            for the URL scheme.
    """
    # NOTE(review): ``bytes`` passes this check, but the helper compares the
    # parsed scheme against ``str`` keys, so a ``str`` URL is what works.
    if not isinstance(url, (str, bytes)):
        raise RuntimeError(f"`url` must be a string. {type(url)}: {url}")

    if not isinstance(rank, numbers.Integral):
        raise RuntimeError(f"`rank` must be an integer. {rank}")

    if not isinstance(world_size, numbers.Integral):
        raise RuntimeError(f"`world_size` must be an integer. {world_size}")

    return _rendezvous_helper(url, rank, world_size, **kwargs)


def _create_store_from_options(backend_options, rank):
    """
    Return the store produced by rendezvousing on ``backend_options.init_method``.

    The world size is passed as ``None`` so that it is resolved by
    ``_rendezvous_helper``; only the store of the first yielded triplet is
    returned.
    """
    store, _, _ = next(_rendezvous_helper(backend_options.init_method, rank, None))
    return store


def _rendezvous_error(msg):
    """Build (but do not raise) the ``ValueError`` used for rendezvous failures."""
    return ValueError("Error initializing torch.distributed using " + msg)


def _file_rendezvous_handler(url: str, **kwargs):
    """
    Rendezvous handler for ``file://`` URLs, backed by a ``FileStore``.

    The URL must carry a path plus ``rank`` and ``world_size`` query
    parameters. Yields a single ``(store, rank, world_size)`` triplet.

    Raises:
        ValueError: if the path, ``rank`` or ``world_size`` is missing.
        RuntimeError: if the generator is advanced a second time, since
            re-rendezvous is not supported by this method.
    """

    def _error(msg):
        return _rendezvous_error("file:// rendezvous: " + msg)

    result = urlparse(url)
    path = result.path
    if sys.platform == "win32":
        import urllib.request

        # On Windows the netloc is part of the file location, so join it
        # back onto the path before converting to a native path.
        full_path = result.netloc + result.path
        path = urllib.request.url2pathname(full_path)
        if path:
            # Normalizing an empty string produces ".", which is not expected.
            path = os.path.normpath(path)

    if not path:
        raise _error("path missing")
    query_dict = _query_to_dict(result.query)
    if "rank" not in query_dict:
        raise _error("rank parameter missing")
    if "world_size" not in query_dict:
        raise _error("world size parameter missing")

    rank = int(query_dict["rank"])
    world_size = int(query_dict["world_size"])
    store = FileStore(path, world_size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")


def _torchelastic_use_agent_store() -> bool:
    """Return True iff ``TORCHELASTIC_USE_AGENT_STORE`` is set to ``"True"``."""
    return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True)


def _create_c10d_store(
    hostname, port, rank, world_size, timeout, use_libuv=True
) -> Store:
    """
    Smartly creates a c10d Store object on ``rank`` based on whether we need to re-use agent store.

    The TCPStore server is assumed to be hosted
    on ``hostname:port``.

    By default, the TCPStore server uses the asynchronous implementation
    ``LibUVStoreDaemon`` which utilizes libuv.

    If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that
    the agent leader (node rank 0) hosts the TCPStore server (for which the
    endpoint is specified by the given ``hostname:port``). Hence
    ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``).

    If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host
    the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname
    and port are correctly passed via ``hostname`` and ``port``. All
    non-zero ranks will create and return a TCPStore client.

    Raises:
        ValueError: if ``port`` is outside the range 0 to 65535.
        KeyError: if the agent store is in use but the
            ``TORCHELASTIC_RESTART_COUNT`` environment variable is not set.
    """
    # check if port is uint16_t
    if not 0 <= port < 2**16:
        raise ValueError(f"port must have value from 0 to 65535 but was {port}.")

    if _torchelastic_use_agent_store():
        # Keys are namespaced by the restart attempt read from the environment.
        attempt = os.environ["TORCHELASTIC_RESTART_COUNT"]
        tcp_store = TCPStore(hostname, port, world_size, False, timeout)
        return PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
    else:
        start_daemon = rank == 0
        return TCPStore(
            hostname,
            port,
            world_size,
            start_daemon,
            timeout,
            multi_tenant=True,
            use_libuv=use_libuv,
        )


def _tcp_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``tcp://`` URLs.

    The URL must carry a port plus ``rank`` and ``world_size`` query
    parameters; ``use_libuv`` is optional. Yields a single
    ``(store, rank, world_size)`` triplet whose store comes from
    ``_create_c10d_store``.

    Raises:
        ValueError: if the port, ``rank`` or ``world_size`` is missing.
        RuntimeError: if the generator is advanced a second time, since
            re-rendezvous is not supported by this method.
    """

    def _error(msg):
        return _rendezvous_error("tcp:// rendezvous: " + msg)

    result = urlparse(url)
    if not result.port:
        raise _error("port number missing")
    query_dict = _query_to_dict(result.query)
    if "rank" not in query_dict:
        raise _error("rank parameter missing")
    if "world_size" not in query_dict:
        raise _error("world size parameter missing")

    rank = int(query_dict["rank"])
    world_size = int(query_dict["world_size"])
    use_libuv = _get_use_libuv_from_query_dict(query_dict)

    assert result.hostname is not None

    store = _create_c10d_store(
        result.hostname, result.port, rank, world_size, timeout, use_libuv
    )

    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using tcp:// method")


def _env_rendezvous_handler(
    url: str, timeout: timedelta = default_pg_timeout, **kwargs
):
    """
    Rendezvous handler for ``env://`` URLs.

    ``rank`` and ``world_size`` are taken from the URL query string when
    present and otherwise from the ``RANK`` / ``WORLD_SIZE`` environment
    variables. ``MASTER_ADDR`` and ``MASTER_PORT`` are always read from the
    environment. Yields a single ``(store, rank, world_size)`` triplet whose
    store comes from ``_create_c10d_store``.

    Raises:
        ValueError: if a required environment variable is unset or empty.
        RuntimeError: if the generator is advanced a second time, since
            re-rendezvous is not supported by this method.
    """

    def _error(msg):
        return _rendezvous_error("env:// rendezvous: " + msg)

    def _env_error(var):
        return _error(f"environment variable {var} expected, but not set")

    def _get_env_or_raise(env_var: str) -> str:
        # An empty value is treated the same as a missing variable.
        env_val = os.environ.get(env_var, None)
        if not env_val:
            raise _env_error(env_var)
        else:
            return env_val

    result = urlparse(url)
    query_dict = _query_to_dict(result.query)

    rank: int
    world_size: int
    master_port: int
    master_addr: str

    if "rank" in query_dict:
        rank = int(query_dict["rank"])
    else:
        rank = int(_get_env_or_raise("RANK"))

    if "world_size" in query_dict:
        world_size = int(query_dict["world_size"])
    else:
        world_size = int(_get_env_or_raise("WORLD_SIZE"))

    master_addr = _get_env_or_raise("MASTER_ADDR")
    master_port = int(_get_env_or_raise("MASTER_PORT"))
    use_libuv = _get_use_libuv_from_query_dict(query_dict)

    store = _create_c10d_store(
        master_addr, master_port, rank, world_size, timeout, use_libuv
    )

    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform re-rendezvous using env:// method")


register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)