# mypy: allow-untyped-defs
import functools
from typing import List, TYPE_CHECKING

import torch
from torch.distributed._shard.op_registry_utils import _decorator_func

from .api import (
    _CUSTOM_SHARDED_OPS,
    _SHARDED_OPS,
    Shard,
    ShardedTensor,
    ShardedTensorBase,
    ShardedTensorMetadata,
    TensorProperties,
)
from .metadata import ShardMetadata  # noqa: F401


if TYPE_CHECKING:
    from torch.distributed._shard.sharding_spec import ShardingSpec
else:
    ShardingSpec = "ShardingSpec"


def empty(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with uninitialized data.
        Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
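
    Example::
        An illustrative sketch (the spec, placements and sizes below are only
        placeholders); it assumes ``torch.distributed`` is already initialized
        with 2 ranks:

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> import torch.distributed._shard.sharded_tensor as sharded_tensor
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        >>> )
        >>> # Each rank ends up holding an uninitialized (5, 5) local shard.
        >>> st = sharded_tensor.empty(spec, 10, 5)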
    """
    return ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def ones(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 1.
        Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    return full(
        sharding_spec,
        size,
        fill_value=1,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def zeros(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 0.
        Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    return full(
        sharding_spec,
        size,
        fill_value=0,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def full(
    sharding_spec: ShardingSpec,
    size,
    fill_value,
    *,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype
        is inferred from fill_value. If dtype is specified, it will override the
        inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...):  a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.
        fill_value (Scalar): the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
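
    Example::
        An illustrative sketch (``spec`` and the sizes are only placeholders);
        assumes a 2-rank process group is already initialized:

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> import torch.distributed._shard.sharded_tensor as sharded_tensor
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        >>> # Every element is 3.0; dtype overrides the type inferred from fill_value.
        >>> st = sharded_tensor.full(spec, (10, 5), fill_value=3.0, dtype=torch.float64)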
    """
    sharded_tensor = ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    torch.nn.init.constant_(sharded_tensor, fill_value)  # type: ignore[arg-type]
    return sharded_tensor


def rand(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution
        on the interval :math:`[0, 1)`. The shape of the tensor is defined by the
        variable argument `size`. Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...):  a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
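
    Example::
        An illustrative sketch (``spec`` and the sizes are only placeholders);
        assumes a 2-rank process group is already initialized. Each rank fills
        only its own local shards:

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> import torch.distributed._shard.sharded_tensor as sharded_tensor
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        >>> st = sharded_tensor.rand(spec, 10, 5)
        >>> local = st.local_shards()[0].tensor  # this rank's (5, 5) uniform sample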
    """
    sharded_tensor = ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    torch.nn.init.uniform_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
    return sharded_tensor


def randn(
    sharding_spec: ShardingSpec,
    *size,
    dtype=None,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
    process_group=None,
    init_rrefs=False,
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with random numbers from a normal distribution
        with mean `0` and variance `1` (also called standard normal distribution). The shape
        of the tensor is defined by the variable argument `size`. Needs to be called on all ranks
        in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...):  a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    sharded_tensor = ShardedTensor(
        sharding_spec,
        *size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )
    torch.nn.init.normal_(sharded_tensor, 0, 1)  # type: ignore[arg-type]
    return sharded_tensor


def init_from_local_shards(
    local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False
) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` from local shards and the global metadata.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        local_shards (List[:class:`torch.distributed._shard.sharded_tensor.Shard`]): A list
            of shards that represent the local shards on this rank.
        global_size (int...):  a list, tuple, or `torch.Size` of integers defining the
            shape of the overall sharded tensor.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object handle on this rank


    Examples:
        Suppose we want to construct a sharded tensor on two ranks with global size = (10, 5),
        where each rank holds a single (5, 5) local shard. We can do it as follows:

        on rank 0:
        >>> # xdoctest: +SKIP("not distributed")
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[0, 0],
        >>>     shard_lengths=[5, 5],
        >>>     placement="rank:0/cuda:0"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])

        on rank 1:
        >>> # xdoctest: +SKIP("not distributed")
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[5, 0],
        >>>     shard_lengths=[5, 5],
        >>>     placement="rank:1/cuda:1"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
    """
    return ShardedTensor._init_from_local_shards(
        local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs
    )


def state_dict_hook(module, destination, prefix, local_metadata):
    """
    Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
    registered to the Module using
    :meth:`torch.nn.Module._register_state_dict_hook`.
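
    Example::
        An illustrative sketch; ``MyShardedModule`` stands in for any module that
        stores :class:`ShardedTensor` attributes. The companion
        :func:`pre_load_state_dict_hook` is registered alongside so the state dict
        can be loaded back:

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> from torch.distributed._shard.sharded_tensor import (
        >>>     pre_load_state_dict_hook,
        >>>     state_dict_hook,
        >>> )
        >>> m = MyShardedModule()
        >>> m._register_state_dict_hook(state_dict_hook)
        >>> # with_module=True so the hook receives the module as its first argument
        >>> m._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
        >>> state = m.state_dict()  # ShardedTensor attributes now included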
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name, attr in submodule.__dict__.items():
            if isinstance(attr, ShardedTensor):
                mod_prefix = prefix + submodule_name
                key = mod_prefix + ("." if mod_prefix else "") + attr_name
                destination[key] = attr


def pre_load_state_dict_hook(
    module,
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """
    Pre-load state dict hook to add ShardedTensor to the module.
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name in submodule.__dict__.keys():
            mod_prefix = prefix + submodule_name
            key = mod_prefix + ("." if mod_prefix else "") + attr_name
            if key in state_dict:
                if isinstance(state_dict[key], ShardedTensor):
                    setattr(submodule, attr_name, state_dict[key])


def custom_sharded_op_impl(func):
    """
    Provides a way for users to write their own custom sharded operator. This
    can be used to override existing ShardedTensor operators or write a new
    one not supported by ShardedTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example::
        >>> # xdoctest: +SKIP
        >>> @custom_sharded_op_impl(torch.nn.functional.linear)
        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
        >>>     ...
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> input = torch.rand(10, 32)
        >>> weight = sharded_tensor.rand(32, 16)
        >>> bias = torch.rand(16)
        >>> # This will call 'my_custom_sharded_linear'
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
    There is an additional ``process_group`` parameter which is the
    process_group used for the ShardedTensor and can be used by
    implementations for communications within a sharded implementation.

    Args:
        func(Callable): Torch function for which we want to provide a sharded
            implementation (ex: torch.nn.functional.linear)
    """
    return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS)


def _sharded_op_impl(func):
    """
    Decorator to register a default sharded op.
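
    Example::
        An illustrative sketch; the registered handler receives the same
        ``(types, args, kwargs, process_group)`` arguments described for
        :func:`custom_sharded_op_impl`:

        >>> # xdoctest: +SKIP("illustrative only")
        >>> @_sharded_op_impl(torch.Tensor.clone)
        >>> def sharded_clone(types, args=(), kwargs=None, pg=None):
        >>>     ...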
    """
    return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS)


# Import all builtin sharded ops
from ._ops import *  # noqa: F403