# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. Each rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    # The dim-0 offset of each chunk is the running sum of the preceding chunk sizes.
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
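
# A minimal sketch (illustrative only, not part of this module) of how the dim-0
# chunking above maps to per-rank offsets, assuming world_size=4 and a tensor of
# shape (10, 3):
#
#     chunks = torch.randn(10, 3).chunk(4, dim=0)   # chunk sizes along dim 0: 3, 3, 3, 1
#     offset_for_rank_r = math.ceil(10 / 4) * r     # dim-0 offsets: 0, 3, 6, 9
#
# torch.chunk may return fewer than world_size chunks when dim 0 is small; any
# trailing ranks then hold no local shard, which is why the `len(chunks) > rank`
# check above can leave `local_shards` empty.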


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. Each rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert (
        root_mesh == tensor.device_mesh
    ), "The device mesh of a tensor should be a root mesh."

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
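

# A minimal end-to-end sketch (illustrative only, not part of this module): on a
# 4-rank job with a 1-D device mesh, _create_chunk_dtensor shards dim 0 and
# _all_gather_dtensor reverses it. This assumes the default process group is
# initialized and uses torch.distributed.device_mesh.init_device_mesh:
#
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh = init_device_mesh("cuda", (4,))
#     full = torch.randn(16, 8)
#     dtensor = _create_chunk_dtensor(full, dist.get_rank(), mesh)  # placements [Shard(0)], local shape (4, 8)
#     restored = _all_gather_dtensor(dtensor, mesh)                 # local shape (16, 8) on every rank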