# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import warnings
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Shard


__all__ = [
    "is_rng_supported_mesh",
    "manual_seed",
    "OffsetBasedRNGTracker",
    "TensorParallelRNGTracker",
]

_rng_tracker: Optional["_RNGStateTracker"] = None


def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
    """Checks if the current device of ``device_mesh`` supports DTensor's random APIs.
    Currently, DTensor random APIs only support cuda/cuda-like devices. We suggest
    users call this API to check availability before using our random APIs.

    Args:
        device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the
            random op APIs are supported.

    Returns:
        A bool value. ``True`` if ``device_mesh`` supports DTensor random APIs; ``False`` otherwise.

    .. warning::
        Currently we only support correct RNG on cuda/cuda-like devices.
    """
    device_handle = _get_device_handle(device_mesh.device_type)
    if device_handle and hasattr(device_handle, "set_rng_state"):
        return True
    else:
        # TODO: Logs way too much
        warnings.warn(
            f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh"
        )
        return False

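# A minimal usage sketch of the check above (illustrative only; assumes
# torch.distributed has already been initialized and 4 CUDA devices are
# available):
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh = init_device_mesh("cuda", (4,))
#     if is_rng_supported_mesh(mesh):
#         ...  # DTensor random ops are safe to use on this mesh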

def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
    """Sets the seed for generating random numbers for the calling rank.

    Args:
        seed (int): The desired seed.
        device_mesh (:class:`DeviceMesh`): The device mesh on which to set the seed.

    Returns:
        None

    .. warning::
        :func:`manual_seed` must be called from all ranks of the default ``ProcessGroup``
        with the same ``seed`` value, even if some ranks are not a part of ``device_mesh``.
        If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
        ``manual_seed`` will not set that rank's GPU generator seed.
        The current implementation only supports a GPU device mesh.
    """
    device_handle = _get_device_handle(device_mesh.device_type)
    if not device_handle:
        raise NotImplementedError(
            f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
        )

    # allgather the seed over the default PG to check that all ranks agree on the seed
    object_list = [seed] * dist.get_world_size()
    dist.all_gather_object(object_list, seed)
    for rank, object in enumerate(object_list):
        if seed != int(object):
            raise RuntimeError(
                f"calling manual_seed function over {device_mesh} but received different seed values on ranks: "
                f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!"
            )
    # instantiate a RNG tracker if we haven't already. By default DTensor uses an
    # OffsetBasedRNGTracker to perform random operators.
    global _rng_tracker
    if not _rng_tracker:
        _rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)

    # only set the seed if the current rank is in the mesh
    if device_mesh.get_coordinate() is not None:
        if isinstance(_rng_tracker, TensorParallelRNGTracker):
            _rng_tracker._manual_seed(device_mesh, seed)
        elif isinstance(_rng_tracker, OffsetBasedRNGTracker):
            _rng_tracker._manual_seed(seed)
        else:
            raise RuntimeError(
                f"Unknown type of cuda RNG state tracker: _rng_tracker = {_rng_tracker}"
            )

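# A usage sketch for ``manual_seed`` (illustrative only; assumes a world size
# of 4 with a CUDA backend). Every rank of the default ProcessGroup must pass
# the same seed, even ranks that are not part of the mesh:
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh = init_device_mesh("cuda", (4,))
#     manual_seed(1234, mesh)  # called on all ranks with an identical seed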

class _RNGStateTracker:
    """
    _RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object)
    in a dict, mapping from a corresponding tag to each state tensor. It also provides
    a set of convenient utility methods to help access/modify the state tensors. The most
    important interface is _distribute_region which will be used when DTensor executes
    a random op (an operator that calls RNG).
    """

    def __init__(self, device_type: str = "cuda"):
        self._device_type = device_type
        self._device_handle = _get_device_handle(device_type)
        if not (self._device_handle and self._device_handle.is_available()):
            raise RuntimeError(
                f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device"
            )

        self._states: Dict[str, Tensor] = {}
        self._devices = [self._device_handle.current_device()]
        self._use_distribute_region = True

    @property
    def rng_states(self) -> Dict[str, Tensor]:
        return self._states

    @property
    def distribute_region_enabled(self) -> bool:
        return self._use_distribute_region

    @distribute_region_enabled.setter
    def distribute_region_enabled(self, value) -> None:
        self._use_distribute_region = value

    def rng_state_is_sync(self, name) -> bool:
        return name in self.rng_states

    def get_seed(self, name: str) -> int:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        # the first 8 bytes of the state tensor hold the 64-bit seed
        seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64)
        return int(seed_tensor.item())

    def set_seed(self, name: str, seed: int) -> None:
        # pack the 64-bit seed and a zero 64-bit offset into a uint8 state tensor
        seed_tensor = torch.tensor([seed]).view(torch.uint8)
        offset_tensor = torch.tensor([0]).view(torch.uint8)
        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])

    def _distribute_region(self, spec: DTensorSpec):
        pass

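# The per-tag RNG state above is a 1-D uint8 tensor that the trackers interpret
# as a 64-bit seed in bytes [0:8] followed by a 64-bit offset in the remaining
# bytes. A sketch of decoding such a state (illustrative only; ``tracker`` is
# any tracker instance holding a "parallel-rng" entry):
#
#     state = tracker.rng_states["parallel-rng"]        # 1-D uint8 tensor
#     seed = int(state[0:8].view(torch.int64).item())   # what get_seed() returns
#     offset = int(state[8:].view(torch.int64).item())  # what OffsetBasedRNGTracker.get_offset() returns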

class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states
    should be shared and synchronized among all ranks to respect the semantics of DTensor
    random operators.
    """

    def __init__(self, device_type: str = "cuda"):
        super().__init__(device_type)
        # synchronize RNG state using rank 0's current one
        rng_state = self._device_handle.get_rng_state().to(device_type)
        dist.broadcast(rng_state, 0)
        self.rng_states["parallel-rng"] = rng_state.to("cpu")

    def _manual_seed(self, parallel_seed: int) -> None:
        self.set_seed("parallel-rng", parallel_seed)

    @contextlib.contextmanager
    def _distribute_region(self, spec: DTensorSpec):
        # check if the parallel rng state has been synchronized or not
        if not self.rng_state_is_sync("parallel-rng"):
            raise RuntimeError(
                "OffsetBasedRNGTracker requires the random state to be synchronized "
                "before entering into a distribute region!"
            )

        if self.distribute_region_enabled:
            old_offset = self.get_offset("parallel-rng")
            self._set_pre_op_offset(spec)
            with torch.random.fork_rng(self._devices, device_type=self._device_type):
                self._device_handle.set_rng_state(self.rng_states["parallel-rng"])
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(spec, old_offset)
        else:
            yield

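    # A sketch of how this context manager is meant to be driven when DTensor
    # executes a random op (illustrative only; ``spec`` is the output DTensor's
    # DTensorSpec and ``local_tensor`` a hypothetical local shard):
    #
    #     with _rng_tracker._distribute_region(spec):
    #         local_tensor.uniform_()  # each rank draws from its own offset range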
    def get_offset(self, name: str) -> int:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64)
        return int(offset_tensor.item())

    def set_offset(self, name: str, offset: int) -> None:
        if name not in self.rng_states:
            raise RuntimeError(
                f"{self.__class__.__name__} does not have random state for {name}"
            )

        seed_tensor = (self.rng_states[name])[0:8]
        offset_tensor = torch.tensor([offset]).view(torch.uint8)
        self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])

    def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
        """Set the starting RNG offset for the current device's local shard before the
        actual op execution. Starting from the current RNG offset, each shard of the
        DTensor is assigned a disjoint offset range whose size equals the local shard's
        number of elements, so that the ranges of all shards together cover the whole
        DTensor. Different ranks that hold the same DTensor shard get the same
        pre-op offset.

        Args:
            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
                we prepare the offset for running random ops.

        Returns:
            None

        .. warning::
            Note that the current implementation does not consider the DTensor's contiguity.

        Example:
            Take a DTensor of shape [8, 16] as an example. Assume that the DTensor
            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
            and the mesh is:
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Another concept to introduce besides the rank coordinate is the shard coordinate.
            Each rank holds a local shard of the DTensor. In the example, the DTensor
            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas:
            rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) each hold one.
            That is, the local shards on rank 0 and rank 2 correspond to the same
            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
            (in the example, it will be a tuple (i, j) where shard (i, j) has the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).

            Once we have the rank coordinate and the shard coordinate, we can calculate on
            each rank which shard of the DTensor the rank holds, with the help of dim_map.
            The dim_map of the above DTensor is [2, 0], so the shard coordinate of a rank
            with rank coord (x, y, z) is simply (z, x), obtained by taking
            (rank_coord[dim_map[0]], rank_coord[dim_map[1]]).
            Following this calculation,
            rank 0 and rank 2 hold the shard of coord (0, 0);
            rank 1 and rank 3 hold the shard of coord (0, 1);
            rank 4 and rank 6 hold the shard of coord (1, 0);
            rank 5 and rank 7 hold the shard of coord (1, 1).

            The last value to calculate before obtaining the starting offset is the shard
            linear index. The starting offset for each rank will be the current offset plus
            its shard_linear_index * local_tensor_numel (rounded up to a multiple of 4).
        """
        dtensor_shape = spec.shape
        mesh = spec.mesh
        dim_map = spec.dim_map

        # Compute shard coordinate:
        # The coordinate on each tensor dim is a tuple (idx, range)
        # If a DTensor is partitioned on its dim i into n shards, and the current rank
        # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i
        coordinate = mesh.get_coordinate()
        assert coordinate is not None
        shard_coord = [
            coordinate[mesh_dim] if mesh_dim >= 0 else 0 for mesh_dim in dim_map
        ]
        shard_size = [
            mesh.size(mesh_dim) if mesh_dim >= 0 else 1 for mesh_dim in dim_map
        ]

        # compute shard linear index
        shard_linear_idx = self._calc_shard_linear_idx(shard_coord, shard_size)

        # compute starting offset using the first shard's size
        local_size_on_rank_0 = list(dtensor_shape)
        for idx, placement in enumerate(spec.placements):
            if isinstance(placement, Shard):
                mesh_dim_size = mesh.size(idx)
                shard_dim = placement.dim
                local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim(
                    dtensor_shape[shard_dim],
                    mesh_dim_size,
                    0,
                    return_offset=False,
                )[0]

        from torch.distributed.tensor._ops.utils import prod

        local_size = prod(local_size_on_rank_0)

        # get current RNG offset
        current_offset = self.get_offset("parallel-rng")

        # pytorch: offset must be multiple of 4
        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
        offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4
        self.set_offset("parallel-rng", current_offset + offset_incr)

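    # A worked instance of the docstring example above: for the [8, 16] DTensor
    # with dim_map [2, 0] on the 2x2x2 mesh, rank 5 has mesh coordinate (1, 0, 1),
    # so its shard coordinate is (z, x) = (1, 1) with shard_size (2, 2). Its shard
    # linear index is 1 * 2 + 1 = 3, the local shard holds 4 * 8 = 32 elements,
    # and the pre-op offset therefore advances by (3 * 32 + 3) // 4 * 4 = 96 past
    # the current offset.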
    def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
        """Sets the RNG to a synchronized state after running the local random op. Every
        rank should set its RNG offset to `old_offset + DTensor.numel()` (rounded up to a
        multiple of 4), where `old_offset` is the offset before calling `_set_pre_op_offset`,
        i.e. the offset before running the DTensor random op.

        Args:
            spec (:class:`DTensorSpec`): the spec of the DTensor object on which
                we post-process the offset for running random ops.

        Returns:
            None
        """
        dtensor_shape = spec.shape

        from torch.distributed.tensor._ops.utils import prod

        numel = prod(dtensor_shape)
        # pytorch: offset must be multiple of 4
        # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
        numel = (numel + 3) // 4 * 4
        self.set_offset("parallel-rng", old_offset + numel)

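    # Continuing the example above: the [8, 16] DTensor has numel 128, already a
    # multiple of 4, so after the op every rank sets its offset to old_offset + 128,
    # which keeps the "parallel-rng" state identical on all ranks.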
    def _calc_shard_linear_idx(
        self, shard_coord: List[int], shard_size: List[int]
    ) -> int:
        # compute shard linear index
        shard_linear_idx = 0
        shard_coord_stride = 1
        for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
            shard_linear_idx += idx * shard_coord_stride
            shard_coord_stride *= size

        return shard_linear_idx


class TensorParallelRNGTracker(_RNGStateTracker):
    def __init__(self, device_type: str = "cuda"):
        super().__init__(device_type)
        # copy the default RNG state
        self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()

    def _manual_seed(
        self,
        tp_mesh: DeviceMesh,
        base_seed: int = 1234,
    ):
        tensor_parallel_rank = tp_mesh.get_local_rank()
        # this magic number 2718 comes from Megatron's code
        # (https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/core/tensor_parallel/random.py#L162-L163)
        MegatronMagicNum = 2718
        tensor_parallel_seed = base_seed + MegatronMagicNum + tensor_parallel_rank
        self.set_seed("tensor-parallel-rng", tensor_parallel_seed)

    @contextlib.contextmanager
    def _distribute_region(self, spec: DTensorSpec):
        # check if the tensor parallel rng state has been synchronized or not
        if not self.rng_state_is_sync("tensor-parallel-rng"):
            raise RuntimeError(
                "TensorParallelRNGTracker requires the random state to be synchronized "
                "before entering into a distribute region!"
            )

        if self.distribute_region_enabled:
            with torch.random.fork_rng(self._devices, device_type=self._device_type):
                self._device_handle.set_rng_state(
                    self.rng_states["tensor-parallel-rng"]
                )
                try:
                    yield
                finally:
                    self.rng_states[
                        "tensor-parallel-rng"
                    ] = self._device_handle.get_rng_state()
        else:
            yield

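# A sketch of the seed produced by TensorParallelRNGTracker._manual_seed
# (illustrative only): with the default base_seed of 1234, the rank whose local
# rank in ``tp_mesh`` is r seeds its "tensor-parallel-rng" state with
# 1234 + 2718 + r, so each tensor-parallel rank draws from a distinct RNG stream.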