# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)

from torch.nn.attention import SDPBackend
from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer

log = logging.getLogger(__name__)


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    if (
        not isinstance(query, NestedTensor)
        or not isinstance(key, NestedTensor)
        or not isinstance(value, NestedTensor)
    ):
        raise ValueError(
            f"Expected query, key, and value to be nested tensors, "
            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
            f"and value.is_nested: {value.is_nested} instead."
        )
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device type, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
        raise ValueError(
            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
        )
    if attn_mask is not None:
        # TODO: Figure out whether masks are actually supported for this layout or not
        raise ValueError("Masks are not yet supported!")
        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
            raise ValueError(
                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
            )


def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
    # This is expected to be called after check_tensor_shapes ensuring that the
    # size() calls won't error since the inputs are all 4 dimensional
    q_batch_size = params.query.size(0)
    k_batch_size = params.key.size(0)
    v_batch_size = params.value.size(0)

    # num_heads logic for nested input is checked in
    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
    # num_heads is not ragged
    return q_batch_size == k_batch_size and q_batch_size == v_batch_size


def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
    max_size = 256
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
                "last dimension and to be a multiple of 8 and less than or equal to 256. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True


def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
    param: torch.Tensor, param_name: str, debug=False
) -> bool:
    assert isinstance(param, NestedTensor), "param should be a jagged NT"

    if param._ragged_idx == 1:
        # num_head_dims is ragged
        if debug:
            log.warning(
                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
                param_name,
            )
        return False

    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
    if param._min_seqlen == 0:
        if debug:
            log.warning(
                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
                param_name,
            )
        return False

    return True


def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
    max_size = max(q_size, k_size, v_size)
    if (
        (q_size != max_size and q_size != 1)
        or (k_size != max_size and k_size != 1)
        or (v_size != max_size and v_size != 1)
    ):
        if debug:
            log.warning(
                "Both fused kernels require query, key and value to have broadcastable %s, "
                "got Query %s %d, Key %s %d, Value %s %d instead.",
                param_name,
                param_name,
                q_size,
                param_name,
                k_size,
                param_name,
                v_size,
            )
        return False
    return True
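
# Worked example for _try_broadcast_param_size (comment only, not executed): with
# param_name="num heads", sizes (4, 1, 4) are broadcastable to the max of 4 and the
# check returns True, while (4, 2, 4) contain a size that is neither 1 nor the max,
# so it returns False (and logs a warning when debug=True).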


def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
    # When this function is called we are assured that the nt is dim==4
    q_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.query, "query", debug
        )
        if params.query.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not q_is_safe:
        return False

    k_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.key, "key", debug
        )
        if params.key.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not k_is_safe:
        return False

    v_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.value, "value", debug
        )
        if params.value.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not v_is_safe:
        return False

    # We now know none of the inputs have ragged num_heads, so we can safely
    # access .size(1)
    q_num_heads = params.query.size(1)
    k_num_heads = params.key.size(1)
    v_num_heads = params.value.size(1)
    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads

    if not same_num_heads:
        if (
            params.query.requires_grad
            or params.key.requires_grad
            or params.value.requires_grad
        ):
            if debug:
                log.warning(
                    "Both fused kernels do not support training with broadcasted NT inputs."
                )
            return False
        return _try_broadcast_param_size(
            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
        )
    return True


def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_head_dim_size_flash_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    if (
        not params.query.transpose(1, 2).is_contiguous()
        or not params.key.transpose(1, 2).is_contiguous()
        or not params.value.transpose(1, 2).is_contiguous()
    ):
        if debug:
            log.warning(
                "If inputs are nested tensors they must be contiguous after transposing."
            )
        return False
    if params.is_causal:
        if debug:
            log.warning(
                "Nested tensors for query / key are not supported when is_causal=True."
            )
        return False
    return True
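
# Illustrative use of the jagged constraint checks (comment only): given params built
# from jagged q/k/v, _can_use_flash_sdpa_jagged(params, debug=True) returns False and
# logs the first failed nested-tensor constraint (batch size, flash head_dim, or
# seq_len/num_heads checks). _select_sdp_backend below combines these jagged-specific
# checks with the generic can_use_* checks imported from torch.backends.cuda.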


def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
        and not math_sdp_enabled()
    ):
        return SDPBackend.ERROR

    ordering = (
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    )

    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)

    for backend in ordering:
        if backend == SDPBackend.FLASH_ATTENTION:
            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
                return SDPBackend.FLASH_ATTENTION
        if backend == SDPBackend.EFFICIENT_ATTENTION:
            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
                params
            ):
                return SDPBackend.EFFICIENT_ATTENTION
        if backend == SDPBackend.MATH:
            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
                return SDPBackend.MATH

    log.warning("Memory efficient kernel not used because:")
    can_use_efficient_attention(params, debug=True)
    _can_use_efficient_sdpa_jagged(params, debug=True)
    log.warning("Flash attention kernel not used because:")
    can_use_flash_attention(params, debug=True)
    _can_use_flash_sdpa_jagged(params, debug=True)
    log.warning("Math attention kernel not used because:")
    _can_use_math_sdpa_jagged(params, debug=True)
    return SDPBackend.ERROR


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    # This function is used to calculate two pieces of metadata that are needed
    # for use with flash-attention and efficient_attention kernels. They are the
    # cumulative sequence_length over a batch of sequences and the maximum
    # sequence length.

    # It returns a tuple of the cumulative sequence lengths, the maximum sequence
    # length, and the total number of elements (i.e. the last entry of the
    # cumulative sequence lengths).
    if not isinstance(qkv, NestedTensor):
        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")

    if qkv.lengths() is None:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
        max_seqlen = qkv._max_seqlen
        n_elem = qkv.values().shape[0]
    else:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = (
            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
        )
        batch_size = qkv.size(0)
        max_seqlen = qkv._max_seqlen
        # TODO: Explore performance impact when compiling
        n_elem = int(cumulative_seqlen[-1].item())
    return cumulative_seqlen, max_seqlen, n_elem
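
# Worked example (comment only, not executed): for a jagged NT with no holes
# (lengths() is None) holding sequences of lengths 2, 3, and 4, offsets() is
# [0, 2, 5, 9], so cumulative_seqlen is the int32 tensor [0, 2, 5, 9],
# max_seqlen is 4, and n_elem (the total token count across the batch) is 9.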


def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
    # This function checks if a nested tensor is valid for
    # use with the flash-attention and efficient_attention kernels without
    # needing to call contiguous on the nested tensor input.
    # It checks that the storage offsets' adjacent differences are a constant
    # multiple of the previous tensor in the nested tensor and that the strides
    # are monotonically decreasing. This check is done after calling transpose on
    # the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim]

    # Returns a boolean indicating if contiguous needs to be called for input
    assert isinstance(tensor, NestedTensor)
    offsets = tensor.offsets()
    strides = tensor._strides

    n_tensors = offsets.size(0) - 1
    if n_tensors <= 1:
        return True

    # Check initially that the tensor strides are in strictly descending order
    prev_stride = strides[1]
    for stride in strides[2:]:
        if prev_stride <= stride:
            # This would mean that the last stride is greater than the seq_len
            # stride
            return False
        prev_stride = stride

    # Congrats you made it!
    return True


def _view_as_dense(
    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
    if tensor.is_nested:
        return buffer_from_jagged(tensor)
    return tensor.view(Nnz, num_heads, head_dim)
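
# Worked example (comment only, not executed): for a jagged NT of shape
# [batch, {seq_len}, num_heads, head_dim] holding 9 total tokens, _view_as_dense
# returns the underlying values buffer of shape [9, num_heads, head_dim]; for a
# regular dense tensor it instead views the input as [Nnz, num_heads, head_dim].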


# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
#     # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
#     # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     q_batch_size = query.size(0)
#     k_batch_size = key.size(0)
#     v_batch_size = value.size(0)

#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)

#     q_num_heads = query.size(1)
#     k_num_heads = key.size(1)
#     v_num_heads = value.size(1)

#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)

#     head_dim_qk = query.size(3)
#     head_dim_v = value.size(3)

#     q_t = query.transpose(1, 2)
#     k_t = key.transpose(1, 2)
#     v_t = value.transpose(1, 2)

#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
#     # output_batch_size/num_heads then they are 1
#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size

#     # If {*}_batch_size_needs_broadcast, then
#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
#     #     this is because needs_broadcast indicates that the batch_size is 1
#     #     and hence there is only 1 value for seq_len
#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
#     #     ..., output_batch_size * {*}_t.size(1)]
#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)

#     if q_batch_size_needs_broadcast or not q_t.is_nested:
#         max_seqlen_batch_q = q_t.size(1)
#         cumulative_sequence_length_q = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_q,
#             max_seqlen_batch_q,
#             device=q_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_q = output_batch_size * max_seqlen_batch_q
#     else:
#         (
#             cumulative_sequence_length_q,
#             max_seqlen_batch_q,
#             Nnz_q,
#         ) = _cumulative_and_max_seq_len_nnz(q_t)

#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
#         assert k_t.size(1) == v_t.size(1)
#         max_seqlen_batch_kv = k_t.size(1)
#         cumulative_sequence_length_kv = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_kv,
#             max_seqlen_batch_kv,
#             device=k_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
#     else:
#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
#             _cumulative_and_max_seq_len_nnz(v_t)
#             if k_batch_size_needs_broadcast
#             else _cumulative_and_max_seq_len_nnz(k_t)
#         )

#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads

#     if not q_t.is_nested:
#         query_buffer_reshaped = q_t.expand(
#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
#         )
#         query_buffer_reshaped = query_buffer_reshaped.reshape(
#             Nnz_q, output_num_heads, head_dim_qk
#         )
#     else:
#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
#             q_t = q_t.contiguous()
#         # If we are broadcasting then Nnz_q will be the output_batch_size since
#         # seq_len is 1
#         effective_batch_size_q = (
#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
#         )
#         query_buffer_reshaped = _view_as_dense(
#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
#         )

#     # If the physical layout of the NestedTensor's storage
#     # is not: batch, {seq_len}, num_heads, head_dim then we need
#     # to call contiguous
#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
#         k_t = k_t.contiguous()
#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
#         v_t = v_t.contiguous()

#     effective_batch_size_k = (
#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
#     )
#     key_buffer_reshaped = _view_as_dense(
#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
#     )

#     effective_batch_size_v = (
#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
#     )
#     value_buffer_reshaped = _view_as_dense(
#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
#     )

#     if not q_batch_size_needs_broadcast:
#         output_shape = q_t._size
#         if head_dim_v != head_dim_qk:
#             output_shape[-1] = head_dim_v
#         if q_num_heads_needs_broadcast:
#             output_shape[1] = output_num_heads
#     else:
#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
#         output_shape[0] = q_t.size(1)
#         output_shape[1] = output_num_heads
#         output_shape[2] = head_dim_v

#     return (
#         query_buffer_reshaped,
#         key_buffer_reshaped,
#         value_buffer_reshaped,
#         cumulative_sequence_length_q,
#         cumulative_sequence_length_kv,
#         max_seqlen_batch_q,
#         max_seqlen_batch_kv,
#         output_shape,
#     )


def _sdpa_nested_preprocessing(query, key, value):
    # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
    # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    q_batch_size = query.size(0)
    k_batch_size = key.size(0)
    v_batch_size = value.size(0)

    q_num_heads = query.size(1)
    k_num_heads = key.size(1)
    v_num_heads = value.size(1)

    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
        q_num_heads == k_num_heads and k_num_heads == v_num_heads
    ):
        raise RuntimeError(
            "This path is currently not implemented for jagged layout NT."
        )
        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)

    num_heads = query.size(1)
    head_dim_qk = query.size(3)
    head_dim_v = value.size(3)
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)

    (
        cumulative_sequence_length_q,
        max_seqlen_batch_q,
        Nnz_q,
    ) = _cumulative_and_max_seq_len_nnz(q_t)
    (
        cumulative_sequence_length_kv,
        max_seqlen_batch_kv,
        Nnz_kv,
    ) = _cumulative_and_max_seq_len_nnz(k_t)

    # [TODO] K and V have to have the same Nnz, should probably torch_check
    # assume in order to not iterate over v

    # If the physical layout of the NestedTensor's storage
    # is not: batch, {seq_len}, num_heads, head_dim then we need
    # to call contiguous
    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
        q_t = q_t.contiguous()
    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
        k_t = k_t.contiguous()
    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
        v_t = v_t.contiguous()

    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)

    output_nt_info = {
        "offsets": q_t.offsets(),
        "_max_seqlen": q_t._max_seqlen,
        "_min_seqlen": q_t._min_seqlen,
    }

    return (
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_nt_info,
    )


def _pad_last_dim(
    tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
    # FlashAttentionV2 requires that the head dimension be a multiple of 8.
    # This was previously done within the kernel; however, that causes the
    # kernel to potentially alias query, key, and value. So instead we pad the
    # head dimensions to be a multiple of 8 in the composite region.
    last_dim_size = tensor.size(-1)
    if last_dim_size % alignment_size == 0:
        return tensor
    pad_count = alignment_size - (last_dim_size % alignment_size)
    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
    if slice:
        return tensor[..., 0:last_dim_size]
    return tensor
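
# Worked example for _pad_last_dim (comment only, not executed): with
# alignment_size=8, a head_dim of 64 is returned unchanged, while a head_dim of
# 100 is zero-padded to 104; with slice=True the result is then narrowed back to
# the original 100 columns, so only the storage (not the logical size) is padded.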


# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
    return softmax_scale


def _post_process_flash_output(out: torch.Tensor, og_size):
    if not out.is_nested and out.size(-1) != og_size:
        out = out[..., 0:og_size]
    return out


def jagged_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # for mypy, ugh
    assert (
        isinstance(query, NestedTensor)
        and isinstance(key, NestedTensor)
        and isinstance(value, NestedTensor)
    )

    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
    # second batch dim instead). For this case, we can just send the dense buffers through
    # vanilla SDPA.
    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
        from torch.nested._internal.ops import extract_kwargs

        output = F.scaled_dot_product_attention(
            query._values,
            key._values,
            value._values,
            attn_mask=(
                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
            ),
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

        return NestedTensor(output, **extract_kwargs(query))

    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

    backend_choice = _select_sdp_backend(
        query, key, value, attn_mask, dropout_p, is_causal
    )

    if backend_choice == SDPBackend.FLASH_ATTENTION:
        og_size = query.size(-1)
        query_padded = _pad_last_dim(query, 8, False)
        key_padded = _pad_last_dim(key, 8, False)
        value_padded = _pad_last_dim(value, 8, False)
        # We need to calculate the scale based off the OG head dim size
        og_scale = _calculate_scale(query, scale)
        (
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)

        (
            attention,
            logsumexp,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._flash_attention_forward(
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            is_causal,
            False,
            scale=og_scale,
        )
        # Reshape output to convert nnz to batch_size and seq_len
        attention = ViewNestedFromBuffer.apply(
            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
            output_nt_info["offsets"],
        ).transpose(1, 2)
        return _post_process_flash_output(attention, og_size)
    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            log_sumexp,
            seed,
            offset,
            max_seqlen_q,
            max_seqlen_batch_kv,
        ) = torch.ops.aten._efficient_attention_forward(
            query_reshaped.unsqueeze(0),
            key_reshaped.unsqueeze(0),
            value_reshaped.unsqueeze(0),
            None,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            int(is_causal),
            compute_logsumexp,
            scale=scale,
        )

        # Reshape output to convert nnz to batch_size and seq_len
        return ViewNestedFromBuffer.apply(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.MATH:
        # save the offsets and shape of the inputs, so we can reshape the final output
        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
        offsets = query.offsets()
        d1 = query._size[1]
        d2 = value._size[-1]

        # convert the jagged layout Nested Tensor to a strided layout Nested Tensor,
        # which supports the math implementation of SDPA
        def get_strided_layout_nested_tensor(jagged_layout_nt):
            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
            transpose = torch.transpose(jagged_layout_nt, 1, 2)
            tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0)
            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
            strided_nt = strided_nt.transpose(1, 2).contiguous()
            return strided_nt

        query = get_strided_layout_nested_tensor(query)
        key = get_strided_layout_nested_tensor(key)
        value = get_strided_layout_nested_tensor(value)

        attn_out = torch._scaled_dot_product_attention_math(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )[0]

        # convert the strided layout Nested Tensor back to a jagged layout Nested Tensor
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = ViewNestedFromBuffer.apply(attn_out, offsets)
        attn_out = attn_out.transpose(1, 2)

        return attn_out
    else:
        raise RuntimeError(
            "No viable backend for scaled_dot_product_attention was found."
        )
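
# Illustrative end-to-end usage (comment only, not executed; assumes a CUDA device and
# a dtype/head_dim accepted by one of the fused kernels, and the public torch.nested
# API; make_njt is a hypothetical helper). Jagged-layout nested tensors passed to the
# public SDPA entry point dispatch to jagged_scaled_dot_product_attention above:
#
#   def make_njt():
#       # per-sequence tensors are [seq_len, num_heads, head_dim]
#       return torch.nested.nested_tensor(
#           [torch.randn(3, 8, 64), torch.randn(5, 8, 64)],
#           layout=torch.jagged, device="cuda", dtype=torch.float16,
#       ).transpose(1, 2)  # -> [B, num_heads, {seq_len}, head_dim]
#
#   out = torch.nn.functional.scaled_dot_product_attention(
#       make_njt(), make_njt(), make_njt()
#   )  # jagged output of shape [2, 8, {3, 5}, 64]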