# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import operator
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.fx as fx
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensor
from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard
from torch.fx.node import map_aggregate
from torch.nn.parallel import DistributedDataParallel

from ._backward import stage_backward, stage_backward_input, stage_backward_weight
from ._debug import map_debug_info
from ._utils import flatten_args, PipeInfo, validate_tensors_metadata


__all__ = [
    "PipelineStage",
    "build_stage",
]

logger = logging.getLogger(__name__)


class _RootArgPlaceholder:
    """
    Placeholder for model-level inputs.
    """

    def __init__(self, tensor):
        self.meta = tensor.to("meta")


class _RecvInfo:
    """
    Represents a stage input.
    """

    def __init__(
        self,
        input_name: str,
        source: int,
        buffer: torch.Tensor,
    ):
        # Name of this input
        self.input_name = input_name
        # Stage index of the source of this input
        self.source = source
        # Buffer to receive the input into.
        self.buffer = buffer

    def __repr__(self):
        return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})"


# An input can be either a received activation or a model input
InputInfo = Union[_RecvInfo, _RootArgPlaceholder]


def _make_tensor_from_meta(
    example: Union[torch.Tensor, FakeTensor],
    device: torch.device,
) -> torch.Tensor:
    """
    Create a real tensor from a tensor.
    """
    return torch.empty(
        example.size(),
        dtype=example.dtype,
        layout=example.layout,
        device=device,
    )


class _PipelineStageBase(ABC):
    """
    Base class for pipeline stages.
    Defines or implements common methods used by the `_PipelineStage` class of
    the tracing frontend and the `PipelineStage` class of the manual frontend.
    """

    def __init__(
        self,
        submodule: torch.nn.Module,
        stage_index: int,
        num_stages: int,
        device: torch.device,
        group: Optional[dist.ProcessGroup] = None,
        dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
    ):
        """
        Args:
            submodule (torch.nn.Module): The module to be executed in this stage.
            stage_index (int): The index of this stage.
            num_stages (int): The total number of stages in this pipeline.
            device (torch.device): The device to run this stage on.
            group (Optional[dist.ProcessGroup]): The process group to use for communication.
                If `None`, the default process group will be used.
                Default: `None`.
            dw_builder (Optional[Callable[[], Callable[..., None]]]): If provided, dw_builder is a builder
                function that will construct a new dw_runner function, which runs the parts of the module's
                backward pass that were intentionally skipped during the module's actual backward. The builder
                must be invoked by the stage after the stage runs the model backward, and the stage should save
                the latest dw_runner to run during the weight pass.
                If not provided, a dw_runner will be generated automatically by traversing the autograd graph.
                When used with schedules that only have F and B steps, the fresh dw_runner function will be
                called as part of B.
                When used with F,B,W schedules, the dw_runner function implements 'W'.
        """
        super().__init__()
        if stage_index >= num_stages:
            raise ValueError(
                f"Stage index {stage_index} is out of range of {num_stages}"
            )

        self.submod = submodule
        self.stage_index = stage_index
        self.num_stages = num_stages
        self.device = device
        self.group = group

        self.dw_builder = dw_builder

        # backward state
        self.backward_state: Dict[int, Tuple[Any, ...]] = {}

        # store dw_runner per microbatch_id
        self.dw_runner: Dict[int, Callable[..., None]] = {}

        # `group_rank` is rank in process group `group`.
        self.group_rank = dist.get_rank(self.group)
        self.group_size = dist.get_world_size(self.group)
        if self.group_size > self.num_stages:
            raise RuntimeError(
                f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}"
            )

        # Run time states
        self._outputs_meta: Optional[Tuple[torch.Tensor, ...]] = None
        # map microbatch ID to list of forward tensor args
        self.fwd_cache: Dict[int, Tuple[Any, List[torch.Tensor]]] = {}
        # Caching chunk outputs for final output merge or reduction
        self.output_chunks: List[Any] = []

        # Initialize has_backward to false; this will be set to true if loss
        # function is passed to pipeline schedule
        self.has_backward = False
        # Log prefix
        self.log_prefix = f"[Stage {self.stage_index}]"

        # Forward infra
        self.args_recv_info: Dict[int, Tuple[InputInfo, ...]] = {}
        self.set_requires_grad: Dict[int, bool] = {}
        self.act_send_info: Dict[int, List] = {}

        # Backward infra will be created lazily
        self.grad_recv_info: Dict = {}
        self.grad_send_info: Optional[List] = None

        # Number of backward chunks seen. This is used to determine when to do
        # grad reduction in DDP or FSDP.
        self._seen_bwd_chunks = 0

        # To be populated later by the Schedule
        self.chunks: Optional[int] = None
        self.stage_index_to_group_rank: Dict[int, int] = {
            i: i % self.group_size for i in range(self.num_stages)
        }

    @property
    def has_backward(self) -> bool:
        """
        Returns true if this stage has a backward pass.
        """
        return self._has_backward

    @has_backward.setter
    def has_backward(self, has_backward: bool):
        self._has_backward = has_backward

    @property
    def is_first(self):
        """
        Returns true if this stage is the first stage in the pipeline.
        """
        return self.stage_index == 0

    @property
    def is_last(self):
        """
        Returns true if this stage is the last stage in the pipeline.
        """
        return self.stage_index == self.num_stages - 1

    def _check_chunk_id(self, chunk_id: int):
        if self.chunks is None:
            raise RuntimeError(
                "Attempted to access chunk_id before chunks have been configured."
            )
        if chunk_id >= self.chunks:
            raise RuntimeError(
                f"Chunk id {chunk_id} is out of range [0, {self.chunks})"
            )

    def _configure_outputs_meta(self, outputs_meta: Tuple[torch.Tensor, ...]):
        """
        Track the output shapes/dtype of this stage since they determine the send operation(s) which must match
        recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial
        configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches
        which could show up as hangs, silent corruption, or other errors.
        """
        assert (
            self._outputs_meta is None
        ), "Attempting to reconfigure output_meta, which is not supported"
        self._outputs_meta = tuple(outputs_meta)  # type: ignore[assignment]

    def get_outputs_meta(self) -> Tuple[torch.Tensor, ...]:
        """Get the output metadata (meta tensors) representing the outputs of this stage"""
        assert (
            self._outputs_meta is not None
        ), "Attempted to get_outputs_meta() without configuring output meta"
        return self._outputs_meta

    def _create_grad_send_info(
        self,
        args_recv_info: Tuple,
    ) -> List[Optional[int]]:
        """
        Create a list of stage indices to send gradients to.
        """
        grad_send_info: List[Optional[int]] = []

        def map_recv_to_send(a):
            # Note: we send gradients back to the previous stage as long as the
            # input was received in forward, regardless of whether it requires
            # grad. It is up to the previous stage to discard this gradient.
            if isinstance(a, _RecvInfo):
                grad_send_info.append(a.source)
                return a.source
            else:
                grad_send_info.append(None)
                return None

        map_aggregate(args_recv_info, map_recv_to_send)

        logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info)
        return grad_send_info

    @abstractmethod
    def _prepare_forward_infra(self, num_microbatches: int):
        raise NotImplementedError

    def _prepare_backward_infra(self, num_microbatches: int):
        # TODO: this is needed for backward_maybe_with_nosync
        self.chunks = num_microbatches

        for mb_index in range(num_microbatches):
            # `grad_recv_info` is a mirror of `act_send_info`
            self.grad_recv_info[mb_index] = self._create_grad_recv_info(
                self.act_send_info
            )

    @abstractmethod
    def _create_grad_recv_info(
        self,
        act_send_info: Dict,
    ) -> Tuple[_RecvInfo, ...]:
        raise NotImplementedError

    def _get_recv_ops(
        self,
        recv_infos: Tuple[InputInfo, ...],
    ) -> List[dist.P2POp]:
        """
        Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`.
        Returns a list of ops that correspond to the recv infos.
        """
        ops: List[dist.P2POp] = []
        for info in recv_infos:
            if not isinstance(info, _RecvInfo):
                continue

            peer_rank = self.stage_index_to_group_rank[info.source]
            peer_global_rank = (
                peer_rank
                if self.group is None
                else dist.get_global_rank(self.group, peer_rank)
            )  # TODO
            ops.append(
                dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group)
            )

        return ops

    def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Returns a list of ops that are needed to receive the input arguments
        for this stage.
        """
        recv_infos: Tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id]

        # In case there is a backward pass, set requires_grad for receive buffers
        # before the first forward
        if self.has_backward and not self.set_requires_grad[fwd_chunk_id]:
            for a in recv_infos:
                if isinstance(a, _RecvInfo):
                    a.buffer.requires_grad_(True)

        return self._get_recv_ops(recv_infos)

    def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Returns a list of ops that are needed to receive the gradients
        for this stage.
        """
        if not self.has_backward or self.is_last:
            return []

        recv_infos = self.grad_recv_info[bwd_chunk_id]
        return self._get_recv_ops(recv_infos)

    def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Get the activation send ops for current stage's forward.
        """
        output = self.output_chunks[fwd_chunk_id]
        # Unify output form to tuple for easy correspondence with
        # `act_send_info`
        output_tuple = output if type(output) is tuple else (output,)

        ops: List[dist.P2POp] = []

        for idx, out in enumerate(output_tuple):
            dst_stages = self.act_send_info[idx]
            for dst in dst_stages:
                if dst is None:
                    continue
                logger.debug(
                    "%s Sending tensor to Stage %s: %s",
                    self.log_prefix,
                    dst,
                    out.size(),
                )
                peer_rank = self.stage_index_to_group_rank[dst]
                peer_global_rank = (
                    peer_rank
                    if self.group is None
                    else dist.get_global_rank(self.group, peer_rank)
                )  # TODO
                ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group))

        return ops

    def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
        """
        Get the gradient send ops for current stage's backward.
        """
        self._check_chunk_id(bwd_chunk_id)

        if not self.has_backward or self.is_first:
            return []

        # Create bwd send infra lazily
        if self.grad_send_info is None:
            # Send info for input grads during backward:
            # List of destinations corresponding to input grads
            # Can be None if an input has no grad
            # `grad_send_info` is a mirror of `args_recv_info`
            self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0])

        ops: List[dist.P2POp] = []
        for grad, grad_recv_stage in zip(self.grads_input, self.grad_send_info):
            if isinstance(grad, torch.Tensor) and grad_recv_stage is not None:
                logger.debug(
                    "%s Sending gradient to Stage %s: %s",
                    self.log_prefix,
                    grad_recv_stage,
                    grad.size(),
                )
                peer_rank = self.stage_index_to_group_rank[grad_recv_stage]
                peer_global_rank = (
                    peer_rank
                    if self.group is None
                    else dist.get_global_rank(self.group, peer_rank)
                )  # TODO
                ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group))
            else:
                if not (grad is None and grad_recv_stage is None):
                    raise RuntimeError(
                        f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} "
                        f"and is expecting to send gradients to stage {grad_recv_stage}"
                    )
        return ops

    def clear_runtime_states(self) -> None:
        """
        Clear runtime states of the stage.
        """
        # map microbatch ID to list of forward tensor args
        self.fwd_cache.clear()
        # Caching chunk outputs for final output merge or reduction
        self.output_chunks.clear()
        # Reset bwd chunk counter
        self._seen_bwd_chunks = 0

        # Clear grad of input buffers in between schedule steps. This is because
        # `torch.autograd.backward()` will accumulate gradients into leaf
        # tensors by default. For gradients to pass back to previous stages, we
        # don't want such accumulation.
        for recv_tuple in self.args_recv_info.values():  # iterate over all chunks
            for a in recv_tuple:  # iterate over all input args
                if isinstance(a, _RecvInfo):
                    # Setting to None is the newer and recommended way to clear grads, compared to `zero_()`.
                    # See https://github.com/pytorch/pytorch/pull/92731
                    a.buffer.grad = None

    def _map_tensor_from_recv_info(
        self,
        recv_infos: Tuple[InputInfo, ...],
    ):
        """
        Map tensors from recv infos to a list.
        """

        def get_recv_tensor(info):
            if isinstance(info, _RecvInfo):
                return info.buffer
            else:
                raise AssertionError(f"Expected _RecvInfo but got {type(info)}")

        tensors = map_aggregate(
            recv_infos,
            get_recv_tensor,
        )

        return tensors

    def _retrieve_recv_activations(self, fwd_chunk_id: int):
        """
        Retrieve the activations received for the current stage during forward.
        """
        recv_infos = self.args_recv_info[fwd_chunk_id]
        activations = self._map_tensor_from_recv_info(recv_infos)
        return activations

    def _retrieve_recv_grads(
        self,
        bwd_chunk_id: int,
    ):
        """
        Retrieve the gradients received for the current stage during backward.
        """
        recv_infos = self.grad_recv_info[bwd_chunk_id]
        grads = self._map_tensor_from_recv_info(recv_infos)
        return grads

    def forward_maybe_with_nosync(self, *args, **kwargs):
        # If submod is wrapped with DDP, we use the `no_sync` context manager to
        # avoid gradient all-reduce per microbatch
        if isinstance(self.submod, DistributedDataParallel):
            with self.submod.no_sync():  # type: ignore[operator]
                out_val = self.submod(*args, **kwargs)
        else:
            out_val = self.submod(*args, **kwargs)
        return out_val

    def backward_maybe_with_nosync(self, backward_type, bwd_kwargs: Dict):
        """
        Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and
        the other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last
        step, but there are additional state variables and performance considerations depending on the data
        parallelism used. This helper should adapt any pipeline parallel schedule to work with common/supported
        data parallel libraries.
        """
        full_backward = bwd_kwargs["full_backward"]
        if full_backward:
            last_backward = self._seen_bwd_chunks == self.chunks - 1  # type: ignore[operator]
        else:
            # When backward is split into weight and input passes, we will see
            # twice as many bwd chunks
            last_backward = self._seen_bwd_chunks == 2 * self.chunks - 1  # type: ignore[operator]

        def perform_backward(backward_type):
            if backward_type == "full":
                return lambda: stage_backward(
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                    bwd_kwargs["input_values"],
                )
            elif backward_type == "input":
                return lambda: stage_backward_input(
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                    bwd_kwargs["input_values"],
                    self.submod.parameters(),
                )
            elif backward_type == "weight":
                return lambda: stage_backward_weight(
                    self.submod.parameters(), bwd_kwargs["param_groups"]
                )
            else:
                raise RuntimeError(f"Unknown backward type: {backward_type}")

        # If submod is wrapped by DDP
        if isinstance(self.submod, DistributedDataParallel):
            if last_backward:
                # Last chunk, prepare for gradient reduction
                # HACK: reaching into DDP implementation details here. Is there a better way?
                self.submod.reducer.prepare_for_backward(  # type: ignore[union-attr, operator]
                    list(
                        torch.nn.parallel.distributed._find_tensors(  # type: ignore[attr-defined]
                            bwd_kwargs["stage_output"]
                        )
                    )
                )
                result = perform_backward(backward_type)()
            else:
                with self.submod.no_sync():  # type: ignore[operator]
                    result = perform_backward(backward_type)()
        # If submod is a FSDP module
        elif isinstance(self.submod, FSDPModule):
            self.submod.set_is_last_backward(False)
            self.submod.set_reshard_after_backward(False)
            self.submod.set_requires_gradient_sync(False)
            result = perform_backward(backward_type)()
            if last_backward:
                # Manually call post backward for FSDP
                def run_post_backward(fsdp_module: FSDPModule) -> None:
                    fsdp_module.set_is_last_backward(True)
                    fsdp_module.set_reshard_after_backward(True)
                    fsdp_module.set_requires_gradient_sync(True)
                    fsdp_state = fully_shard.state(fsdp_module)
                    for state in fsdp_state._state_ctx.all_states:
                        if state._fsdp_param_group:
                            state._fsdp_param_group.post_backward()

                run_post_backward(self.submod)
        else:
            # Non-DP submodule, regular backward
            result = perform_backward(backward_type)()

        self._seen_bwd_chunks += 1

        if isinstance(result, tuple) and len(result) == 2:
            # for stage_backward_input()
            grads, param_groups = result
        else:
            grads, param_groups = result, None

        return grads, param_groups

    def forward_one_chunk(
        self,
        fwd_chunk_id: int,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Perform forward pass on the stage with one microbatch.
        `args` and `kwargs` are the inputs from *external* to this stage. They
        apply only to the first stage in most cases.
        """

        if self.is_first:
            # First stage doesn't need to receive anything
            composite_args = args
            composite_kwargs = kwargs or {}
        else:
            # Receive activations for this chunk
            # Activations only come in args form
            composite_args = self._retrieve_recv_activations(fwd_chunk_id)
            composite_kwargs = {}

        self._validate_fwd_input(args, kwargs)

        # Compute forward
        try:
            output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)

        except Exception as e:
            exc_msg = f"""
            {self.log_prefix} failed to run forward:
            args: {map_debug_info(composite_args)}
            kwargs: {map_debug_info(composite_kwargs)}
            """
            raise RuntimeError(exc_msg) from e

        if type(output) is list:
            # HACK: this is a hacky workaround for the fact that export creates
            # output in list format
            output = tuple(output)

        # Unify output form to tuple for easy correspondence with
        # `act_send_info`
        output_tuple = output if type(output) is tuple else (output,)
        # Prepare for final output merge or reduction
        self.output_chunks.append(output)

        # Save activations and inputs for backward
        flat_args = flatten_args(composite_args)
        flat_kwargs = flatten_args(composite_kwargs)
        flatten_input_tensors = flat_args + flat_kwargs
        self.fwd_cache[fwd_chunk_id] = (
            output_tuple,  # stage_output
            flatten_input_tensors,  # input_values
        )

        logger.debug(
            "%s Forwarded chunk %s, outputs: %s",
            self.log_prefix,
            fwd_chunk_id,
            map_debug_info(output),
        )
        self._validate_fwd_outputs(output_tuple)
        return output

    def backward_one_chunk(
        self, bwd_chunk_id: int, loss=None, full_backward: bool = True
    ):
        """
        Perform backward pass on the module.
        This should only be called once per microbatch.

        If full_backward is True (the default), the full backward pass including weight and input gradients will
        be run, and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id.

        If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__
        time, and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete
        the backward.
        """
        self._check_chunk_id(bwd_chunk_id)

        (
            stage_output,
            input_values,
        ) = self.fwd_cache.pop(bwd_chunk_id)

        # Compute backward
        if self.is_last:
            # Last stage computes gradients from loss and has no gradients from
            # next stage
            bwd_kwargs = {
                "stage_output": loss,
                "output_grads": None,
                "input_values": input_values,
            }
        else:
            # Otherwise, receive gradients from next stage
            grads_output = self._retrieve_recv_grads(bwd_chunk_id)
            # If an input to the pipeline requires gradient,
            # `torch.autograd.backward` will accumulate the gradient into the
            # `.grad` field of such input
            bwd_kwargs = {
                "stage_output": stage_output,
                "output_grads": grads_output,
                "input_values": input_values,
            }

        # Save full_backward
        bwd_kwargs["full_backward"] = full_backward

        # Custom backward function
        if self.dw_builder:
            # TODO: We may want to change our semantics so we are allowed to ignore
            # the 'dw_builder' and call full_backward directly when it is a full_backward op.
            self.grads_input, _ = self.backward_maybe_with_nosync("full", bwd_kwargs)
            if full_backward:
                self.dw_builder()()
            else:
                self.dw_runner[bwd_chunk_id] = self.dw_builder()
        else:
            if full_backward:
                self.grads_input, _ = self.backward_maybe_with_nosync(
                    "full", bwd_kwargs
                )
            else:
                # perform the partial backward for the inputs with a custom backward function
                # when the "stage_output" is a loss, then it is a tensor, otherwise it is a tuple of tensors
                if isinstance(bwd_kwargs["stage_output"], torch.Tensor):
                    bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],)

                grads_input, param_groups = self.backward_maybe_with_nosync(
                    "input", bwd_kwargs
                )

                # TODO: we don't need to save this, add to dw_runner?
                self.backward_state[bwd_chunk_id] = (
                    input_values,
                    param_groups,
                    bwd_kwargs["stage_output"],
                    bwd_kwargs["output_grads"],
                )
                self.grads_input = grads_input
                # Save a placeholder for the dw_runner
                self.dw_runner[bwd_chunk_id] = lambda: None
        logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id)

    def backward_weight_one_chunk(self, bwd_chunk_id: int):
        assert bwd_chunk_id in self.dw_runner, (
            f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}"
            " without first calling `backward_one_chunk(full_backward=False)`"
        )

        if self.dw_builder is not None:
            self.dw_runner.pop(bwd_chunk_id)()
        else:
            (
                input_values,
                param_groups,
                stage_output,
                output_grads,
            ) = self.backward_state.pop(bwd_chunk_id)

            if self.stage_index != 0:
                bwd_kwargs = {
                    "stage_output": stage_output,
                    "param_groups": param_groups,
                    "full_backward": False,
                }
                weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs)
            else:
                # TODO: figure out a better way to do this:
                # if inputs do not require gradient,
                # then the parameter group will not be fully captured during stage_backward_input
                # in this case, we need to call grad directly on the parameters
                # To solve: make input fn do the intersect compute and then finish it off during W
                bwd_kwargs = {
                    "stage_output": stage_output,
                    "output_grads": output_grads,
                    "input_values": input_values,
                    "full_backward": False,
                }
                self.backward_maybe_with_nosync("full", bwd_kwargs)

    def _validate_fwd_input(self, args, kwargs):
        """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage."""

        if self.is_first:
            # TODO why is there a separate recv_info for each pipeline chunk?
            # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we
            # check all chunks against args_recv_info[0]
            expected_args = self.args_recv_info[0]
        else:
            # We don't check inputs for non-0 stages assuming they don't accept
            # user inputs in canonical pipeline scenarios
            return

        if len(kwargs):
            # TODO- need a mapping of kwarg to position in self.args_recv_info
            # without it, we just validate shapes for args and ignore kwargs
            expected_args = expected_args[: len(expected_args) - len(kwargs)]

        # TODO- need a mapping of kwarg to position in self.args_recv_info
        # maybe it's impossible to tell whether the len mismatches because
        # (a) the user passed an extra arg or missed an arg
        # (b) the user did not pass a kwarg, which has a default value baked into expected_args
        expected_tensors_meta = [
            e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer
            for e in expected_args
        ]
        validate_tensors_metadata(
            f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args
        )

    def _validate_fwd_outputs(self, outputs: Tuple[torch.Tensor, ...]):
        """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype.
        Most likely, this could be caused either by incorrect user specification of output shapes, or because
        shape inference was done on the original model but then at runtime the model is wrapped with something
        like mixed precision which changes output dtype.
        """
        expected_tensors_meta = self.get_outputs_meta()
        validate_tensors_metadata(
            f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs
        )


class _PipelineStage(_PipelineStageBase):
    def __init__(
        self,
        stage_module: torch.nn.Module,
        stage_index: int,
        pipe_info: PipeInfo,
        device: torch.device,
        group: Optional[dist.ProcessGroup] = None,
    ):
        """
        Create a pipeline stage given a stage_module to be wrapped by this stage
        and a `pipe_info` describing the stage relationship of the pipeline.

        Args:
            stage_module (torch.nn.Module): the module to be wrapped by this stage
            stage_index (int): the index of this stage in the pipeline
            pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
            device (torch.device): the device to be used by this stage
            group (Optional[dist.ProcessGroup]): the process group to be used by this stage
        """
        _PipelineStageBase.__init__(
            self,
            stage_module,
            stage_index,
            pipe_info.num_stages,
            device,
            group,
        )
        self.pipe_info = pipe_info

        # Find stage nodes in graph
        submod_nodes = [
            node for node in pipe_info.graph.nodes if node.op == "call_module"
        ]
        if len(submod_nodes) != self.num_stages:
            raise AssertionError(
                f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}"
            )

        # Find my stage node in graph
        self.node = submod_nodes[self.stage_index]
        self.name = self.node.name
        logger.info(
            "[%s] Creating PipelineStage %s for %s",
            self.group_rank,
            stage_index,
            self.name,
        )

        # Create mapping from stage name to stage index
        self.submod_to_stage_index: Dict[str, int] = {}
        for i, node in enumerate(submod_nodes):
            self.submod_to_stage_index.setdefault(node.name, i)

        # Cast submodule to device
        self._move_submod_to_device()

    def _move_submod_to_device(self):
        # Move submodule to indicated device if possible
        # Note: we cannot move a meta module to real devices because meta tensors
        # do not support the to() method. One needs to do an in-place tensor swap
        # in that case.
        has_meta_param = any(
            isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters()
        )
        if has_meta_param:
            logger.debug("%s Found meta parameters!", self.log_prefix)
        else:
            self.submod.to(self.device)

    def _prepare_forward_infra(self, num_microbatches: int):
        """
        Create send/recv infrastructures for activations (during forward)
        """
        # Flag per chunk to keep track of whether we have set `requires_grad`
        # for receive buffers. Format: {chunk : Boolean}
        for chunk in range(num_microbatches):
            self.args_recv_info[chunk] = self._create_act_recv_info()
            self.set_requires_grad[chunk] = False

        # Send info during forward for each activation
        self.act_send_info = self._create_act_send_info()

    def get_stage_index_of_submod(
        self,
        submod_name: str,
    ):
        """
        Given a submodule name, return the stage index of the submodule.
        """
        if submod_name not in self.submod_to_stage_index:
            raise AssertionError(f"Stage id of {submod_name} not found")

        return self.submod_to_stage_index[submod_name]

    def _create_act_recv_info(
        self,
    ):
        """
        Create a tuple of `_RecvInfo` for inputs to the stage.
        """

        def create_recv_tensor(placeholder, arg_node):
            """
            Create a receive buffer for a placeholder.
            """
            example_value = placeholder.meta["val"]
            if arg_node.op == "placeholder":
                # This is a root level placeholder, thus an input argument to the entire model.
                # We are likely at stage 0, hence no need to create a receive buffer.
                return _RootArgPlaceholder(example_value)

            # Figure out the source stage of this input
            while arg_node.target is operator.getitem:
                # If the input is a getitem, we need to go deeper
                arg_node = arg_node.args[0]

            assert (
                arg_node.op == "call_module"
            ), f"Expecting call_module, got {arg_node.op}"
            src_stage = self.get_stage_index_of_submod(arg_node.name)

            # Create a receive buffer for this placeholder
            logger.debug(
                "%s Creating recv buffer for input '%s' : %s, %s",
                self.log_prefix,
                placeholder.name,
                example_value.shape,
                example_value.dtype,
            )
            buffer = _make_tensor_from_meta(example_value, self.device)

            return _RecvInfo(
                arg_node.name,
                src_stage,
                buffer,
            )

        args_recv_info: List[InputInfo] = []
        # Filter out placeholder nodes from `self.submod` (a GraphModule)
        placeholders = filter(
            lambda node: node.op == "placeholder", self.submod.graph.nodes
        )
        # `placeholders` are nodes internal to submod.
        # `self.node.args` are dependency nodes in the outer graph.
        # The two are 1:1.
        for placeholder, arg_node in zip(placeholders, self.node.args):
            # Create a receive buffer for this placeholder
            recv_info = create_recv_tensor(placeholder, arg_node)
            args_recv_info.append(recv_info)

        logger.debug(
            "%s Activation recv / args info: %s", self.log_prefix, args_recv_info
        )
        # `args` is a Tuple, hence we will return a Tuple[InputInfo]
        return tuple(args_recv_info)

    def find_dst_rank(
        self,
        user: fx.Node,
    ) -> Optional[int]:
        """
        Find the destination rank of a `user` node.
        If the `user` is not a submod, `None` may be returned.
        """
        if user.op == "call_module":
            # User is a stage (`call_module`)
            return self.get_stage_index_of_submod(user.name)
        else:
            # - If user.op == "output":
            #   No need to send back to rank 0
            # - If user.target is stage_backward:
            #   No need to send assuming submod output is stored locally or
            #   should be re-calculated in case of activation checkpointing
            return None

    def _create_act_send_info(self):
        """
        Create a dict of send info for activations.
        The dict is of the form:
        {
            output_index: [dst_rank_0, dst_rank_1, ...],
            ...
        }
        where the list of `dst_rank`s covers the case where an output value may
        be consumed by multiple stages.
        """
        # Output index: List of receiver ranks
        act_send_info: Dict[int, List] = {}
        out_idx = 0

        for user in self.node.users:
            if user.target is operator.getitem:
                # Recursively find the real destination
                gi_dsts = act_send_info.setdefault(out_idx, [])
                for gi_user in user.users:
                    dst_rank = self.find_dst_rank(gi_user)
                    if dst_rank is not None:
                        gi_dsts.append(dst_rank)
                # Next `getitem` will point to the next output index
                out_idx += 1
            else:
                # In case of single output value, `out_idx` will not increase
                dsts = act_send_info.setdefault(out_idx, [])
                dst_rank = self.find_dst_rank(user)
                if dst_rank is not None:
                    dsts.append(dst_rank)

        output_node = self._get_output_node()
        output_vals: Tuple[torch.Tensor, ...] = tuple(
            v.meta["val"] for v in flatten_args(output_node.args)
        )
        self._configure_outputs_meta(output_vals)

        logger.debug("%s Send info: %s", self.log_prefix, act_send_info)
        return act_send_info

    def _get_output_node(self):
        output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"]
        assert len(output_nodes) == 1
        output_node = output_nodes[0]
        return output_node

    def _create_grad_recv_info(
        self,
        act_send_info: Dict,
    ) -> Tuple[_RecvInfo, ...]:
        """
        Create a tuple of `_RecvInfo` for gradients.
        """
        # Dict[output_index, _RecvInfo]
        grad_recv_info: Dict[int, _RecvInfo] = {}
        output_node = self._get_output_node()

        # The output node may take multiple args, meaning the submod has multiple output values.
        output_vals = flatten_args(output_node.args)

        for out_idx, dst_list in act_send_info.items():
            if not dst_list:
                # No actual receiver for activation so no grad coming back
                continue

            output = output_vals[out_idx]
            example_value = output.meta["val"]
            logger.debug(
                f"{self.log_prefix} Creating grad recv buffer for output {output.name} "  # noqa: G004
                f": {example_value.shape}, {example_value.dtype}"
            )

            # TODO: otherwise needs grad accumulation
            assert len(dst_list) == 1, "Backward of skip connections not supported yet"
            grad_src = dst_list[0]
            grad_recv_info[out_idx] = _RecvInfo(
                f"{grad_src}",  # noqa: G004
                grad_src,
                _make_tensor_from_meta(example_value, self.device),
            )

        # Convert to tuple for convenience in get_ops and retrieve tensor
        grad_recv_info_tuple = tuple(grad_recv_info.values())
        logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple)
        return grad_recv_info_tuple


# A helper function to create a pipeline stage based on traced pipeline information
def build_stage(
    stage_module: torch.nn.Module,
    stage_index: int,
    pipe_info: PipeInfo,
    device: torch.device,
    group: Optional[dist.ProcessGroup] = None,
) -> _PipelineStage:
    """
    Create a pipeline stage given a stage_module to be wrapped by this stage
    and pipeline information.

    Args:
        stage_module (torch.nn.Module): the module to be wrapped by this stage
        stage_index (int): the index of this stage in the pipeline
        pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()`
        device (torch.device): the device to be used by this stage
        group (Optional[dist.ProcessGroup]): the process group to be used by this stage

    Returns:
        _PipelineStage: a pipeline stage that can run with `PipelineSchedule`s.
    """
    return _PipelineStage(
        stage_module,
        stage_index,
        pipe_info,
        device,
        group,
    )
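

# Illustrative sketch (not executed here): `build_stage` is intended to wrap a
# submodule produced by the tracing frontend. `pipeline()` and `pipe.info()` come
# from that frontend as referenced above; `pipe.get_stage_module(stage_idx)` is
# assumed here as the way to obtain the traced submodule for a stage, and
# `model`, `example_microbatch`, and `device` are hypothetical placeholders.
#
#   from torch.distributed.pipelining import pipeline
#
#   pipe = pipeline(model, mb_args=(example_microbatch,))
#   stage_idx = dist.get_rank()
#   stage = build_stage(
#       pipe.get_stage_module(stage_idx), stage_idx, pipe.info(), device
#   )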


# Manual PipelineStage functions and definition

METADATA_TENSOR_LEN = 100
PLACEHOLDER_VAL = -1


def _create_empty_tensors(
    tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device
) -> List[torch.Tensor]:
    """
    Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
    and places them on the specified device.
    Args:
        tensor (Union[torch.Tensor, List[torch.Tensor]]): The input tensor(s).
        device (torch.device): The device where the new tensors will be placed.
    Returns:
        List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
    """
    if isinstance(tensor, torch.Tensor):
        return [torch.empty_like(tensor, device=device)]
    elif isinstance(tensor, (list, tuple)):
        return [torch.empty_like(t, device=device) for t in tensor]
    raise TypeError(f"Unsupported type {type(tensor)}; cannot create empty tensors")


def _create_metadata_tensor(
    tensors: Optional[List[torch.Tensor]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
) -> torch.Tensor:
    """
    Create a metadata tensor that can be sent over the wire.
    This tensor contains the number of dimensions and the shape of each tensor being sent.

    The data is of format [num_dims, dim1, dim2, ...].
    If `tensors` is None, a tensor of only placeholder values will be returned.

    Args:
        tensors: A list of tensors; the tensors will be converted into their shape dimensions and
            these dimensions will be concatenated.
        device: The device where the metadata tensor will be created.
            If `tensors` is None, the returned tensor will contain only PLACEHOLDER_VALs.
    """
    metadata_tensor = torch.full(
        (METADATA_TENSOR_LEN,),
        PLACEHOLDER_VAL,
        dtype=torch.int32,
        device=device,
    )
    if tensors:
        # Create a list of tensors containing the number of dimensions and the shape of each tensor
        data = [
            # data is of format [num_dims, dim1, dim2, ...]
            torch.tensor(
                [len(tensor.shape)] + list(tensor.shape),
                dtype=torch.int32,
                device=device,
            )
            for tensor in tensors
        ]
        # Concatenate the data into a single tensor
        data_tensor = torch.cat(data)
        dt_shape = data_tensor.shape[0]
        if dt_shape > METADATA_TENSOR_LEN:
            raise ValueError(
                f"Metadata tensor size ({dt_shape}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
            )
        metadata_tensor[:dt_shape] = data_tensor
    return metadata_tensor
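

# Illustrative example of the wire format produced by `_create_metadata_tensor`
# and decoded by `_extract_metadata_from_tensor` below (shapes chosen arbitrarily
# for illustration). For tensors of shapes (2, 3) and (4,), the payload is
# [2, 2, 3, 1, 4], padded with PLACEHOLDER_VAL (-1) up to METADATA_TENSOR_LEN:
#
#   meta = _create_metadata_tensor([torch.empty(2, 3), torch.empty(4)])
#   # meta[:5] == tensor([2, 2, 3, 1, 4], dtype=torch.int32); meta[5:] is all -1
#   # _extract_metadata_from_tensor(meta) == [torch.Size([2, 3]), torch.Size([4])]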


def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
    """
    Extract the number of dimensions and the shape of each tensor from a metadata tensor.
    """
    metadata: List[torch.Size] = []
    i = 0
    while i < len(tensor) and tensor[i] != PLACEHOLDER_VAL:
        num_dims = int(tensor[i].item())
        shape = torch.Size(tensor[i + 1 : i + 1 + num_dims].tolist())
        metadata.append(shape)
        i += num_dims + 1
    return metadata


def _get_stage_shapes(
    stage_modules: List[nn.Module],
    stage_ids: List[int],
    num_stages: int,
    rank: int,
    world_size: int,
    device: torch.device,
    microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
):
    """
    Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of
    virtual pipelining) and returns the shape of the inputs and outputs of the module.
    Only the first stage needs to pass in a microbatch.

    Each rank must call _get_stage_shapes or the program will hang.

    Args:
        stage_modules: The chunks assigned to this rank. The length should be 1 for any
            non-interleaved schedules and >1 for any interleaved schedules.
        stage_ids: The ids of the stages assigned to this rank.
        num_stages: Total number of stages.
        rank: Rank of the current process.
        world_size: Number of processes participating in the pipeline.
        device: Device where the tensors are allocated.

    Returns a dictionary containing the following keys:
        "inputs": Shape of the inputs to the module
        "outputs": Shape of the outputs of the module
    """

    stage_id_to_shapes: Dict[int, Dict[str, list[torch.Size]]] = {}
    for stage_id, model in zip(stage_ids, stage_modules):
        input_shape_metadata_tensor = _create_metadata_tensor(device=device)
        # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
        prev_rank = (rank - 1) % world_size
        next_rank = (rank + 1) % world_size
        shapes = {}

        # first stage doesn't receive anything and uses a microbatch
        if stage_id == 0:
            if microbatch is None:
                raise RuntimeError("Microbatch is required for first stage")
            example_fwd_inputs = microbatch
            if isinstance(example_fwd_inputs, torch.Tensor):
                example_fwd_inputs = [example_fwd_inputs]
        else:
            # other stages must receive shape information
            # TODO: send/recv should take a group, rather than use the default group
            dist.recv(input_shape_metadata_tensor, prev_rank)
            metadata = _extract_metadata_from_tensor(input_shape_metadata_tensor)
            example_fwd_inputs = [
                torch.empty(shape_list, device=device) for shape_list in metadata
            ]
        shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]

        # perform forward
        # TODO: if forward fails, raise a more descriptive error explaining which stage failed
        fwd_outputs = model(*example_fwd_inputs)
        fwd_outputs = _create_empty_tensors(fwd_outputs, device)
        shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]

        # send shape dims
        if stage_id != num_stages - 1:
            output_shape_metadata_tensor = _create_metadata_tensor(
                fwd_outputs, device=device
            )
            dist.send(output_shape_metadata_tensor, next_rank)
        stage_id_to_shapes[stage_id] = shapes
    logger.info(stage_id_to_shapes)
    return stage_id_to_shapes
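

# Illustrative sketch (assuming a 2-rank, non-interleaved setup with one stage per
# rank; `my_stage_module`, `example_input`, and `device` are hypothetical
# placeholders) of how each rank could drive the dry-run helper above; only
# rank 0 supplies a microbatch:
#
#   rank, world = dist.get_rank(), dist.get_world_size()
#   shapes = _get_stage_shapes(
#       [my_stage_module],
#       [rank],
#       num_stages=world,
#       rank=rank,
#       world_size=world,
#       device=device,
#       microbatch=example_input if rank == 0 else None,
#   )
#   # shapes[rank]["inputs"] / shapes[rank]["outputs"] hold the observed torch.Size lists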


class PipelineStage(_PipelineStageBase):
    """
    A class representing a pipeline stage in a pipeline parallelism setup.
    This class is created manually by providing an example input (and optionally output),
    as opposed to the PipelineStage class that is output from pipeline().
    This class extends the `_PipelineStageBase` class and can similarly be used
    in `PipelineSchedule`.

    Args:
        submodule (nn.Module): The PyTorch module wrapped by this stage.
        stage_index (int): The ID of this stage.
        num_stages (int): The total number of stages.
        device (torch.device): The device where this stage is located.
        input_args (Union[torch.Tensor, Tuple[torch.Tensor]], optional): The input arguments for the submodule.
        output_args (Union[torch.Tensor, Tuple[torch.Tensor]], optional): The output arguments for the submodule.
        group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group.
        dw_builder: TODO clean up comments
    """

    def __init__(
        self,
        submodule: nn.Module,
        stage_index: int,
        num_stages: int,
        device: torch.device,
        input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
        output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None,
        group: Optional[dist.ProcessGroup] = None,
        dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
    ):
        super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
        self.submod.to(self.device)
        # When we materialize the model partition on cuda, we call reset_parameters() if it is available
        self.inputs: List[torch.Tensor] = []
        self.outputs: List[torch.Tensor] = []

        self.inputs = _create_empty_tensors(input_args, device)

        if output_args is None:
            logger.info("output_args not provided, performing forward using input_args")
            self.outputs = self.submod(*self.inputs)
            # create buffers for the output so that the data is in the correct
            # shape in order to use in p2p op (send)
            self.outputs = _create_empty_tensors(self.outputs, device)
        else:
            self.outputs = _create_empty_tensors(output_args, device)

        self._configure_outputs_meta(tuple(self.outputs))

        # these are the buffers used in backwards send/recv, they are allocated later
        self.outputs_grad: List[torch.Tensor] = []

        def stage_global_rank(peer_rank):
            return (
                peer_rank
                if self.group is None
                else dist.get_global_rank(self.group, peer_rank)
            )

        self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size)
        self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size)

        logger.debug(
            f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, "  # noqa: G004
            f"{self.is_last=}, {self.num_stages=}, "
            f"inputs: {[inp.shape for inp in self.inputs]}, "
            f"output: {[output.shape for output in self.outputs]}"
        )

    def _prepare_forward_infra(self, num_microbatches: int) -> None:
        # Receive info during forward
        # TODO: create args_recv_info lazily? (same needed for PipelineStage)
        for chunk_id in range(num_microbatches):
            self.set_requires_grad[chunk_id] = False
            if not self.is_first:
                # We assume that we always receive from stage - 1
                recv_infos = tuple(
                    [
                        _RecvInfo(
                            f"recv_for_{self.stage_index}_from_{self.stage_index - 1}",
                            self.stage_index - 1,
                            _make_tensor_from_meta(inp, self.device),
                        )
                        for inp in self.inputs
                    ]
                )

                self.args_recv_info[chunk_id] = recv_infos
            else:
                self.args_recv_info[chunk_id] = tuple(
                    [_RootArgPlaceholder(i) for i in self.inputs]
                )

        # Send info during forward for each activation
        # only need the rank that is being sent to
        self.act_send_info: Dict[int, List] = {}
        for idx in range(len(self.outputs)):
            # We assume we always send to stage + 1
            if not self.is_last:
                self.act_send_info[idx] = [self.stage_index + 1]
            else:
                self.act_send_info[idx] = []

    def _create_grad_recv_info(
        self,
        act_send_info: Dict,
    ) -> Tuple[_RecvInfo, ...]:
        grad_recv_info: Tuple[_RecvInfo, ...] = ()
        if not self.is_last:
            # Receiving gradients from multiple sources is not supported
            # hence we only take the first destination
            grad_recv_info = tuple(
                [
                    _RecvInfo(
                        f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}",
                        dst_list[0],
                        _make_tensor_from_meta(self.outputs[idx], self.device),
                    )
                    for idx, dst_list in act_send_info.items()
                ]
            )
        return grad_recv_info

    def _init_p2p_neighbors(self):
        """
        Set up p2p communicators between previous and next stages
        by sending a dummy tensor.

        If this is used, it must be called for all pipeline stages.
        """
        ops = []
        recv_tensor = torch.zeros(1, device="cuda")
        send_tensor = torch.ones(1, device="cuda")
        # forward
        if not self.is_first:
            ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage, self.group))
        if not self.is_last:
            ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage, self.group))

        # backward
        if not self.is_first:
            ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage, self.group))
        if not self.is_last:
            ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage, self.group))

        # Post the dummy sends/recvs and wait for them, so the communicators are
        # actually initialized before returning.
        if ops:
            works = dist.batch_isend_irecv(ops)
            for work in works:
                work.wait()

        return True


def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
    """
    Check that the buffer shapes match between stages as expected, by performing an all_gather
    across all stages.
    """
    if len(pipeline_stages) == 0:
        raise ValueError("No pipeline stages provided.")

    virtual_pipeline_size = len(pipeline_stages)
    all_inputs = []
    all_outputs = []
    world_size = pipeline_stages[0].group_size
    num_stages = pipeline_stages[0].num_stages

    # perform all gathers between all stages
    for virtual_id, stage in enumerate(pipeline_stages):
        world_size = stage.group_size
        stage_id: int = stage.stage_index
        rank = stage.group_rank
        # check that world_size and num_stages are consistent across all stages
        if stage.group_size != world_size:
            raise ValueError(
                f"Stage id {stage_id} has world size ({stage.group_size}) \
                which does not match world size ({world_size}) of other stages."
            )
        if stage.num_stages != num_stages:
            raise ValueError(
                f"Stage id {stage_id} has num stages ({stage.num_stages}) \
                which does not match num stages ({num_stages}) of other stages."
            )

        pg_rank = dist.get_rank(stage.group)
        if rank != pg_rank:
            raise ValueError(
                f"Rank {rank} is not equal to process group rank {pg_rank}"
            )

        if (num_stages := stage.num_stages) % world_size != 0:
            raise ValueError(
                f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
            )

        # all gather each rank's inputs
        tensor_list = [
            _create_metadata_tensor(device=stage.device)
            for _ in range(stage.group_size)
        ]
        expected_inputs = stage.inputs
        stage_input = _create_metadata_tensor(expected_inputs, device=stage.device)
        dist.all_gather(tensor_list, stage_input)
        stage_input_shapes = [
            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
        ]

        # all gather each rank's outputs
        tensor_list = [
            _create_metadata_tensor(device=stage.device)
            for _ in range(stage.group_size)
        ]
        expected_outputs = stage.outputs
        stage_output = _create_metadata_tensor(expected_outputs, device=stage.device)
        dist.all_gather(tensor_list, stage_output)
        stage_output_shapes = [
            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
        ]

        logger.debug(
            f"Rank: {pg_rank}"  # noqa: G004
            f"Stage id: {stage_id}"
            f"Stage num stages: {stage.num_stages}"
            f"Stage rank: {rank}"
            f"Stage world size: {world_size}"
            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}"  # noqa: G003
            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}"  # noqa: G003
        )

        all_inputs.extend(stage_input_shapes)
        all_outputs.extend(stage_output_shapes)

        # log only rank 0's view, they will all be equivalent
        if pg_rank == 0:
            logger.info(
                "all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs
            )

    # Check if the output for stage 0 matches the input at stage 1, and so forth
    for i in range(virtual_pipeline_size * world_size - 1):
        if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
            raise ValueError(
                f"Stage_id {i} output shape {out} does not match stage_id {i + 1} input shape {inp}."
            )
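

# Illustrative sketch (assuming a simple one-stage-per-rank setup; `part`,
# `example_input`, and `device` are hypothetical placeholders): constructing a
# manual PipelineStage and sanity-checking shapes across ranks. Not executed here.
#
#   rank, world = dist.get_rank(), dist.get_world_size()
#   stage = PipelineStage(
#       part,                         # the nn.Module partition owned by this rank
#       stage_index=rank,
#       num_stages=world,
#       device=device,
#       input_args=example_input,     # example microbatch input for this stage
#   )
#   _validate_stage_shapes([stage])   # optional cross-rank shape check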