# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten


__all__ = [
    "TensorChunkSpec",
    "split_args_kwargs_into_chunks",
    "merge_chunks",
]

logger = logging.getLogger(__name__)

19"""
20_debug_mask_minibatches specifies to send masked versions of the mini-batch
21through instead of micro-batch slices--this can be used for more stable
22numerical testing (see [A Note About Correctness Testing])
23"""
24_debug_mask_minibatches = False


class _CustomReducer:
    """
    Custom reducer class that can be used to specify a custom operation that
    reduces losses of multiple microbatches into one value.

    Example:
    >>> # xdoctest: +SKIP
    >>> sum_reducer = _CustomReducer(
    >>>     torch.tensor(0.0),
    >>>     lambda a, b: a + b
    >>> )
    """

    def __init__(self, init_value, reduce_fn):
        self.init_value = init_value
        self.reduce_fn = reduce_fn


class _LossReducer(_CustomReducer):
    pass


sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)

# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0


class TensorChunkSpec:
    """
    Class used to specify chunking of inputs
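
    Example (illustrative sketch; the "mask" key below is just a placeholder):
        >>> # xdoctest: +SKIP
        >>> # Chunk the sole positional arg along dim 0 and the "mask" kwarg
        >>> # along dim 1
        >>> args_chunk_spec = (TensorChunkSpec(0),)
        >>> kwargs_chunk_spec = {"mask": TensorChunkSpec(1)}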
59    """
60
61    def __init__(self, split_dim):
62        self.split_dim = split_dim
63
64    split_dim: int
65
66    def __repr__(self):
67        return (
68            f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
69        )
70
71    def __str__(self):
72        return f"TensorChunkSpec({self.split_dim})"
73
    @staticmethod
    def from_tuple(
        chunk_dims: Tuple[int, ...],
    ):
        """
        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
        dimensions (ints).
        Example:
            >>> # xdoctest: +SKIP
            >>> # There are three positional arguments to the model, and
            >>> # we are chunking them along dimensions 0, 0, and 1, respectively
            >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
        """
        args_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return args_chunk_spec

    @staticmethod
    def from_dict(
        chunk_dims: Dict[str, int],
    ):
        """
        A helper for creating a dictionary of `TensorChunkSpec` from a
        dictionary of chunk dimensions (ints).
        Example:
            >>> # xdoctest: +SKIP
            >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
            >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
        """
        kwargs_chunk_spec = map_aggregate(
            chunk_dims,
            lambda dim: TensorChunkSpec(dim),  # type: ignore[arg-type,return-value]
        )
        return kwargs_chunk_spec


# Class used to specify replication of inputs
class _Replicate:
    pass


def _shard_dict_of_args(
    args_dict,
    args_chunk_spec,
    num_chunks,
):
    """
    Given a dictionary of args, and a dictionary of chunking specs, shard the
    args according to the chunking specs.

    Args:
        args_dict: Dictionary of args
        args_chunk_spec: Dictionary of chunking specs
        num_chunks: Number of chunks to shard the args into

    Returns:
        args_split: List of sharded args
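
    Example (illustrative sketch; the key and shapes below are placeholders):
        >>> # xdoctest: +SKIP
        >>> _shard_dict_of_args(
        >>>     {"x": torch.randn(4, 8)},
        >>>     {"x": TensorChunkSpec(0)},
        >>>     num_chunks=2,
        >>> )
        >>> # -> list of 2 dicts, each mapping "x" to a tensor of shape (2, 8)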
133    """
134    # Stage 1+2: flatten and shard/replicate
135
136    # args_sharded_replicated : [num args, num flat values, num chunks]
137    args_sharded_replicated = {}
138    arg_specs = []
139
140    real_num_chunks = num_chunks
141    first_tensor = True
142
143    assert len(args_dict) == len(
144        args_chunk_spec
145    ), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
146
147    for arg_key, arg in args_dict.items():
148        flat, spec = tree_flatten(arg)
149        arg_specs.append(spec)
150
151        chunk_spec = args_chunk_spec[arg_key]
152        assert chunk_spec is not None  # Should have been set by caller
153        chunk_spec_flat, _ = tree_flatten(chunk_spec)
154        if len(flat) != len(chunk_spec_flat):
155            raise ValueError(
156                f"Argument value {arg} did not have the same number of "
                f"values as its chunk spec {chunk_spec}"
            )

        sharded_arg_flat = []

        for v, chunk_v in zip(flat, chunk_spec_flat):
            if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
                sharded_arg_flat.append([v] * real_num_chunks)
            elif isinstance(chunk_v, TensorChunkSpec):
                # TODO: check type of v. If it's a tensor, use chunk (or debug mask).
                # If it's a collection type, split it as you would expect. Otherwise,
                # throw an error.
                assert isinstance(v, torch.Tensor), f"{v} is not a tensor"

                v_split_dim_size = v.size(chunk_v.split_dim)
                if v_split_dim_size < real_num_chunks:
                    if first_tensor:
                        # We can only adjust number of chunks when we hit this
                        # issue at the first tensor encountered
                        logger.warning(
                            f"Tensor size on chunking dimension is {v_split_dim_size}, "  # noqa: G004
                            f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
                        )
                        real_num_chunks = v_split_dim_size
                    else:
                        raise RuntimeError(
                            f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
                            f"smaller than the number of chunks {num_chunks}. "
                            "PiPPy cannot reduce the number of chunks because "
                            "other arguments have bigger chunk-dimension sizes. "
                            "Please adjust your num_chunks setting."
                        )

                chunk_tensors = torch.tensor_split(
                    v, real_num_chunks, chunk_v.split_dim
                )

                if _debug_mask_minibatches:
                    expanded_chunks = []

                    split_dim_idx = 0
                    for chunk_tensor in chunk_tensors:
                        new_val = torch.zeros_like(v)
                        upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)

                        slice_indices = [slice(None, None, None)] * new_val.ndim
                        slice_indices[chunk_v.split_dim] = slice(
                            split_dim_idx, upper_idx
                        )
                        new_val[slice_indices] = chunk_tensor

                        expanded_chunks.append(new_val)

                        split_dim_idx += chunk_tensor.size(chunk_v.split_dim)

                    sharded_arg_flat.append(expanded_chunks)
                else:
                    sharded_arg_flat.append(chunk_tensors)  # type: ignore[arg-type]

                first_tensor = False
            else:
                raise TypeError(f"Unrecognized chunk spec: {chunk_v}")

        args_sharded_replicated[arg_key] = sharded_arg_flat

    # chunks_flat : [num chunks, num args, num flat values]
    chunks_flat = []
    for chunk_idx in range(real_num_chunks):
        chunk_args = {}
        for key, arg in args_sharded_replicated.items():
            arg_single_chunk = []
            for v_flat in arg:
                arg_single_chunk.append(v_flat[chunk_idx])
            chunk_args[key] = arg_single_chunk
        chunks_flat.append(chunk_args)

    # args_split : [num chunks, num args]
    args_split = []

    for chunk in chunks_flat:
        per_chunk_args = {}
        assert len(arg_specs) == len(chunk)
        for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
            per_chunk_args[key] = tree_unflatten(arg, arg_spec)
        args_split.append(per_chunk_args)

    return args_split


def split_args_kwargs_into_chunks(
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    chunks: int,
    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> Tuple[List[Tuple], List[Dict]]:
    """
    Given a sequence of args and kwargs, split them into a number of chunks
    according to their respective chunking specs.

    Args:
        args: Tuple of args
        kwargs: Dict of kwargs
        chunks: Number of chunks to split the args and kwargs into
        args_chunk_spec: chunking specs for args, in same shape as args
        kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs

    Returns:
        args_split: List of sharded args
        kwargs_split: List of sharded kwargs
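
    Example (illustrative sketch; shapes and the "mask" kwarg are placeholders):
        >>> # xdoctest: +SKIP
        >>> x = torch.randn(4, 16)
        >>> mask = torch.ones(4, 16)
        >>> # Split along dim 0 into 2 micro-batches of 2 rows each
        >>> args_split, kwargs_split = split_args_kwargs_into_chunks(
        >>>     (x,), {"mask": mask}, chunks=2
        >>> )
        >>> len(args_split), args_split[0][0].shape
        (2, torch.Size([2, 16]))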
267    """
268    # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
269    # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
270    # and `kwargs_chunk_spec` specifications. The steps are as follows:
271    #
272    # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values.
273    #    To use a running example: suppose our inputs look like
274    #
275    #       args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
276    #       (kwargs not shown but it's a similar process)
277    #
278    #    Then for this step we would end up with
279    #
280    #       args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
281    #
282    # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
283    #
284    #       args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
285    #
286    # 3. Rotate the nesting order such that chunks are the outer dimension
287    #
288    #       args_chunks = [
289    #           ([A, B, C_1], D),
290    #           ([A, B, C_2], D),
291    #       ]
292    #
293    # 4. Unflatten each chunk according to the spec
294    #
295    #       args_chunks = [
296    #           ([A, [B, C_1]], D),
297    #           ([A, [B, C_2]], D),
298    #       ]
299
300    # TODO: _debug_mask_minibatches
301    # Handle the case where kwargs is None
302    if kwargs is None:
303        kwargs = {}
304
305    # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
306    # their format and use default chunking along dim 0
307    if args_chunk_spec is None:
308        args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
309
310    if kwargs_chunk_spec is None:
311        kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))
312
    args_split_dict = _shard_dict_of_args(
        dict(enumerate(args)),
        dict(enumerate(args_chunk_spec)),
        chunks,
    )
    real_num_chunks = len(args_split_dict)

    kwargs_split = _shard_dict_of_args(
        kwargs,
        kwargs_chunk_spec,
        real_num_chunks,
    )

    if len(kwargs_split) < real_num_chunks:
        # In case kwargs are sharded into fewer chunks,
        # e.g. when `args` has no tensors, just plain values
        real_num_chunks = len(kwargs_split)
        # Re-shard args
        args_split_dict = _shard_dict_of_args(
            dict(enumerate(args)),
            dict(enumerate(args_chunk_spec)),
            real_num_chunks,
        )

    if len(args_split_dict) != len(kwargs_split):
        raise RuntimeError(
            "args and kwargs are split into different number of chunks: "
            f"{len(args_split_dict)}, {len(kwargs_split)}"
        )

    args_split = []
    for chunk_args in args_split_dict:
        args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))

    return args_split, kwargs_split


def merge_chunks(
    chunks: List[Any],
    chunk_spec,
):
    """
    Given a list of chunks, merge them into a single value according to
    the chunk spec.

    Args:
        chunks: list of chunks
        chunk_spec: Chunking spec for the chunks

    Returns:
        value: Merged value
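
    Example (illustrative sketch; shapes are placeholders):
        >>> # xdoctest: +SKIP
        >>> # Re-assemble two output chunks along dim 0
        >>> out0, out1 = torch.randn(2, 16), torch.randn(2, 16)
        >>> merged = merge_chunks([out0, out1], TensorChunkSpec(0))
        >>> merged.shape
        torch.Size([4, 16])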
364    """
365    # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
366    # steps are similar to the steps in that function but in reverse. Given the
367    # input values:
368    #
369    #       chunks = [
370    #           ([A, [B, C_1]], D),
371    #           ([A, [B, C_2]], D),
372    #       ]
373    #       args_spec = ([None, [None, TensorChunkSpec]], None)
374    #
375    # 1. Flatten the chunks according to the chunk_spec
376    #
377    #       chunks_flat = [
378    #           ([A, B, C_1], D),
379    #           ([A, B, C_2], D),
380    #       ]
381    #
382    # 2. Rotate the nesting order such that chunks are the inner dimension
383    #
384    #       value_inner = ([A, B, [C_1, C_2]], D)
385    #
386    # 3. Concatenate sharded arguments
387    #
388    #       value_combined = ([A, B, C], D)
389    #
390    # 4. Unflatten the combined args given the spec
391    #
392    #       value = ([A, [B, C]], D)
393
394    # Preliminary: flatten the chunk spec
395    if chunk_spec is not None:
396        spec_flattened, flatten_spec = tree_flatten(chunk_spec)
397    else:
398        # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
399        # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
400        chunk0_flat, flatten_spec = tree_flatten(chunks[0])
401        spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)
402
403    # Stage 1: flatten chunks
404    # chunks_flattened : [num chunks, num args]
405    chunks_flattened = []
406
407    for chunk in chunks:
408        chunk_flattened, _ = tree_flatten(chunk)
409        if len(chunk_flattened) != len(spec_flattened):
410            raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
411
412        chunks_flattened.append(chunk_flattened)
413
414    # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
415    #                concatenate sharded operands
416    # args_flattened : [num args]
417    args_flattened = []
418    for arg_idx, arg in enumerate(spec_flattened):
419        if isinstance(arg, TensorChunkSpec):
420            partial_values = [
421                chunks_flattened[chunk_idx][arg_idx]
422                for chunk_idx in range(len(chunks_flattened))
423            ]
424
425            if _debug_mask_minibatches:
426                # Infer size of individual chunks by running `tensor_split` again
427                overall_shape = partial_values[0].shape
428                for val in partial_values[1:]:
429                    assert val.shape == overall_shape
430                meta_chunks = torch.tensor_split(
431                    torch.empty(*overall_shape, device="meta"),
432                    sections=len(partial_values),
433                    dim=arg.split_dim,
434                )
435
436                values_to_cat = []
437                chunk_start_idx = 0
438                assert len(partial_values) == len(meta_chunks)
439                for partial_value, meta_chunk in zip(partial_values, meta_chunks):
440                    chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
441
442                    slice_indices = [slice(None, None, None)] * partial_value.ndim
443                    slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
444                    sliced = partial_value[slice_indices]
445                    values_to_cat.append(sliced)
446
447                    chunk_start_idx = chunk_end_idx
448
449            else:
450                values_to_cat = partial_values
451
452            args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
453        elif isinstance(arg, _CustomReducer):
454            reduced_val = arg.init_value
455
456            for chunk_idx in range(len(chunks_flattened)):
457                reduced_val = arg.reduce_fn(
458                    reduced_val, chunks_flattened[chunk_idx][arg_idx]
459                )
460
461            args_flattened.append(reduced_val)
462        else:
463            value = chunks_flattened[0][arg_idx]
464            for chunk_idx in range(1, len(chunks_flattened)):
465                assert chunks_flattened[chunk_idx][arg_idx] == value
466            args_flattened.append(value)
467
468    # Stage 4: Unflatten combined args
469    return tree_unflatten(args_flattened, flatten_spec)
470