1# Owner(s): ["module: nn"] 2import contextlib 3import random 4import unittest 5import unittest.mock as mock 6 7import torch 8import torch.nn as nn 9from torch.nn import MultiheadAttention 10from torch.testing._internal.common_device_type import ( 11 dtypes, 12 instantiate_device_type_tests, 13 onlyCUDAAndPRIVATEUSE1, 14) 15from torch.testing._internal.common_nn import NNTestCase 16from torch.testing._internal.common_utils import ( 17 instantiate_parametrized_tests, 18 parametrize as parametrize_test, 19 run_tests, 20 skipIfRocm, 21 TEST_NUMPY, 22 TEST_WITH_CROSSREF, 23) 24 25 26if TEST_NUMPY: 27 import numpy as np 28 29 30# WARNING: If you add a new top-level test case to this file, you MUST 31# update test/run_test.py to list it, otherwise it will NOT be run in 32# CI. 33 34 35class TestMultiheadAttentionNN(NNTestCase): 36 _do_cuda_memory_leak_check = True 37 _do_cuda_non_default_stream = True 38 39 @unittest.skipIf(not TEST_NUMPY, "numpy not found") 40 @parametrize_test("average_attn_weights", [True, False]) 41 def test_multihead_attention(self, average_attn_weights): 42 def _scaled_dot_attn_ref( 43 Q, 44 K, 45 V, 46 dims, 47 unseen_mask=None, 48 key_padding_mask=None, 49 average_attn_weights=average_attn_weights, 50 ): 51 """Numpy-based reference implementation of scaled dot attention 52 for testing""" 53 54 QKT = _batchmatmul( 55 Q, 56 np.transpose(K, axes=[0, 1, 3, 2]) 57 / np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head) 58 ) 59 b1, b2, s1, s2 = QKT.shape 60 if unseen_mask is not None or key_padding_mask is not None: 61 # assert s1 == s2 62 for i in range(b1): 63 for j in range(b2): 64 for m in range(s1): 65 for n in range(s2): 66 if unseen_mask is not None and unseen_mask[m][n] == 0: 67 QKT[i, j, m, n] = -np.inf 68 if ( 69 key_padding_mask is not None 70 and key_padding_mask[i][n] 71 ): 72 QKT[i, j, m, n] = -np.inf 73 74 reference = _softmax(QKT) 75 ref_attn_weight = reference 76 if average_attn_weights: 77 ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2 78 reference = _batchmatmul(reference, V) 79 return reference, ref_attn_weight 80 81 def _batchmatmul(a, b): # batchmatmul over 4 dim matrix 82 """Numpy-based batch matrix multiply over 4 dim matrix""" 83 assert a.shape[0] == b.shape[0] 84 assert a.shape[1] == b.shape[1] 85 retval = np.zeros( 86 (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32 87 ) 88 for i in range(a.shape[0]): 89 for j in range(a.shape[1]): 90 retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :]) 91 return retval 92 93 def _softmax(x): # softmax over 4 dim matrix 94 """Numpy-based reference softmax over 4 dim matrix""" 95 np.seterr(invalid="ignore") 96 output = np.zeros(x.shape, dtype=np.float64) 97 for i in range(x.shape[0]): 98 for j in range(x.shape[1]): 99 for k in range(x.shape[2]): 100 x_curr = x[i, j, k, :] 101 e_x = np.exp(x_curr - np.amax(x_curr)) 102 output[i, j, k, :] = e_x / np.sum(e_x) 103 return output 104 105 def _split_heads_ref(X, dims, nheads, d_head): 106 X_split = np.reshape(X, dims[:2] + [nheads, d_head]) 107 X_split_transposed = np.transpose(X_split, [0, 2, 1, 3]) 108 reference = np.reshape( 109 X_split_transposed, [dims[0], nheads, dims[1], d_head] 110 ) 111 return reference 112 113 def _combine_heads_ref(X, dims, nheads, d_head): 114 X_transposed = np.transpose(X, [0, 2, 1, 3]) 115 reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head]) 116 return reference 117 118 def _fc(X, X_weight, X_bias): 119 X_fc_b = X_bias.detach().numpy() 120 X_fc_w = X_weight.detach().numpy() 121 return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b 122 123 def _create_src_lengths_mask(batch_size, src_lengths): 124 """ 125 Generate boolean mask to prevent attention beyond the end of source 126 Inputs: 127 batch_size : int 128 src_lengths : [batch_size] of sentence lengths 129 Outputs: 130 [batch_size, max_src_len] 131 """ 132 max_srclen = src_lengths.max() 133 src_indices = torch.arange(0, max_srclen).unsqueeze(0).to(src_lengths) 134 src_indices = src_indices.expand(batch_size, max_srclen) 135 src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen) 136 # returns [batch_size, max_seq_len] 137 return (src_indices < src_lengths).int().detach() 138 139 def _multihead_attn_test_helper( 140 add_key_padding_mask=False, 141 add_bias_kv=False, 142 add_zero_attn=False, 143 saved_kv=False, 144 same_embed_dim=False, 145 average_attn_weights=average_attn_weights, 146 ): 147 for _ in range(100): 148 batch_sz, seq_len = (random.randint(2, 10) for r in range(2)) 149 d_head = random.randint(3, 10) 150 nheads = random.randint(2, 5) * 2 151 d_model = d_head * nheads 152 if same_embed_dim: 153 kv_dim = d_model 154 else: 155 kv_dim = random.randint(5, 20) 156 dims = [batch_sz, seq_len, kv_dim] 157 158 saved_k = None 159 saved_k_tensor = None 160 saved_v = None 161 saved_v_tensor = None 162 if saved_kv: 163 saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head) 164 saved_k_tensor = torch.from_numpy(saved_k).to( 165 torch.get_default_dtype() 166 ) 167 saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head) 168 saved_v_tensor = torch.from_numpy(saved_v).to( 169 torch.get_default_dtype() 170 ) 171 172 key_padding_mask = None 173 key_padding_mask_tensor = None 174 if add_key_padding_mask: 175 seq_mask = np.random.randint(0, 2, (1, seq_len)) 176 key_padding_mask = np.repeat(seq_mask, batch_sz, axis=0) == 1 177 key_padding_mask_tensor = torch.from_numpy(key_padding_mask) 178 decoder_state = np.random.rand(batch_sz, d_model) 179 K = np.random.rand(*dims) 180 V = K 181 Q = np.expand_dims(decoder_state, 1) 182 attn_mask = np.random.randint(0, 2, size=(1, seq_len)) 183 attn_mask_tensor = torch.from_numpy(attn_mask).float() 184 attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float("-inf")) 185 attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float("0.0")) 186 187 decoder_state_tensor = torch.from_numpy(decoder_state).to( 188 torch.get_default_dtype() 189 ) 190 source_hid_tensor = ( 191 torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1) 192 ) 193 194 multihead_attn_module = MultiheadAttention( 195 d_model, 196 nheads, 197 add_bias_kv=add_bias_kv, 198 add_zero_attn=add_zero_attn, 199 kdim=kv_dim, 200 vdim=kv_dim, 201 ) 202 203 if add_bias_kv: 204 bias_k = multihead_attn_module.bias_k.detach().numpy() 205 bias_v = multihead_attn_module.bias_v.detach().numpy() 206 else: 207 bias_k = None 208 bias_v = None 209 210 _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1) 211 _V = source_hid_tensor 212 _K = source_hid_tensor 213 214 if multihead_attn_module._qkv_same_embed_dim: 215 ( 216 result, 217 result_weight, 218 ) = torch.nn.functional.multi_head_attention_forward( 219 _Q, 220 _K, 221 _V, 222 d_model, 223 nheads, 224 multihead_attn_module.in_proj_weight, 225 multihead_attn_module.in_proj_bias, 226 multihead_attn_module.bias_k, 227 multihead_attn_module.bias_v, 228 multihead_attn_module.add_zero_attn, 229 multihead_attn_module.dropout, 230 multihead_attn_module.out_proj.weight, 231 multihead_attn_module.out_proj.bias, 232 multihead_attn_module.training, 233 key_padding_mask_tensor, 234 True, 235 attn_mask_tensor, 236 static_k=saved_k_tensor, 237 static_v=saved_v_tensor, 238 average_attn_weights=average_attn_weights, 239 is_causal=False, 240 ) 241 else: 242 ( 243 result, 244 result_weight, 245 ) = torch.nn.functional.multi_head_attention_forward( 246 _Q, 247 _K, 248 _V, 249 d_model, 250 nheads, 251 None, 252 multihead_attn_module.in_proj_bias, 253 multihead_attn_module.bias_k, 254 multihead_attn_module.bias_v, 255 multihead_attn_module.add_zero_attn, 256 multihead_attn_module.dropout, 257 multihead_attn_module.out_proj.weight, 258 multihead_attn_module.out_proj.bias, 259 multihead_attn_module.training, 260 key_padding_mask_tensor, 261 True, 262 attn_mask_tensor, 263 True, 264 multihead_attn_module.q_proj_weight, 265 multihead_attn_module.k_proj_weight, 266 multihead_attn_module.v_proj_weight, 267 static_k=saved_k_tensor, 268 static_v=saved_v_tensor, 269 average_attn_weights=average_attn_weights, 270 is_causal=False, 271 ) 272 273 result = result.squeeze(0).detach().numpy() 274 275 if multihead_attn_module._qkv_same_embed_dim: 276 q_proj_weight = multihead_attn_module.in_proj_weight[:d_model] 277 k_proj_weight = multihead_attn_module.in_proj_weight[ 278 d_model : (d_model * 2) 279 ] 280 v_proj_weight = multihead_attn_module.in_proj_weight[ 281 (d_model * 2) : 282 ] 283 else: 284 q_proj_weight = multihead_attn_module.q_proj_weight 285 k_proj_weight = multihead_attn_module.k_proj_weight 286 v_proj_weight = multihead_attn_module.v_proj_weight 287 288 Q_fc = _fc( 289 Q, q_proj_weight, multihead_attn_module.in_proj_bias[:d_model] 290 ) 291 K_fc = _fc( 292 K, 293 k_proj_weight, 294 multihead_attn_module.in_proj_bias[d_model : (d_model * 2)], 295 ) 296 V_fc = _fc( 297 V, 298 v_proj_weight, 299 multihead_attn_module.in_proj_bias[(d_model * 2) :], 300 ) 301 302 if add_bias_kv: 303 K_fc = np.concatenate( 304 (K_fc, np.repeat(bias_k, K_fc.shape[0], axis=0)), axis=1 305 ) 306 V_fc = np.concatenate( 307 (V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1 308 ) 309 if attn_mask is not None: 310 attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1) 311 if key_padding_mask is not None: 312 key_padding_mask = np.concatenate( 313 ( 314 key_padding_mask, 315 np.full((batch_sz, 1), False, dtype=bool), 316 ), 317 axis=1, 318 ) 319 dims[1] += 1 320 Q_split = _split_heads_ref(Q_fc, [batch_sz, 1, d_model], nheads, d_head) 321 322 if saved_k is not None: 323 K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head]) 324 else: 325 K_split = _split_heads_ref(K_fc, dims, nheads, d_head) 326 327 if saved_v is not None: 328 V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head]) 329 else: 330 V_split = _split_heads_ref(V_fc, dims, nheads, d_head) 331 332 if add_zero_attn: 333 dims[1] += 1 334 K_split = np.concatenate( 335 ( 336 K_split, 337 np.zeros( 338 [ 339 K_split.shape[0], 340 K_split.shape[1], 341 1, 342 K_split.shape[3], 343 ] 344 ), 345 ), 346 axis=2, 347 ) 348 V_split = np.concatenate( 349 ( 350 V_split, 351 np.zeros( 352 [ 353 V_split.shape[0], 354 V_split.shape[1], 355 1, 356 V_split.shape[3], 357 ] 358 ), 359 ), 360 axis=2, 361 ) 362 363 if attn_mask is not None: 364 attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1) 365 366 if key_padding_mask is not None: 367 key_padding_mask = np.concatenate( 368 ( 369 key_padding_mask, 370 np.full((batch_sz, 1), False, dtype=bool), 371 ), 372 axis=1, 373 ) 374 attn_heads, ref_attn_weight = _scaled_dot_attn_ref( 375 Q=Q_split, 376 K=K_split, 377 V=V_split, 378 dims=Q_split.shape, 379 unseen_mask=attn_mask, 380 key_padding_mask=key_padding_mask, 381 ) 382 combined_attn_heads = _combine_heads_ref( 383 X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head 384 ) 385 386 reference = _fc( 387 combined_attn_heads, 388 multihead_attn_module.out_proj.weight, 389 multihead_attn_module.out_proj.bias, 390 ) 391 reference = np.squeeze(reference, axis=1) 392 393 # result = reference 394 self.assertEqual(tuple(result.shape), (batch_sz, d_model)) 395 np.testing.assert_allclose(result, reference, atol=1e-5) 396 397 # result_weight = ref_attn_weight 398 result_weight = result_weight.detach().numpy() 399 self.assertEqual( 400 tuple(result_weight.shape), tuple(ref_attn_weight.shape) 401 ) 402 np.testing.assert_allclose(result_weight, ref_attn_weight, atol=1e-5) 403 404 def test_multihead_attn_add_bias_kv(): 405 _multihead_attn_test_helper(add_bias_kv=True) 406 407 def test_multihead_attn_add_zero_attn(): 408 _multihead_attn_test_helper(add_zero_attn=True) 409 410 def test_multihead_attn_no_masking(): 411 _multihead_attn_test_helper() 412 413 def test_multihead_attn_key_padding_mask(): 414 _multihead_attn_test_helper(add_key_padding_mask=True) 415 416 def test_multihead_attn_saved_kv(): 417 _multihead_attn_test_helper(saved_kv=True) 418 419 def test_multihead_attn_add_bias_kv_zero_attn(): 420 _multihead_attn_test_helper( 421 add_key_padding_mask=True, add_bias_kv=True, add_zero_attn=True 422 ) 423 424 def test_multihead_attn_all_arguments1(): 425 _multihead_attn_test_helper( 426 add_key_padding_mask=True, add_zero_attn=True, saved_kv=True 427 ) 428 429 def test_multihead_attn_all_arguments2(): 430 _multihead_attn_test_helper( 431 add_key_padding_mask=True, 432 add_bias_kv=True, 433 add_zero_attn=True, 434 saved_kv=True, 435 ) 436 437 def test_multihead_attn_all_arguments3(): 438 _multihead_attn_test_helper( 439 add_key_padding_mask=True, 440 add_zero_attn=True, 441 saved_kv=True, 442 same_embed_dim=True, 443 ) 444 445 test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn 446 test_multihead_attn_add_bias_kv() # Test MultiheadAttention with add_bias_kv 447 test_multihead_attn_no_masking() # Test MultiheadAttention without masking 448 test_multihead_attn_key_padding_mask() # Test MultiheadAttention with src lengths 449 test_multihead_attn_saved_kv() # Test MultiheadAttention with static kv. 450 test_multihead_attn_add_bias_kv_zero_attn() # Test MultiheadAttention with bias_kv and zero_attn. 451 test_multihead_attn_all_arguments1() # Test MultiheadAttention with all the argument. 452 with self.assertRaisesRegex( 453 AssertionError, "bias cannot be added to static key." 454 ): 455 test_multihead_attn_all_arguments2() # Test MultiheadAttention with all the argument. 456 test_multihead_attn_all_arguments3() # Test MultiheadAttention with all the argument. 457 458 def test_multihead_attn_3d_attn_mask(self): 459 embed_dim = 8 460 num_heads = 4 461 batch_size = 8 462 src_len = 3 463 tgt_len = 2 464 465 query = torch.rand(batch_size, tgt_len, embed_dim) # [N, T, D] 466 key = torch.rand(batch_size, src_len, embed_dim) # [N, S, D] 467 value = key # [N, S, D] 468 attn_mask = torch.randint( 469 0, 2, (batch_size, tgt_len, src_len) 470 ).float() # [N, T, S] 471 attn_mask = attn_mask.masked_fill(attn_mask == 0, float("-inf")).masked_fill( 472 attn_mask == 1, 0.0 473 ) 474 475 mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads) 476 477 # Generate 3D results 478 attn_mask_3d = torch.repeat_interleave( 479 attn_mask, num_heads, dim=0 480 ) # [N * H, T, S] 481 output_3d = mta_model( 482 query.transpose(0, 1), 483 key.transpose(0, 1), 484 value.transpose(0, 1), 485 attn_mask=attn_mask_3d, 486 )[0] 487 output_3d = output_3d.transpose(0, 1) # [N, T, D] 488 489 for i in range(0, batch_size): 490 output_2d = mta_model( 491 query[i].unsqueeze(0).transpose(0, 1), 492 key[i].unsqueeze(0).transpose(0, 1), 493 value[i].unsqueeze(0).transpose(0, 1), 494 attn_mask=attn_mask[i], 495 )[0] 496 497 # output_2d in shape of [T, 1, D] 498 self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d) 499 500 def test_multihead_attn_no_bias(self): 501 embed_dim = 8 502 num_heads = 4 503 mha = torch.nn.MultiheadAttention(embed_dim, num_heads, bias=False) 504 505 # Verify that bias=False applies to both in and out projection layers. 506 self.assertIsNone(mha.in_proj_bias) 507 self.assertIsNone(mha.out_proj.bias) 508 509 def _test_multihead_attn_invalid_shape_impl(self, mha): 510 # Batched (3D) query cases 511 query = torch.randn(4, 4, 4) 512 key = torch.randn(4, 4, 4) 513 value = torch.randn(4, 4, 4) 514 515 msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively" 516 # 3D query, 2D key and 3D value 517 with self.assertRaisesRegex(AssertionError, msg): 518 mha(query, torch.randn(4, 4), value) 519 520 msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively" 521 # 3D query, 3D key and 2D value 522 with self.assertRaisesRegex(AssertionError, msg): 523 mha(query, key, torch.randn(4, 4)) 524 525 msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead" 526 # 3D query, 3D key, 3D value and 1D key_padding_mask 527 with self.assertRaisesRegex(AssertionError, msg): 528 mha( 529 query, 530 key, 531 value, 532 key_padding_mask=torch.tensor( 533 [False, False, True, True], dtype=torch.bool 534 ), 535 ) 536 537 msg = ( 538 "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" 539 ) 540 # 3D query, 3D key, 3D value and 1D attn_mask 541 with self.assertRaisesRegex(AssertionError, msg): 542 mha( 543 query, 544 key, 545 value, 546 attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool), 547 ) 548 549 # Unbatched (2D) query cases 550 query = torch.randn(4, 4) 551 key = torch.randn(4, 4) 552 value = torch.randn(4, 4) 553 554 msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively" 555 # 2D query, 3D key and 2D value 556 with self.assertRaisesRegex(AssertionError, msg): 557 mha(query, torch.randn(4, 4, 4), value) 558 559 msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively" 560 # 2D query, 3D key and 2D value 561 with self.assertRaisesRegex(AssertionError, msg): 562 mha(query, key, torch.randn(4, 4, 4)) 563 564 msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead" 565 # 2D query, 2D key, 2D value and 1D key_padding_mask 566 with self.assertRaisesRegex(AssertionError, msg): 567 mha( 568 query, 569 key, 570 value, 571 key_padding_mask=torch.tensor( 572 [[False, False, True, True] * 2], dtype=torch.bool 573 ), 574 ) 575 576 msg = ( 577 "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" 578 ) 579 # 2D query, 2D key, 2D value and 1D attn_mask 580 with self.assertRaisesRegex(AssertionError, msg): 581 mha( 582 query, 583 key, 584 value, 585 attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool), 586 ) 587 588 msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)" 589 # 2D query, 2D key, 2D value and 3D incorrect attn_mask 590 with self.assertRaisesRegex(AssertionError, msg): 591 mha( 592 query, 593 key, 594 value, 595 attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool), 596 ) 597 598 def test_multihead_attn_invalid_shape(self): 599 mha = torch.nn.MultiheadAttention(4, 4) 600 self._test_multihead_attn_invalid_shape_impl(mha) 601 # Give the test a chance to hit the fast path. (Right now, it 602 # won't, but gating may be less restricted in the future.) 603 with torch.no_grad(): 604 self._test_multihead_attn_invalid_shape_impl(mha.eval()) 605 606 @torch.no_grad() 607 def test_multihead_attn_fast_path_invalid_shape(self): 608 mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval() 609 610 # Batched (3D) query cases 611 query = torch.randn(4, 4, 4) 612 key = torch.randn(4, 4, 4) 613 value = torch.randn(4, 4, 4) 614 615 # Currently, this case will just go to the slow path and get 616 # the usual message because it fails the requirement to be 617 # batched. 618 msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively" 619 # 3D query, 2D key and 3D value 620 with self.assertRaisesRegex(AssertionError, msg): 621 mha(query, torch.randn(3, 3), value, need_weights=False) 622 623 # Currently, this case will just go to the slow path and get 624 # the usual message because it fails the requirement to be 625 # batched. 626 msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively" 627 # 3D query, 3D key and 2D value 628 with self.assertRaisesRegex(AssertionError, msg): 629 mha(query, key, torch.randn(3, 3), need_weights=False) 630 631 msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead" 632 # 3D query, 3D key, 3D value and 1D key_padding_mask 633 with self.assertRaisesRegex(AssertionError, msg): 634 mha( 635 query, 636 key, 637 value, 638 key_padding_mask=torch.tensor([False, True, True], dtype=torch.bool), 639 need_weights=False, 640 ) 641 642 msg = ( 643 "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" 644 ) 645 # 3D query, 3D key, 3D value and 1D attn_mask 646 with self.assertRaisesRegex(AssertionError, msg): 647 mha( 648 query, 649 key, 650 value, 651 attn_mask=torch.tensor([False, True, True], dtype=torch.bool), 652 need_weights=False, 653 ) 654 655 # Unbatched (2D) query cases 656 # NOTE: error messages are the same as regular path because the fast path doesn't support 2D. 657 query = torch.randn(4, 4) 658 key = torch.randn(4, 4) 659 value = torch.randn(4, 4) 660 661 msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively" 662 # 2D query, 3D key and 2D value 663 with self.assertRaisesRegex(AssertionError, msg): 664 mha(query, torch.randn(4, 4, 4), value) 665 666 msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively" 667 # 2D query, 3D key and 2D value 668 with self.assertRaisesRegex(AssertionError, msg): 669 mha(query, key, torch.randn(4, 4, 4)) 670 671 msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead" 672 # 2D query, 2D key, 2D value and 1D key_padding_mask 673 with self.assertRaisesRegex(AssertionError, msg): 674 mha( 675 query, 676 key, 677 value, 678 key_padding_mask=torch.tensor( 679 [[False, False, True, True] * 2], dtype=torch.bool 680 ), 681 ) 682 683 msg = ( 684 "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" 685 ) 686 # 2D query, 2D key, 2D value and 1D attn_mask 687 with self.assertRaisesRegex(AssertionError, msg): 688 mha( 689 query, 690 key, 691 value, 692 attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool), 693 ) 694 695 msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)" 696 # 2D query, 2D key, 2D value and 3D incorrect attn_mask 697 with self.assertRaisesRegex(AssertionError, msg): 698 mha( 699 query, 700 key, 701 value, 702 attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool), 703 ) 704 705 def test_multihead_attn_nested_tensor_outside_fast_path(self): 706 mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval() 707 nt = torch.nested.nested_tensor([torch.randn(4, 4)]) 708 # One tested platform (linux-bionic-py3.7-clang) has a torch_function for one 709 # or more of these. Take advantage of that to test the torch_function bailout. 710 has_torch_func = torch.overrides.has_torch_function( 711 ( 712 nt, 713 mha.in_proj_weight, 714 mha.in_proj_bias, 715 mha.out_proj.weight, 716 mha.out_proj.bias, 717 ) 718 ) 719 if has_torch_func: 720 msg = "MultiheadAttention does not support NestedTensor.*argument has_torch_function" 721 else: 722 msg = ( 723 "MultiheadAttention does not support NestedTensor outside of its fast path.*grad is " 724 + "enabled and.*or biases requires_grad" 725 ) 726 with self.assertRaisesRegex(AssertionError, msg): 727 mha(nt, nt, nt) 728 729 if has_torch_func: 730 # Just give up, they're all going to fail with the same message. 731 return 732 733 with torch.no_grad(): 734 mha(nt, nt, nt) 735 with torch.inference_mode(): 736 mha(nt, nt, nt) 737 nt = torch.nested.nested_tensor([torch.randn(4, 4, requires_grad=False)]) 738 nt.requires_grad = False 739 with self.assertRaisesRegex(AssertionError, msg): 740 mha(nt, nt, nt) 741 mha.in_proj_weight.requires_grad = False 742 mha.in_proj_bias.requires_grad = False 743 mha.out_proj.weight.requires_grad = False 744 mha.out_proj.bias.requires_grad = False 745 mha(nt, nt, nt) 746 747 748class TestMultiheadAttentionNNDeviceType(NNTestCase): 749 @skipIfRocm(msg="To investigate: yields NaN") 750 def test_multihead_self_attn_two_masks_fast_path(self, device): 751 """ 752 Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path 753 when both attention mask (mask type 0) and key padding mask (mask type 1) are provided 754 """ 755 with torch.no_grad(): 756 embed_dim = 14 757 num_heads = 7 758 batch_size = 8 759 src_len = 5 760 761 query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device) 762 # Create masks of two different types 763 attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device) 764 key_padding_mask = ( 765 torch.randint(0, 2, (batch_size, src_len)).bool().to(device) 766 ) 767 768 # We'll need expanded versions of the masks for masking out the outputs below 769 attn_mask_expanded = attn_mask.reshape(1, 1, src_len, src_len).expand( 770 batch_size, num_heads, src_len, src_len 771 ) 772 key_padding_mask_expanded = key_padding_mask.reshape( 773 batch_size, 1, 1, src_len 774 ).expand(batch_size, num_heads, src_len, src_len) 775 merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) 776 777 # Compute attention on the fast path 778 mta_model = torch.nn.MultiheadAttention( 779 embed_dim, num_heads, batch_first=True, device=device 780 ) 781 mta_model.training = False 782 result_fast_path, _ = mta_model( 783 query, 784 key, 785 value, 786 attn_mask=attn_mask, 787 key_padding_mask=key_padding_mask, 788 ) 789 790 # Compute attention on the slow path 791 result_ref, _ = torch.nn.functional.multi_head_attention_forward( 792 query.transpose(0, 1), 793 key.transpose(0, 1), 794 value.transpose(0, 1), 795 embed_dim, 796 num_heads, 797 mta_model.in_proj_weight, 798 mta_model.in_proj_bias, 799 mta_model.bias_k, 800 mta_model.bias_v, 801 mta_model.add_zero_attn, 802 mta_model.dropout, 803 mta_model.out_proj.weight, 804 mta_model.out_proj.bias, 805 training=mta_model.training, 806 key_padding_mask=key_padding_mask, 807 need_weights=False, 808 attn_mask=attn_mask, 809 use_separate_proj_weight=False, 810 q_proj_weight=mta_model.q_proj_weight, 811 k_proj_weight=mta_model.k_proj_weight, 812 v_proj_weight=mta_model.v_proj_weight, 813 average_attn_weights=False, 814 ) 815 result_ref = result_ref.transpose(0, 1) # Convert to batch-first 816 817 # Rows which are completely masked out are nan, we need to exclude them from comparison 818 mask_out = ( 819 merged_mask[:, 0, :, :] 820 .all(-1, keepdim=True) 821 .expand(batch_size, src_len, embed_dim) 822 ) 823 result_fast_path_masked = result_fast_path.masked_fill(mask_out, 0) 824 result_ref_masked = result_ref.masked_fill(mask_out, 0) 825 826 self.assertEqual(result_fast_path_masked, result_ref_masked) 827 828 @torch.no_grad() 829 @unittest.skipIf( 830 TEST_WITH_CROSSREF, 831 "CrossRef turns on TorchFunctionMode, and so disables fastpath.", 832 ) 833 def test_multihead_self_attn_two_masks_fast_path_mock(self, device): 834 """ 835 Multihead self-attention should take fast path when both attention mask (mask type 0) 836 and key padding mask (mask type 1) are provided at the same time on CPU and CUDA and PrivateUse1 837 """ 838 device = device.rstrip(":0123456789") 839 if device not in ["cpu", "cuda", torch._C._get_privateuse1_backend_name()]: 840 self.skipTest("Fastpath only runs on CPU and CUDA and PrivateUse1.") 841 842 with torch.autocast(device_type=device, enabled=False): 843 embed_dim = 16 844 num_heads = 8 845 batch_size = 8 846 src_len = 5 847 848 query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device) 849 # Create masks of two different types 850 attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device) 851 key_padding_mask = ( 852 torch.randint(0, 2, (batch_size, src_len)).bool().to(device) 853 ) 854 855 with mock.patch( 856 "torch._native_multi_head_attention", 857 new=mock.MagicMock(return_value=(torch.Tensor(), torch.Tensor())), 858 ) as fastpath_mock: 859 # Compute attention on the fast path 860 mta_model = torch.nn.MultiheadAttention( 861 embed_dim, num_heads, batch_first=True, device=device 862 ).eval() 863 mta_model.training = False 864 mta_model( 865 query, 866 key, 867 value, 868 attn_mask=attn_mask, 869 key_padding_mask=key_padding_mask, 870 ) 871 # If mock was called, fastpath was taken 872 self.assertTrue(fastpath_mock.called) 873 874 @onlyCUDAAndPRIVATEUSE1 875 @dtypes(torch.half, torch.float, torch.double) 876 def test_multihead_attention_dtype(self, device, dtype): 877 embed_dim = 128 878 num_heads = 8 879 sl = 10 880 bs = 8 881 model = nn.MultiheadAttention(embed_dim, num_heads).to(device).to(dtype) 882 q = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) 883 k = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) 884 v = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) 885 out = model(q, k, v) 886 self.assertEqual(q.size(), out[0].size()) 887 self.assertEqual(dtype, out[0].dtype) 888 889 @onlyCUDAAndPRIVATEUSE1 890 @dtypes(torch.half, torch.float, torch.double) 891 def test_multihead_attention_dtype_batch_first(self, device, dtype): 892 embed_dim = 128 893 num_heads = 8 894 sl = 10 895 bs = 8 896 # With batch_first=True, we have the possibility of hitting 897 # the native fast path if we call .eval() and enable inference 898 # mode. Test both paths. 899 for training in (True, False): 900 model = ( 901 nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) 902 .to(device) 903 .to(dtype) 904 ) 905 if not training: 906 model = model.eval() 907 cm = torch.no_grad() 908 else: 909 cm = contextlib.nullcontext() 910 with cm: 911 q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) 912 k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) 913 v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) 914 # fast path currently doesn't support weights 915 out = model(q, k, v, need_weights=False) 916 self.assertEqual(q.size(), out[0].size()) 917 self.assertEqual(dtype, out[0].dtype) 918 919 @dtypes(torch.double) 920 @torch.no_grad() 921 def test_multihead_attn_fast_path_query_and_bias_have_different_dtypes( 922 self, device, dtype 923 ): 924 mha = torch.nn.MultiheadAttention( 925 4, 4, batch_first=True, dtype=dtype, device=device 926 ).eval() 927 mha.in_proj_bias = torch.nn.Parameter( 928 mha.in_proj_bias.to(torch.half).to(device) 929 ) 930 query = torch.randn(4, 4, 4, dtype=dtype, device=device) 931 mha(query, query, query) 932 933 @dtypes(torch.double) 934 @torch.no_grad() 935 def test_multihead_attn_fast_path_small_test(self, device, dtype): 936 mha = torch.nn.MultiheadAttention( 937 4, 4, batch_first=True, dtype=dtype, device=device 938 ).eval() 939 query = torch.randn(4, 4, 4, dtype=dtype, device=device) 940 mha(query, query, query) 941 942 @dtypes(torch.double) 943 @torch.no_grad() 944 def test_multihead_attn_in_proj_bias_none(self, device, dtype): 945 mha = torch.nn.MultiheadAttention(2, 2, bias=False, dtype=dtype, device=device) 946 query = torch.rand(2, 2, 2, dtype=dtype, device=device) 947 mha(query, query, query) 948 949 @dtypes(torch.double) 950 @torch.no_grad() 951 def test_multihead_attn_in_proj_weight_none(self, device, dtype): 952 # Setting kdim == vdim == 2 means that vdim != embed_dim 953 # will cause the logic to use per-input project weights, thereby 954 # forcing self.in_proj_weight = None 955 mha = torch.nn.MultiheadAttention( 956 4, 4, vdim=2, kdim=2, dtype=dtype, device=device 957 ) 958 query = torch.rand(4, 4, 4, dtype=dtype, device=device) 959 key = torch.rand(4, 4, 2, dtype=dtype, device=device) 960 mha(query, key, key) 961 962 963instantiate_device_type_tests(TestMultiheadAttentionNNDeviceType, globals()) 964instantiate_parametrized_tests(TestMultiheadAttentionNN) 965 966if __name__ == "__main__": 967 run_tests() 968