# xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/_fsdp_extensions.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import (
    _all_gather_dtensor,
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)
from torch.distributed.tensor import DeviceMesh, DTensor


class FSDPExtensions(ABC):
    """
    This exposes customizable hooks that enable composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements them.
    """

    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting a ``DTensor`` to its local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting a local tensor back to a ``DTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Shards a tensor into chunks and returns the local chunk."""
        ...

    @abstractmethod
    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        """Shards a tensor or DTensor and returns the local DTensor."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and list of shards from which to load data.
        """
        ...

    @abstractmethod
    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        """
        This is to be called before loading a *sharded* DTensor state dict.
        It all-gathers the tensor across the FSDP mesh dimension and returns
        the local tensor of the TP DTensor.
        """
        ...


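# A minimal sketch of a custom extension, assuming DTensor-based tensor
# parallelism where parameters arrive as ``DTensor``s. The class name and the
# hook bodies below are illustrative only, not the implementation used by
# PyTorch's tensor-parallel integration; they show the expected round trip of
# the hooks using the helpers imported above.
#
#   class _ExampleDTensorExtensions(FSDPExtensions):
#       def pre_flatten_transform(self, tensor):
#           if isinstance(tensor, DTensor):
#               # Record (mesh, placements) so the DTensor can be rebuilt later.
#               return tensor.to_local(), (tensor.device_mesh, tensor.placements)
#           return tensor, None
#
#       def post_unflatten_transform(self, tensor, param_extension):
#           mesh, placements = param_extension
#           return DTensor.from_local(tensor, mesh, placements)
#
#       def chunk_tensor(
#           self, tensor, rank, world_size, num_devices_per_node, pg, device=None
#       ):
#           return _create_chunk_sharded_tensor(
#               tensor, rank, world_size, num_devices_per_node, pg
#           )
#
#       def chunk_dtensor(self, tensor, rank, device_mesh):
#           return _create_chunk_dtensor(tensor, rank, device_mesh)
#
#       def pre_load_state_dict_transform(self, tensor):
#           # Assumes ``tensor`` is a ShardedTensor in this simplified sketch.
#           return tensor, tensor.local_shards()
#
#       def all_gather_dtensor(self, tensor, parent_mesh):
#           return _all_gather_dtensor(tensor, parent_mesh)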
_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener


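# Registering a custom extension is a single call, e.g. with the illustrative
# class sketched above (an assumed name, not part of this module's API):
#
#   _set_fsdp_extensions(_ExampleDTensorExtensions())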
def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, Optional[Any]]:
    if fsdp_extension is not None:
        new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor)
        if param_extension is not None:
            return new_tensor, param_extension
    return tensor, None

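# Illustrative behavior, assuming the DTensor-style extension sketched above:
# a DTensor parameter is replaced by its local shard and the (mesh, placements)
# pair is returned as ``param_extension``; with no extension registered (or a
# plain tensor), the tensor passes through unchanged with ``None``.
#
#   local, ext = _ext_pre_flatten_transform(param, fsdp_extension=my_extension)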

def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    if fsdp_extension is not None and param_extension is not None:
        return fsdp_extension.post_unflatten_transform(tensor, param_extension)
    return tensor

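# This is the inverse of ``_ext_pre_flatten_transform``: when unflattening, the
# saved ``param_extension`` (e.g. the (mesh, placements) pair in the sketch
# above) is used to rebuild the distributed form of the parameter, e.g. via
# ``DTensor.from_local``. A ``None`` extension means the tensor is returned
# as-is.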

def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_tensor_fn = (
        fsdp_extension.chunk_tensor
        if fsdp_extension is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )

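# With no extension registered, this falls back to
# ``_create_chunk_sharded_tensor``, which chunks the full tensor across ``pg``
# and returns the local chunk as a ``ShardedTensor``. A hypothetical call when
# building a sharded state dict (argument values are illustrative):
#
#   sharded = _ext_chunk_tensor(
#       full_param,
#       rank=dist.get_rank(),
#       world_size=dist.get_world_size(),
#       num_devices_per_node=torch.cuda.device_count(),
#       pg=dist.group.WORLD,
#   )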

def _ext_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_dtensor_fn = (
        fsdp_extension.chunk_dtensor
        if fsdp_extension is not None
        else _create_chunk_dtensor
    )
    return chunk_dtensor_fn(
        tensor,
        rank,
        device_mesh,
    )

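# The DTensor counterpart of ``_ext_chunk_tensor``: the default
# ``_create_chunk_dtensor`` shards the tensor over ``device_mesh`` and returns
# the local DTensor, while a custom extension can handle e.g. parameters that
# are already TP-sharded DTensors.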

def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, List[Shard]]:
    if fsdp_extension is not None:
        return fsdp_extension.pre_load_state_dict_transform(tensor)

    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)

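# Without an extension, the incoming state-dict value must already be a
# ``ShardedTensor`` (hence the assert), and its local shards are what the
# checkpoint loader copies data into; an extension can instead accept other
# distributed tensor types (e.g. DTensor) and produce an equivalent shard list.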

def _ext_all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    all_gather_dtensor_fn = (
        fsdp_extension.all_gather_dtensor
        if fsdp_extension is not None
        else _all_gather_dtensor
    )
    return all_gather_dtensor_fn(tensor, parent_mesh)
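# Illustrative use when loading a sharded DTensor state dict on a 2D FSDP x TP
# mesh (names are assumptions): the stored DTensor is all-gathered along the
# FSDP mesh dimension so each rank ends up with the local tensor of the TP-only
# DTensor.
#
#   local_tp_tensor = _ext_all_gather_dtensor(saved_dtensor, parent_mesh=mesh_2d)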