# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)

from torch.nn.attention import SDPBackend
from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer

log = logging.getLogger(__name__)


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    if (
        not isinstance(query, NestedTensor)
        or not isinstance(key, NestedTensor)
        or not isinstance(value, NestedTensor)
    ):
        raise ValueError(
            f"Expected query, key, and value to be nested tensors, "
            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
            f"and value.is_nested: {value.is_nested} instead."
        )
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device type, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()}, and value.dim: {value.dim()} instead."
        )
    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
        raise ValueError(
            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
        )
    if attn_mask is not None:
        # TODO: Figure out whether masks are actually supported for this layout or not
        raise ValueError("Masks are not yet supported!")
        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
            raise ValueError(
                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
            )


def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
    # This is expected to be called after check_tensor_shapes ensuring that the
    # size() calls won't error since the inputs are all 4 dimensional
    q_batch_size = params.query.size(0)
    k_batch_size = params.key.size(0)
    v_batch_size = params.value.size(0)

    # num_heads logic for nested input is checked in
    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
    # num_heads is not ragged
    return q_batch_size == k_batch_size and q_batch_size == v_batch_size


def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
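    # Flash attention over jagged NTs needs q, k, and v to share the same head
    # dim, and that head dim must be a multiple of 8 and at most 256 (so e.g.
    # 64 or 128 passes, while 65 or 320 does not).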
    max_size = 256
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, Flash attention requires q, k, and v to have the same "
                "last dimension, which must be a multiple of 8 and at most 256. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True


def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
    param: torch.Tensor, param_name: str, debug=False
) -> bool:
    assert isinstance(param, NestedTensor), "param should be a jagged NT"

    if param._ragged_idx == 1:
        # num_head_dims is ragged
        if debug:
            log.warning(
                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
                param_name,
            )
        return False

    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
    if param._min_seqlen == 0:
        if debug:
            log.warning(
                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
                param_name,
            )
        return False

    return True


def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
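    # Broadcasting only works when every size is either 1 or the shared maximum;
    # e.g. num_heads of (8, 1, 8) for (q, k, v) is broadcastable, while
    # (8, 2, 8) is not.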
    max_size = max(q_size, k_size, v_size)
    if (
        (q_size != max_size and q_size != 1)
        or (k_size != max_size and k_size != 1)
        or (v_size != max_size and v_size != 1)
    ):
        if debug:
            log.warning(
                "Both fused kernels require query, key and value to have broadcastable %s, "
                "got Query %s %d, Key %s %d, Value %s %d instead.",
                param_name,
                param_name,
                q_size,
                param_name,
                k_size,
                param_name,
                v_size,
            )
        return False
    return True


def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
    # When this function is called we are assured that the nt is dim==4
    q_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.query, "query", debug
        )
        if params.query.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not q_is_safe:
        return False

    k_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.key, "key", debug
        )
        if params.key.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not k_is_safe:
        return False

    v_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.value, "value", debug
        )
        if params.value.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not v_is_safe:
        return False

    # We now know none of the inputs have ragged num_heads, so we can safely
    # access .size(1)
    q_num_heads = params.query.size(1)
    k_num_heads = params.key.size(1)
    v_num_heads = params.value.size(1)
    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads

    if not same_num_heads:
        if (
            params.query.requires_grad
            or params.key.requires_grad
            or params.value.requires_grad
        ):
            if debug:
                log.warning(
                    "Both fused kernels do not support training with broadcasted NT inputs."
                )
            return False
        return _try_broadcast_param_size(
            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
        )
    return True


def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_head_dim_size_flash_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    if (
        not params.query.transpose(1, 2).is_contiguous()
        or not params.key.transpose(1, 2).is_contiguous()
        or not params.value.transpose(1, 2).is_contiguous()
    ):
        if debug:
            log.warning(
                "If inputs are nested tensors they must be contiguous after transposing."
            )
        return False
    if params.is_causal:
        if debug:
            log.warning(
                "Nested tensors for query / key are not supported when is_causal=True."
            )
        return False
    return True


def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
        and not math_sdp_enabled()
    ):
        return SDPBackend.ERROR

    ordering = (
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    )
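    # Backends are tried in priority order: flash attention first, then the
    # memory-efficient kernel, then the math fallback; if none applies, the
    # debug passes below log why each one was rejected.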

    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)

    for backend in ordering:
        if backend == SDPBackend.FLASH_ATTENTION:
            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
                return SDPBackend.FLASH_ATTENTION
        if backend == SDPBackend.EFFICIENT_ATTENTION:
            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
                params
            ):
                return SDPBackend.EFFICIENT_ATTENTION
        if backend == SDPBackend.MATH:
            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
                return SDPBackend.MATH

    log.warning("Memory efficient kernel not used because:")
    can_use_efficient_attention(params, debug=True)
    _can_use_efficient_sdpa_jagged(params, debug=True)
    log.warning("Flash attention kernel not used because:")
    can_use_flash_attention(params, debug=True)
    _can_use_flash_sdpa_jagged(params, debug=True)
    log.warning("Math attention kernel not used because:")
    _can_use_math_sdpa_jagged(params, debug=True)
    return SDPBackend.ERROR


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    # This function is used to calculate two pieces of metadata that are needed
    # for use with flash-attention and efficient_attention kernels. They are the
    # cumulative sequence_length over a batch of sequences and the maximum
    # sequence length.

    # It returns a tuple of the cumulative sequence lengths, the maximum
    # sequence length, and the total number of elements (the last entry of the
    # cumulative sequence lengths).
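    # Illustrative example (assumed values): a jagged NT with contiguous
    # offsets [0, 2, 6, 9] (i.e. sequence lengths [2, 4, 3]) yields cumulative
    # seqlens [0, 2, 6, 9] as an int32 tensor, max_seqlen 4, and n_elem 9.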
    if not isinstance(qkv, NestedTensor):
        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")

    if qkv.lengths() is None:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
        max_seqlen = qkv._max_seqlen
        n_elem = qkv.values().shape[0]
    else:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = (
            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
        )
        batch_size = qkv.size(0)
        max_seqlen = qkv._max_seqlen
        # TODO: Explore performance impact when compiling
        n_elem = int(cumulative_seqlen[-1].item())
    return cumulative_seqlen, max_seqlen, n_elem


def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
    # This function checks if a nested tensor is valid for
    # use with the flash-attention and efficient_attention kernels without
    # needing to call contiguous on the nested tensor input.
    # It checks that the storage offsets' adjacent differences are a constant
    # multiple of the previous tensor in the nested tensor and that the strides
    # are monotonically decreasing. This check is done after calling transpose on
    # the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim]

    # Returns True if the storage can be used as-is, i.e. contiguous() does not
    # need to be called on the input
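    # Illustrative example (assumed layout): for the transposed
    # [bsz, {seq_len}, num_heads, dim] view of a densely packed buffer, the
    # non-batch strides look like (num_heads * dim, dim, 1), which is strictly
    # decreasing and therefore safe; if the num_heads stride were >= the
    # seq_len stride, the buffer could not be viewed directly and contiguous()
    # would be needed.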
    assert isinstance(tensor, NestedTensor)
    offsets = tensor.offsets()
    strides = tensor._strides

    n_tensors = offsets.size(0) - 1
    if n_tensors <= 1:
        return True

    # Check initially that the tensor strides are in strictly descending order
    prev_stride = strides[1]
    for stride in strides[2:]:
        if prev_stride <= stride:
            # This would mean that the last stride is greater than the seq_len
            # stride
            return False
        prev_stride = stride

    # Congrats you made it!
    return True


def _view_as_dense(
    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
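    # For a jagged NT (already transposed to [bsz, {seq_len}, num_heads, head_dim])
    # this returns the packed values buffer of shape [Nnz, num_heads, head_dim];
    # dense inputs are simply viewed with that shape.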
    if tensor.is_nested:
        return buffer_from_jagged(tensor)
    return tensor.view(Nnz, num_heads, head_dim)


# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
#     # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
#     # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     q_batch_size = query.size(0)
#     k_batch_size = key.size(0)
#     v_batch_size = value.size(0)

#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)

#     q_num_heads = query.size(1)
#     k_num_heads = key.size(1)
#     v_num_heads = value.size(1)

#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)

#     head_dim_qk = query.size(3)
#     head_dim_v = value.size(3)

#     q_t = query.transpose(1, 2)
#     k_t = key.transpose(1, 2)
#     v_t = value.transpose(1, 2)

#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
#     # output_batch_size/num_heads then they are 1
#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size

#     # If {*}_batch_size_needs_broadcast, then
#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
#     #     this is because needs_broadcast indicates that the batch_size is 1
#     #     and hence there is only 1 value for seq_len
#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
#     # ..., output_batch_size * {*}_t.size(1)]
#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)

#     if q_batch_size_needs_broadcast or not q_t.is_nested:
#         max_seqlen_batch_q = q_t.size(1)
#         cumulative_sequence_length_q = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_q,
#             max_seqlen_batch_q,
#             device=q_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_q = output_batch_size * max_seqlen_batch_q
#     else:
#         (
#             cumulative_sequence_length_q,
#             max_seqlen_batch_q,
#             Nnz_q,
#         ) = _cumulative_and_max_seq_len_nnz(q_t)

#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
#         assert k_t.size(1) == v_t.size(1)
#         max_seqlen_batch_kv = k_t.size(1)
#         cumulative_sequence_length_kv = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_kv,
#             max_seqlen_batch_kv,
#             device=k_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
#     else:
#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
#             _cumulative_and_max_seq_len_nnz(v_t)
#             if k_batch_size_needs_broadcast
#             else _cumulative_and_max_seq_len_nnz(k_t)
#         )

#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads

#     if not q_t.is_nested:
#         query_buffer_reshaped = q_t.expand(
#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
#         )
#         query_buffer_reshaped = query_buffer_reshaped.reshape(
#             Nnz_q, output_num_heads, head_dim_qk
#         )
#     else:
#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
#             q_t = q_t.contiguous()
#         # If we are broadcasting then Nnz_q will be the output_batch_size since
#         # seq_len is 1
#         effective_batch_size_q = (
#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
#         )
#         query_buffer_reshaped = _view_as_dense(
#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
#         )

#     # If the physical layout of the NestedTensor's storage
#     # is not: batch, {seq_len}, num_heads, head_dim then we need
#     # to call contiguous
#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
#         k_t = k_t.contiguous()
#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
#         v_t = v_t.contiguous()

#     effective_batch_size_k = (
#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
#     )
#     key_buffer_reshaped = _view_as_dense(
#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
#     )

#     effective_batch_size_v = (
#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
#     )
#     value_buffer_reshaped = _view_as_dense(
#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
#     )

#     if not q_batch_size_needs_broadcast:
#         output_shape = q_t._size
#         if head_dim_v != head_dim_qk:
#             output_shape[-1] = head_dim_v
#         if q_num_heads_needs_broadcast:
#             output_shape[1] = output_num_heads
#     else:
#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
#         output_shape[0] = q_t.size(1)
#         output_shape[1] = output_num_heads
#         output_shape[2] = head_dim_v

#     return (
#         query_buffer_reshaped,
#         key_buffer_reshaped,
#         value_buffer_reshaped,
#         cumulative_sequence_length_q,
#         cumulative_sequence_length_kv,
#         max_seqlen_batch_q,
#         max_seqlen_batch_kv,
#         output_shape,
#     )


def _sdpa_nested_preprocessing(query, key, value):
    # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
    # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
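    # Illustrative example (assumed values): two ragged sequences of lengths
    # [3, 5] with 4 heads and head_dim 16 give q_t of shape [2, {3, 5}, 4, 16],
    # a packed query buffer of shape [8, 4, 16], cu_seqlens [0, 3, 8], and
    # max_seqlen_batch_q 5.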
    q_batch_size = query.size(0)
    k_batch_size = key.size(0)
    v_batch_size = value.size(0)

    q_num_heads = query.size(1)
    k_num_heads = key.size(1)
    v_num_heads = value.size(1)

    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
        q_num_heads == k_num_heads and k_num_heads == v_num_heads
    ):
        raise RuntimeError(
            "This path is currently not implemented for jagged layout NT."
        )
        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)

    num_heads = query.size(1)
    head_dim_qk = query.size(3)
    head_dim_v = value.size(3)
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)

    (
        cumulative_sequence_length_q,
        max_seqlen_batch_q,
        Nnz_q,
    ) = _cumulative_and_max_seq_len_nnz(q_t)
    (
        cumulative_sequence_length_kv,
        max_seqlen_batch_kv,
        Nnz_kv,
    ) = _cumulative_and_max_seq_len_nnz(k_t)

    # [TODO] K and V have to have the same Nnz; this should probably be
    # torch_check'd. For now it is assumed, to avoid iterating over v.

    # If the physical layout of the NestedTensor's storage
    # is not: batch, {seq_len}, num_heads, head_dim then we need
    # to call contiguous
    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
        q_t = q_t.contiguous()
    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
        k_t = k_t.contiguous()
    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
        v_t = v_t.contiguous()

    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)

    output_nt_info = {
        "offsets": q_t.offsets(),
        "_max_seqlen": q_t._max_seqlen,
        "_min_seqlen": q_t._min_seqlen,
    }

    return (
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_nt_info,
    )


def _pad_last_dim(
    tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
    # FlashAttentionV2 requires the head dimension to be a multiple of 8.
    # This padding was previously done within the kernel, but that can cause
    # the kernel to alias query, key, and value, so instead we pad the head
    # dimension to a multiple of 8 here in the composite region.
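    # Illustrative example: with alignment_size=8, a head dim of 96 is returned
    # unchanged, while a head dim of 100 is zero-padded to 104 (and sliced back
    # to 100 when slice=True).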
    last_dim_size = tensor.size(-1)
    if last_dim_size % alignment_size == 0:
        return tensor
    pad_count = alignment_size - (last_dim_size % alignment_size)
    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
    if slice:
        return tensor[..., 0:last_dim_size]
    return tensor


# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
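    # When no explicit scale is passed, fall back to the standard SDPA scaling
    # of 1 / sqrt(head_dim), i.e. softmax(q @ k.T / sqrt(E)) @ v;
    # torch.sym_sqrt(1.0 / head_dim) is used rather than math.sqrt per the
    # TODO above.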
    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
    return softmax_scale


def _post_process_flash_output(out: torch.Tensor, og_size):
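    # The flash kernel may have run on inputs whose head dim was padded up to a
    # multiple of 8; slice dense outputs back to the original (og) head dim size.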
    if not out.is_nested and out.size(-1) != og_size:
        out = out[..., 0:og_size]
    return out


def jagged_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # for mypy, ugh
    assert (
        isinstance(query, NestedTensor)
        and isinstance(key, NestedTensor)
        and isinstance(value, NestedTensor)
    )

    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
    # second batch dim instead). For this case, we can just send the dense buffers through
    # vanilla SDPA.
    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
        from torch.nested._internal.ops import extract_kwargs

        output = F.scaled_dot_product_attention(
            query._values,
            key._values,
            value._values,
            attn_mask=(
                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
            ),
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

        return NestedTensor(output, **extract_kwargs(query))

    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

    backend_choice = _select_sdp_backend(
        query, key, value, attn_mask, dropout_p, is_causal
    )

    if backend_choice == SDPBackend.FLASH_ATTENTION:
        og_size = query.size(-1)
        query_padded = _pad_last_dim(query, 8, False)
        key_padded = _pad_last_dim(key, 8, False)
        value_padded = _pad_last_dim(value, 8, False)
        # We need to calculate the scale based off the OG head dim size
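        # (padding changes size(-1), so the default 1/sqrt(head_dim) scale must
        # come from the unpadded head dim)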
        og_scale = _calculate_scale(query, scale)
        (
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)

        (
            attention,
            logsumexp,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._flash_attention_forward(
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            is_causal,
            False,
            scale=og_scale,
        )
        # Reshape output to convert nnz to batch_size and seq_len
        attention = ViewNestedFromBuffer.apply(
            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
            output_nt_info["offsets"],
        ).transpose(1, 2)
        return _post_process_flash_output(attention, og_size)
    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            log_sumexp,
            seed,
            offset,
            max_seqlen_q,
            max_seqlen_batch_kv,
        ) = torch.ops.aten._efficient_attention_forward(
            query_reshaped.unsqueeze(0),
            key_reshaped.unsqueeze(0),
            value_reshaped.unsqueeze(0),
            None,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            int(is_causal),
            compute_logsumexp,
            scale=scale,
        )

        # Reshape output to convert nnz to batch_size and seq_len
        return ViewNestedFromBuffer.apply(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.MATH:
        # save the offsets and shape of the inputs, so we can reshape the final output
        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
        offsets = query.offsets()
        d1 = query._size[1]
        d2 = value._size[-1]

        # convert jagged layout Nested Tensor to strided layout Nested Tensor,
        # which supports the math implementation of SDPA
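        # Illustrative example (assumed values): a jagged NT with offsets
        # [0, 2, 5] is split into two constituent tensors with seq lens 2 and 3
        # and repacked as a strided-layout nested tensor that
        # torch._scaled_dot_product_attention_math can consume directly.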
        def get_strided_layout_nested_tensor(jagged_layout_nt):
            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
            transpose = torch.transpose(jagged_layout_nt, 1, 2)
            tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0)
            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
            strided_nt = strided_nt.transpose(1, 2).contiguous()
            return strided_nt

        query = get_strided_layout_nested_tensor(query)
        key = get_strided_layout_nested_tensor(key)
        value = get_strided_layout_nested_tensor(value)

        attn_out = torch._scaled_dot_product_attention_math(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )[0]

        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = ViewNestedFromBuffer.apply(attn_out, offsets)
        attn_out = attn_out.transpose(1, 2)

        return attn_out
    else:
        raise RuntimeError(
            "No viable backend for scaled_dot_product_attention was found."
        )