# mypy: allow-untyped-defs
import functools
from typing import List, TYPE_CHECKING

import torch
from torch.distributed._shard.op_registry_utils import _decorator_func

from .api import (
    _CUSTOM_SHARDED_OPS,
    _SHARDED_OPS,
    Shard,
    ShardedTensor,
    ShardedTensorBase,
    ShardedTensorMetadata,
    TensorProperties,
)
from .metadata import ShardMetadata  # noqa: F401


if TYPE_CHECKING:
    from torch.distributed._shard.sharding_spec import ShardingSpec
else:
    ShardingSpec = "ShardingSpec"


def empty(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with uninitialized data.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
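
    Example (illustrative sketch; assumes two ranks with the default process group
    already initialized, and uses a :class:`ChunkShardingSpec` with hypothetical
    device placements)::

        >>> # xdoctest: +SKIP("requires torch.distributed initialization")
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        >>> )
        >>> # Each rank ends up holding an uninitialized (5, 5) local shard.
        >>> st = empty(spec, 10, 5)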
    """
    return ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def ones(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 1.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    return full(
        sharding_spec,
        size,
        fill_value=1,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def zeros(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 0.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    return full(
        sharding_spec,
        size,
        fill_value=0,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def full(
    sharding_spec: ShardingSpec,
    size,
    fill_value,
    *,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with ``fill_value``. The tensor's dtype
    is inferred from ``fill_value``. If ``dtype`` is specified, it will override the
    inferred type from ``fill_value``.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.
        fill_value (Scalar): the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
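
    Example (illustrative sketch; assumes two ranks with the default process group
    already initialized and hypothetical device placements)::

        >>> # xdoctest: +SKIP("requires torch.distributed initialization")
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        >>> )
        >>> # Every element of the (10, 5) sharded tensor is set to 3.0 and each
        >>> # rank holds a (5, 5) local shard.
        >>> st = full(spec, (10, 5), fill_value=3.0)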
    """
    sharded_tensor = ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    torch.nn.init.constant_(sharded_tensor, fill_value)  # type: ignore[arg-type]
    return sharded_tensor


def rand(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution
    on the interval :math:`[0, 1)`. The shape of the tensor is defined by the
    variable argument `size`. Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
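
    Example (illustrative sketch; assumes two ranks with the default process group
    already initialized and hypothetical device placements; ``randn`` follows the
    same calling convention)::

        >>> # xdoctest: +SKIP("requires torch.distributed initialization")
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        >>> )
        >>> # Each rank draws its own (5, 5) local shard from U[0, 1).
        >>> st = rand(spec, 10, 5)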
    """
    sharded_tensor = ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    torch.nn.init.uniform_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
    return sharded_tensor


def randn(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with random numbers from a normal distribution
    with mean `0` and variance `1` (also called the standard normal distribution). The shape
    of the tensor is defined by the variable argument `size`. Needs to be called on all ranks
    in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    sharded_tensor = ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    torch.nn.init.normal_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
    return sharded_tensor


def init_from_local_shards(
    local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` from local shards and the global metadata.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        local_shards (List[:class:`torch.distributed._shard.sharded_tensor.Shard`]): A list
            of shards that represent the local shards on this rank.
        global_size (int...): a list, tuple, or `torch.Size` of integers defining the
            shape of the overall sharded tensor.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object handle on this rank

    Examples:
        Suppose we want to construct a sharded tensor on two ranks with a global size
        of (10, 5), where each rank holds a (5, 5) local shard. We can do it as follows:

        on rank 0:
        >>> # xdoctest: +SKIP("not distributed")
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[0, 0],
        >>>     shard_sizes=[5, 5],
        >>>     placement="rank:0/cuda:0"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])

        on rank 1:
        >>> # xdoctest: +SKIP("not distributed")
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[5, 0],
        >>>     shard_sizes=[5, 5],
        >>>     placement="rank:1/cuda:1"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
    """
    return ShardedTensor._init_from_local_shards(
        local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs
    )


def state_dict_hook(module, destination, prefix, local_metadata):
    """
    Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
    registered to the Module using
    :meth:`torch.nn.Module._register_state_dict_hook`.
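
    Example (illustrative sketch; ``MyShardedModule`` is a hypothetical module whose
    attributes include :class:`ShardedTensor` objects, and the private
    :meth:`torch.nn.Module._register_load_state_dict_pre_hook` helper is assumed for
    registering the companion pre-load hook)::

        >>> # xdoctest: +SKIP("requires torch.distributed initialization")
        >>> m = MyShardedModule()
        >>> m._register_state_dict_hook(state_dict_hook)
        >>> # with_module=True passes the module as the first hook argument.
        >>> m._register_load_state_dict_pre_hook(pre_load_state_dict_hook, with_module=True)
        >>> state = m.state_dict()  # now includes the ShardedTensor attributes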
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name, attr in submodule.__dict__.items():
            if isinstance(attr, ShardedTensor):
                mod_prefix = prefix + submodule_name
                key = mod_prefix + ("." if mod_prefix else "") + attr_name
                destination[key] = attr


def pre_load_state_dict_hook(
    module,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """
    Pre-load state dict hook to add ShardedTensor to the module.
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name in submodule.__dict__.keys():
            mod_prefix = prefix + submodule_name
            key = mod_prefix + ("." if mod_prefix else "") + attr_name
            if key in state_dict:
                if isinstance(state_dict[key], ShardedTensor):
                    setattr(submodule, attr_name, state_dict[key])


def custom_sharded_op_impl(func):
    """
    Provides a way for users to write their own custom sharded operator. This
    can be used to override existing ShardedTensor operators or write a new
    one not supported by ShardedTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example::
        >>> # xdoctest: +SKIP
        >>> @custom_sharded_op_impl(torch.nn.functional.linear)
        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
        >>>     ...
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> input = torch.rand(10, 32)
        >>> weight = sharded_tensor.rand(32, 16)
        >>> bias = torch.rand(16)
        >>> # This will call 'my_custom_sharded_linear'
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to the ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
    There is an additional ``process_group`` parameter, which is the process
    group used for the ShardedTensor and can be used by the custom
    implementation for communication.

    Args:
        func (Callable): Torch function for which we want to provide a sharded
            implementation (e.g., torch.nn.functional.linear)
    """
    return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS)


def _sharded_op_impl(func):
    """
    Decorator to register a default sharded op.
    """
    return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS)


# Import all builtin sharded ops
from ._ops import *  # noqa: F403