xref: /aosp_15_r20/external/pytorch/torch/distributed/_shard/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Sequence
2
3import torch
4from torch.distributed._shard.metadata import ShardMetadata
5
6
# Deprecation notice emitted by ShardedTensor entry points elsewhere in the
# package; ShardedTensor users are being steered toward DTensor.
DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor."
8
9
def narrow_tensor_by_index(
    tensor: torch.Tensor,
    offsets: Sequence[int],
    sizes: Sequence[int],
) -> torch.Tensor:
    """
    Return a view of ``tensor`` narrowed along each dimension ``d`` to the
    window ``[offsets[d], offsets[d] + sizes[d])``.

    Dimensions whose requested size already spans the full dimension are left
    untouched, so passing the tensor's own shape returns the tensor unchanged.
    Narrowing produces views (no data copy); autograd still records the
    narrow ops as usual.
    """
    result = tensor
    for dim, (start, length) in enumerate(zip(offsets, sizes)):
        # Compare against the ORIGINAL tensor's size: prior iterations only
        # ever narrowed other (earlier) dims, so untouched dims keep their
        # original extent.
        if length >= tensor.size(dim):
            continue
        result = result.narrow(dim, start, length)
    return result
26
27
def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
    """
    Return the view of ``tensor`` corresponding to the shard described by
    ``metadata`` (its ``shard_offsets`` / ``shard_sizes`` fields).
    """
    shard_offsets = metadata.shard_offsets
    shard_sizes = metadata.shard_sizes
    return narrow_tensor_by_index(tensor, shard_offsets, shard_sizes)
33