# Copyright (c) Meta Platforms, Inc. and affiliates

import contextlib
import itertools
import logging
import types
import weakref
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Protocol,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel.style import ParallelStyle


# TODO: expose a single API
__all__ = ["context_parallel"]

aten = torch.ops.aten
logger = logging.getLogger(__name__)
# Whether to upcast parameters and gradients to float32 to avoid accumulation
# errors. It is likely that this will always be True, but we keep the flag for
# experimentation purposes.
_convert_to_f32 = True


class _CausalBehavior(Enum):
    SKIP = None
    NOT_IS_CAUSAL = False
    IS_CAUSAL = True


def _is_causal_behavior(
    rank: int, world_size: int, i: int, is_causal: bool
) -> _CausalBehavior:
    """
    Calculate the is_causal behavior for each KV block. The attention can either
    be calculated in full, not at all, or with the causal mask applied.
    """
    if not is_causal:
        return _CausalBehavior.NOT_IS_CAUSAL

    if i == 0:
        return _CausalBehavior.IS_CAUSAL

    source_rank = (rank - i) % world_size
    if source_rank < rank:
        return _CausalBehavior.NOT_IS_CAUSAL
    else:
        return _CausalBehavior.SKIP


def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
    """
    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
    so we cannot call ``wait()``.
    """
    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
        return tensor.wait()
    return tensor


class _SDPAMerger:
    """A class to help merge the local SDPA results."""

    def __init__(self, convert_to_f32: bool):
        self._out: Optional[torch.Tensor] = None
        self._lse: Optional[torch.Tensor] = None
        self._convert_to_f32 = convert_to_f32
        self._out_dtype = torch.float32
        self._lse_dtype = torch.float32

    def _merge_one(self, block_out: torch.Tensor, block_lse: torch.Tensor) -> None:
        block_lse = block_lse.unsqueeze(dim=-1)
        if self._lse is None:
            self._lse = block_lse
            self._out = block_out
        else:
            # The algorithm from
            # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            # gives a relatively stable result.
            self._out = self._out - F.sigmoid(block_lse - self._lse) * (
                self._out - block_out
            )
            self._lse = self._lse - F.logsigmoid(self._lse - block_lse)

    def step(self, out: torch.Tensor, lse: torch.Tensor) -> None:
        self._out_dtype = out.dtype
        self._lse_dtype = lse.dtype

        if self._convert_to_f32:
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)

        self._merge_one(out, lse)

    def results(self) -> Tuple[torch.Tensor, torch.Tensor]:
        assert self._out is not None
        assert self._lse is not None
        out, lse = self._out, self._lse.squeeze(-1)
        return out.to(self._out_dtype), lse.to(self._lse_dtype)

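# Note on the merge above: the two updates in ``_SDPAMerger._merge_one`` are a
# numerically stable way of combining per-block attention outputs using their
# log-sum-exp (LSE) values. They are equivalent (up to floating point) to
#
#   new_lse = log(exp(lse) + exp(block_lse))
#   out     = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
#
# because sigmoid(x - y) = exp(x) / (exp(x) + exp(y)) and
# logsigmoid(y - x) = y - log(exp(x) + exp(y)).
# A minimal sketch that checks the equivalence on random data (illustrative
# shapes only; not executed by this module):
#
#   lse, block_lse = torch.randn(2, 8, 16, 1), torch.randn(2, 8, 16, 1)
#   out, block_out = torch.randn(2, 8, 16, 32), torch.randn(2, 8, 16, 32)
#   merged = out - F.sigmoid(block_lse - lse) * (out - block_out)
#   new_lse = torch.logaddexp(lse, block_lse)
#   expected = (
#       torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
#   )
#   torch.testing.assert_close(merged, expected)
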

def _scaled_dot_product_ring_flash_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    if return_debug_mask:
        raise NotImplementedError("return_debug_mask is not supported yet")

    return _templated_ring_attention(
        mesh,
        aten._scaled_dot_product_flash_attention,
        query=query,
        key=key,
        value=value,
        is_causal=is_causal,
        dropout_p=dropout_p,
        scale=scale,
    )


def _scaled_dot_product_ring_efficient_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    compute_log_sumexp: bool = True,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    if attn_bias is not None:
        raise NotImplementedError("attn_bias is not supported yet")
    if not compute_log_sumexp:
        raise NotImplementedError("compute_log_sumexp must be set")

    return _templated_ring_attention(
        mesh,
        aten._scaled_dot_product_efficient_attention,
        query=query,
        key=key,
        value=value,
        is_causal=is_causal,
        attn_bias=attn_bias,
        dropout_p=dropout_p,
        scale=scale,
        compute_log_sumexp=compute_log_sumexp,
    )


class _AttentionOp(Protocol):
    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        **kwargs: object,
    ) -> Tuple[torch.Tensor, ...]:
        ...


def _ring_rotate(
    block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool
) -> torch.Tensor:
    size = dist.get_world_size(pg)
    dsts = (
        list(range(1, size)) + [0]
        if send_to_next
        else [size - 1] + list(range(0, size - 1))
    )
    return ft_c.permute_tensor(block, dsts, pg)

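# For example (illustrative only), with a group of size 4, ``send_to_next=True``
# yields ``dsts = [1, 2, 3, 0]``: rank 0 sends its block to rank 1, rank 1 to
# rank 2, ..., and rank 3 wraps around to rank 0, so after the collective each
# rank holds the block of its ring predecessor ``(rank - 1) % size``.
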

def _templated_ring_attention(
    mesh: DeviceMesh,
    op: _AttentionOp,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    **kwargs: object,
) -> Tuple[torch.Tensor, ...]:
    """
    This is a generalized ring attention implementation that can support multiple attention ops.

    Parameters
    ----------
    op:
        The attention op to use.
    **kwargs:
        Additional kwargs are passed to the op.

    Returns
    -------
    out:
        The merged attention output.
    softmax_lse:
        The logsumexp of the merged attention output.
    """
    if is_causal and (query.size(2) != key.size(2)):
        raise NotImplementedError(
            "is_causal requires the same query and context sequence lengths"
        )

    if isinstance(mesh, dist.ProcessGroup):
        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
    else:
        pg = mesh.get_group()
    assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)

    next_kv = None

    # Without making key and value contiguous(), the loss curve is bad.
    # TODO(fegin): figure out why this is a requirement since SDPA does not have
    # this requirement.
    key = key.contiguous()
    value = value.contiguous()

    sdpa_merger = _SDPAMerger(_convert_to_f32)

    rest: List[Any]
    out: torch.Tensor
    logsumexp: torch.Tensor

    for i in range(size):
        # overlap communication with compute
        if next_kv is not None:
            next_kv = _maybe_wait(next_kv)
            key = next_kv[: key.numel()].reshape(key.shape)
            value = next_kv[key.numel() :].reshape(value.shape)

        if i < (size - 1):
            next_kv = torch.cat([key.flatten(), value.flatten()])
            next_kv = _ring_rotate(next_kv, pg, send_to_next=True)

        is_causal_behavior = _is_causal_behavior(
            rank=rank, world_size=size, i=i, is_causal=is_causal
        )

        if is_causal_behavior != _CausalBehavior.SKIP:
            out, logsumexp, *rest = op(
                query,
                key,
                value,
                is_causal=is_causal_behavior.value,
                **kwargs,
            )

            sdpa_merger.step(out, logsumexp)

    return *sdpa_merger.results(), *rest

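# Ring schedule (illustrative, world_size = 4): at iteration ``i`` each rank
# computes attention of its local ``query`` against the key/value block that
# originated on ``source_rank = (rank - i) % world_size``. With ``is_causal``,
# ``_is_causal_behavior`` applies the causal mask only to the local block
# (i == 0), computes blocks from lower-ranked sources without a mask (they hold
# earlier positions in the sequence), and skips blocks from higher-ranked
# sources entirely (those positions are in the future):
#
#   rank \ i      0        1        2        3
#     0        causal    skip     skip     skip
#     1        causal    full     skip     skip
#     2        causal    full     full     skip
#     3        causal    full     full     full
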

def _sdpa_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # extract local tensor and sharding infos to an OpInfo
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    logger.debug("Dispatching op_call: %s", op_info.schema)

    # sharding propagation
    # TODO: remove the context parallel strategy from the default propagation
    # rule. Either figure out how to dynamically enable it or just don't call
    # propagate.
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    assert not output_sharding.needs_redistribute, "inputs need to be redistributed"

    if op_call == aten._scaled_dot_product_flash_attention.default:
        local_results = _scaled_dot_product_ring_flash_attention(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    elif op_call == aten._scaled_dot_product_efficient_attention.default:
        local_results = _scaled_dot_product_ring_efficient_attention(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    else:
        raise NotImplementedError(
            "CP only supports flash attention and memory efficient attention now."
        )

    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)


def _sdpa_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # Redistribute grad_output tensor to the same placement as output tensor
    args = list(args)
    args = tuple(args)

    # extract local tensor and sharding infos to an OpInfo
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    logger.debug("Dispatching op_call: %s", op_info.schema)

    # sharding propagation
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    assert not output_sharding.needs_redistribute, "inputs need to be redistributed"

    if op_call == aten._scaled_dot_product_flash_attention_backward.default:
        local_results = _scaled_dot_product_ring_flash_attention_backward(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    elif op_call == aten._scaled_dot_product_efficient_attention_backward.default:
        local_results = _scaled_dot_product_ring_efficient_attention_backward(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    else:
        raise NotImplementedError(f"{op_call=}")

    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)

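# Dispatch path (a sketch of how the handlers above are reached): inside the CP
# context, ``F.scaled_dot_product_attention`` receives DTensor inputs, so the
# underlying aten op (e.g. ``aten._scaled_dot_product_flash_attention.default``)
# goes through DTensor's op dispatcher. ``_enable_cp_dispatcher`` (defined
# below) registers ``customized_ops`` as ``_custom_op_handlers``, which routes
# the forward ops to ``_sdpa_handler`` and the backward ops to
# ``_sdpa_backward_handler``, where the local computation is replaced by the
# ring implementations in this file.
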

def _templated_ring_attention_backward(
    mesh: DeviceMesh,
    op: _AttentionOp,
    grad_out: torch.Tensor,
    grad_out_name: str,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    is_causal: bool,
    **kwargs: Any,
) -> Tuple[torch.Tensor, ...]:
    pg = mesh.get_group()
    assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)
    next_kv = None
    next_grad_kv = None
    rest: List[Any]
    grad_query_, grad_key_, grad_value_ = None, None, None

    accum_dtype = torch.float32 if _convert_to_f32 else query.dtype
    grad_query = torch.zeros_like(query, dtype=accum_dtype)
    grad_key = torch.zeros_like(key, dtype=accum_dtype)
    grad_value = torch.zeros_like(value, dtype=accum_dtype)

    key = key.contiguous()
    value = value.contiguous()
    for i in range(size):
        if next_kv is not None:
            buffer = _maybe_wait(next_kv)
            pointer = 0
            key = buffer[pointer : pointer + key.numel()].reshape(key.shape)
            pointer += key.numel()
            value = buffer[pointer : pointer + value.numel()].reshape(value.shape)
            pointer += value.numel()

        if i != size - 1:
            next_kv = torch.cat([key.flatten(), value.flatten()])
            next_kv = _ring_rotate(next_kv, pg, send_to_next=True)

        is_causal_behavior = _is_causal_behavior(
            rank=rank, world_size=size, i=i, is_causal=is_causal
        )

        if is_causal_behavior != _CausalBehavior.SKIP:
            kwargs[grad_out_name] = grad_out
            grad_query_, grad_key_, grad_value_, *rest = op(
                query=query,
                key=key,
                value=value,
                out=out,
                logsumexp=logsumexp,
                is_causal=is_causal_behavior.value,
                **kwargs,
            )
        else:
            grad_query_ = torch.zeros_like(query, dtype=accum_dtype)
            grad_key_ = torch.zeros_like(key, dtype=accum_dtype)
            grad_value_ = torch.zeros_like(value, dtype=accum_dtype)

        # Get the grad key and grad value for the i-th round.
        if i > 0:
            pointer = 0
            assert next_grad_kv is not None
            next_grad_kv = _maybe_wait(next_grad_kv)
            grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape(
                grad_key.shape
            )
            pointer += grad_key.numel()
            grad_value = next_grad_kv[pointer : pointer + grad_value.numel()].reshape(
                grad_value.shape
            )

        grad_key += grad_key_
        grad_value += grad_value_

        # Send the grad key and grad value to the next rank.
        next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()])
        next_grad_kv = _ring_rotate(next_grad_kv, pg, send_to_next=True)
        grad_query += grad_query_

    assert next_grad_kv is not None
    assert grad_key_ is not None
    assert grad_value_ is not None
    grad_query = grad_query.to(query.dtype)
    next_grad_kv = _maybe_wait(next_grad_kv).to(key.dtype)
    grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
    grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape)
    return (
        grad_query,
        grad_key,
        grad_value,
        *rest,
    )


def _scaled_dot_product_ring_flash_attention_backward(
    mesh: DeviceMesh,
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor,
    cum_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    return _templated_ring_attention_backward(
        mesh,
        aten._scaled_dot_product_flash_attention_backward.default,
        grad_out=grad_out,
        grad_out_name="grad_out",
        query=query,
        key=key,
        value=value,
        out=out,
        logsumexp=logsumexp,
        is_causal=is_causal,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
        dropout_p=dropout_p,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        scale=scale,
    )


def _scaled_dot_product_ring_efficient_attention_backward(
    mesh: DeviceMesh,
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bias: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    dropout_p: float,
    grad_input_mask: Tuple[bool, ...],
    is_causal: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    return _templated_ring_attention_backward(
        mesh,
        aten._scaled_dot_product_efficient_attention_backward.default,
        grad_out=grad_out,
        grad_out_name="grad_out_",
        query=query,
        key=key,
        value=value,
        attn_bias=bias,
        out=out,
        logsumexp=logsumexp,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        dropout_p=dropout_p,
        grad_input_mask=grad_input_mask,
        is_causal=is_causal,
        scale=scale,
    )

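# Backward-pass note: in ``_templated_ring_attention_backward`` the key/value
# blocks travel around the ring exactly as in the forward pass, while the
# partial ``grad_key``/``grad_value`` accumulators are rotated along with them.
# Each accumulator is created on the rank that owns the corresponding key/value
# shard and, after ``size`` rotations, has collected that shard's gradient
# contribution from every rank and arrived back at its owner. That is why the
# final ``grad_key``/``grad_value`` are read from the last ``next_grad_kv``
# rotation rather than from the local accumulators.
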

customized_ops = {
    aten._scaled_dot_product_flash_attention.default: _sdpa_handler,
    aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler,
    aten._scaled_dot_product_efficient_attention.default: _sdpa_handler,
    aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler,
}


_replaced_functions: Dict[Callable, Tuple[str, Callable]] = {}


def _distribute_function(
    fn: Callable,
    fn_module: types.ModuleType,
    device_mesh: DeviceMesh,
    input_fn: Optional[Callable] = None,
    output_fn: Optional[Callable] = None,
) -> None:
    """
    ``distribute_function`` is an experimental API that allows users to "distribute"
    the inputs and outputs of a function. Similar to ``distribute_module``, this API
    installs hooks on ``fn`` to convert the inputs and outputs. There are two major
    differences between ``distribute_function`` and ``distribute_module``. First, a
    function does not have parameters and buffers; as a result, ``distribute_function``
    itself won't convert any parameters/buffers but simply installs the input and
    output hooks. The tensor conversion will happen in the hooks. Another difference
    is that an nn.Module subclass can have several instances, and each instance can
    be fed into ``distribute_module`` independently without affecting the other
    instances. A function, on the other hand, is a singleton object, so once a
    function is distributed by ``distribute_function``, all subsequent calls to the
    function will invoke the installed hooks.

    Args:
        fn (Callable): the function to be distributed.
        fn_module (types.ModuleType): the Python module in which the function is
            declared. e.g., if ``fn`` is ``torch.nn.functional.scaled_dot_product_attention``,
            ``fn_module`` is ``torch.nn.functional``.
        device_mesh (:class:`DeviceMesh`): the device mesh that will be used by the
            input and output hooks to distribute the tensors.
        input_fn (Optional[Callable]): the hook to distribute or convert the input
            arguments of ``fn``.
        output_fn (Optional[Callable]): the hook to distribute or convert the output
            arguments of ``fn``.
    """

    def wrapper(
        target_fn: Callable, input_fn: Optional[Callable], output_fn: Optional[Callable]
    ) -> Callable:
        def inner_fn(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
            if input_fn is not None:
                args, kwargs = input_fn(device_mesh, *args, **kwargs)
            output = target_fn(*args, **kwargs)
            if output_fn is not None:
                output = output_fn(device_mesh, output)
            return output

        return inner_fn

    global _replaced_functions

    if fn in _replaced_functions:
        return

    wrapper_fn = wrapper(fn, input_fn, output_fn)
    setattr(fn_module, fn.__name__, wrapper_fn)
    _replaced_functions[wrapper_fn] = (fn.__name__, fn)

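# A minimal usage sketch of ``_distribute_function`` (hypothetical hooks; the
# real hooks used by this file live in ``_context_parallel`` below):
#
#   def input_fn(mesh, *args, **kwargs):
#       ...  # e.g. convert torch.Tensor args to DTensor
#       return args, kwargs
#
#   def output_fn(mesh, output):
#       ...  # e.g. convert DTensor outputs back to torch.Tensor
#       return output
#
#   _distribute_function(F.scaled_dot_product_attention, F, mesh, input_fn, output_fn)
#   # ... every call to F.scaled_dot_product_attention now goes through the hooks ...
#   _restore_function(F.scaled_dot_product_attention, F)
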

def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:
    """Restore the function that is replaced by _distribute_function."""
    global _replaced_functions

    if fn not in _replaced_functions:
        return

    original_name, original_fn = _replaced_functions[fn]
    setattr(fn_module, original_name, original_fn)


@contextlib.contextmanager
def _enable_cp_dispatcher() -> Generator[None, None, None]:
    """Enables the DTensor dispatcher to dispatch SDPA to CP."""
    old_handlers = DTensor._op_dispatcher._custom_op_handlers
    DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}

    yield

    DTensor._op_dispatcher._custom_op_handlers = old_handlers


class _AttentionContextParallel(ParallelStyle):
    """
    Applies context parallel optimizations to the attention layer.

    This will work for nn.MultiheadAttention and custom attention layers that
    call F.scaled_dot_product_attention with a similar signature.

    This expects the `forward` method to consume either:

    * a single tensor for self attention
    * one argument for each of: query, key, value

    This currently only supports ring attention and the
    SDPBackend.FLASH_ATTENTION backend. See sdpa_kernel.

    Non-flash attention backends will produce incorrect results.
    """

    # use a weakref dictionary to store context managers for each nn.Module
    _CONTEXT_MANAGERS: "weakref.WeakKeyDictionary[nn.Module, Any]" = (
        weakref.WeakKeyDictionary()
    )

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if not isinstance(device_mesh, DeviceMesh):
            raise ValueError(
                f"{type(device_mesh)} is not supported by {type(self)} yet."
            )

        if not device_mesh.ndim == 1:
            raise ValueError("CP only supports single-dimension device meshes.")

        return distribute_module(
            module,
            device_mesh,
            input_fn=self._input_fn,  # type: ignore[arg-type]
            output_fn=self._output_fn,  # type: ignore[arg-type]
        )

    @classmethod
    def _input_fn(
        cls,
        module: nn.Module,
        inputs: Tuple[Union[torch.Tensor, int, float], ...],
        device_mesh: DeviceMesh,
    ) -> Tuple[Union[torch.Tensor, int, float], ...]:
        # TODO(d4l3k): this should be Shard(2), need to fix Linear layer rules
        placement = [Replicate()]

        def backward_hook(grad: torch.Tensor) -> None:
            if module in cls._CONTEXT_MANAGERS:
                cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
                del cls._CONTEXT_MANAGERS[module]

        # convert inputs to DTensor
        inp = []
        for input in inputs:
            if isinstance(input, torch.Tensor) and not isinstance(input, DTensor):
                input = DTensor.from_local(
                    input.contiguous(), device_mesh, placement, run_check=False
                )

            if isinstance(input, torch.Tensor) and input.requires_grad:
                input.register_hook(backward_hook)

            inp.append(input)

        manager = _enable_cp_dispatcher()
        manager.__enter__()
        cls._CONTEXT_MANAGERS[module] = manager

        return tuple(inp)

    @classmethod
    def _output_fn(
        cls,
        module: nn.Module,
        outputs: Union[torch.Tensor, Tuple[Union[torch.Tensor, int, float], ...]],
        device_mesh: DeviceMesh,
    ) -> Union[
        Union[torch.Tensor, int, float], Tuple[Union[torch.Tensor, int, float], ...]
    ]:
        cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
        del cls._CONTEXT_MANAGERS[module]

        def backward_hook(grad: torch.Tensor) -> None:
            if module not in cls._CONTEXT_MANAGERS:
                manager = _enable_cp_dispatcher()
                manager.__enter__()
                cls._CONTEXT_MANAGERS[module] = manager

        # back to local tensor
        out = []
        for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
            output = output.to_local() if isinstance(output, DTensor) else output

            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(backward_hook)

            out.append(output)

        if isinstance(outputs, torch.Tensor):
            return out[0]

        return tuple(out)

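# A minimal usage sketch (hypothetical module names), assuming the attention
# submodule calls F.scaled_dot_product_attention internally and a 1-D device
# mesh is used for CP:
#
#   from torch.distributed.tensor.parallel import parallelize_module
#
#   block = parallelize_module(
#       block, cp_mesh, {"attention": _AttentionContextParallel()}
#   )
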

@contextlib.contextmanager
def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, None]:
    """Replace SDPA with the CP-wrapped version and enable the DTensor CP dispatcher."""

    def attention_input_fn(
        mesh: DeviceMesh, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        placement = [Shard(seq_dim)]
        all_args = []

        for arg in itertools.chain(args, kwargs.values()):
            if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
                arg = DTensor.from_local(arg, mesh, placement, run_check=False)

            all_args.append(arg)

        new_args = tuple(all_args[0 : len(args)])
        new_kwargs = dict(zip(kwargs.keys(), all_args[len(args) :]))
        return new_args, new_kwargs

    def attention_output_fn(mesh: DeviceMesh, outputs: Any) -> Any:
        new_outputs = []
        for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
            output = output.to_local() if isinstance(output, DTensor) else output
            new_outputs.append(output)

        if isinstance(outputs, torch.Tensor):
            return new_outputs[0]

        return tuple(new_outputs)

    # TODO: provide a more robust way to replace SDPA.
    # Currently we use a monkey patch to replace scaled_dot_product_attention with
    # the wrapped fn. This is okay if users do `import torch.nn.functional` but will
    # not work if users do `from torch.nn.functional import scaled_dot_product_attention`.
    _distribute_function(
        F.scaled_dot_product_attention,
        F,
        mesh,
        attention_input_fn,
        attention_output_fn,
    )

    with _enable_cp_dispatcher():
        yield

    _restore_function(F.scaled_dot_product_attention, F)


def _get_sequence_shard(
    buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
    return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()]


def _context_parallel_buffers(
    mesh: DeviceMesh,
    buffers: List[torch.Tensor],
    buffer_seq_dims: List[int],
) -> List[torch.Tensor]:
    """Shard the buffers along the sequence dimensions according to CP rules."""
    new_buffers = []
    for buffer, seq_dim in zip(buffers, buffer_seq_dims):
        new_buffers.append(_get_sequence_shard(buffer, mesh, seq_dim))

    return new_buffers

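# For example (illustrative shapes), with ``mesh.size() == 2`` and a buffer of
# shape ``(batch, seq) == (4, 8)`` sharded with ``seq_dim = 1``, rank 0 keeps
# columns 0-3 and rank 1 keeps columns 4-7. ``torch.Tensor.chunk`` does not pad,
# so shards can be uneven if the sequence length is not divisible by the mesh
# size.
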

@contextlib.contextmanager
@torch.no_grad()
def context_parallel(
    mesh: DeviceMesh,
    *,
    buffers: Optional[List[torch.Tensor]] = None,
    buffer_seq_dims: Optional[List[int]] = None,
    no_restore_buffers: Optional[Set[torch.Tensor]] = None,
) -> Generator[None, None, None]:
    """
    ``context_parallel`` is an experimental API to enable context
    parallelism (CP). This API performs two actions: 1) patching the SDPA
    (``torch.nn.functional.scaled_dot_product_attention``) with the CP-enabled
    one, 2) sharding ``buffers`` along the sequence dimension so that each rank
    preserves its corresponding shard according to ``mesh``.

    Args:
        mesh (:class:`DeviceMesh`): the device mesh for the context parallelism.
        buffers (Optional[List[torch.Tensor]]): buffers whose usage depends on
            the sequence dimension. Examples are the input batch, labels, and
            positional embedding buffers. These buffers must be sharded along
            the sequence dimension to ensure accuracy. The sharding will
            happen in-place and the buffers' shapes will change within the
            context. The buffers will be restored after the context finishes.
            ``no_restore_buffers`` can be used to specify which buffers don't
            need to be restored. Note that ``buffers`` should not contain any
            nn.Parameter.
        buffer_seq_dims (Optional[List[int]]): the sequence dimensions of ``buffers``.
        no_restore_buffers (Optional[Set[torch.Tensor]]): buffers in this set
            won't be restored after the context exits. This set must be a subset
            of ``buffers``. If the buffers won't be used after the context exits,
            they can be put in this set to avoid extra restore time.

    .. warning::
        `torch.distributed._tensor.experimental.attention.context_parallel` is a
        prototype feature in PyTorch. The API is subject to change.
    """
    buffers = [] if buffers is None else buffers
    buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
    no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers

    if len(buffers) != len(buffer_seq_dims):
        raise ValueError(
            "`buffer_seq_dims` must have the same number of elements as `buffers`."
        )

    for buffer in no_restore_buffers:
        # Cannot use `buffer not in buffers`, which would incur a tensor comparison.
        if not any(b is buffer for b in buffers):
            raise ValueError("`no_restore_buffers` must be a subset of `buffers`.")

    original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers]

    chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims)
    for buffer, chunk in zip(buffers, chunks):
        chunk = chunk.clone()
        buffer.resize_(chunk.shape)
        buffer.copy_(chunk)

    with _context_parallel(seq_dim=2, mesh=mesh):
        yield

    for buffer, original_buffer in zip(buffers, original_buffers):
        if original_buffer is not None:
            buffer.resize_(original_buffer.shape)
            buffer.copy_(original_buffer)

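# A minimal end-to-end usage sketch (hypothetical names and shapes; assumes the
# process group is already initialized and the model's attention layers call
# F.scaled_dot_product_attention):
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   cp_mesh = init_device_mesh("cuda", (world_size,))
#   # input_ids and labels have shape (batch, seq), so their sequence dim is 1;
#   # the q/k/v tensors inside SDPA are sharded on dim 2 by the context itself.
#   with context_parallel(
#       cp_mesh,
#       buffers=[input_ids, labels],
#       buffer_seq_dims=[1, 1],
#       no_restore_buffers={input_ids, labels},
#   ):
#       loss = loss_fn(model(input_ids), labels)
#       loss.backward()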