# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, List, NamedTuple, Optional, Type

import torch
import torch.distributed as dist


__all__ = ["JoinHook", "Joinable", "Join"]


class JoinHook:
    r"""
    This defines a join hook, which provides two entry points in the join context manager.

    The entry points are a main hook, which is called repeatedly while there
    exists a non-joined process, and a post-hook, which is called once all
    processes have joined.

    To implement a join hook for the generic join context manager, define a
    class that inherits from :class:`JoinHook` and override ``main_hook()`` and
    ``post_hook()`` as appropriate.
    """

    def main_hook(self) -> None:
        r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.

        A training iteration here means one forward pass, backward pass, and
        optimizer step.
        """

    def post_hook(self, is_last_joiner: bool) -> None:
        r"""
        Call this hook after all processes have joined.

        It is passed an additional ``bool`` argument ``is_last_joiner``, which
        indicates if the rank is one of the last to join.

        Arguments:
            is_last_joiner (bool): ``True`` if the rank is one of the last to
                join; ``False`` otherwise.
        """


class Joinable(ABC):
    r"""
    This defines an abstract base class for joinable classes.

    A joinable class (one inheriting from :class:`Joinable`) should implement
    :meth:`join_hook`, which returns a :class:`JoinHook` instance, in addition
    to :meth:`join_device` and :meth:`join_process_group`, which return device
    and process group information, respectively.
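
    The following is a minimal sketch of a joinable all-reduce counter; the
    names ``Counter`` and ``CounterJoinHook`` are illustrative, not part of
    the API, and the main hook simply shadows the per-iteration all-reduce
    with zeros::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.join import Join, Joinable, JoinHook
        >>>
        >>> class CounterJoinHook(JoinHook):
        >>>     def __init__(self, counter):
        >>>         self.counter = counter
        >>>     def main_hook(self):
        >>>         # Shadow the all-reduce of the non-joined ranks with zeros
        >>>         t = torch.zeros(1, device=self.counter.device)
        >>>         dist.all_reduce(t)
        >>>
        >>> class Counter(Joinable):
        >>>     def __init__(self, device):
        >>>         super().__init__()
        >>>         self.device = device
        >>>         self.total = torch.zeros(1, device=device)
        >>>     def __call__(self):
        >>>         # Notify the context manager before the per-iteration collective
        >>>         Join.notify_join_context(self)
        >>>         t = torch.ones(1, device=self.device)
        >>>         dist.all_reduce(t)
        >>>         self.total += t
        >>>     def join_hook(self, **kwargs) -> JoinHook:
        >>>         return CounterJoinHook(self)
        >>>     @property
        >>>     def join_device(self) -> torch.device:
        >>>         return self.device
        >>>     @property
        >>>     def join_process_group(self):
        >>>         return dist.group.WORLD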
    """

    @abstractmethod
    def __init__(self) -> None:
        super().__init__()
        self._join_config = _JoinConfig.construct_disabled_join_config()

    @abstractmethod
    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a :class:`JoinHook` instance for the given :class:`Joinable`.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments
                to modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.
        """
        ...

    @property
    @abstractmethod
    def join_device(self) -> torch.device:
        r"""Return the device from which to perform collective communications needed by the join context manager."""
        ...

    @property
    @abstractmethod
    def join_process_group(self) -> Any:
        r"""Return the process group for the collective communications needed by the join context manager itself."""
        ...


class _JoinConfig(NamedTuple):
    r"""This includes all fields needed by the join context manager from a participating :class:`Joinable` instance."""

    enable: bool
    throw_on_early_termination: bool
    is_first_joinable: bool

    @staticmethod
    def construct_disabled_join_config():
        r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.

        This applies, e.g., when the caller is not within a join context
        manager.
        """
        return _JoinConfig(
            enable=False, throw_on_early_termination=False, is_first_joinable=False
        )


class Join:
    r"""
    This class defines the generic join context manager, which allows custom hooks to be called after a process joins.

    These hooks should shadow the collective communications of non-joined
    processes to prevent hanging and erroring and to ensure algorithmic
    correctness. Refer to :class:`JoinHook` for details about the hook
    definition.

    .. warning::
        The context manager requires each participating :class:`Joinable` to
        call the method :meth:`notify_join_context()` before its own
        per-iteration collective communications to ensure correctness.

    .. warning::
        The context manager requires that all ``process_group`` attributes of
        the participating :class:`Joinable` objects be the same. If there are
        multiple :class:`Joinable` objects, then the ``device`` of the first
        is used. The process group and device information is used for
        checking for non-joined processes and for notifying processes to
        throw an exception if ``throw_on_early_termination`` is enabled, both
        of which use an all-reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.

        enable (bool): a flag enabling uneven input detection; setting to
            ``False`` disables the context manager's functionality and should
            only be set when the user knows the inputs will not be uneven
            (default: ``True``).

        throw_on_early_termination (bool): a flag controlling whether to throw an
            exception upon detecting uneven inputs (default: ``False``).

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>>     # All ranks reach here without hanging/erroring
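
    The following sketch shows the same loop with
    ``throw_on_early_termination=True``: once any rank exhausts its inputs,
    all ranks raise a ``RuntimeError``, so they can rendezvous in the
    ``except`` block (the recovery logic shown is only illustrative)::

        >>> # xdoctest: +SKIP
        >>> try:
        >>>     with Join([model, optim], throw_on_early_termination=True):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>> except RuntimeError:
        >>>     # e.g. break out of an enclosing epoch loop on all ranks together
        >>>     pass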
    """

    def __init__(
        self,
        joinables: List[Joinable],
        enable: bool = True,
        throw_on_early_termination: bool = False,
        **kwargs,
    ):
        if len(joinables) == 0:
            raise ValueError("The join context manager requires at least one joinable")
        self._joinables = joinables
        self._join_hooks = [
            joinable.join_hook(**kwargs) for joinable in self._joinables
        ]
        self._enable = enable
        self._throw_on_early_termination = throw_on_early_termination
        self._set_joinable_configs()
        self._extract_dist_info()

    def _set_joinable_configs(self) -> None:
        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
        assert len(self._joinables) > 0
        is_first_joinable = True
        for joinable in self._joinables:
            joinable._join_config = _JoinConfig(
                enable=self._enable,
                throw_on_early_termination=self._throw_on_early_termination,
                is_first_joinable=is_first_joinable,
            )
            is_first_joinable = False

    def _extract_dist_info(self) -> None:
        r"""
        Extract the process group and device information from the joinables.

        If there are multiple joinables, then the context manager uses the
        first specified device.

        Preconditions:
            ``self._joinables`` is not ``None`` and is non-empty.

        Raises:
            ValueError
                If there are multiple conflicting ``process_group`` attributes
                among the ``Joinable`` objects.
        """
        process_group = None
        device = None
        for joinable in self._joinables:
            if process_group is None:
                process_group = joinable.join_process_group
            elif process_group != joinable.join_process_group:
                raise ValueError(
                    "Using join context manager with multiple process groups"
                )
            if device is None:
                device = joinable.join_device
        self._process_group = process_group
        self._rank = dist.get_rank(self._process_group)
        self._device = device

    def __enter__(self):
        ...

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ):
        r"""
        Repeatedly run the main hooks until all processes join; then, run the post-hooks.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True`` and uneven inputs are
                detected.
        """
        if not self._enable or type:
            return  # propagate the exception directly if one was raised

        all_procs_joined = False
        is_last_joiner = True

        i = 0
        WARN_THRESHOLD = 1000
        warnings.simplefilter("once")

        while not all_procs_joined:
            if i > WARN_THRESHOLD:
                warnings.warn(
                    "Detected uneven input skew of greater than "
                    f"{WARN_THRESHOLD}. This means that rank "
                    f"{self._rank} has at least {WARN_THRESHOLD} "
                    "fewer inputs than other currently-active ranks. "
                    "This level of skew could lead to performance "
                    "degradation during training."
                )
            # Shadow the all-reduce in non-joined processes
            num_nonjoined_procs = self._get_num_nonjoined_procs()
            if num_nonjoined_procs == 0:
                all_procs_joined = True
            else:
                if self._throw_on_early_termination:
                    self._notify_procs_to_terminate()

                # Run main hooks
                for join_hook in self._join_hooks:
                    join_hook.main_hook()

                is_last_joiner = False
                i += 1

        # Run post-hooks
        for join_hook in self._join_hooks:
            join_hook.post_hook(is_last_joiner)

    def _get_num_nonjoined_procs(self):
        r"""Return the number of non-joined processes.

        This shadows the all-reduce issued by the non-joined processes in
        :meth:`notify_join_context`: this (joined) process contributes a zero
        while each non-joined process contributes a one, so the resulting sum
        equals the number of non-joined processes.
        """
        num_nonjoined_procs = torch.zeros(1, device=self._device)
        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
        return num_nonjoined_procs.item()

    def _notify_procs_to_terminate(self):
        r"""Schedule an all-reduce to notify non-joined processes to terminate.

        Also raise a ``RuntimeError`` indicating that the current process has
        exhausted its inputs.
        """
        ones = torch.ones(1, device=self._device)
        dist.all_reduce(ones, group=self._process_group)
        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")

    @staticmethod
    def notify_join_context(joinable: Joinable):
        r"""
        Notify the join context manager that the calling process has not yet joined.

        Then, if ``throw_on_early_termination=True``, check if uneven inputs
        have been detected (i.e. if one process has already joined) and throw
        an exception if so.

        This method should be called from a :class:`Joinable` object before
        its per-iteration collective communications. For example, this should
        be called at the beginning of the forward pass in
        :class:`DistributedDataParallel`.

        Only the first :class:`Joinable` object passed into the context
        manager performs the collective communications in this method, and
        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the context
            manager that the process has not yet joined if ``joinable`` is the
            first one passed into the context manager; ``None`` otherwise.
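
        The following sketch shows the intended call placement, where
        ``__call__`` belongs to a hypothetical :class:`Joinable` that issues
        one all-reduce per iteration::

            >>> # xdoctest: +SKIP
            >>> def __call__(self, t):
            >>>     # Indicate that this rank is still active before any
            >>>     # per-iteration collective communication
            >>>     Join.notify_join_context(self)
            >>>     dist.all_reduce(t, group=self.join_process_group)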
        """
        assert hasattr(joinable, "_join_config"), (
            f"Check that the {type(joinable)} constructor calls the "
            "``Joinable`` constructor"
        )

        join_config = joinable._join_config
        # First joinable is responsible for the collective communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None

        device = joinable.join_device
        process_group = joinable.join_process_group

        # Schedule an all-reduce to indicate that the caller has not yet joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)

        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work