from typing import Sequence

import torch
from torch.distributed._shard.metadata import ShardMetadata


DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor."


def narrow_tensor_by_index(
    tensor: torch.Tensor,
    offsets: Sequence[int],
    sizes: Sequence[int],
) -> torch.Tensor:
    """
    Narrow ``tensor`` dimension by dimension using ``offsets`` and ``sizes``.

    Dimension ``i`` is narrowed to ``[offsets[i], offsets[i] + sizes[i])``
    only when ``sizes[i]`` is strictly smaller than ``tensor.size(i)``.
    Otherwise that dimension is kept whole and its offset is not used.
    Pairing stops at the shorter of ``offsets`` / ``sizes``, and any
    remaining dimensions are left untouched.

    Returns:
        The result of chained ``Tensor.narrow`` calls on ``tensor``, or
        ``tensor`` itself (the same object) when no dimension needed
        narrowing.
    """
    # NOTE: ``Tensor.narrow`` is a differentiable op, and nothing in this
    # function changes grad mode. A caller that needs the shard to be an
    # autograd leaf must arrange that itself (e.g. by calling under
    # ``torch.no_grad()``).
    result = tensor
    for dim, (start, length) in enumerate(zip(offsets, sizes)):
        if length >= tensor.size(dim):
            # The shard spans this whole dimension, so there is nothing to slice.
            continue
        result = result.narrow(dim, start, length)
    return result


def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
    """
    Narrow ``tensor`` to the shard described by ``metadata``.

    This is a thin wrapper that forwards ``metadata.shard_offsets`` and
    ``metadata.shard_sizes`` to :func:`narrow_tensor_by_index`.
    """
    shard_offsets = metadata.shard_offsets
    shard_sizes = metadata.shard_sizes
    return narrow_tensor_by_index(tensor, shard_offsets, shard_sizes)