xref: /aosp_15_r20/external/pytorch/torch/distributed/pipelining/stage.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and affiliates
3import logging
4import operator
5from abc import ABC, abstractmethod
6from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
7
8import torch
9import torch.distributed as dist
10import torch.fx as fx
11import torch.nn as nn
12from torch._subclasses.fake_tensor import FakeTensor
13from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard
14from torch.fx.node import map_aggregate
15from torch.nn.parallel import DistributedDataParallel
16
17from ._backward import stage_backward, stage_backward_input, stage_backward_weight
18from ._debug import map_debug_info
19from ._utils import flatten_args, PipeInfo, validate_tensors_metadata
20
21
22__all__ = [
23    "PipelineStage",
24    "build_stage",
25]
26
27logger = logging.getLogger(__name__)
28
29
30class _RootArgPlaceholder:
31    """
32    Placeholder for model-level inputs.
33    """
34
35    def __init__(self, tensor):
36        self.meta = tensor.to("meta")
37
38
39class _RecvInfo:
40    """
41    Represents a stage input.
42    """
43
44    def __init__(
45        self,
46        input_name: str,
47        source: int,
48        buffer: torch.Tensor,
49    ):
50        # Name of this input
51        self.input_name = input_name
52        # Stage index of the source of this input
53        self.source = source
54        # Buffer to receive the input into.
55        self.buffer = buffer
56
57    def __repr__(self):
58        return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})"
59
60
61# An input can be either a received activation or a model input
62InputInfo = Union[_RecvInfo, _RootArgPlaceholder]
63
64
65def _make_tensor_from_meta(
66    example: Union[torch.Tensor, FakeTensor],
67    device: torch.device,
68) -> torch.Tensor:
69    """
70    Create a real tensor from a tensor.
71    """
72    return torch.empty(
73        example.size(),
74        dtype=example.dtype,
75        layout=example.layout,
76        device=device,
77    )
78
79
80class _PipelineStageBase(ABC):
81    """
82    Base class for pipeline stages.
83    Defines or implements common methods shared by `_PipelineStage`, which is used by
84    the tracing frontend, and `PipelineStage`, which is used by the manual frontend.
85    """
86
87    def __init__(
88        self,
89        submodule: torch.nn.Module,
90        stage_index: int,
91        num_stages: int,
92        device: torch.device,
93        group: Optional[dist.ProcessGroup] = None,
94        dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
95    ):
96        """
97        Args:
98            submodule (torch.nn.Module): The module to be executed in this stage.
99            stage_index (int): The index of this stage.
100            num_stages (int): The total number of stages in this pipeline.
101            device (torch.device): The device to run this stage on.
102            group (Optional[dist.ProcessGroup]): The process group to use for communication.
103                If `None`, the default process group will be used.
104                Default: `None`.
105            dw_builder (Optional[Callable[[], Callable[..., None]]]): If provided, dw_builder is a builder function
106                that will build a new dw_runner function, which runs the parts of the module backward that were intentionally
107                skipped during the module's actual backward pass. The builder must be invoked by the stage after the stage runs
108                the model backward, and the stage should save the latest dw_runner to run during the weight pass.
109                If not provided, a dw_runner will be generated automatically by traversing the autograd graph.
110                When used with schedules that only have F and B steps, the fresh dw_runner function will be called as
111                part of B.
112                When used with F,B,W schedules, the dw_runner function implements 'W'.
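
        Example (a minimal sketch; ``stage_mod``, ``example_input``, and the deferred
        weight-gradient logic are hypothetical placeholders)::

            def dw_builder() -> Callable[..., None]:
                # Built by the stage right after it runs the module's backward for a
                # microbatch; capture whatever is needed to finish weight grads later.
                def dw_runner() -> None:
                    ...  # run the intentionally-skipped parts of backward (the 'W' step)
                return dw_runner

            stage = PipelineStage(
                stage_mod, stage_index=0, num_stages=2, device=device,
                input_args=example_input, dw_builder=dw_builder,
            )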
113        """
114        super().__init__()
115        if stage_index >= num_stages:
116            raise ValueError(
117                f"Stage index {stage_index} is out of range of {num_stages}"
118            )
119
120        self.submod = submodule
121        self.stage_index = stage_index
122        self.num_stages = num_stages
123        self.device = device
124        self.group = group
125
126        self.dw_builder = dw_builder
127
128        # backward state
129        self.backward_state: Dict[int, Tuple[Any, ...]] = {}
130
131        # store dw_runner per microbatch_id
132        self.dw_runner: Dict[int, Callable[..., None]] = {}
133
134        # `group_rank` is rank in process group `group`.
135        self.group_rank = dist.get_rank(self.group)
136        self.group_size = dist.get_world_size(self.group)
137        if self.group_size > self.num_stages:
138            raise RuntimeError(
139                f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}"
140            )
141
142        # Run time states
143        self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None
144        # map microbatch ID to list of forward tensor args
145        self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {}
146        # Caching chunk outputs for final output merge or reduction
147        self.output_chunks: List[Any] = []
148
149        # Initialize has_backward to false; this will be set to true if loss
150        # function is passed to pipeline schedule
151        self.has_backward = False
152        # Log prefix
153        self.log_prefix = f"[Stage {self.stage_index}]"
154
155        # Forward infra
156        self.args_recv_info: Dict[int, Tuple[InputInfo, ...]] = {}
157        self.set_requires_grad: Dict[int, bool] = {}
158        self.act_send_info: Dict[int, List] = {}
159
160        # Backward infra will be created lazily
161        self.grad_recv_info: Dict = {}
162        self.grad_send_info: Optional[List] = None
163
164        # Number of backward chunks seen. This is used to determine when to do
165        # grad reduction in DDP or FSDP.
166        self._seen_bwd_chunks = 0
167
168        # To be populated later by the Schedule
169        self.chunks: Optional[int] = None
170        self.stage_index_to_group_rank: Dict[int, int] = {
171            i: i % self.group_size for i in range(self.num_stages)
172        }
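        # Illustrative example (round-robin placement): with num_stages=4 and
        # group_size=2, this maps {0: 0, 1: 1, 2: 0, 3: 1}.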
173
174    @property
175    def has_backward(self) -> bool:
176        """
177        Returns true if this stage has a backward pass.
178        """
179        return self._has_backward
180
181    @has_backward.setter
182    def has_backward(self, has_backward: bool):
183        self._has_backward = has_backward
184
185    @property
186    def is_first(self):
187        """
188        Returns true if this stage is the first stage in the pipeline.
189        """
190        return self.stage_index == 0
191
192    @property
193    def is_last(self):
194        """
195        Returns true if this stage is the last stage in the pipeline.
196        """
197        return self.stage_index == self.num_stages - 1
198
199    def _check_chunk_id(self, chunk_id: int):
200        if self.chunks is None:
201            raise RuntimeError(
202                "Attempted to access chunk_id before chunks have been configured."
203            )
204        if chunk_id >= self.chunks:
205            raise RuntimeError(
206                f"Chunk id {chunk_id} is out of range [0, {self.chunks})"
207            )
208
209    def _configure_outputs_meta(self, outputs_meta: Tuple[torch.Tensor, ...]):
210        """
211        Track the output shapes/dtype of this stage since they determine the send operation(s) which must match
212        recv operations of the next stage.  The next stage _will_ be freezing its recv buffers based on its initial
213        configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
214        which could show up as hangs, silent corruption, or other errors.
215        """
216        assert (
217            self._outputs_meta is None
218        ), "Attempting to reconfigure output_meta, which is not supported"
219        self._outputs_meta = tuple(outputs_meta)  # type: ignore[assignment]
220
221    def get_outputs_meta(self) -> Tuple[torch.Tensor, ...]:
222        """Get the output metadata (meta tensors) reprensenting the outputs of this stage"""
223        assert (
224            self._outputs_meta is not None
225        ), "Attempted to get_outputs_meta() without configuring output meta"
226        return self._outputs_meta
227
228    def _create_grad_send_info(
229        self,
230        args_recv_info: Tuple,
231    ) -> List[Optional[int]]:
232        """
233        Create a list of stage indices to send gradients to.
234        """
235        grad_send_info: List[Optional[int]] = []
236
237        def map_recv_to_send(a):
238            # Note: we send gradients back to the previous stage as long as the
239            # input was received in forward, regardless of whether it requires
240            # grad. It is up to the previous stage to discard this gradient.
241            if isinstance(a, _RecvInfo):
242                grad_send_info.append(a.source)
243                return a.source
244            else:
245                grad_send_info.append(None)
246                return None
247
248        map_aggregate(args_recv_info, map_recv_to_send)
249
250        logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info)
251        return grad_send_info
252
253    @abstractmethod
254    def _prepare_forward_infra(self, num_microbatches: int):
255        raise NotImplementedError
256
257    def _prepare_backward_infra(self, num_microbatches: int):
258        # TODO: this is needed for backward_maybe_with_nosync
259        self.chunks = num_microbatches
260
261        for mb_index in range(num_microbatches):
262            # `grad_recv_info` is a mirror of `act_send_info`
263            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
264                self.act_send_info
265            )
266
267    @abstractmethod
268    def _create_grad_recv_info(
269        self,
270        act_send_info: Dict,
271    ) -> Tuple[_RecvInfo, ...]:
272        raise NotImplementedError
273
274    def _get_recv_ops(
275        self,
276        recv_infos: Tuple[InputInfo, ...],
277    ) -> List[dist.P2POp]:
278        """
279        Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`.
280        Returns a list of ops that correspond to the recv infos.
281        """
282        ops: List[dist.P2POp] = []
283        for info in recv_infos:
284            if not isinstance(info, _RecvInfo):
285                continue
286
287            peer_rank = self.stage_index_to_group_rank[info.source]
288            peer_global_rank = (
289                peer_rank
290                if self.group is None
291                else dist.get_global_rank(self.group, peer_rank)
292            )  # TODO
293            ops.append(
294                dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group)
295            )
296
297        return ops
298
299    def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
300        """
301        Returns a list of ops that are needed to receive the input arguments
302        for this stage.
303        """
304        recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id]
305
306        # In case there is backward pass, set requires_grad for receive buffers
307        # before first forward
308        if self.has_backward and not self.set_requires_grad[fwd_chunk_id]:
309            for a in recv_infos:
310                if isinstance(a, _RecvInfo):
311                    a.buffer.requires_grad_(True)
312
313        return self._get_recv_ops(recv_infos)
314
315    def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
316        """
317        Returns a list of ops that are needed to receive the gradients
318        for this stage.
319        """
320        if not self.has_backward or self.is_last:
321            return []
322
323        recv_infos = self.grad_recv_info[bwd_chunk_id]
324        return self._get_recv_ops(recv_infos)
325
326    def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
327        """
328        Get the activation send ops for current stage's forward.
329        """
330        output = self.output_chunks[fwd_chunk_id]
331        # Unify output form to tuple for easy correspondence with
332        # `act_send_info`
333        output_tuple = output if type(output) is tuple else (output,)
334
335        ops: List[dist.P2POp] = []
336
337        for idx, out in enumerate(output_tuple):
338            dst_stages = self.act_send_info[idx]
339            for dst in dst_stages:
340                if dst is None:
341                    continue
342                logger.debug(
343                    "%s Sending tensor to Stage %s: %s",
344                    self.log_prefix,
345                    dst,
346                    out.size(),
347                )
348                peer_rank = self.stage_index_to_group_rank[dst]
349                peer_global_rank = (
350                    peer_rank
351                    if self.group is None
352                    else dist.get_global_rank(self.group, peer_rank)
353                )  # TODO
354                ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group))
355
356        return ops
357
358    def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
359        """
360        Get the gradient send ops for current stage's backward.
361        """
362        self._check_chunk_id(bwd_chunk_id)
363
364        if not self.has_backward or self.is_first:
365            return []
366
367        # Create bwd send infra lazily
368        if self.grad_send_info is None:
369            # Send info for input grads during backward:
370            # List of destinations corresponding to input grads
371            # Can be None if an input has no grad
372            # `grad_send_info` is a mirror of `args_recv_info`
373            self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0])
374
375        ops: List[dist.P2POp] = []
376        for grad, grad_recv_stage in zip(self.grads_input, self.grad_send_info):
377            if isinstance(grad, torch.Tensor) and grad_recv_stage is not None:
378                logger.debug(
379                    "%s Sending gradient to Stage %s: %s",
380                    self.log_prefix,
381                    grad_recv_stage,
382                    grad.size(),
383                )
384                peer_rank = self.stage_index_to_group_rank[grad_recv_stage]
385                peer_global_rank = (
386                    peer_rank
387                    if self.group is None
388                    else dist.get_global_rank(self.group, peer_rank)
389                )  # TODO
390                ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group))
391            else:
392                if not (grad is None and grad_recv_stage is None):
393                    raise RuntimeError(
394                        f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
395                        f"and is expecting to send gradients to stage {grad_recv_stage}"
396                    )
397        return ops
398
399    def clear_runtime_states(self) -> None:
400        """
401        Clear runtime states of the stage.
402        """
403        # map microbatch ID to list of forward tensor args
404        self.fwd_cache.clear()
405        # Caching chunk outputs for final output merge or reduction
406        self.output_chunks.clear()
407        # Reset bwd chunk counter
408        self._seen_bwd_chunks = 0
409
410        # Clear grad of input buffers in between schedule steps. This is because
411        # `torch.autograd.backward()` will accumulate gradients into leaf
412        # tensors by default. For gradients to pass back to previous stages, we
413        # don't want such accumulation.
414        for recv_tuple in self.args_recv_info.values():  # iterate over all chunks
415            for a in recv_tuple:  # iterate over all input args
416                if isinstance(a, _RecvInfo):
417                    # Setting to None is the newer and recommended way to clear grads, compared to `zero_()`.
418                    # See https://github.com/pytorch/pytorch/pull/92731
419                    a.buffer.grad = None
420
421    def _map_tensor_from_recv_info(
422        self,
423        recv_infos: Tuple[InputInfo, ...],
424    ):
425        """
426        Map tensors from recv infos to a list.
427        """
428
429        def get_recv_tensor(info):
430            if isinstance(info, _RecvInfo):
431                return info.buffer
432            else:
433                raise AssertionError(f"Expected _RecvInfo but got {type(info)}")
434
435        tensors = map_aggregate(
436            recv_infos,
437            get_recv_tensor,
438        )
439
440        return tensors
441
442    def _retrieve_recv_activations(self, fwd_chunk_id: int):
443        """
444        Retrieve the activations received for the current stage during forward.
445        """
446        recv_infos = self.args_recv_info[fwd_chunk_id]
447        activations = self._map_tensor_from_recv_info(recv_infos)
448        return activations
449
450    def _retrieve_recv_grads(
451        self,
452        bwd_chunk_id: int,
453    ):
454        """
455        Retrieve the gradients received for the current stage during backward.
456        """
457        recv_infos = self.grad_recv_info[bwd_chunk_id]
458        grads = self._map_tensor_from_recv_info(recv_infos)
459        return grads
460
461    def forward_maybe_with_nosync(self, *args, **kwargs):
462        # If submod is wrapped with DDP, we use the `no_sync` context manager to
463        # avoid gradient all-reduce per microbatch
464        if isinstance(self.submod, DistributedDataParallel):
465            with self.submod.no_sync():  # type: ignore[operator]
466                out_val = self.submod(*args, **kwargs)
467        else:
468            out_val = self.submod(*args, **kwargs)
469        return out_val
470
471    def backward_maybe_with_nosync(self, backward_type, bwd_kwargs: Dict):
472        """
473        Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the
474        other steps.  Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but
475        there are additional state-variables and performance considerations depending on the data parallelism used.
476        This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries.
477        """
478        full_backward = bwd_kwargs["full_backward"]
479        if full_backward:
480            last_backward = self._seen_bwd_chunks == self.chunks - 1  # type: ignore[operator]
481        else:
482            # For backwards split into weight and input passes, we will see twice as many bwd_chunks
483            last_backward = self._seen_bwd_chunks == 2 * self.chunks - 1  # type: ignore[operator]
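        # Illustrative example: with self.chunks == 4, full backwards trigger grad
        # reduction on the 4th backward chunk (_seen_bwd_chunks == 3), while split
        # input/weight backwards trigger it on the 8th (_seen_bwd_chunks == 7).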
484
485        def perform_backward(backward_type):
486            if backward_type == "full":
487                return lambda: stage_backward(
488                    bwd_kwargs["stage_output"],
489                    bwd_kwargs["output_grads"],
490                    bwd_kwargs["input_values"],
491                )
492            elif backward_type == "input":
493                return lambda: stage_backward_input(
494                    bwd_kwargs["stage_output"],
495                    bwd_kwargs["output_grads"],
496                    bwd_kwargs["input_values"],
497                    self.submod.parameters(),
498                )
499            elif backward_type == "weight":
500                return lambda: stage_backward_weight(
501                    self.submod.parameters(), bwd_kwargs["param_groups"]
502                )
503            else:
504                raise RuntimeError(f"Unknown backward type: {backward_type}")
505
506        # If submod is wrapped by DDP
507        if isinstance(self.submod, DistributedDataParallel):
508            if last_backward:
509                # Last chunk, prepare for gradient reduction
510                # HACK: reaching into DDP implementation details here. Is there a better way?
511                self.submod.reducer.prepare_for_backward(  # type: ignore[union-attr, operator]
512                    list(
513                        torch.nn.parallel.distributed._find_tensors(  # type: ignore[attr-defined]
514                            bwd_kwargs["stage_output"]
515                        )
516                    )
517                )
518                result = perform_backward(backward_type)()
519            else:
520                with self.submod.no_sync():  # type: ignore[operator]
521                    result = perform_backward(backward_type)()
522        # If submod is a FSDP module
523        elif isinstance(self.submod, FSDPModule):
524            self.submod.set_is_last_backward(False)
525            self.submod.set_reshard_after_backward(False)
526            self.submod.set_requires_gradient_sync(False)
527            result = perform_backward(backward_type)()
528            if last_backward:
529                # Manually call post backward for FSDP
530                def run_post_backward(fsdp_module: FSDPModule) -> None:
531                    fsdp_module.set_is_last_backward(True)
532                    fsdp_module.set_reshard_after_backward(True)
533                    fsdp_module.set_requires_gradient_sync(True)
534                    fsdp_state = fully_shard.state(fsdp_module)
535                    for state in fsdp_state._state_ctx.all_states:
536                        if state._fsdp_param_group:
537                            state._fsdp_param_group.post_backward()
538
539                run_post_backward(self.submod)
540        else:
541            # Non-DP submodule, regular backward
542            result = perform_backward(backward_type)()
543
544        self._seen_bwd_chunks += 1
545
546        if isinstance(result, tuple) and len(result) == 2:
547            # for stage_backward_input()
548            grads, param_groups = result
549        else:
550            grads, param_groups = result, None
551
552        return grads, param_groups
553
554    def forward_one_chunk(
555        self,
556        fwd_chunk_id: int,
557        args: Tuple[Any, ...],
558        kwargs: Optional[Dict[str, Any]] = None,
559    ):
560        """
561        Perform forward pass on the stage with one microbatch.
562        `args` and `kwargs` are the inputs from *external* to this stage. They
563        apply only to the first stage in most cases.
564        """
565
566        if self.is_first:
567            # First stage doesn't need to receive anything
568            composite_args = args
569            composite_kwargs = kwargs or {}
570        else:
571            # Receive activations for this chunk
572            # Activations only come in args form
573            composite_args = self._retrieve_recv_activations(fwd_chunk_id)
574            composite_kwargs = {}
575
576        self._validate_fwd_input(args, kwargs)
577
578        # Compute forward
579        try:
580            output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)
581
582        except Exception as e:
583            exc_msg = f"""
584            {self.log_prefix} failed to run forward:
585            args: {map_debug_info(composite_args)}
586            kwargs: {map_debug_info(composite_kwargs)}
587            """
588            raise RuntimeError(exc_msg) from e
589
590        if type(output) is list:
591            # HACK: this is a hacky workaround for the fact that export creates
592            # output in list format
593            output = tuple(output)
594
595        # Unify output form to tuple for easy correspondence with
596        # `act_send_info`
597        output_tuple = output if type(output) is tuple else (output,)
598        # Prepare for final output merge or reduction
599        self.output_chunks.append(output)
600
601        # Save activations and inputs for backward
602        flat_args = flatten_args(composite_args)
603        flat_kwargs = flatten_args(composite_kwargs)
604        flatten_input_tensors = flat_args + flat_kwargs
605        self.fwd_cache[fwd_chunk_id] = (
606            output_tuple,  # stage_output
607            flatten_input_tensors,  # input_values
608        )
609
610        logger.debug(
611            "%s Forwarded chunk %s, outputs: %s",
612            self.log_prefix,
613            fwd_chunk_id,
614            map_debug_info(output),
615        )
616        self._validate_fwd_outputs(output_tuple)
617        return output
618
619    def backward_one_chunk(
620        self, bwd_chunk_id: int, loss=None, full_backward: bool = True
621    ):
622        """
623        Perform backward pass on the module.
624        This should only be called once per microbatch.
625
626        If full_backward is True (the default), the full backward pass including weight and input gradients will be run,
627        and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id.
628
629        If full_backward is False, providing a `dw_builder` to the PipelineStage at __init__ time is optional,
630        and a subsequent call to `backward_weight_one_chunk` is required to invoke the dw_runner and complete the backward.
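
        Example (illustrative call pattern for a split backward; a schedule normally
        drives these calls, and `loss` is only passed on the last stage)::

            stage.backward_one_chunk(chunk_id, loss=loss, full_backward=False)  # B: input grads
            stage.backward_weight_one_chunk(chunk_id)                           # W: weight grads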
631        """
632        self._check_chunk_id(bwd_chunk_id)
633
634        (
635            stage_output,
636            input_values,
637        ) = self.fwd_cache.pop(bwd_chunk_id)
638
639        # Compute backward
640        if self.is_last:
641            # Last stage computes gradients from loss and has no gradients from
642            # next stage
643            bwd_kwargs = {
644                "stage_output": loss,
645                "output_grads": None,
646                "input_values": input_values,
647            }
648        else:
649            # Otherwise, receive gradients from next stage
650            grads_output = self._retrieve_recv_grads(bwd_chunk_id)
651            # If an input to the pipeline requires gradient,
652            # `torch.autograd.backward` will accumulate the gradient into the
653            # `.grad` field of such input
654            bwd_kwargs = {
655                "stage_output": stage_output,
656                "output_grads": grads_output,
657                "input_values": input_values,
658            }
659
660        # Save full_backward
661        bwd_kwargs["full_backward"] = full_backward
662
663        # Custom backward function
664        if self.dw_builder:
665            # TODO: We may want to change our semantics so we are allowed to ignore
666            # the 'dw_builder' and call full_backward directly when it is a full_backward op.
667            self.grads_input, _ = self.backward_maybe_with_nosync("full", bwd_kwargs)
668            if full_backward:
669                self.dw_builder()()
670            else:
671                self.dw_runner[bwd_chunk_id] = self.dw_builder()
672        else:
673            if full_backward:
674                self.grads_input, _ = self.backward_maybe_with_nosync(
675                    "full", bwd_kwargs
676                )
677            else:
678                # perform the partial backwards for the inputs with a custom backward function
679                # when "stage_output" is a loss, it is a tensor; otherwise it is a tuple of tensors
680                if isinstance(bwd_kwargs["stage_output"], torch.Tensor):
681                    bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],)
682
683                grads_input, param_groups = self.backward_maybe_with_nosync(
684                    "input", bwd_kwargs
685                )
686
687                # TODO: we don't need to save this, add to dw_runner?
688                self.backward_state[bwd_chunk_id] = (
689                    input_values,
690                    param_groups,
691                    bwd_kwargs["stage_output"],
692                    bwd_kwargs["output_grads"],
693                )
694                self.grads_input = grads_input
695                # Save a placeholder for the dw_runner
696                self.dw_runner[bwd_chunk_id] = lambda: None
697        logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id)
698
699    def backward_weight_one_chunk(self, bwd_chunk_id: int):
700        assert bwd_chunk_id in self.dw_runner, (
701            f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}"
702            " without first calling `backward_one_chunk(full_backward=False)`"
703        )
704
705        if self.dw_builder is not None:
706            self.dw_runner.pop(bwd_chunk_id)()
707        else:
708            (
709                input_values,
710                param_groups,
711                stage_output,
712                output_grads,
713            ) = self.backward_state.pop(bwd_chunk_id)
714
715            if self.stage_index != 0:
716                bwd_kwargs = {
717                    "stage_output": stage_output,
718                    "param_groups": param_groups,
719                    "full_backward": False,
720                }
721                weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs)
722            else:
723                # TODO: figure out a better way to do this:
724                # if the inputs do not require gradient,
725                # then the parameter group will not be fully captured during stage_backward_input
726                # in this case, we need to call grad directly on the parameters
727                # To solve: make input fn do the intersect compute and then finish it off during W
728                bwd_kwargs = {
729                    "stage_output": stage_output,
730                    "output_grads": output_grads,
731                    "input_values": input_values,
732                    "full_backward": False,
733                }
734                self.backward_maybe_with_nosync("full", bwd_kwargs)
735
736    def _validate_fwd_input(self, args, kwargs):
737        """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage."""
738
739        if self.is_first:
740            # TODO why is there a separate recv_info for each pipeline chunk?
741            # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we
742            # check all chunks against args_recv_info[0]
743            expected_args = self.args_recv_info[0]
744        else:
745            # We don't check inputs for non-0 stages assuming they don't accept
746            # user inputs in canonical pipeline scenarios
747            return
748
749        if len(kwargs):
750            # TODO- need a mapping of kwarg to position in self.args_recv_info
751            # without it, we just validate shapes for args and ignore kwargs
752            expected_args = expected_args[: len(expected_args) - len(kwargs)]
753
754        # TODO- need a mapping of kwarg to position in self.args_recv_info
755        # maybe it's impossible to tell whether the len mismatches because
756        # (a) the user passed an extra arg or missed an arg
757        # (b) the user did not pass a kwarg, which has a default value baked into expected_args
758        expected_tensors_meta = [
759            e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer
760            for e in expected_args
761        ]
762        validate_tensors_metadata(
763            f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args
764        )
765
766    def _validate_fwd_outputs(self, outputs: Tuple[torch.Tensor, ...]):
767        """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype.
768        Most likely, this could be caused either by incorrect user specification of output shapes, or because
769        shape inference was done on the original model but then at runtime the model is wrapped with something like
770        mixed precision which changes output dtype.
771        """
772        expected_tensors_meta = self.get_outputs_meta()
773        validate_tensors_metadata(
774            f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs
775        )
776
777
778class _PipelineStage(_PipelineStageBase):
779    def __init__(
780        self,
781        stage_module: torch.nn.Module,
782        stage_index: int,
783        pipe_info: PipeInfo,
784        device: torch.device,
785        group: Optional[dist.ProcessGroup] = None,
786    ):
787        """
788        Create a pipeline stage given a stage_module to be wrapped by this stage
789        and a `pipe_info` describing the stage relationship of the pipeline.
790
791        Args:
792            stage_module (torch.nn.Module): the module to be wrapped by this stage
793            stage_index (int): the index of this stage in the pipeline
794            pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
795            device (torch.device): the device to be used by this stage
796            group (Optional[dist.ProcessGroup]): the process group to be used by this stage
797        """
798        _PipelineStageBase.__init__(
799            self,
800            stage_module,
801            stage_index,
802            pipe_info.num_stages,
803            device,
804            group,
805        )
806        self.pipe_info = pipe_info
807
808        # Find stage nodes in graph
809        submod_nodes = [
810            node for node in pipe_info.graph.nodes if node.op == "call_module"
811        ]
812        if len(submod_nodes) != self.num_stages:
813            raise AssertionError(
814                f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}"
815            )
816
817        # Find my stage node in graph
818        self.node = submod_nodes[self.stage_index]
819        self.name = self.node.name
820        logger.info(
821            "[%s] Creating PipelineStage %s for %s",
822            self.group_rank,
823            stage_index,
824            self.name,
825        )
826
827        # Create mapping from stage name to stage index
828        self.submod_to_stage_index: Dict[str, int] = {}
829        for i, node in enumerate(submod_nodes):
830            self.submod_to_stage_index.setdefault(node.name, i)
831
832        # Cast submodule to device
833        self._move_submod_to_device()
834
835    def _move_submod_to_device(self):
836        # Move submodule to indicated device if possible
837        # Note: we cannot move meta module to real devices because meta tensors
838        # do not support the to() method. One needs to do an in-place tensor swap in
839        # that case.
840        has_meta_param = any(
841            isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters()
842        )
843        if has_meta_param:
844            logger.debug("%s Found meta parameters!", self.log_prefix)
845        else:
846            self.submod.to(self.device)
847
848    def _prepare_forward_infra(self, num_microbatches: int):
849        """
850        Create send/recv infrastructures for activations (during forward)
851        """
852        # Flag per chunk to keep track of whether we have set `requires_grad`
853        # for receive buffers. Format: {chunk : Boolean}
854        for chunk in range(num_microbatches):
855            self.args_recv_info[chunk] = self._create_act_recv_info()
856            self.set_requires_grad[chunk] = False
857
858        # Send info during forward for each activation
859        self.act_send_info = self._create_act_send_info()
860
861    def get_stage_index_of_submod(
862        self,
863        submod_name: str,
864    ):
865        """
866        Given a submodule name, return the stage index of the submodule.
867        """
868        if submod_name not in self.submod_to_stage_index:
869            raise AssertionError(f"Stage id of {submod_name} not found")
870
871        return self.submod_to_stage_index[submod_name]
872
873    def _create_act_recv_info(
874        self,
875    ):
876        """
877        Create a tuple of `_RecvInfo` for inputs to the stage.
878        """
879
880        def create_recv_tensor(placeholder, arg_node):
881            """
882            Create a receive buffer for a placeholder.
883            """
884            example_value = placeholder.meta["val"]
885            if arg_node.op == "placeholder":
886                # This is a root level placeholder, thus an input argument to the entire model.
887                # We are likely at stage 0, hence no need to create a receive buffer.
888                return _RootArgPlaceholder(example_value)
889
890            # Figure out the source stage of this input
891            while arg_node.target is operator.getitem:
892                # If the input is a getitem, we need to go deeper
893                arg_node = arg_node.args[0]
894
895            assert (
896                arg_node.op == "call_module"
897            ), f"Expecting call_module, got {arg_node.op}"
898            src_stage = self.get_stage_index_of_submod(arg_node.name)
899
900            # Create a receive buffer for this placeholder
901            logger.debug(
902                "%s Creating recv buffer for input '%s' : %s, %s",
903                self.log_prefix,
904                placeholder.name,
905                example_value.shape,
906                example_value.dtype,
907            )
908            buffer = _make_tensor_from_meta(example_value, self.device)
909
910            return _RecvInfo(
911                arg_node.name,
912                src_stage,
913                buffer,
914            )
915
916        args_recv_info: List[InputInfo] = []
917        # Filter out placeholder nodes from `self.submod` (a GraphModule)
918        placeholders = filter(
919            lambda node: node.op == "placeholder", self.submod.graph.nodes
920        )
921        # `placeholders` are nodes internal to submod.
922        # `self.node.args` are dependency nodes in the outer graph.
923        # The two are 1:1.
924        for placeholder, arg_node in zip(placeholders, self.node.args):
925            # Create a receive buffer for this placeholder
926            recv_info = create_recv_tensor(placeholder, arg_node)
927            args_recv_info.append(recv_info)
928
929        logger.debug(
930            "%s Activation recv / args info: %s", self.log_prefix, args_recv_info
931        )
932        # `args` is a Tuple, hence we will return a Tuple[InputInfo]
933        return tuple(args_recv_info)
934
935    def find_dst_rank(
936        self,
937        user: fx.Node,
938    ) -> Optional[int]:
939        """
940        Find the destination rank of a `user` node.
941        If the `user` is not a submod, `None` may be returned.
942        """
943        if user.op == "call_module":
944            # User is a stage (`call_module`)
945            return self.get_stage_index_of_submod(user.name)
946        else:
947            # - If user.op == "output":
948            #   No need to send back to rank 0
949            # - If user.target is stage_backward:
950            #   No need to send assuming submod output is stored locally or
951            #   should be re-calculated in case of activation checkpointing
952            return None
953
954    def _create_act_send_info(self):
955        """
956        Create a dict of send info for activations.
957        The dict is of the form:
958        {
959            output_index: [dst_rank_0, dst_rank_1, ...],
960            ...
961        }
962        where the list of `dst_rank`s covers the case where an output value may
963        be consumed by multiple stages.
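        For example (illustrative only), {0: [2], 1: [2, 3]} means output 0 is consumed
        by stage 2 while output 1 is consumed by stages 2 and 3.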
964        """
965        # Output index: List of receiver ranks
966        act_send_info: Dict[int, List] = {}
967        out_idx = 0
968
969        for user in self.node.users:
970            if user.target is operator.getitem:
971                # Recursively find the real destination
972                gi_dsts = act_send_info.setdefault(out_idx, [])
973                for gi_user in user.users:
974                    dst_rank = self.find_dst_rank(gi_user)
975                    if dst_rank is not None:
976                        gi_dsts.append(dst_rank)
977                # Next `getitem` will point to the next output index
978                out_idx += 1
979            else:
980                # In case of single output value, `out_idx` will not increase
981                dsts = act_send_info.setdefault(out_idx, [])
982                dst_rank = self.find_dst_rank(user)
983                if dst_rank is not None:
984                    dsts.append(dst_rank)
985
986        output_node = self._get_output_node()
987        output_vals: Tuple[torch.Tensor] = tuple(
988            v.meta["val"] for v in flatten_args(output_node.args)
989        )
990        self._configure_outputs_meta(output_vals)
991
992        logger.debug("%s Send info: %s", self.log_prefix, act_send_info)
993        return act_send_info
994
995    def _get_output_node(self):
996        output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"]
997        assert len(output_nodes) == 1
998        output_node = output_nodes[0]
999        return output_node
1000
1001    def _create_grad_recv_info(
1002        self,
1003        act_send_info: Dict,
1004    ) -> Tuple[_RecvInfo, ...]:
1005        """
1006        Create a tuple of `_RecvInfo` for gradients.
1007        """
1008        # Dict[output_index, _RecvInfo]
1009        grad_recv_info: Dict[int, _RecvInfo] = {}
1010        output_node = self._get_output_node()
1011
1012        # The output node may take multiple args, meaning the submod has multiple output values.
1013        output_vals = flatten_args(output_node.args)
1014
1015        for out_idx, dst_list in act_send_info.items():
1016            if not dst_list:
1017                # No actual receiver for activation so no grad coming back
1018                continue
1019
1020            output = output_vals[out_idx]
1021            example_value = output.meta["val"]
1022            logger.debug(
1023                f"{self.log_prefix} Creating grad recv buffer for output {output.name} "  # noqa: G004
1024                f": {example_value.shape}, {example_value.dtype}"
1025            )
1026
1027            # TODO: otherwise needs grad accumulation
1028            assert len(dst_list) == 1, "Backward of skip connections not supported yet"
1029            grad_src = dst_list[0]
1030            grad_recv_info[out_idx] = _RecvInfo(
1031                f"{grad_src}",  # noqa: G004
1032                grad_src,
1033                _make_tensor_from_meta(example_value, self.device),
1034            )
1035
1036        # Convert to tuple for convenience in get_ops and retrieve tensor
1037        grad_recv_info_tuple = tuple(grad_recv_info.values())
1038        logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple)
1039        return grad_recv_info_tuple
1040
1041
1042# A helper function to create a pipeline stage based on traced pipeline information
1043def build_stage(
1044    stage_module: torch.nn.Module,
1045    stage_index: int,
1046    pipe_info: PipeInfo,
1047    device: torch.device,
1048    group: Optional[dist.ProcessGroup] = None,
1049) -> _PipelineStage:
1050    """
1051    Create a pipeline stage given a stage_module to be wrapped by this stage
1052    and pipeline information.
1053
1054    Args:
1055        stage_module (torch.nn.Module): the module to be wrapped by this stage
1056        stage_index (int): the index of this stage in the pipeline
1057        pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
1058        device (torch.device): the device to be used by this stage
1059        group (Optional[dist.ProcessGroup]): the process group to be used by this stage
1060
1061    Returns:
1062        _PipelineStage: a pipeline stage that can run with `PipelineSchedules`.
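
    Example (a minimal sketch; assumes `pipe` was produced by the `pipeline()` tracing
    frontend and that this rank owns `stage_index`)::

        info = pipe.info()
        stage_mod = pipe.get_stage_module(stage_index)
        stage = build_stage(stage_mod, stage_index, info, device)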
1063    """
1064    return _PipelineStage(
1065        stage_module,
1066        stage_index,
1067        pipe_info,
1068        device,
1069        group,
1070    )
1071
1072
1073# Manual PipelineStage functions and definition
1074
1075METADATA_TENSOR_LEN = 100
1076PLACEHOLDER_VAL = -1
1077
1078
1079def _create_empty_tensors(
1080    tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device
1081) -> List[torch.Tensor]:
1082    """
1083    Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
1084    and places them on the specified device.
1085    Args:
1086        tensor (Union[torch.Tensor, List[torch.tensor]]): The input tensor(s).
1087        device (torch.device): The device where the new tensors will be placed.
1088    Returns:
1089        List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
1090    """
1091    if isinstance(tensor, torch.Tensor):
1092        return [torch.empty_like(tensor, device=device)]
1093    elif isinstance(tensor, (list, tuple)):
1094        return [torch.empty_like(t, device=device) for t in tensor]
1095    raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors")
1096
1097
1098def _create_metadata_tensor(
1099    tensors: Optional[List[torch.Tensor]] = None,
1100    device: Optional[torch.device] = torch.device("cpu"),
1101) -> torch.Tensor:
1102    """
1103    Create a metadata tensor that can be sent over the wire.
1104    This tensor contains the number of dimensions and the shape of each tensor being sent.
1105
1106    The data is of format [num_dims, dim1, dim2, ...].
1107    If the tensor is None, a tensor of only placeholder values will be returned.
1108
1109    Inputs:
1110        tensors: A list of tensors; each tensor will be converted into its shape dimensions, and
1111                 these dimensions will be concatenated.
1112        device: The device where the metadata tensor will be created.
1113    If `tensors` is None, the returned tensor will contain only PLACEHOLDER_VALs.
1114
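    For example (illustrative only), for two tensors of shapes (2, 3) and (4,), the metadata
    tensor starts with [2, 2, 3, 1, 4] and is padded with PLACEHOLDER_VAL (-1) up to
    METADATA_TENSOR_LEN.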
1115    """
1116    metadata_tensor = torch.full(
1117        (METADATA_TENSOR_LEN,),
1118        PLACEHOLDER_VAL,
1119        dtype=torch.int32,
1120        device=device,
1121    )
1122    if tensors:
1123        # Create a list of tensors containing the number of dimensions and the shape of each tensor
1124        data = [
1125            # data is of format [num_dims, dim1, dim2, ...]
1126            torch.tensor(
1127                [len(tensor.shape)] + list(tensor.shape),
1128                dtype=torch.int32,
1129                device=device,
1130            )
1131            for tensor in tensors
1132        ]
1133        # Concatenate the data into a single tensor
1134        data_tensor = torch.cat(data)
1135        dt_shape = data_tensor.shape[0]
1136        if dt_shape > METADATA_TENSOR_LEN:
1137            raise ValueError(
1138                f"Metadata tensor size ({dt_shape}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
1139            )
1140        metadata_tensor[:dt_shape] = data_tensor
1141    return metadata_tensor
1142
1143
1144def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
1145    """
1146    Extract the number of dimensions and the shape of each tensor from a metadata tensor.
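    For example, a metadata tensor beginning with [2, 2, 3, 1, 4, -1, ...] decodes to
    [torch.Size([2, 3]), torch.Size([4])], mirroring the layout written by _create_metadata_tensor.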
1147    """
1148    metadata: List[torch.Size] = []
1149    i = 0
1150    while i < len(tensor) and tensor[i] != PLACEHOLDER_VAL:
1151        num_dims = int(tensor[i].item())
1152        shape = torch.Size(tensor[i + 1 : i + 1 + num_dims].tolist())
1153        metadata.append(shape)
1154        i += num_dims + 1
1155    return metadata
1156
1157
1158def _get_stage_shapes(
1159    stage_modules: List[nn.Module],
1160    stage_ids: List[int],
1161    num_stages: int,
1162    rank: int,
1163    world_size: int,
1164    device: torch.device,
1165    microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
1166):
1167    """
1168    Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of
1169    virtual pipelining) and returns the shape of the inputs and outputs of the module.
1170    Only the first stage must pass in a microbatch.
1171
1172    Each rank must call _get_stage_shapes or the program will hang.
1173
1174    Args:
1175        stage_modules: The chunks assigned to this rank. The length should be 1 for any
1176                non-interleaved schedules and >1 for any interleaved schedules.
1177        stage_ids: The id of the stages assigned to this rank.
1178        num_stages: Total number of stages.
1179        rank: Rank of the current process.
1180        world_size: Number of processes participating in the pipeline.
1181        device: Device where the tensors are allocated.
1182
1183    Returns a dictionary containing the following keys:
1184        "inputs": Shape of the inputs to the module
1185        "outputs": Shape of the outputs of the module
1186    """
1187
1188    stage_id_to_shapes: Dict[int, Dict[str, list[torch.Size]]] = {}
1189    for stage_id, model in zip(stage_ids, stage_modules):
1190        input_shape_metadata_tensor = _create_metadata_tensor(device=device)
1191        # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
1192        prev_rank = (rank - 1) % world_size
1193        next_rank = (rank + 1) % world_size
1194        shapes = {}
1195
1196        # first stage doesn't receive anything and uses a microbatch
1197        if stage_id == 0:
1198            if microbatch is None:
1199                raise RuntimeError("Microbatch is required for first stage")
1200            example_fwd_inputs = microbatch
1201            if isinstance(example_fwd_inputs, torch.Tensor):
1202                example_fwd_inputs = [example_fwd_inputs]
1203        else:
1204            # other stages must receive shape information
1205            # TODO: send/recv should take a group, rather than use the default group
1206            dist.recv(input_shape_metadata_tensor, prev_rank)
1207            metadata = _extract_metadata_from_tensor(input_shape_metadata_tensor)
1208            example_fwd_inputs = [
1209                torch.empty(shape_list, device=device) for shape_list in metadata
1210            ]
1211        shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]
1212
1213        # perform forward
1214        # TODO: if forward fails raise a more descriptive error explaining which stage failed
1215        fwd_outputs = model(*example_fwd_inputs)
1216        fwd_outputs = _create_empty_tensors(fwd_outputs, device)
1217        shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]
1218
1219        # send shape dims
1220        if stage_id != num_stages - 1:
1221            output_shape_metadata_tensor = _create_metadata_tensor(
1222                fwd_outputs, device=device
1223            )
1224            dist.send(output_shape_metadata_tensor, next_rank)
1225        stage_id_to_shapes[stage_id] = shapes
1226    logger.info(stage_id_to_shapes)
1227    return stage_id_to_shapes
1228
1229
1230class PipelineStage(_PipelineStageBase):
1231    """
1232    A class representing a pipeline stage in a pipeline parallelism setup.
1233    This class is created manually by providing an example input (and optionally output),
1234    as opposed to the stage class produced by the pipeline() tracing frontend.
1235    This class extends the `_PipelineStageBase` class and can similarly be used
1236    in `PipelineSchedule`.
1237
1238    Args:
1239        submodule (nn.Module): The PyTorch module wrapped by this stage.
1240        stage_index (int): The ID of this stage.
1241        num_stages (int): The total number of stages.
1242        device (torch.device): The device where this stage is located.
1243        input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule.
1244        output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule.
1245        group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group.
1246        dw_builder: TODO clean up comments
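
    Example (a minimal sketch for a 2-stage pipeline; `block0`, `block1`, and the tensor
    shapes are hypothetical placeholders)::

        # On the rank holding stage 0
        stage = PipelineStage(block0, stage_index=0, num_stages=2, device=device,
                              input_args=torch.randn(32, 1024, device=device))
        # On the rank holding stage 1 (input_args describes this stage's input,
        # i.e. the output of stage 0)
        stage = PipelineStage(block1, stage_index=1, num_stages=2, device=device,
                              input_args=torch.randn(32, 512, device=device))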
1247    """
1248
1249    def __init__(
1250        self,
1251        submodule: nn.Module,
1252        stage_index: int,
1253        num_stages: int,
1254        device: torch.device,
1255        input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
1256        output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
1257        group: Optional[dist.ProcessGroup] = None,
1258        dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
1259    ):
1260        super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
1261        self.submod.to(self.device)
1262        # When we materialize the model partition on cuda, we call reset_parameters() if it is available
1263        self.inputs: List[torch.Tensor] = []
1264        self.outputs: List[torch.Tensor] = []
1265
1266        self.inputs = _create_empty_tensors(input_args, device)
1267
1268        if output_args is None:
1269            logger.info("output_args not provided, performing forward using input_args")
1270            self.outputs = self.submod(*self.inputs)
1271            # create buffers for the output so that the data is in the correct
1272            # shape in order to use in p2p op (send)
1273            self.outputs = _create_empty_tensors(self.outputs, device)
1274        else:
1275            self.outputs = _create_empty_tensors(output_args, device)
1276
1277        self._configure_outputs_meta(tuple(self.outputs))
1278
1279        # these are the buffers used in backwards send/recv, they are allocated later
1280        self.outputs_grad: List[torch.Tensor] = []
1281
1282        def stage_global_rank(peer_rank):
1283            return (
1284                peer_rank
1285                if self.group is None
1286                else dist.get_global_rank(self.group, peer_rank)
1287            )
1288
1289        self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size)
1290        self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size)
1291
1292        logger.debug(
1293            f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, "  # noqa: G004
1294            f"{self.is_last=}, {self.num_stages=}, "
1295            f"inputs: {[inp.shape for inp in self.inputs]}, "
1296            f"output: {[output.shape for output in self.outputs]}"
1297        )
1298
1299    def _prepare_forward_infra(self, num_microbatches: int) -> None:
1300        # Receive info during forward
1301        # TODO: create args_recv_info lazily? (same needed for PipelineStage)
1302        for chunk_id in range(num_microbatches):
1303            self.set_requires_grad[chunk_id] = False
1304            if not self.is_first:
1305                # We assume that we always receive from stage - 1
1306                recv_infos = tuple(
1307                    [
1308                        _RecvInfo(
1309                            f"recv_for_{self.stage_index}_from_{self.stage_index - 1}",
1310                            self.stage_index - 1,
1311                            _make_tensor_from_meta(inp, self.device),
1312                        )
1313                        for inp in self.inputs
1314                    ]
1315                )
1316
1317                self.args_recv_info[chunk_id] = recv_infos
1318            else:
1319                self.args_recv_info[chunk_id] = tuple(
1320                    [_RootArgPlaceholder(i) for i in self.inputs]
1321                )
1322
1323        # Send info during forward for each activation
1324        # only need the rank that is being sent to
1325        self.act_send_info: Dict[int, List] = {}
1326        for idx in range(len(self.outputs)):
1327            # We assume we always send to stage + 1
1328            if not self.is_last:
1329                self.act_send_info[idx] = [self.stage_index + 1]
1330            else:
1331                self.act_send_info[idx] = []
1332
1333    def _create_grad_recv_info(
1334        self,
1335        act_send_info: Dict,
1336    ) -> Tuple[_RecvInfo, ...]:
1337        grad_recv_info: Tuple[_RecvInfo, ...] = ()
1338        if not self.is_last:
1339            # Receiving gradients from multiple sources is not supported
1340            # hence we only take the first destination
1341            grad_recv_info = tuple(
1342                [
1343                    _RecvInfo(
1344                        f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}",
1345                        dst_list[0],
1346                        _make_tensor_from_meta(self.outputs[idx], self.device),
1347                    )
1348                    for idx, dst_list in act_send_info.items()
1349                ]
1350            )
1351        return grad_recv_info
1352
1353    def _init_p2p_neighbors(self):
1354        """
1355        Set up p2p communicators between previous and next stages
1356        by sending a dummy tensor.
1357
1358        If this is used, must be called for all pipeline stages.
1359        """
1360        ops = []
1361        recv_tensor = torch.zeros(1, device="cuda")
1362        send_tensor = torch.ones(1, device="cuda")
1363        # forward
1364        if not self.is_first:
1365            ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage, self.group))
1366        if not self.is_last:
1367            ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage, self.group))
1368
1369        # backward
1370        if not self.is_first:
1371            ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage, self.group))
1372        if not self.is_last:
1373            ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage, self.group))
1374
1375        return True
1376
1377
1378def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
1379    """
1380    Check that the buffer shapes match between stages as expected by performing an all_gather across
1381    all stages.
1382    """
1383    if len(pipeline_stages) == 0:
1384        raise ValueError("No pipeline stages provided.")
1385
1386    virtual_pipeline_size = len(pipeline_stages)
1387    all_inputs = []
1388    all_outputs = []
1389    world_size = pipeline_stages[0].group_size
1390    num_stages = pipeline_stages[0].num_stages
1391
1392    # perform all gathers between all stages
1393    for virtual_id, stage in enumerate(pipeline_stages):
1394        world_size = stage.group_size
1395        stage_id: int = stage.stage_index
1396        rank = stage.group_rank
1397        # check that world_size and num_stages are consistent across all stages
1398        if stage.group_size != world_size:
1399            raise ValueError(
1400                f"Stage id {stage_id} has world size ({stage.group_size}) \
1401                which does not match world size ({world_size}) of other stages."
1402            )
1403        if stage.num_stages != num_stages:
1404            raise ValueError(
1405                f"Stage id {stage_id} has num stages ({stage.num_stages}) \
1406                which does not match num stages ({num_stages}) of other stages."
1407            )
1408
1409        pg_rank = dist.get_rank(stage.group)
1410        if rank != pg_rank:
1411            raise ValueError(
1412                f"Rank {rank} is not equal to process group rank {pg_rank}"
1413            )
1414
1415        if (num_stages := stage.num_stages) % world_size != 0:
1416            raise ValueError(
1417                f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
1418            )
1419
1420        # all gather each ranks inputs
1421        tensor_list = [
1422            _create_metadata_tensor(device=stage.device)
1423            for _ in range(stage.group_size)
1424        ]
1425        expected_inputs = stage.inputs
1426        stage_input = _create_metadata_tensor(expected_inputs, device=stage.device)
1427        dist.all_gather(tensor_list, stage_input)
1428        stage_input_shapes = [
1429            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
1430        ]
1431
1432        # all gather each ranks outputs
1433        tensor_list = [
1434            _create_metadata_tensor(device=stage.device)
1435            for _ in range(stage.group_size)
1436        ]
1437        expected_outputs = stage.outputs
1438        stage_output = _create_metadata_tensor(expected_outputs, device=stage.device)
1439        dist.all_gather(tensor_list, stage_output)
1440        stage_output_shapes = [
1441            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
1442        ]
1443
1444        logger.debug(
1445            f"Rank: {pg_rank}"  # noqa: G004
1446            f"Stage id: {stage_id}"
1447            f"Stage num stages: {stage.num_stages}"
1448            f"Stage rank: {rank}"
1449            f"Stage world size: {world_size}"
1450            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}"  # noqa: G003
1451            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}"  # noqa: G003
1452        )
1453
1454        all_inputs.extend(stage_input_shapes)
1455        all_outputs.extend(stage_output_shapes)
1456
1457        # log only rank 0's view, they will all be equivalent
1458        if pg_rank == 0:
1459            logger.info(
1460                "all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs
1461            )
1462
1463    # Check if the output for stage 0 matches the input at stage 1, and so forth
1464    for i in range(virtual_pipeline_size * world_size - 1):
1465        if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
1466            raise ValueError(
1467                f"Stage_id {i} output shape {out} at does not match stage_id {i + 1} input shape {inp}."
1468            )
1469