xref: /aosp_15_r20/external/pytorch/torch/distributed/tensor/experimental/_attention.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Copyright (c) Meta Platforms, Inc. and affiliates

import contextlib
import itertools
import logging
import types
import weakref
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Protocol,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel.style import ParallelStyle


# TODO: expose a single API
__all__ = ["context_parallel"]

aten = torch.ops.aten
logger = logging.getLogger(__name__)
# Whether to upcast parameters and gradients to float32 to avoid accumulation
# errors. This is likely always True, but we keep the flag for experimental
# purposes.
_convert_to_f32 = True


class _CausalBehavior(Enum):
    SKIP = None
    NOT_IS_CAUSAL = False
    IS_CAUSAL = True


def _is_causal_behavior(
    rank: int, world_size: int, i: int, is_causal: bool
) -> _CausalBehavior:
    """
    Calculate the is_causal behavior for each KV block. The attention can either
    be computed in full, not at all, or with the causal mask applied.
    """
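    # At ring step ``i`` this rank attends to the KV block that originated on
    # rank ``(rank - i) % world_size``. Blocks from lower ranks hold earlier
    # tokens (attend in full), the local block (i == 0) needs the causal mask,
    # and blocks from higher ranks hold only future tokens (skip). E.g. with
    # world_size=4, rank 1 computes: causal, full, skip, skip over steps 0..3.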
    if not is_causal:
        return _CausalBehavior.NOT_IS_CAUSAL

    if i == 0:
        return _CausalBehavior.IS_CAUSAL

    source_rank = (rank - i) % world_size
    if source_rank < rank:
        return _CausalBehavior.NOT_IS_CAUSAL
    else:
        return _CausalBehavior.SKIP


def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
    """
    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
    so we cannot call ``wait()``.
    """
    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
        return tensor.wait()
    return tensor


class _SDPAMerger:
    """A class that helps merge the local SDPA results."""
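    # Each ring step produces a partial attention output together with its
    # log-sum-exp (LSE) over one KV block. ``step`` folds these partial results
    # into a single (out, lse) pair as if attention had been computed over the
    # full sequence.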

    def __init__(self, convert_to_f32: bool):
        self._out: Optional[torch.Tensor] = None
        self._lse: Optional[torch.Tensor] = None
        self._convert_to_f32 = convert_to_f32
        self._out_dtype = torch.float32
        self._lse_dtype = torch.float32

    def _merge_one(self, block_out: torch.Tensor, block_lse: torch.Tensor) -> None:
        block_lse = block_lse.unsqueeze(dim=-1)
        if self._lse is None:
            self._lse = block_lse
            self._out = block_out
        else:
            # The algorithm from
            # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            # gives a relatively stable result.
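            # With s = sigmoid(block_lse - lse), the update below is the
            # numerically stable form of the weighted average
            #   out = (exp(lse) * out + exp(block_lse) * block_out) / (exp(lse) + exp(block_lse)),
            # since s = exp(block_lse) / (exp(lse) + exp(block_lse)); similarly,
            # lse - logsigmoid(lse - block_lse) = log(exp(lse) + exp(block_lse)).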
            self._out = self._out - F.sigmoid(block_lse - self._lse) * (
                self._out - block_out
            )
            self._lse = self._lse - F.logsigmoid(self._lse - block_lse)

    def step(self, out: torch.Tensor, lse: torch.Tensor) -> None:
        self._out_dtype = out.dtype
        self._lse_dtype = lse.dtype

        if self._convert_to_f32:
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)

        self._merge_one(out, lse)

    def results(self) -> Tuple[torch.Tensor, torch.Tensor]:
        assert self._out is not None
        assert self._lse is not None
        out, lse = self._out, self._lse.squeeze(-1)
        return out.to(self._out_dtype), lse.to(self._lse_dtype)


def _scaled_dot_product_ring_flash_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    if return_debug_mask:
        raise NotImplementedError("return_debug_mask is not supported yet")

    return _templated_ring_attention(
        mesh,
        aten._scaled_dot_product_flash_attention,
        query=query,
        key=key,
        value=value,
        is_causal=is_causal,
        dropout_p=dropout_p,
        scale=scale,
    )


def _scaled_dot_product_ring_efficient_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    compute_log_sumexp: bool = True,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    if attn_bias is not None:
        raise NotImplementedError("attn_bias is not supported yet")
    if not compute_log_sumexp:
        raise NotImplementedError("compute_log_sumexp must be set")

    return _templated_ring_attention(
        mesh,
        aten._scaled_dot_product_efficient_attention,
        query=query,
        key=key,
        value=value,
        is_causal=is_causal,
        attn_bias=attn_bias,
        dropout_p=dropout_p,
        scale=scale,
        compute_log_sumexp=compute_log_sumexp,
    )


class _AttentionOp(Protocol):
    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        **kwargs: object,
    ) -> Tuple[torch.Tensor, ...]:
        ...


def _ring_rotate(
    block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool
) -> torch.Tensor:
    size = dist.get_world_size(pg)
    dsts = (
        list(range(1, size)) + [0]
        if send_to_next
        else [size - 1] + list(range(0, size - 1))
    )
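    # With send_to_next=True the permutation is e.g. [1, 2, 3, 0] for size == 4,
    # i.e. every rank forwards its block to the next rank in the ring; the
    # False branch rotates in the opposite direction.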
    return ft_c.permute_tensor(block, dsts, pg)


def _templated_ring_attention(
    mesh: DeviceMesh,
    op: _AttentionOp,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    **kwargs: object,
) -> Tuple[torch.Tensor, ...]:
    """
    This is a generalized ring attention implementation that can support multiple attention ops.

    Parameters
    ----------
    op:
        The attention op to use
    **kwargs:
        additional kwargs are passed to the op

    Returns
    -------
    out:
        The merged attention output
    softmax_lse:
        The logsumexp of the merged attention output
    """
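    # Ring attention: each rank keeps its local query shard while the key/value
    # shards rotate around the ring. At every step the local query attends to
    # the currently resident KV block, and the partial outputs are combined via
    # their log-sum-exp in ``_SDPAMerger``.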
    if is_causal and (query.size(2) != key.size(2)):
        raise NotImplementedError(
            "is_causal requires the same query and context sequence lengths"
        )

    if isinstance(mesh, dist.ProcessGroup):
        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
    else:
        pg = mesh.get_group()
    assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)

    next_kv = None

    # Without making key and value contiguous(), the loss curve is bad.
    # TODO(fegin): figure out why this is a requirement since SDPA does not have
    # this requirement.
    key = key.contiguous()
    value = value.contiguous()

    sdpa_merger = _SDPAMerger(_convert_to_f32)

    rest: List[Any]
    out: torch.Tensor
    logsumexp: torch.Tensor

    for i in range(size):
        # overlap communication with compute
        if next_kv is not None:
            next_kv = _maybe_wait(next_kv)
            key = next_kv[: key.numel()].reshape(key.shape)
            value = next_kv[key.numel() :].reshape(value.shape)

        if i < (size - 1):
            next_kv = torch.cat([key.flatten(), value.flatten()])
            next_kv = _ring_rotate(next_kv, pg, send_to_next=True)

        is_causal_behavior = _is_causal_behavior(
            rank=rank, world_size=size, i=i, is_causal=is_causal
        )

        if is_causal_behavior != _CausalBehavior.SKIP:
            out, logsumexp, *rest = op(
                query,
                key,
                value,
                is_causal=is_causal_behavior.value,
                **kwargs,
            )

            sdpa_merger.step(out, logsumexp)

    return *sdpa_merger.results(), *rest


def _sdpa_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # extract the local tensors and sharding infos into an OpInfo
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    logger.debug("Dispatching op_call: %s", op_info.schema)

    # sharding propagation
    # TODO: remove the context parallel strategy from the default propagation
    # rule. Either figure out how to dynamically enable it or just don't call
    # propagate.
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    assert not output_sharding.needs_redistribute, "inputs need to be redistributed"

    if op_call == aten._scaled_dot_product_flash_attention.default:
        local_results = _scaled_dot_product_ring_flash_attention(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    elif op_call == aten._scaled_dot_product_efficient_attention.default:
        local_results = _scaled_dot_product_ring_efficient_attention(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    else:
        raise NotImplementedError(
            "CP only supports flash attention and memory-efficient attention now."
        )

    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)


def _sdpa_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    # Redistribute grad_output tensor to the same placement as output tensor
    args = list(args)
    args = tuple(args)

    # extract the local tensors and sharding infos into an OpInfo
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    logger.debug("Dispatching op_call: %s", op_info.schema)

    # sharding propagation
    DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
    output_sharding = op_info.output_sharding
    assert output_sharding is not None, "output sharding should not be None"
    assert not output_sharding.needs_redistribute, "inputs need to be redistributed"

    if op_call == aten._scaled_dot_product_flash_attention_backward.default:
        local_results = _scaled_dot_product_ring_flash_attention_backward(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    elif op_call == aten._scaled_dot_product_efficient_attention_backward.default:
        local_results = _scaled_dot_product_ring_efficient_attention_backward(
            op_info.mesh,
            *op_info.local_args,  # type: ignore[arg-type]
            **op_info.local_kwargs,  # type: ignore[arg-type]
        )
    else:
        raise NotImplementedError(f"{op_call=}")

    return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)


def _templated_ring_attention_backward(
    mesh: DeviceMesh,
    op: _AttentionOp,
    grad_out: torch.Tensor,
    grad_out_name: str,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    is_causal: bool,
    **kwargs: Any,
) -> Tuple[torch.Tensor, ...]:
    pg = mesh.get_group()
    assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
    rank = dist.get_rank(pg)
    size = dist.get_world_size(pg)
    next_kv = None
    next_grad_kv = None
    rest: List[Any]
    grad_query_, grad_key_, grad_value_ = None, None, None

    accum_dtype = torch.float32 if _convert_to_f32 else query.dtype
    grad_query = torch.zeros_like(query, dtype=accum_dtype)
    grad_key = torch.zeros_like(key, dtype=accum_dtype)
    grad_value = torch.zeros_like(value, dtype=accum_dtype)

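    # Like the forward pass, the KV blocks rotate around the ring each step.
    # The grad_key/grad_value accumulators rotate along with the block they
    # correspond to, so after the final rotation (picked up after the loop)
    # each rank holds the fully accumulated gradients for its own KV shard.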
    key = key.contiguous()
    value = value.contiguous()
    for i in range(size):
        if next_kv is not None:
            buffer = _maybe_wait(next_kv)
            pointer = 0
            key = buffer[pointer : pointer + key.numel()].reshape(key.shape)
            pointer += key.numel()
            value = buffer[pointer : pointer + value.numel()].reshape(value.shape)
            pointer += value.numel()

        if i != size - 1:
            next_kv = torch.cat([key.flatten(), value.flatten()])
            next_kv = _ring_rotate(next_kv, pg, send_to_next=True)

        is_causal_behavior = _is_causal_behavior(
            rank=rank, world_size=size, i=i, is_causal=is_causal
        )

        if is_causal_behavior != _CausalBehavior.SKIP:
            kwargs[grad_out_name] = grad_out
            grad_query_, grad_key_, grad_value_, *rest = op(
                query=query,
                key=key,
                value=value,
                out=out,
                logsumexp=logsumexp,
                is_causal=is_causal_behavior.value,
                **kwargs,
            )
        else:
            grad_query_ = torch.zeros_like(query, dtype=accum_dtype)
            grad_key_ = torch.zeros_like(key, dtype=accum_dtype)
            grad_value_ = torch.zeros_like(value, dtype=accum_dtype)

        # Get the grad key and grad value for round i.
        if i > 0:
            pointer = 0
            assert next_grad_kv is not None
            next_grad_kv = _maybe_wait(next_grad_kv)
            grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape(
                grad_key.shape
            )
            pointer += grad_key.numel()
            grad_value = next_grad_kv[pointer : pointer + grad_value.numel()].reshape(
                grad_value.shape
            )

        grad_key += grad_key_
        grad_value += grad_value_

        # Send the accumulated grad key and grad value to the next rank.
        next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()])
        next_grad_kv = _ring_rotate(next_grad_kv, pg, send_to_next=True)
        grad_query += grad_query_

    assert next_grad_kv is not None
    assert grad_key_ is not None
    assert grad_value_ is not None
    grad_query = grad_query.to(query.dtype)
    next_grad_kv = _maybe_wait(next_grad_kv).to(key.dtype)
    grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
    grad_value = next_grad_kv[grad_value.numel() :].reshape(grad_value.shape)
    return (
        grad_query,
        grad_key,
        grad_value,
        *rest,
    )


def _scaled_dot_product_ring_flash_attention_backward(
    mesh: DeviceMesh,
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    cum_seq_q: torch.Tensor,
    cum_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    return _templated_ring_attention_backward(
        mesh,
        aten._scaled_dot_product_flash_attention_backward.default,
        grad_out=grad_out,
        grad_out_name="grad_out",
        query=query,
        key=key,
        value=value,
        out=out,
        logsumexp=logsumexp,
        is_causal=is_causal,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=max_q,
        max_k=max_k,
        dropout_p=dropout_p,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        scale=scale,
    )


def _scaled_dot_product_ring_efficient_attention_backward(
    mesh: DeviceMesh,
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bias: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    dropout_p: float,
    grad_input_mask: Tuple[bool, ...],
    is_causal: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    return _templated_ring_attention_backward(
        mesh,
        aten._scaled_dot_product_efficient_attention_backward.default,
        grad_out=grad_out,
        grad_out_name="grad_out_",
        query=query,
        key=key,
        value=value,
        attn_bias=bias,
        out=out,
        logsumexp=logsumexp,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        dropout_p=dropout_p,
        grad_input_mask=grad_input_mask,
        is_causal=is_causal,
        scale=scale,
    )


customized_ops = {
    aten._scaled_dot_product_flash_attention.default: _sdpa_handler,
    aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler,
    aten._scaled_dot_product_efficient_attention.default: _sdpa_handler,
    aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler,
}
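# These handlers are installed on the DTensor op dispatcher by
# ``_enable_cp_dispatcher`` below, so that SDPA calls on DTensors are routed to
# the ring-attention implementations above.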


_replaced_functions: Dict[Callable, Tuple[str, Callable]] = {}


def _distribute_function(
    fn: Callable,
    fn_module: types.ModuleType,
    device_mesh: DeviceMesh,
    input_fn: Optional[Callable] = None,
    output_fn: Optional[Callable] = None,
) -> None:
    """
    ``distribute_function`` is an experimental API that allows users to "distribute"
    the inputs and outputs of a function. Similar to ``distribute_module``, this API
    installs hooks on ``fn`` to convert its inputs and outputs. There are two major
    differences between ``distribute_function`` and ``distribute_module``. First, a
    function does not have parameters and buffers; as a result, ``distribute_function``
    itself won't convert any parameters/buffers but simply installs the input and
    output hooks. The tensor conversion happens in the hooks. Second, an nn.Module
    subclass can have several instances, and each instance can be fed into
    ``distribute_module`` independently without affecting the other instances. A
    function, on the other hand, is a singleton object, so once a function is
    distributed by ``distribute_function``, all subsequent calls to that function
    will invoke the installed hooks.

    Args:
        fn (Callable): the function to be distributed.
        fn_module (types.ModuleType): the Python module in which the function is
            declared, e.g., if ``fn`` is ``torch.nn.functional.scaled_dot_product_attention``,
            ``fn_module`` is ``torch.nn.functional``.
        device_mesh (:class:`DeviceMesh`): the device mesh that will be used by the
            input and output hooks to distribute the tensors.
        input_fn (Optional[Callable]): the hook to distribute or convert the input
            arguments of ``fn``.
        output_fn (Optional[Callable]): the hook to distribute or convert the output
            arguments of ``fn``.
    """
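    # For example, ``_context_parallel`` below wraps SDPA roughly like this:
    #
    #   _distribute_function(
    #       F.scaled_dot_product_attention,
    #       F,
    #       device_mesh,
    #       attention_input_fn,   # converts tensor args to DTensor (Shard(seq_dim))
    #       attention_output_fn,  # converts DTensor outputs back to local tensors
    #   )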

    def wrapper(
        target_fn: Callable, input_fn: Optional[Callable], output_fn: Optional[Callable]
    ) -> Callable:
        def inner_fn(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
            if input_fn is not None:
                args, kwargs = input_fn(device_mesh, *args, **kwargs)
            output = target_fn(*args, **kwargs)
            if output_fn is not None:
                output = output_fn(device_mesh, output)
            return output

        return inner_fn

    global _replaced_functions

    if fn in _replaced_functions:
        return

    wrapper_fn = wrapper(fn, input_fn, output_fn)
    setattr(fn_module, fn.__name__, wrapper_fn)
    _replaced_functions[wrapper_fn] = (fn.__name__, fn)


def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:
    """Restore the function that was replaced by _distribute_function."""
    global _replaced_functions

    if fn not in _replaced_functions:
        return

    original_name, original_fn = _replaced_functions[fn]
    setattr(fn_module, original_name, original_fn)


@contextlib.contextmanager
def _enable_cp_dispatcher() -> Generator[None, None, None]:
    """Enables the DTensor dispatcher to dispatch SDPA to CP."""
    old_handlers = DTensor._op_dispatcher._custom_op_handlers
    DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}

    yield

    DTensor._op_dispatcher._custom_op_handlers = old_handlers


class _AttentionContextParallel(ParallelStyle):
    """
    Applies context parallel optimizations to the attention layer.

    This will work for nn.MultiheadAttention and custom attention layers that
    call F.scaled_dot_product_attention with a similar signature.

    This expects the ``forward`` method to consume either:

    * a single tensor for self attention
    * one argument for each of: query, key, value

    This currently only supports ring attention and the
    SDPBackend.FLASH_ATTENTION backend. See sdpa_kernel.

    Non-flash attention backends will result in incorrect results.
    """
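    # A minimal usage sketch (``attention_layer`` and ``mesh`` are hypothetical
    # placeholders, not defined in this module):
    #
    #   from torch.distributed.tensor.parallel import parallelize_module
    #
    #   parallelize_module(attention_layer, mesh, _AttentionContextParallel())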

    # use a weakref dictionary to store context managers for each nn.Module
    _CONTEXT_MANAGERS: "weakref.WeakKeyDictionary[nn.Module, Any]" = (
        weakref.WeakKeyDictionary()
    )

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if not isinstance(device_mesh, DeviceMesh):
            raise ValueError(
                f"{type(device_mesh)} is not supported by {type(self)} yet."
            )

        if device_mesh.ndim != 1:
            raise ValueError("device_mesh must be 1 dimensional.")

        return distribute_module(
            module,
            device_mesh,
            input_fn=self._input_fn,  # type: ignore[arg-type]
            output_fn=self._output_fn,  # type: ignore[arg-type]
        )

    @classmethod
    def _input_fn(
        cls,
        module: nn.Module,
        inputs: Tuple[Union[torch.Tensor, int, float], ...],
        device_mesh: DeviceMesh,
    ) -> Tuple[Union[torch.Tensor, int, float], ...]:
        # TODO(d4l3k): this should be Shard(2), need to fix Linear layer rules
        placement = [Replicate()]

        def backward_hook(grad: torch.Tensor) -> None:
            if module in cls._CONTEXT_MANAGERS:
                cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
                del cls._CONTEXT_MANAGERS[module]

        # convert inputs to DTensor
        inp = []
        for input in inputs:
            if isinstance(input, torch.Tensor) and not isinstance(input, DTensor):
                input = DTensor.from_local(
                    input.contiguous(), device_mesh, placement, run_check=False
                )

            if isinstance(input, torch.Tensor) and input.requires_grad:
                input.register_hook(backward_hook)

            inp.append(input)

        manager = _enable_cp_dispatcher()
        manager.__enter__()
        cls._CONTEXT_MANAGERS[module] = manager

        return tuple(inp)

    @classmethod
    def _output_fn(
        cls,
        module: nn.Module,
        outputs: Union[torch.Tensor, Tuple[Union[torch.Tensor, int, float], ...]],
        device_mesh: DeviceMesh,
    ) -> Union[
        Union[torch.Tensor, int, float], Tuple[Union[torch.Tensor, int, float], ...]
    ]:
        cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
        del cls._CONTEXT_MANAGERS[module]

        def backward_hook(grad: torch.Tensor) -> None:
            if module not in cls._CONTEXT_MANAGERS:
                manager = _enable_cp_dispatcher()
                manager.__enter__()
                cls._CONTEXT_MANAGERS[module] = manager

        # back to local tensor
        out = []
        for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
            output = output.to_local() if isinstance(output, DTensor) else output

            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(backward_hook)

            out.append(output)

        if isinstance(outputs, torch.Tensor):
            return out[0]

        return tuple(out)


@contextlib.contextmanager
def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, None]:
    """Replace SDPA with the CP-wrapped version and enable the DTensor CP dispatcher."""

    def attention_input_fn(
        mesh: DeviceMesh, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        placement = [Shard(seq_dim)]
        all_args = []

        for arg in itertools.chain(args, kwargs.values()):
            if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
                arg = DTensor.from_local(arg, mesh, placement, run_check=False)

            all_args.append(arg)

        new_args = tuple(all_args[0 : len(args)])
        new_kwargs = dict(zip(kwargs.keys(), all_args[len(args) :]))
        return new_args, new_kwargs

    def attention_output_fn(mesh: DeviceMesh, outputs: Any) -> Any:
        new_outputs = []
        for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
            output = output.to_local() if isinstance(output, DTensor) else output
            new_outputs.append(output)

        if isinstance(outputs, torch.Tensor):
            return new_outputs[0]

        return tuple(new_outputs)

    # TODO: provide a more robust way to replace SDPA.
    # Currently we monkey patch ``scaled_dot_product_attention`` on
    # ``torch.nn.functional`` with the wrapped fn. This works when users call
    # ``F.scaled_dot_product_attention``, but not when they bind the function
    # directly via ``from torch.nn.functional import scaled_dot_product_attention``.
    _distribute_function(
        F.scaled_dot_product_attention,
        F,
        mesh,
        attention_input_fn,
        attention_output_fn,
    )

    with _enable_cp_dispatcher():
        yield

    _restore_function(F.scaled_dot_product_attention, F)


def _get_sequence_shard(
    buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
    return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()]


def _context_parallel_buffers(
    mesh: DeviceMesh,
    buffers: List[torch.Tensor],
    buffer_seq_dims: List[int],
) -> List[torch.Tensor]:
    """Shard the buffers along the sequence dimensions according to CP rules."""
    new_buffers = []
    for buffer, seq_dim in zip(buffers, buffer_seq_dims):
        new_buffers.append(_get_sequence_shard(buffer, mesh, seq_dim))

    return new_buffers


@contextlib.contextmanager
@torch.no_grad()
def context_parallel(
    mesh: DeviceMesh,
    *,
    buffers: Optional[List[torch.Tensor]] = None,
    buffer_seq_dims: Optional[List[int]] = None,
    no_restore_buffers: Optional[Set[torch.Tensor]] = None,
) -> Generator[None, None, None]:
    """
    ``context_parallel`` is an experimental API to enable context
    parallelism (CP). This API performs two actions: 1) patch the SDPA
    (``torch.nn.functional.scaled_dot_product_attention``) with the CP-enabled
    one, 2) shard ``buffers`` along the sequence dimension so that each rank
    preserves the corresponding shard according to ``mesh``.

    Args:
        mesh (:class:`DeviceMesh`): the device mesh for context parallelism.
        buffers (Optional[List[torch.Tensor]]): buffers whose usage depends
            on the sequence dimension. Examples are the input batch, labels and
            positional embedding buffers. These buffers must be sharded along
            the sequence dimension to ensure accuracy. The sharding happens
            in-place, so the buffers' shapes change within the context.
            The buffers are restored after the context finishes.
            ``no_restore_buffers`` can be used to specify which buffers don't
            need to be restored. Note that ``buffers`` should not contain any
            nn.Parameter.
        buffer_seq_dims (Optional[List[int]]): the sequence dimensions of ``buffers``.
        no_restore_buffers (Optional[Set[torch.Tensor]]): buffers in this set
            won't be restored after the context exits. This set must be a subset
            of ``buffers``. If the buffers won't be used after the context exits,
            they can be put in this set to avoid extra restore time.

    .. warning::
        `torch.distributed._tensor.experimental.attention.context_parallel` is a
        prototype feature in PyTorch. The API is subject to change.
    """
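    # A minimal usage sketch (``model``, ``inputs`` and ``labels`` are
    # hypothetical placeholders, not defined in this module):
    #
    #   with context_parallel(
    #       mesh,
    #       buffers=[inputs, labels],
    #       buffer_seq_dims=[1, 1],
    #       no_restore_buffers={inputs, labels},
    #   ):
    #       loss = model(inputs).sum()  # SDPA inside ``model`` runs ring attention
    #       loss.backward()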
    buffers = [] if buffers is None else buffers
    buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
    no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers

    if len(buffers) != len(buffer_seq_dims):
        raise ValueError(
            "`buffer_seq_dims` must have the same number of elements as `buffers`."
        )

    for buffer in no_restore_buffers:
        # Cannot use `buffer not in buffers` because that would trigger a tensor comparison.
        if not any(b is buffer for b in buffers):
            raise ValueError("`no_restore_buffers` must be a subset of `buffers`.")

    original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers]

    chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims)
    for buffer, chunk in zip(buffers, chunks):
        chunk = chunk.clone()
        buffer.resize_(chunk.shape)
        buffer.copy_(chunk)

    with _context_parallel(seq_dim=2, mesh=mesh):
        yield

    for buffer, original_buffer in zip(buffers, original_buffers):
        if original_buffer is not None:
            buffer.resize_(original_buffer.shape)
            buffer.copy_(original_buffer)