from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed.fsdp._shard_utils import (
    _all_gather_dtensor,
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)
from torch.distributed.tensor import DeviceMesh, DTensor


class FSDPExtensions(ABC):
    """
    This exposes customizable hooks that enable composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements them.
    """

    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting a ``DTensor`` to a local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting a local tensor back to a ``DTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Shards a tensor into chunks and returns the local chunk."""
        ...

    @abstractmethod
    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        """Shards a tensor/DTensor into a DTensor and returns the local DTensor."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and the list of shards from which to load
        data.
        """
        ...

    @abstractmethod
    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        """
        This is to be called before loading a *sharded* DTensor state dict.
        This gathers the tensor along the FSDP dimension and returns the
        local tensor of the TP DTensor.
        """
        ...
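

# A minimal sketch of a concrete extension, given purely for illustration: it
# leaves plain tensors untouched and defers to the default helpers from
# ``_shard_utils``. The name ``_IdentityFSDPExtensions`` is an assumption of
# this sketch and is not part of this module; a real implementation (e.g. for
# tensor parallelism) would convert between ``DTensor`` and local tensors in
# the flatten/unflatten hooks instead.
class _IdentityFSDPExtensions(FSDPExtensions):
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # No transform needed for plain tensors, so return no extension state.
        return tensor, None

    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        return tensor

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        # Defer to the default ShardedTensor chunking helper.
        return _create_chunk_sharded_tensor(
            tensor, rank, world_size, num_devices_per_node, pg
        )

    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        # Defer to the default DTensor chunking helper.
        return _create_chunk_dtensor(tensor, rank, device_mesh)

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        # Mirror the default path below: expect a ShardedTensor and expose
        # its local shards for loading.
        assert type(tensor) is ShardedTensor
        return tensor, tensor.local_shards()

    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        # Defer to the default all-gather helper.
        return _all_gather_dtensor(tensor, parent_mesh)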

# The currently registered extensions instance, if any
# (see ``_set_fsdp_extensions``).
_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener


def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, Optional[Any]]:
    if fsdp_extension is not None:
        new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor)
        if param_extension is not None:
            return new_tensor, param_extension
    # No extension registered, or the hook produced no extension metadata:
    # return the tensor unchanged.
    return tensor, None


def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    if fsdp_extension is not None and param_extension is not None:
        return fsdp_extension.post_unflatten_transform(tensor, param_extension)
    return tensor


def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    # Dispatch to the extension's hook if one is given, else the default.
    chunk_tensor_fn = (
        fsdp_extension.chunk_tensor
        if fsdp_extension is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )


def _ext_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_dtensor_fn = (
        fsdp_extension.chunk_dtensor
        if fsdp_extension is not None
        else _create_chunk_dtensor
    )
    return chunk_dtensor_fn(
        tensor,
        rank,
        device_mesh,
    )


def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, List[Shard]]:
    if fsdp_extension is not None:
        return fsdp_extension.pre_load_state_dict_transform(tensor)

    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)


def _ext_all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    all_gather_dtensor_fn = (
        fsdp_extension.all_gather_dtensor
        if fsdp_extension is not None
        else _all_gather_dtensor
    )
    return all_gather_dtensor_fn(tensor, parent_mesh)
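

# Usage sketch (an assumption for illustration, not part of the original
# module): the ``_ext_*`` helpers dispatch to an extension when one is passed
# and otherwise fall back to the default helpers. ``_IdentityFSDPExtensions``
# is the hypothetical example class defined above.
if __name__ == "__main__":
    t = torch.randn(4, 4)

    # Without an extension, the pre-flatten transform is a no-op.
    out, ext = _ext_pre_flatten_transform(t)
    assert out is t and ext is None

    # Register the illustrative extension; since its hook returns no
    # param_extension, the helper still falls through to the no-op path.
    _set_fsdp_extensions(_IdentityFSDPExtensions())
    out, ext = _ext_pre_flatten_transform(t, _extensions)
    assert out is t and ext is None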