# mypy: allow-untyped-defs
import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    elif device_type.lower() == "hpu":
        return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"
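# For reference, example placement strings produced by _get_remote_device_str
# (assuming 8 devices per node):
#   _get_remote_device_str(3, "cuda", 8)   -> "rank:3/cuda:3"
#   _get_remote_device_str(11, "cuda", 8)  -> "rank:11/cuda:3"
#   _get_remote_device_str(3, "cpu", 8)    -> "rank:3/cpu"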


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    # tensor.chunk() may return fewer than world_size chunks, in which case the
    # trailing ranks hold no local shard.
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(
            dist.get_global_rank(pg, r),
            device_type,
            num_devices_per_node,
        )
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
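# A minimal usage sketch for _create_chunk_sharded_tensor (hypothetical caller,
# assuming a default process group is already initialized and `full_param` is the
# unsharded parameter on this rank):
#
#   st = _create_chunk_sharded_tensor(
#       full_param,
#       rank=dist.get_rank(),
#       world_size=dist.get_world_size(),
#       num_devices_per_node=torch.cuda.device_count(),
#       pg=distributed_c10d._get_default_group(),
#   )
#   # Each rank holds at most one dim-0 chunk of `full_param` as its local shard.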


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )
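# A minimal usage sketch for _create_chunk_dtensor (hypothetical caller, assuming
# torch.distributed is initialized and `full_param` is the unsharded parameter):
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   mesh = init_device_mesh("cuda", (dist.get_world_size(),))
#   dt = _create_chunk_dtensor(full_param, rank=dist.get_rank(), device_mesh=mesh)
#   # For this 1-D mesh, dt.placements == (Shard(0),) and dt.to_local() is this
#   # rank's dim-0 chunk of `full_param`.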


def _all_gather_dtensor(
    tensor: DTensor,
    root_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert (
        root_mesh == tensor.device_mesh
    ), "The device mesh of a tensor should be a root mesh."

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()
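# A minimal usage sketch for _all_gather_dtensor (hypothetical caller; `dt` is
# assumed to be the DTensor produced by _create_chunk_dtensor above):
#
#   full = _all_gather_dtensor(dt, root_mesh=dt.device_mesh)
#   # `full` is the reassembled, unsharded tensor as a plain torch.Tensor.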