# Owner(s): ["module: nn"]
import contextlib
import random
import unittest
import unittest.mock as mock

import torch
import torch.nn as nn
from torch.nn import MultiheadAttention
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDAAndPRIVATEUSE1,
)
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize as parametrize_test,
    run_tests,
    skipIfRocm,
    TEST_NUMPY,
    TEST_WITH_CROSSREF,
)


if TEST_NUMPY:
    import numpy as np


# WARNING: If you add a new top-level test case to this file, you MUST
# update test/run_test.py to list it, otherwise it will NOT be run in
# CI.
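
# Run locally with: python test/nn/test_multihead_attention.py
# (run_tests() at the bottom of the file is the entry point).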


class TestMultiheadAttentionNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    @parametrize_test("average_attn_weights", [True, False])
    def test_multihead_attention(self, average_attn_weights):
        def _scaled_dot_attn_ref(
            Q,
            K,
            V,
            dims,
            unseen_mask=None,
            key_padding_mask=None,
            average_attn_weights=average_attn_weights,
        ):
            """Numpy-based reference implementation of scaled dot attention
            for testing"""

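            # Reference attention: softmax(Q @ K^T / sqrt(d_head)) @ V, computed
            # per (batch, head) slice; masked positions are pushed to -inf before
            # the softmax so they contribute zero weight.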
            QKT = _batchmatmul(
                Q,
                np.transpose(K, axes=[0, 1, 3, 2])
                / np.sqrt(dims[3], dtype=np.float32),  # divide by sqrt(d_head)
            )
            b1, b2, s1, s2 = QKT.shape
            if unseen_mask is not None or key_padding_mask is not None:
                # assert s1 == s2
                for i in range(b1):
                    for j in range(b2):
                        for m in range(s1):
                            for n in range(s2):
                                if unseen_mask is not None and unseen_mask[m][n] == 0:
                                    QKT[i, j, m, n] = -np.inf
                                if (
                                    key_padding_mask is not None
                                    and key_padding_mask[i][n]
                                ):
                                    QKT[i, j, m, n] = -np.inf

            reference = _softmax(QKT)
            ref_attn_weight = reference
            if average_attn_weights:
                ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2
            reference = _batchmatmul(reference, V)
            return reference, ref_attn_weight

        def _batchmatmul(a, b):  # batchmatmul over 4 dim matrix
            """Numpy-based batch matrix multiply over 4 dim matrix"""
            assert a.shape[0] == b.shape[0]
            assert a.shape[1] == b.shape[1]
            retval = np.zeros(
                (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32
            )
            for i in range(a.shape[0]):
                for j in range(a.shape[1]):
                    retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :])
            return retval

        def _softmax(x):  # softmax over 4 dim matrix
            """Numpy-based reference softmax over 4 dim matrix"""
            np.seterr(invalid="ignore")
            output = np.zeros(x.shape, dtype=np.float64)
            for i in range(x.shape[0]):
                for j in range(x.shape[1]):
                    for k in range(x.shape[2]):
                        x_curr = x[i, j, k, :]
                        e_x = np.exp(x_curr - np.amax(x_curr))
                        output[i, j, k, :] = e_x / np.sum(e_x)
            return output

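        # Split heads: [batch, seq, nheads * d_head] -> [batch, nheads, seq, d_head]
        # via reshape + transpose; _combine_heads_ref below inverts the operation.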
        def _split_heads_ref(X, dims, nheads, d_head):
            X_split = np.reshape(X, dims[:2] + [nheads, d_head])
            X_split_transposed = np.transpose(X_split, [0, 2, 1, 3])
            reference = np.reshape(
                X_split_transposed, [dims[0], nheads, dims[1], d_head]
            )
            return reference

        def _combine_heads_ref(X, dims, nheads, d_head):
            X_transposed = np.transpose(X, [0, 2, 1, 3])
            reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head])
            return reference

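        # Apply a torch Linear's parameters to a numpy input: X @ W^T + b.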
        def _fc(X, X_weight, X_bias):
            X_fc_b = X_bias.detach().numpy()
            X_fc_w = X_weight.detach().numpy()
            return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b

        def _create_src_lengths_mask(batch_size, src_lengths):
            """
            Generate boolean mask to prevent attention beyond the end of source
            Inputs:
              batch_size : int
              src_lengths : [batch_size] of sentence lengths
            Outputs:
              [batch_size, max_src_len]
            """
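            # Example: src_lengths = tensor([2, 3]) with batch_size=2 yields
            # [[1, 1, 0], [1, 1, 1]]; positions at or past each length are zeroed.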
            max_srclen = src_lengths.max()
            src_indices = torch.arange(0, max_srclen).unsqueeze(0).to(src_lengths)
            src_indices = src_indices.expand(batch_size, max_srclen)
            src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen)
            # returns [batch_size, max_seq_len]
            return (src_indices < src_lengths).int().detach()

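        # Runs 100 random configurations and checks the output and attention
        # weights of F.multi_head_attention_forward against the numpy reference
        # built from the helpers above.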
        def _multihead_attn_test_helper(
            add_key_padding_mask=False,
            add_bias_kv=False,
            add_zero_attn=False,
            saved_kv=False,
            same_embed_dim=False,
            average_attn_weights=average_attn_weights,
        ):
            for _ in range(100):
                batch_sz, seq_len = (random.randint(2, 10) for r in range(2))
                d_head = random.randint(3, 10)
                nheads = random.randint(2, 5) * 2
                d_model = d_head * nheads
                if same_embed_dim:
                    kv_dim = d_model
                else:
                    kv_dim = random.randint(5, 20)
                dims = [batch_sz, seq_len, kv_dim]

                saved_k = None
                saved_k_tensor = None
                saved_v = None
                saved_v_tensor = None
                if saved_kv:
                    saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head)
                    saved_k_tensor = torch.from_numpy(saved_k).to(
                        torch.get_default_dtype()
                    )
                    saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head)
                    saved_v_tensor = torch.from_numpy(saved_v).to(
                        torch.get_default_dtype()
                    )

                key_padding_mask = None
                key_padding_mask_tensor = None
                if add_key_padding_mask:
                    seq_mask = np.random.randint(0, 2, (1, seq_len))
                    key_padding_mask = np.repeat(seq_mask, batch_sz, axis=0) == 1
                    key_padding_mask_tensor = torch.from_numpy(key_padding_mask)
                decoder_state = np.random.rand(batch_sz, d_model)
                K = np.random.rand(*dims)
                V = K
                Q = np.expand_dims(decoder_state, 1)
                attn_mask = np.random.randint(0, 2, size=(1, seq_len))
                attn_mask_tensor = torch.from_numpy(attn_mask).float()
                attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float("-inf"))
                attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float("0.0"))

                decoder_state_tensor = torch.from_numpy(decoder_state).to(
                    torch.get_default_dtype()
                )
                source_hid_tensor = (
                    torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1)
                )

                multihead_attn_module = MultiheadAttention(
                    d_model,
                    nheads,
                    add_bias_kv=add_bias_kv,
                    add_zero_attn=add_zero_attn,
                    kdim=kv_dim,
                    vdim=kv_dim,
                )

                if add_bias_kv:
                    bias_k = multihead_attn_module.bias_k.detach().numpy()
                    bias_v = multihead_attn_module.bias_v.detach().numpy()
                else:
                    bias_k = None
                    bias_v = None

                _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1)
                _V = source_hid_tensor
                _K = source_hid_tensor

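                # When q/k/v share an embedding dim, the module packs the three
                # input projections into a single in_proj_weight; otherwise
                # separate q/k/v projection weights are passed along with
                # use_separate_proj_weight=True.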
                if multihead_attn_module._qkv_same_embed_dim:
                    (
                        result,
                        result_weight,
                    ) = torch.nn.functional.multi_head_attention_forward(
                        _Q,
                        _K,
                        _V,
                        d_model,
                        nheads,
                        multihead_attn_module.in_proj_weight,
                        multihead_attn_module.in_proj_bias,
                        multihead_attn_module.bias_k,
                        multihead_attn_module.bias_v,
                        multihead_attn_module.add_zero_attn,
                        multihead_attn_module.dropout,
                        multihead_attn_module.out_proj.weight,
                        multihead_attn_module.out_proj.bias,
                        multihead_attn_module.training,
                        key_padding_mask_tensor,
                        True,
                        attn_mask_tensor,
                        static_k=saved_k_tensor,
                        static_v=saved_v_tensor,
                        average_attn_weights=average_attn_weights,
                        is_causal=False,
                    )
                else:
                    (
                        result,
                        result_weight,
                    ) = torch.nn.functional.multi_head_attention_forward(
                        _Q,
                        _K,
                        _V,
                        d_model,
                        nheads,
                        None,
                        multihead_attn_module.in_proj_bias,
                        multihead_attn_module.bias_k,
                        multihead_attn_module.bias_v,
                        multihead_attn_module.add_zero_attn,
                        multihead_attn_module.dropout,
                        multihead_attn_module.out_proj.weight,
                        multihead_attn_module.out_proj.bias,
                        multihead_attn_module.training,
                        key_padding_mask_tensor,
                        True,
                        attn_mask_tensor,
                        True,
                        multihead_attn_module.q_proj_weight,
                        multihead_attn_module.k_proj_weight,
                        multihead_attn_module.v_proj_weight,
                        static_k=saved_k_tensor,
                        static_v=saved_v_tensor,
                        average_attn_weights=average_attn_weights,
                        is_causal=False,
                    )

                result = result.squeeze(0).detach().numpy()

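                # Reference path: in_proj_weight stacks [W_q; W_k; W_v] along dim 0,
                # so slicing in blocks of d_model recovers the per-input projections.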
                if multihead_attn_module._qkv_same_embed_dim:
                    q_proj_weight = multihead_attn_module.in_proj_weight[:d_model]
                    k_proj_weight = multihead_attn_module.in_proj_weight[
                        d_model : (d_model * 2)
                    ]
                    v_proj_weight = multihead_attn_module.in_proj_weight[
                        (d_model * 2) :
                    ]
                else:
                    q_proj_weight = multihead_attn_module.q_proj_weight
                    k_proj_weight = multihead_attn_module.k_proj_weight
                    v_proj_weight = multihead_attn_module.v_proj_weight

                Q_fc = _fc(
                    Q, q_proj_weight, multihead_attn_module.in_proj_bias[:d_model]
                )
                K_fc = _fc(
                    K,
                    k_proj_weight,
                    multihead_attn_module.in_proj_bias[d_model : (d_model * 2)],
                )
                V_fc = _fc(
                    V,
                    v_proj_weight,
                    multihead_attn_module.in_proj_bias[(d_model * 2) :],
                )

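                # bias_k / bias_v are appended as one extra key/value position,
                # so the masks (and the sequence length in dims) grow by one as well.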
                if add_bias_kv:
                    K_fc = np.concatenate(
                        (K_fc, np.repeat(bias_k, K_fc.shape[0], axis=0)), axis=1
                    )
                    V_fc = np.concatenate(
                        (V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1
                    )
                    if attn_mask is not None:
                        attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
                    if key_padding_mask is not None:
                        key_padding_mask = np.concatenate(
                            (
                                key_padding_mask,
                                np.full((batch_sz, 1), False, dtype=bool),
                            ),
                            axis=1,
                        )
                    dims[1] += 1
                Q_split = _split_heads_ref(Q_fc, [batch_sz, 1, d_model], nheads, d_head)

                if saved_k is not None:
                    K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head])
                else:
                    K_split = _split_heads_ref(K_fc, dims, nheads, d_head)

                if saved_v is not None:
                    V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head])
                else:
                    V_split = _split_heads_ref(V_fc, dims, nheads, d_head)

                if add_zero_attn:
                    dims[1] += 1
                    K_split = np.concatenate(
                        (
                            K_split,
                            np.zeros(
                                [
                                    K_split.shape[0],
                                    K_split.shape[1],
                                    1,
                                    K_split.shape[3],
                                ]
                            ),
                        ),
                        axis=2,
                    )
                    V_split = np.concatenate(
                        (
                            V_split,
                            np.zeros(
                                [
                                    V_split.shape[0],
                                    V_split.shape[1],
                                    1,
                                    V_split.shape[3],
                                ]
                            ),
                        ),
                        axis=2,
                    )

                    if attn_mask is not None:
                        attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)

                    if key_padding_mask is not None:
                        key_padding_mask = np.concatenate(
                            (
                                key_padding_mask,
                                np.full((batch_sz, 1), False, dtype=bool),
                            ),
                            axis=1,
                        )
                attn_heads, ref_attn_weight = _scaled_dot_attn_ref(
                    Q=Q_split,
                    K=K_split,
                    V=V_split,
                    dims=Q_split.shape,
                    unseen_mask=attn_mask,
                    key_padding_mask=key_padding_mask,
                )
                combined_attn_heads = _combine_heads_ref(
                    X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
                )

                reference = _fc(
                    combined_attn_heads,
                    multihead_attn_module.out_proj.weight,
                    multihead_attn_module.out_proj.bias,
                )
                reference = np.squeeze(reference, axis=1)

                # result = reference
                self.assertEqual(tuple(result.shape), (batch_sz, d_model))
                np.testing.assert_allclose(result, reference, atol=1e-5)

                # result_weight = ref_attn_weight
                result_weight = result_weight.detach().numpy()
                self.assertEqual(
                    tuple(result_weight.shape), tuple(ref_attn_weight.shape)
                )
                np.testing.assert_allclose(result_weight, ref_attn_weight, atol=1e-5)

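        # The nested helpers below each exercise one combination of flags; they
        # are called explicitly at the end of this test rather than being
        # collected by the test runner.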
        def test_multihead_attn_add_bias_kv():
            _multihead_attn_test_helper(add_bias_kv=True)

        def test_multihead_attn_add_zero_attn():
            _multihead_attn_test_helper(add_zero_attn=True)

        def test_multihead_attn_no_masking():
            _multihead_attn_test_helper()

        def test_multihead_attn_key_padding_mask():
            _multihead_attn_test_helper(add_key_padding_mask=True)

        def test_multihead_attn_saved_kv():
            _multihead_attn_test_helper(saved_kv=True)

        def test_multihead_attn_add_bias_kv_zero_attn():
            _multihead_attn_test_helper(
                add_key_padding_mask=True, add_bias_kv=True, add_zero_attn=True
            )

        def test_multihead_attn_all_arguments1():
            _multihead_attn_test_helper(
                add_key_padding_mask=True, add_zero_attn=True, saved_kv=True
            )

        def test_multihead_attn_all_arguments2():
            _multihead_attn_test_helper(
                add_key_padding_mask=True,
                add_bias_kv=True,
                add_zero_attn=True,
                saved_kv=True,
            )

        def test_multihead_attn_all_arguments3():
            _multihead_attn_test_helper(
                add_key_padding_mask=True,
                add_zero_attn=True,
                saved_kv=True,
                same_embed_dim=True,
            )

        test_multihead_attn_add_zero_attn()  # Test MultiheadAttention with add_zero_attn
        test_multihead_attn_add_bias_kv()  # Test MultiheadAttention with add_bias_kv
        test_multihead_attn_no_masking()  # Test MultiheadAttention without masking
        test_multihead_attn_key_padding_mask()  # Test MultiheadAttention with src lengths
        test_multihead_attn_saved_kv()  # Test MultiheadAttention with static kv.
        test_multihead_attn_add_bias_kv_zero_attn()  # Test MultiheadAttention with bias_kv and zero_attn.
        test_multihead_attn_all_arguments1()  # Test MultiheadAttention with all the arguments.
        with self.assertRaisesRegex(
            AssertionError, "bias cannot be added to static key."
        ):
            test_multihead_attn_all_arguments2()  # Test MultiheadAttention with all the arguments.
        test_multihead_attn_all_arguments3()  # Test MultiheadAttention with all the arguments.

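    # A 3D attn_mask of shape [N * num_heads, T, S] should give the same result
    # as applying each sample's 2D [T, S] mask separately.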
    def test_multihead_attn_3d_attn_mask(self):
        embed_dim = 8
        num_heads = 4
        batch_size = 8
        src_len = 3
        tgt_len = 2

        query = torch.rand(batch_size, tgt_len, embed_dim)  # [N, T, D]
        key = torch.rand(batch_size, src_len, embed_dim)  # [N, S, D]
        value = key  # [N, S, D]
        attn_mask = torch.randint(
            0, 2, (batch_size, tgt_len, src_len)
        ).float()  # [N, T, S]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float("-inf")).masked_fill(
            attn_mask == 1, 0.0
        )

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads)

        # Generate 3D results
        attn_mask_3d = torch.repeat_interleave(
            attn_mask, num_heads, dim=0
        )  # [N * H, T, S]
        output_3d = mta_model(
            query.transpose(0, 1),
            key.transpose(0, 1),
            value.transpose(0, 1),
            attn_mask=attn_mask_3d,
        )[0]
        output_3d = output_3d.transpose(0, 1)  # [N, T, D]

        for i in range(0, batch_size):
            output_2d = mta_model(
                query[i].unsqueeze(0).transpose(0, 1),
                key[i].unsqueeze(0).transpose(0, 1),
                value[i].unsqueeze(0).transpose(0, 1),
                attn_mask=attn_mask[i],
            )[0]

            # output_2d in shape of [T, 1, D]
            self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d)

    def test_multihead_attn_no_bias(self):
        embed_dim = 8
        num_heads = 4
        mha = torch.nn.MultiheadAttention(embed_dim, num_heads, bias=False)

        # Verify that bias=False applies to both in and out projection layers.
        self.assertIsNone(mha.in_proj_bias)
        self.assertIsNone(mha.out_proj.bias)

    def _test_multihead_attn_invalid_shape_impl(self, mha):
        # Batched (3D) query cases
        query = torch.randn(4, 4, 4)
        key = torch.randn(4, 4, 4)
        value = torch.randn(4, 4, 4)

        msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively"
        # 3D query, 2D key and 3D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, torch.randn(4, 4), value)

        msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively"
        # 3D query, 3D key and 2D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, key, torch.randn(4, 4))

        msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead"
        # 3D query, 3D key, 3D value and 1D key_padding_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                key_padding_mask=torch.tensor(
                    [False, False, True, True], dtype=torch.bool
                ),
            )

        msg = (
            "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
        )
        # 3D query, 3D key, 3D value and 1D attn_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool),
            )

        # Unbatched (2D) query cases
        query = torch.randn(4, 4)
        key = torch.randn(4, 4)
        value = torch.randn(4, 4)

        msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively"
        # 2D query, 3D key and 2D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, torch.randn(4, 4, 4), value)

        msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively"
        # 2D query, 2D key and 3D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, key, torch.randn(4, 4, 4))

        msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead"
        # 2D query, 2D key, 2D value and 2D key_padding_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                key_padding_mask=torch.tensor(
                    [[False, False, True, True] * 2], dtype=torch.bool
                ),
            )

        msg = (
            "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
        )
        # 2D query, 2D key, 2D value and 1D attn_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool),
            )

        msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)"
        # 2D query, 2D key, 2D value and 3D incorrect attn_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool),
            )

    def test_multihead_attn_invalid_shape(self):
        mha = torch.nn.MultiheadAttention(4, 4)
        self._test_multihead_attn_invalid_shape_impl(mha)
        # Give the test a chance to hit the fast path. (Right now, it
        # won't, but gating may be less restricted in the future.)
        with torch.no_grad():
            self._test_multihead_attn_invalid_shape_impl(mha.eval())

    @torch.no_grad()
    def test_multihead_attn_fast_path_invalid_shape(self):
        mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval()

        # Batched (3D) query cases
        query = torch.randn(4, 4, 4)
        key = torch.randn(4, 4, 4)
        value = torch.randn(4, 4, 4)

        # Currently, this case will just go to the slow path and get
        # the usual message because it fails the requirement to be
        # batched.
        msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively"
        # 3D query, 2D key and 3D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, torch.randn(3, 3), value, need_weights=False)

        # Currently, this case will just go to the slow path and get
        # the usual message because it fails the requirement to be
        # batched.
        msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively"
        # 3D query, 3D key and 2D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, key, torch.randn(3, 3), need_weights=False)

        msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead"
        # 3D query, 3D key, 3D value and 1D key_padding_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                key_padding_mask=torch.tensor([False, True, True], dtype=torch.bool),
                need_weights=False,
            )

        msg = (
            "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
        )
        # 3D query, 3D key, 3D value and 1D attn_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                attn_mask=torch.tensor([False, True, True], dtype=torch.bool),
                need_weights=False,
            )

        # Unbatched (2D) query cases
        # NOTE: error messages are the same as regular path because the fast path doesn't support 2D.
        query = torch.randn(4, 4)
        key = torch.randn(4, 4)
        value = torch.randn(4, 4)

        msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively"
        # 2D query, 3D key and 2D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, torch.randn(4, 4, 4), value)

        msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively"
        # 2D query, 2D key and 3D value
        with self.assertRaisesRegex(AssertionError, msg):
            mha(query, key, torch.randn(4, 4, 4))

        msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead"
        # 2D query, 2D key, 2D value and 2D key_padding_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                key_padding_mask=torch.tensor(
                    [[False, False, True, True] * 2], dtype=torch.bool
                ),
            )

        msg = (
            "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead"
        )
        # 2D query, 2D key, 2D value and 1D attn_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool),
            )

        msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)"
        # 2D query, 2D key, 2D value and 3D incorrect attn_mask
        with self.assertRaisesRegex(AssertionError, msg):
            mha(
                query,
                key,
                value,
                attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool),
            )

    def test_multihead_attn_nested_tensor_outside_fast_path(self):
        mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval()
        nt = torch.nested.nested_tensor([torch.randn(4, 4)])
        # One tested platform (linux-bionic-py3.7-clang) has a torch_function for one
        # or more of these. Take advantage of that to test the torch_function bailout.
        has_torch_func = torch.overrides.has_torch_function(
            (
                nt,
                mha.in_proj_weight,
                mha.in_proj_bias,
                mha.out_proj.weight,
                mha.out_proj.bias,
            )
        )
        if has_torch_func:
            msg = "MultiheadAttention does not support NestedTensor.*argument has_torch_function"
        else:
            msg = (
                "MultiheadAttention does not support NestedTensor outside of its fast path.*grad is "
                + "enabled and.*or biases requires_grad"
            )
        with self.assertRaisesRegex(AssertionError, msg):
            mha(nt, nt, nt)

        if has_torch_func:
            # Just give up, they're all going to fail with the same message.
            return

        with torch.no_grad():
            mha(nt, nt, nt)
        with torch.inference_mode():
            mha(nt, nt, nt)
        nt = torch.nested.nested_tensor([torch.randn(4, 4, requires_grad=False)])
        nt.requires_grad = False
        with self.assertRaisesRegex(AssertionError, msg):
            mha(nt, nt, nt)
        mha.in_proj_weight.requires_grad = False
        mha.in_proj_bias.requires_grad = False
        mha.out_proj.weight.requires_grad = False
        mha.out_proj.bias.requires_grad = False
        mha(nt, nt, nt)


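# Instantiated per device via instantiate_device_type_tests() at the bottom of
# the file, so every test method receives a `device` argument.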
class TestMultiheadAttentionNNDeviceType(NNTestCase):
    @skipIfRocm(msg="To investigate: yields NaN")
    def test_multihead_self_attn_two_masks_fast_path(self, device):
        """
        Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path
        when both an attention mask (mask type 0) and a key padding mask (mask type 1) are provided.
        """
        with torch.no_grad():
            embed_dim = 14
            num_heads = 7
            batch_size = 8
            src_len = 5

            query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device)
            # Create masks of two different types
            attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device)
            key_padding_mask = (
                torch.randint(0, 2, (batch_size, src_len)).bool().to(device)
            )

            # We'll need expanded versions of the masks for masking out the outputs below
            attn_mask_expanded = attn_mask.reshape(1, 1, src_len, src_len).expand(
                batch_size, num_heads, src_len, src_len
            )
            key_padding_mask_expanded = key_padding_mask.reshape(
                batch_size, 1, 1, src_len
            ).expand(batch_size, num_heads, src_len, src_len)
            merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded)

            # Compute attention on the fast path
            mta_model = torch.nn.MultiheadAttention(
                embed_dim, num_heads, batch_first=True, device=device
            )
            mta_model.training = False
            result_fast_path, _ = mta_model(
                query,
                key,
                value,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
            )

            # Compute attention on the slow path
            result_ref, _ = torch.nn.functional.multi_head_attention_forward(
                query.transpose(0, 1),
                key.transpose(0, 1),
                value.transpose(0, 1),
                embed_dim,
                num_heads,
                mta_model.in_proj_weight,
                mta_model.in_proj_bias,
                mta_model.bias_k,
                mta_model.bias_v,
                mta_model.add_zero_attn,
                mta_model.dropout,
                mta_model.out_proj.weight,
                mta_model.out_proj.bias,
                training=mta_model.training,
                key_padding_mask=key_padding_mask,
                need_weights=False,
                attn_mask=attn_mask,
                use_separate_proj_weight=False,
                q_proj_weight=mta_model.q_proj_weight,
                k_proj_weight=mta_model.k_proj_weight,
                v_proj_weight=mta_model.v_proj_weight,
                average_attn_weights=False,
            )
            result_ref = result_ref.transpose(0, 1)  # Convert to batch-first

            # Rows that are completely masked out come out as NaN, so exclude them from the comparison
            mask_out = (
                merged_mask[:, 0, :, :]
                .all(-1, keepdim=True)
                .expand(batch_size, src_len, embed_dim)
            )
            result_fast_path_masked = result_fast_path.masked_fill(mask_out, 0)
            result_ref_masked = result_ref.masked_fill(mask_out, 0)

            self.assertEqual(result_fast_path_masked, result_ref_masked)

    @torch.no_grad()
    @unittest.skipIf(
        TEST_WITH_CROSSREF,
        "CrossRef turns on TorchFunctionMode, and so disables fastpath.",
    )
    def test_multihead_self_attn_two_masks_fast_path_mock(self, device):
        """
        Multihead self-attention should take the fast path when both an attention mask (mask type 0)
        and a key padding mask (mask type 1) are provided at the same time on CPU, CUDA, and PrivateUse1.
        """
        device = device.rstrip(":0123456789")
        if device not in ["cpu", "cuda", torch._C._get_privateuse1_backend_name()]:
            self.skipTest("Fastpath only runs on CPU, CUDA, and PrivateUse1.")

        with torch.autocast(device_type=device, enabled=False):
            embed_dim = 16
            num_heads = 8
            batch_size = 8
            src_len = 5

            query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device)
            # Create masks of two different types
            attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device)
            key_padding_mask = (
                torch.randint(0, 2, (batch_size, src_len)).bool().to(device)
            )

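            # The fast path dispatches to torch._native_multi_head_attention;
            # mocking it lets the test assert that dispatch happened without
            # checking numerics (the mock returns empty tensors).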
            with mock.patch(
                "torch._native_multi_head_attention",
                new=mock.MagicMock(return_value=(torch.Tensor(), torch.Tensor())),
            ) as fastpath_mock:
                # Compute attention on the fast path
                mta_model = torch.nn.MultiheadAttention(
                    embed_dim, num_heads, batch_first=True, device=device
                ).eval()
                mta_model.training = False
                mta_model(
                    query,
                    key,
                    value,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask,
                )
                # If mock was called, fastpath was taken
                self.assertTrue(fastpath_mock.called)

    @onlyCUDAAndPRIVATEUSE1
    @dtypes(torch.half, torch.float, torch.double)
    def test_multihead_attention_dtype(self, device, dtype):
        embed_dim = 128
        num_heads = 8
        sl = 10
        bs = 8
        model = nn.MultiheadAttention(embed_dim, num_heads).to(device).to(dtype)
        q = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype)
        k = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype)
        v = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype)
        out = model(q, k, v)
        self.assertEqual(q.size(), out[0].size())
        self.assertEqual(dtype, out[0].dtype)

    @onlyCUDAAndPRIVATEUSE1
    @dtypes(torch.half, torch.float, torch.double)
    def test_multihead_attention_dtype_batch_first(self, device, dtype):
        embed_dim = 128
        num_heads = 8
        sl = 10
        bs = 8
        # With batch_first=True, we have the possibility of hitting
        # the native fast path if we call .eval() and enable inference
        # mode. Test both paths.
        for training in (True, False):
            model = (
                nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
                .to(device)
                .to(dtype)
            )
            if not training:
                model = model.eval()
                cm = torch.no_grad()
            else:
                cm = contextlib.nullcontext()
            with cm:
                q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
                k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
                v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
                # fast path currently doesn't support weights
                out = model(q, k, v, need_weights=False)
                self.assertEqual(q.size(), out[0].size())
                self.assertEqual(dtype, out[0].dtype)

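    # The remaining tests are smoke tests: they only check that the forward call
    # completes for edge cases such as a half-precision in_proj_bias, bias=False,
    # and per-input projection weights.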
    @dtypes(torch.double)
    @torch.no_grad()
    def test_multihead_attn_fast_path_query_and_bias_have_different_dtypes(
        self, device, dtype
    ):
        mha = torch.nn.MultiheadAttention(
            4, 4, batch_first=True, dtype=dtype, device=device
        ).eval()
        mha.in_proj_bias = torch.nn.Parameter(
            mha.in_proj_bias.to(torch.half).to(device)
        )
        query = torch.randn(4, 4, 4, dtype=dtype, device=device)
        mha(query, query, query)

    @dtypes(torch.double)
    @torch.no_grad()
    def test_multihead_attn_fast_path_small_test(self, device, dtype):
        mha = torch.nn.MultiheadAttention(
            4, 4, batch_first=True, dtype=dtype, device=device
        ).eval()
        query = torch.randn(4, 4, 4, dtype=dtype, device=device)
        mha(query, query, query)

    @dtypes(torch.double)
    @torch.no_grad()
    def test_multihead_attn_in_proj_bias_none(self, device, dtype):
        mha = torch.nn.MultiheadAttention(2, 2, bias=False, dtype=dtype, device=device)
        query = torch.rand(2, 2, 2, dtype=dtype, device=device)
        mha(query, query, query)

    @dtypes(torch.double)
    @torch.no_grad()
    def test_multihead_attn_in_proj_weight_none(self, device, dtype):
        # Setting kdim == vdim == 2 makes vdim != embed_dim, which causes the
        # module to use per-input projection weights and forces
        # self.in_proj_weight = None.
        mha = torch.nn.MultiheadAttention(
            4, 4, vdim=2, kdim=2, dtype=dtype, device=device
        )
        query = torch.rand(4, 4, 4, dtype=dtype, device=device)
        key = torch.rand(4, 4, 2, dtype=dtype, device=device)
        mha(query, key, key)


instantiate_device_type_tests(TestMultiheadAttentionNNDeviceType, globals())
instantiate_parametrized_tests(TestMultiheadAttentionNN)

if __name__ == "__main__":
    run_tests()