from dataclasses import dataclass
from typing import Any, cast, List, NamedTuple, Optional, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


class TensorMeta(NamedTuple):
    # simple named tuple to represent tensor metadata;
    # intentionally kept minimal, used only for sharding
    # propagation purposes.
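    # e.g. (illustrative):
    #   TensorMeta(torch.Size([8, 16]), stride=(16, 1), dtype=torch.float32)
    # describes a contiguous 8 x 16 float32 tensor.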
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype


# used internally to propagate the placements
@dataclass
class DTensorSpec:
    mesh: DeviceMesh
    placements: Tuple[Placement, ...]

    # tensor meta will only be set during sharding propagation
    tensor_meta: Optional[TensorMeta] = None
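
    # e.g. (illustrative): DTensorSpec(mesh, (Shard(0), Replicate())) describes a
    # tensor sharded on tensor dim 0 across mesh dim 0 and replicated on mesh dim 1.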

    def __post_init__(self) -> None:
        if not isinstance(self.placements, tuple):
            self.placements = tuple(self.placements)
        self._hash: Optional[int] = None

    def __setattr__(self, attr: str, value: Any) -> None:
        super().__setattr__(attr, value)
        # Make sure to recompute the hash in case any of the hashed attributes
        # change (though we do not expect `mesh` or `placements` to change)
        if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
            self._hash = None

    def _hash_impl(self) -> int:
        # hashing and equality checks for DTensorSpec are used to cache the sharding
        # propagation results. We only need to consider the mesh, placements, shape,
        # dtype and stride.
        # Caveat: we need to keep this in mind and keep hash and eq in sync if we add
        # more fields to them.
        if self.tensor_meta is not None:
            return hash(
                (
                    self.mesh,
                    self.placements,
                    self.tensor_meta.shape,
                    self.tensor_meta.stride,
                    self.tensor_meta.dtype,
                )
            )
        return hash((self.mesh, self.placements))

    def __hash__(self) -> int:
        # We lazily cache the hash to avoid recomputing it on every use, and we
        # make sure to invalidate the cached hash when `tensor_meta` changes by
        # overriding `__setattr__`. This must be lazy so that Dynamo does not
        # try to hash non-singleton `SymInt`s for the stride.
        if self._hash is None:
            self._hash = self._hash_impl()
        return self._hash

    def __eq__(self, __o: object) -> bool:
        if not (
            isinstance(__o, DTensorSpec)
            and self.mesh == __o.mesh
            and self.placements == __o.placements
        ):
            return False
        if self.tensor_meta is None or __o.tensor_meta is None:
            return self.tensor_meta == __o.tensor_meta

        return (
            self.tensor_meta.shape == __o.tensor_meta.shape  # type: ignore[union-attr]
            and self.tensor_meta.stride == __o.tensor_meta.stride  # type: ignore[union-attr]
            and self.tensor_meta.dtype == __o.tensor_meta.dtype  # type: ignore[union-attr]
        )

    def __str__(self) -> str:
        """
        Human-readable representation of the DTensorSpec.
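
        Example (illustrative; the exact placement formatting comes from the
        placement types' own ``__str__``): a spec with placements ``(Shard(1),)``
        and shape ``(18, 20, 30)`` prints as something like
        ``Spec(S(1) on (18, 20, 30))``.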
90        """
91        if len(self.placements) == 1:
92            placement_str = str(self.placements[0])
93        else:
94            placement_str = str(self.placements)
95
96        if self.tensor_meta is not None:
97            tensor_shape = str(tuple(self.tensor_meta.shape))
98        else:
99            tensor_shape = "unknown shape"
100
101        return f"Spec({placement_str} on {tensor_shape})"
102
103    @property
104    def shape(self) -> torch.Size:
105        if self.tensor_meta is None:
106            raise ValueError("tensor_meta is not set")
107        return self.tensor_meta.shape
108
109    @property
110    def stride(self) -> Tuple[int, ...]:
111        if self.tensor_meta is None:
112            raise ValueError("tensor_meta is not set")
113        return self.tensor_meta.stride
114
115    @property
116    def ndim(self) -> int:
117        if self.tensor_meta is None:
118            raise ValueError("tensor_meta is not set")
119        return len(self.tensor_meta.shape)
120
121    @property
122    def num_shards(self) -> int:
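        # e.g. (illustrative): on a 2 x 4 mesh with placements (Shard(0), Replicate()),
        # only mesh dim 0 shards the tensor, so num_shards == 2.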
        num_shards = 1
        for i, placement in enumerate(self.placements):
            if placement.is_shard():
                num_shards *= self.mesh.size(i)
        return num_shards

    @property
    def device_mesh(self) -> DeviceMesh:
        # simple alias for the mesh field; makes some checks
        # that mix DTensor/DTensorSpec easier
        return self.mesh

    @property
    def dim_map(self) -> List[int]:
137        """
138        dim_map is a property we derive from `placements` of
139        the distributed tensor. It simply return a list of ints
140        where dim_map[i] denotes the sharding mapping to the mesh
141        dimension, and len(dim_map) == dist_tensor.ndim
142        dim_map[i] = -1: means tensor dim i replicate on mesh
143        dim_map[i] = j: means tensor dim i shard on mesh dim j
144
145        For example, we have a dist tensor that have the shape of
146        [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements:
147        [Shard(1)], the dim_map of this placement would be:
148        [-1, 0, -1]. This representation is pretty helpful during
149        sharding propagation where we could know exactly each
150        tensor dimension is sharded or not.
151
152        Note that if placements contains `_Partial`, we have to
153        explicitly deal with it, so that when we create a DTensorSpec
154        with dim_map, we could properly record the pending sums.
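
        Example (an illustrative sketch, assuming ``mesh`` is a 1-D DeviceMesh
        over 4 ranks and ``tensor_meta`` describes a tensor of shape
        [18, 20, 30])::

            spec = DTensorSpec(mesh, (Shard(1),), tensor_meta=tensor_meta)
            spec.dim_map  # [-1, 0, -1]: tensor dim 1 is sharded on mesh dim 0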
155        """
156        # dims mapping of dist tensor sharding
157        # return size of tensor ndim, -1 represent replicate
158        # and int >=0 represent shard on that device mesh dim
        r = [-1] * self.ndim
        for i, placement in enumerate(self.placements):
            if placement.is_shard():
                shard_dim = cast(Shard, placement).dim
                if r[shard_dim] > -1:
                    raise ValueError(
                        f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]},"
                        " DTensor operator implementation does not support things like hybrid"
                        " sharding strategies yet (i.e. [Shard(0), Shard(0)])"
                    )
                r[shard_dim] = i
        return r

    @property
    def num_shards_map(self) -> List[int]:
174        """
175        dim_map is a property we derive from `placements` of
176        the distributed tensor. Unlike `dim_map`, `num_shards_map`
177        denotes how many shards each tensor dim has. Like `dim_map`:
178            len(num_shards_map) == dist_tensor.ndim
179            num_shards_map[i] = 1: means tensor dim i is not sharded
180            num_shards_map[i] = j: means tensor dim i has j shards in total
181
182        For example, we have a dist tensor of shape [18, 20, 30],
183        a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
184        ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor
185        would be: [4, 2, 1].
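
        Example (an illustrative sketch, assuming ``mesh`` is a 2 x 4 DeviceMesh
        and ``tensor_meta`` describes a tensor of shape [18, 20, 30])::

            spec = DTensorSpec(mesh, (Shard(1), Shard(0)), tensor_meta=tensor_meta)
            spec.num_shards_map  # [4, 2, 1]: dim 0 split 4 ways, dim 1 split 2 ways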
186        """
187        r = [1] * self.ndim
188        for i, placement in enumerate(self.placements):
189            if placement.is_shard():
190                shard_dim = cast(Shard, placement).dim
191                r[shard_dim] *= self.mesh.size(i)
192
193        return r
194
195    @property
196    def sums(self) -> List[int]:
197        """
198        sums is a property we derive from `placements` of the
199        distributed tensor. It simply return a list of ints where
200        sums[i] denotes the pending sum (partial) on mesh dim i
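
        Example (an illustrative sketch, assuming ``mesh`` is a 2-D DeviceMesh)::

            spec = DTensorSpec(mesh, (Partial(), Shard(0)))
            spec.sums  # [0]: mesh dim 0 holds a pending (partial) sum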
201        """
202        return [
203            idx
204            for idx, placement in enumerate(self.placements)
205            if placement.is_partial()
206        ]
207
208    @classmethod
209    def from_dim_map(
210        cls,
211        mesh: DeviceMesh,
212        dim_map: List[int],
213        sums: List[int],
214        tensor_meta: Optional[TensorMeta] = None,
215    ) -> "DTensorSpec":
216        """
217        Construct a DTensorSpec from dim_map list and pending sum.
218
219        Args:
220            mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec
221            dim_map (List[int]): a list of integer that represents sharding on each
222                tensor dimension, see `dim_map` property doc for details
223            sums (List[int]): a list of integer that represents the dist tensor have
224                pending sum on which device mesh dimension.
225            tensor meta (TensorMeta): DTensor metadata
226
227        Return:
228            a class:`DTensorSpec` object
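
        Example (an illustrative sketch, assuming ``mesh`` is a 2-D DeviceMesh)::

            # tensor dim 0 sharded on mesh dim 1, remaining dims replicated,
            # plus a pending sum on mesh dim 0
            spec = DTensorSpec.from_dim_map(mesh, dim_map=[1, -1, -1], sums=[0])
            # spec.placements is now (Partial(), Shard(0))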
229        """
230        # by default replicate on device mesh dims
231        placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
232
233        # find all mesh dims that need pending reductions
234        for s in sums:
235            placements[s] = Partial()
236
237        for i, m in enumerate(dim_map):
238            if m >= 0:
239                placement = placements[m]
240                if placement.is_shard():
241                    placement = cast(Shard, placement)
                    raise RuntimeError(
                        f"DeviceMesh dimension cannot be mapped to two dimensions of the same tensor: {i} and {placement.dim}"
                    )
                elif placement.is_partial():
                    raise RuntimeError(
                        f"DeviceMesh dimension {m} cannot be both shard and partial!"
                    )
                placements[m] = Shard(i)

        return cls(mesh, tuple(placements), tensor_meta=tensor_meta)

    def is_replicated(self) -> bool:
        """
        Return True if the current DTensorSpec is replicated on all mesh dims (devices).
256        """
257        return all(placement.is_replicate() for placement in self.placements)
258
259    def is_sharded(self) -> bool:
260        """
        Return True if the current DTensorSpec is sharded on any mesh dim (devices).
262        """
263        return any(placement.is_shard() for placement in self.placements)
264
265    def shallow_copy_with_tensor_meta(
266        self, tensor_meta: Optional[TensorMeta]
267    ) -> "DTensorSpec":
268        """
269        Shallow copy the DTensorSpec with a new tensor_meta.
270        """
271        assert tensor_meta is not None, "shallow copy with no tensor_meta!"
272        return DTensorSpec(
273            self.mesh,
274            self.placements,
275            tensor_meta=tensor_meta,
276        )
277