# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import List

import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed.remote_device import _remote_device


@dataclass
class Shard:
    """
    Container which holds the data for a shard as a Tensor and also
    the associated metadata for that shard.

    Args:
        tensor(torch.Tensor): Local tensor for the shard.
        metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`):
            The metadata for the shard, including offsets, lengths and device placement.
    """

    __slots__ = ["tensor", "metadata"]
    tensor: torch.Tensor
    metadata: ShardMetadata

    def __post_init__(self):
        # Verify that the local tensor agrees with its metadata: the tensor's
        # size must match the declared shard_sizes, and (when a placement is
        # given) the tensor must live on the declared device.
        if list(self.tensor.size()) != self.metadata.shard_sizes:
            raise ValueError(
                "Shard tensor size does not match with metadata.shard_sizes! "
                f"Found shard tensor size: {list(self.tensor.size())}, "
                f"metadata.shard_sizes: {self.metadata.shard_sizes}, "
            )
        placement_device = self.metadata.placement
        if (
            placement_device is not None
            and placement_device.device() != self.tensor.device
        ):
            raise ValueError(
                "Local shard tensor device does not match with local Shard's placement! "
                f"Found local shard tensor device: {self.tensor.device}, "
                f"local shard metadata placement device: {placement_device.device()}"
            )

    @classmethod
    def from_tensor_and_offsets(
        cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int
    ):
        """
        Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.

        Args:
            tensor(torch.Tensor): Local tensor for the shard.
            shard_offsets(List[int]): List of integers specify the offset
                of the shard on each dimension.
            rank(int): Specify the rank for the shard.

        Returns:
            A :class:`Shard` whose metadata sizes are derived from ``tensor.size()``
            and whose placement is ``rank:{rank}/{tensor.device}``.
        """
        # Sizes come straight from the local tensor so __post_init__'s size
        # check is satisfied by construction.
        shard_sizes = list(tensor.size())
        placement = _remote_device(f"rank:{rank}/{str(tensor.device)}")
        shard_meta = ShardMetadata(
            shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement
        )
        return Shard(tensor, shard_meta)