# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import warnings
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Shard


__all__ = [
    "is_rng_supported_mesh",
    "manual_seed",
    "OffsetBasedRNGTracker",
    "TensorParallelRNGTracker",
]

_rng_tracker: Optional["_RNGStateTracker"] = None


def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
    """Checks if the current device of ``device_mesh`` supports DTensor's random APIs.
    Currently DTensor Random APIs only support cuda/cuda-like devices. We suggest
    users call this API to test the availability before using our random APIs.

    Args:
        device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the
            random ops APIs are supported.

    Returns:
        A bool value. True if ``device_mesh`` supports DTensor Random APIs; False otherwise.

    .. warning::
        Currently we only support correct RNG on cuda/cuda-like devices.
    """
    device_handle = _get_device_handle(device_mesh.device_type)
    if device_handle and hasattr(device_handle, "set_rng_state"):
        return True
    else:
        # TODO: Logs way too much
        warnings.warn(
            f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh"
        )
        return False


def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
    """Sets the seed for generating random numbers for the calling rank.

    Args:
        seed (int): The desired seed.
        device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.

    Returns:
        None

    .. warning::
        :func:`manual_seed` must be called with the same ``seed`` value from all ranks
        of the default ``ProcessGroup``, even if some ranks are not a part of the
        ``device_mesh``.
        If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
        ``manual_seed`` will not set its GPU device's generator seed.
        The current implementation only supports a GPU device mesh.
    """
    device_handle = _get_device_handle(device_mesh.device_type)
    if not device_handle:
        raise NotImplementedError(
            f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
        )

    # allgather the seed over the default PG
    object_list = [seed] * dist.get_world_size()
    dist.all_gather_object(object_list, seed)
    for rank, object in enumerate(object_list):
        if seed != int(object):
            raise RuntimeError(
                f"calling manual_seed function over {device_mesh} but received different seed values on ranks: "
                f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!"
            )
    # instantiate a RNG tracker if one hasn't been created yet. By default DTensor
    # uses an OffsetBasedRNGTracker to perform random operators.
    global _rng_tracker
    if not _rng_tracker:
        _rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)

    # the current rank is in mesh
    if device_mesh.get_coordinate() is not None:
        if isinstance(_rng_tracker, TensorParallelRNGTracker):
            _rng_tracker._manual_seed(device_mesh, seed)
        elif isinstance(_rng_tracker, OffsetBasedRNGTracker):
            _rng_tracker._manual_seed(seed)
        else:
            raise RuntimeError(
                f"Unknown type of cuda RNG state tracker: _rng_tracker = {_rng_tracker}"
            )


class _RNGStateTracker:
    """
    _RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object)
    in a dict, mapping from a corresponding tag to each state tensor. It also provides
    a set of convenient utility methods to help access/modify the state tensors. The most
    important interface is _distribute_region, which is used when DTensor executes
    a random op (an operator that calls RNG).
    """

    def __init__(self, device_type: str = "cuda"):
        self._device_type = device_type
        self._device_handle = _get_device_handle(device_type)
        if not (self._device_handle and self._device_handle.is_available()):
            raise RuntimeError(
                f"{self.__class__.__name__} instantiation requires the presence of a CUDA/CUDA-like device"
            )

        self._states: Dict[str, Tensor] = {}
        self._devices = [self._device_handle.current_device()]
        self._use_distribute_region = True

    @property
    def rng_states(self) -> Dict[str, Tensor]:
        return self._states

    @property
    def distribute_region_enabled(self) -> bool:
        return self._use_distribute_region

    @distribute_region_enabled.setter
    def distribute_region_enabled(self, value) -> None:
        self._use_distribute_region = value

    def rng_state_is_sync(self, name) -> bool:
        return name in self.rng_states

    def get_seed(self, name: str) -> int:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64)
        return int(seed_tensor.item())

    def set_seed(self, name: str, seed: int) -> None:
        seed_tensor = torch.tensor([seed]).view(torch.uint8)
        offset_tensor = torch.tensor([0]).view(torch.uint8)
        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])

    def _distribute_region(self, spec: DTensorSpec):
        pass


class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states
    should be shared and synchronized among all ranks to respect the semantics of DTensor
    random operators.
    """

    def __init__(self, device_type: str = "cuda"):
        super().__init__(device_type)
        # synchronize RNG state using rank 0's current one
        rng_state = self._device_handle.get_rng_state().to(device_type)
        dist.broadcast(rng_state, 0)
        self.rng_states["parallel-rng"] = rng_state.to("cpu")

    def _manual_seed(self, parallel_seed: int) -> None:
        self.set_seed("parallel-rng", parallel_seed)

    @contextlib.contextmanager
    def _distribute_region(self, spec: DTensorSpec):
        # check if the parallel rng state has been synchronized or not
        if not self.rng_state_is_sync("parallel-rng"):
            raise RuntimeError(
                "OffsetBasedRNGTracker requires the random state to be synchronized "
                "before entering into a distribute region!"
            )

        if self.distribute_region_enabled:
            old_offset = self.get_offset("parallel-rng")
            self._set_pre_op_offset(spec)
            with torch.random.fork_rng(self._devices, device_type=self._device_type):
                self._device_handle.set_rng_state(self.rng_states["parallel-rng"])
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(spec, old_offset)
        else:
            yield

    def get_offset(self, name: str) -> int:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64)
        return int(offset_tensor.item())

    def set_offset(self, name: str, offset: int) -> None:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        seed_tensor = (self.rng_states[name])[0:8]
        offset_tensor = torch.tensor([offset]).view(torch.uint8)
        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])

    def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
        """Set the starting RNG offset for the current device's local shard before actual
        op execution. The pre_op_offset value should start from the current RNG offset
        and increment by the size of the local shard until it reaches the size of the whole
        DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset
        will be the same.

        Args:
            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
                we prepare the offset for running random ops.

        Returns:
            None

        .. warning::
            Note that the current implementation does not consider DTensor's contiguity.

        Example:
            Take a DTensor of shape [8, 16] as an example. Assume that the DTensor
            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
            and the mesh is:
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Another concept to introduce besides rank coordinate is shard coordinate.
            Each rank holds a local shard of the DTensor. In the example, the DTensor
            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas:
            rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) hold 1 replica each.
            That is, the local shards on rank 0 and rank 2 correspond to the same
            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
            (in the example, it will be a tuple (i, j) where shard (i, j) has the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).

            Once we have rank coordinate and shard coordinate, we can calculate on each rank
            what shard of the DTensor the rank holds, with the help of dim_map. The dim_map
            of the above DTensor is [2, 0], so the shard coordinate of a rank with rank coord
            (x, y, z) is simply (z, x), obtained by taking (rank_coord[dim_map[0]], rank_coord[dim_map[1]]).
            Following this calculation,
            rank 0 and rank 2 hold the shard of coord (0, 0);
            rank 1 and rank 3 hold the shard of coord (0, 1);
            rank 4 and rank 6 hold the shard of coord (1, 0);
            rank 5 and rank 7 hold the shard of coord (1, 1);

            The last value to calculate before obtaining the starting offset is the shard linear index.
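            For instance, continuing the example above, the number of shards along each
            tensor dim is (mesh.size(2), mesh.size(0)) = (2, 2), so a shard coord (i, j)
            maps to the linear index i * 2 + j (row-major over the shard grid), i.e.
            shard (0, 0) -> 0, shard (0, 1) -> 1, shard (1, 0) -> 2, shard (1, 1) -> 3.
            Rank 5, which holds shard (1, 1), therefore has shard_linear_idx = 3.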
            The starting offset for each rank will be its shard_linear_index * local_tensor_numel.
        """
        dtensor_shape = spec.shape
        mesh = spec.mesh
        dim_map = spec.dim_map

        # Compute shard coordinate:
        # The coordinate on each tensor dim is a tuple (idx, range)
        # If a DTensor is partitioned on its dim i into n shards, and the current rank
        # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i
        coordinate = mesh.get_coordinate()
        assert coordinate is not None
        shard_coord = [
            coordinate[mesh_dim] if mesh_dim >= 0 else 0 for mesh_dim in dim_map
        ]
        shard_size = [
            mesh.size(mesh_dim) if mesh_dim >= 0 else 1 for mesh_dim in dim_map
        ]

        # compute shard linear index
        shard_linear_idx = self._calc_shard_linear_idx(shard_coord, shard_size)

        # compute starting offset using the first shard's size
        local_size_on_rank_0 = list(dtensor_shape)
        for idx, placement in enumerate(spec.placements):
            if isinstance(placement, Shard):
                mesh_dim_size = mesh.size(idx)
                shard_dim = placement.dim
                local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim(
                    dtensor_shape[shard_dim],
                    mesh_dim_size,
                    0,
                    return_offset=False,
                )[0]

        from torch.distributed.tensor._ops.utils import prod

        local_size = prod(local_size_on_rank_0)

        # get current RNG offset
        current_offset = self.get_offset("parallel-rng")

        # pytorch: offset must be multiple of 4
        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
        offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4
        self.set_offset("parallel-rng", current_offset + offset_incr)

    def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
        """Sets the RNG to a synchronized state after running the local random op. Every
        rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is
        the offset before calling `set_pre_op_offset`, i.e. the offset before running DTensor
        random ops.

        Args:
            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
                we post-process the offset for running random ops.

        Returns:
            None
        """
        dtensor_shape = spec.shape

        from torch.distributed.tensor._ops.utils import prod

        numel = prod(dtensor_shape)
        # pytorch: offset must be multiple of 4
        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
        numel = (numel + 3) // 4 * 4
        self.set_offset("parallel-rng", old_offset + numel)

    def _calc_shard_linear_idx(
        self, shard_coord: List[int], shard_size: List[int]
    ) -> int:
        # compute shard linear index
        shard_linear_idx = 0
        shard_coord_stride = 1
        for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
            shard_linear_idx += idx * shard_coord_stride
            shard_coord_stride *= size

        return shard_linear_idx


class TensorParallelRNGTracker(_RNGStateTracker):
    def __init__(self, device_type: str = "cuda"):
        super().__init__(device_type)
        # copy the default RNG state
        self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()

    def _manual_seed(
        self,
        tp_mesh: DeviceMesh,
        base_seed: int = 1234,
    ):
        tensor_parallel_rank = tp_mesh.get_local_rank()
        # this magic number 2718 comes from Megatron's code
        # (https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/core/tensor_parallel/random.py#L162-L163)
        MegatronMagicNum = 2718
        tensor_parallel_seed = base_seed + MegatronMagicNum + tensor_parallel_rank
        self.set_seed("tensor-parallel-rng", tensor_parallel_seed)

    @contextlib.contextmanager
    def _distribute_region(self, spec: DTensorSpec):
        # check if the tensor parallel rng state has been synchronized or not
        if not self.rng_state_is_sync("tensor-parallel-rng"):
            raise RuntimeError(
                "TensorParallelRNGTracker requires the random state to be synchronized "
                "before entering into a distribute region!"
            )

        if self.distribute_region_enabled:
            with torch.random.fork_rng(self._devices, device_type=self._device_type):
                self._device_handle.set_rng_state(
                    self.rng_states["tensor-parallel-rng"]
                )
                try:
                    yield
                finally:
                    self.rng_states[
                        "tensor-parallel-rng"
                    ] = self._device_handle.get_rng_state()
        else:
            yield
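# A minimal usage sketch (illustrative only, not executed; it assumes an
# already-initialized default process group, CUDA devices, and the public
# ``torch.distributed.tensor`` factory API). The trackers above are exercised
# indirectly through DTensor random ops rather than being called by users:
#
#     >>> # xdoctest: +SKIP("requires distributed init and GPUs")
#     >>> from torch.distributed.device_mesh import init_device_mesh
#     >>> from torch.distributed.tensor import rand, Replicate, Shard
#     >>> mesh = init_device_mesh("cuda", (dist.get_world_size(),))
#     >>> manual_seed(1234, mesh)  # every rank must pass the same seed
#     >>> # replicated placement: all ranks sample identical values
#     >>> x = rand(8, 16, device_mesh=mesh, placements=[Replicate()])
#     >>> # sharded placement: each rank samples its own shard at a distinct offset
#     >>> y = rand(8, 16, device_mesh=mesh, placements=[Shard(0)])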