# Owner(s): ["module: nn"]
import math
import copy

import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_WITH_ROCM


class TestMHADeviceType(TestCase):
    @torch.no_grad()
    def _test_transform_bias_rescale_qkv_impl(
        self, device, dtype, use_nt, use_padding=False
    ):
        tests = [
            (64, 4, 16, 8),
            # dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
            (24, 2, 4, 2),
            # Make sure CUDA can handle small input sizes
            (2, 2, 2, 2),
            # dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4,
            # causes alignment issues
            (24, 4, 4, 2),
            (48, 4, 16, 8),
        ]
        for (embed_dim, num_heads, bs, sl) in tests:
            with self.subTest(embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl):
                torch.manual_seed(9343)
                dense_x = x = (
                    torch.randn(bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10
                )
                if use_padding:
                    x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
                if use_nt:
                    xs = list(torch.unbind(x))
                    if use_padding:
                        xs[0] = xs[0][:-1]
                    x = torch.nested.nested_tensor(xs, device=device, dtype=dtype)

                qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)

                # We have to use inference_mode here because q/k/v are
                # all views of the same Tensor, which autograd doesn't
                # like. This is fine because this function is only
                # exposed to Python for purposes of writing this test.
                with torch.inference_mode():
                    (q, k, v) = torch._transform_bias_rescale_qkv(
                        x, qkv.bias, num_heads=num_heads
                    )

                    def simple_transform_bias_rescale_qkv(qkv, bias):
                        (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
                        (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)

                        def embiggen(x):
                            if not use_nt:
                                return x
                            b, t, d = x.size()
                            t = t + (8 - t % 8) % 8
                            newsize = (b, t, d)
                            new_x = torch.zeros(newsize, device=device, dtype=dtype)
                            new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
                            return new_x

                        return tuple(
                            embiggen(x).reshape(
                                (bs, -1, num_heads, embed_dim // num_heads)
                            ).transpose(2, 1)
                            for x in (
                                (q + q_bias) / math.sqrt(embed_dim // num_heads),
                                (k + k_bias),
                                (v + v_bias),
                            )
                        )

                    correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
                        dense_x, qkv.bias
                    )
                    if use_nt and use_padding:
                        for t in (correct_q, correct_k, correct_v):
                            t[t == float("-Inf")] = 0

                self.assertEqual(q.size(), correct_q.size())
                torch.testing.assert_close(q, correct_q)
                torch.testing.assert_close(k, correct_k)
                torch.testing.assert_close(v, correct_v)

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    def test_transform_bias_rescale_qkv(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=False, use_padding=use_padding
                )

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    @onlyCUDA
    def test_transform_bias_rescale_qkv_nested(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=True, use_padding=use_padding
                )
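    # Extra sanity check sketched from the reference math in
    # simple_transform_bias_rescale_qkv above. This method is an addition,
    # not part of the original test: with a zero bias and the same block
    # repeated three times in the packed qkv input, the op should return
    # q == k / sqrt(dim_per_head) and k == v.
    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_transform_bias_rescale_qkv_scaling(self, device, dtype):
        embed_dim, num_heads, bs, sl = 24, 4, 2, 8
        y = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype)
        # Pack the same tensor into the q, k, and v slots.
        x = torch.cat([y, y, y], dim=-1)
        bias = torch.zeros(3 * embed_dim, device=device, dtype=dtype)
        # inference_mode for the same reason as above: q/k/v are views.
        with torch.inference_mode():
            (q, k, v) = torch._transform_bias_rescale_qkv(
                x, bias, num_heads=num_heads
            )
        # Only q is rescaled by 1 / sqrt(dim_per_head); k and v are identical.
        torch.testing.assert_close(q, k / math.sqrt(embed_dim // num_heads))
        torch.testing.assert_close(k, v)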
    def _test_multihead_attention_impl(
        self, device, dtype, mode, use_nt, need_weights, average_attn_weights,
        use_padding=False, pad_all=False
    ):
        embed_dim = 64
        num_heads = 4
        bs = 16
        sl = 8

        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        if use_padding:
            if pad_all:
                for q_i in q:
                    q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                for mask_i in mask:
                    mask_i[-1] = True
            else:
                q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                mask[0][-1] = True
        if mode == "self":
            k = q
            v = q
        elif mode == "encdec":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = k
        elif mode == "generic":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        else:
            self.fail(f"invalid mode `{mode}`!")

        qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=torch.float32)
        native_qkv = copy.deepcopy(qkv).to(dtype=dtype)

        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
        native_proj = copy.deepcopy(proj).to(dtype=dtype)

        pt = torch.nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True, device=device, dtype=torch.float32
        )

        pt.in_proj_weight = qkv.weight
        pt.in_proj_bias = qkv.bias
        pt.out_proj.weight = proj.weight
        pt.out_proj.bias = proj.bias

        class NativeMHA(torch.nn.Module):
            def __init__(self, embed_dim, num_heads, qkv, proj):
                super().__init__()
                self.qkv = qkv
                self.proj = proj
                self.embed_dim = embed_dim
                self.num_heads = num_heads

            def forward(self, q, k, v, key_padding_mask):
                return torch._native_multi_head_attention(
                    q,
                    k,
                    v,
                    self.embed_dim,
                    self.num_heads,
                    self.qkv.weight,
                    self.qkv.bias,
                    self.proj.weight,
                    self.proj.bias,
                    key_padding_mask,
                    need_weights=need_weights,
                    average_attn_weights=average_attn_weights,
                    # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
                    mask_type=1,
                )

        npt = NativeMHA(
            embed_dim=embed_dim, num_heads=num_heads, qkv=native_qkv, proj=native_proj
        ).to(dtype)

        if device == "cuda":
            pt = pt.cuda()
            npt = npt.cuda()

        ypt, weight_pt = pt(
            q,
            k,
            v,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            key_padding_mask=mask if use_padding else None,
        )
        if use_nt:
            qs = list(torch.unbind(q))
            if use_padding:
                if pad_all:
                    qs = [x[:-1] for x in qs]
                else:
                    qs[0] = qs[0][:-1]
            q = torch.nested.nested_tensor(qs, device=device, dtype=dtype)
            if mode == "self":
                k = v = q
            elif mode == "encdec":
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = k
            else:
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = torch.nested.nested_tensor(torch.unbind(v), device=device, dtype=dtype)

        native_q = q.to(dtype=dtype)
        native_k = k.to(dtype=dtype)
        native_v = v.to(dtype=dtype)

        ynpt, weight_npt = npt(
            native_q, native_k, native_v,
            key_padding_mask=mask if use_padding and not use_nt else None
        )
        if use_nt:
            ynpt = ynpt.to_padded_tensor(0)
            if pad_all:
                ynpt_final = torch.zeros_like(ypt)
                ynpt_final[:, :ynpt.shape[1], :] = ynpt
                ynpt = ynpt_final

        def do_pad_all(tensors):
            for t in tensors:
                for t_i in t:
                    t_i[-1] = torch.zeros_like(t_i[-1], device=device, dtype=dtype)

        # PyTorch implementation returns non-zero junk in the padding
        # locations; overwrite it so that the comparison works out.
        if use_padding:
            ypt[0][-1] = torch.zeros_like(ypt[0][-1], device=device, dtype=dtype)
            ynpt[0][-1] = torch.zeros_like(ynpt[0][-1], device=device, dtype=dtype)
            if pad_all:
                do_pad_all((ypt, ynpt))
            # Zero the last row of each TxT weight matrix
            if need_weights:
                if average_attn_weights:
                    weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1], device=device, dtype=dtype)
                    weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1], device=device, dtype=dtype)
                    if pad_all:
                        do_pad_all((weight_pt, weight_npt))
                else:
                    for nh in range(num_heads):
                        weight_pt[0][nh][-1] = torch.zeros_like(weight_pt[0][nh][-1], device=device, dtype=dtype)
                        weight_npt[0][nh][-1] = torch.zeros_like(weight_npt[0][nh][-1], device=device, dtype=dtype)

        if dtype == torch.half:
            torch.testing.assert_close(ypt, ynpt.to(torch.float32), atol=1e-3, rtol=1e-3)
        else:
            # High rtol seems necessary for
            # test_native_multihead_attention_cpu_float32 on Windows;
            # otherwise 2e-4 would likely be fine.
            torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)

        if need_weights:
            torch.testing.assert_close(weight_pt, weight_npt.to(torch.float32), atol=5e-4, rtol=5e-4)
        else:
            self.assertEqual(weight_pt, weight_npt)
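    # Summary of the helper above: _test_multihead_attention_impl treats a
    # float32 torch.nn.MultiheadAttention with copied projection weights as
    # the reference, runs torch._native_multi_head_attention in the test
    # dtype, zeroes the padding rows that both implementations may fill with
    # junk, and compares the outputs (and, optionally, the attention weights)
    # with dtype-dependent tolerances. The `fused` flag in
    # test_native_multihead_self_attention below toggles the flash and
    # memory-efficient SDP backends together; with both disabled, the
    # unfused math fallback is exercised.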
    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @parametrize("use_nt", [False, True])
    @parametrize("use_padding, pad_all", [(False, False), (True, False), (True, True)])
    @parametrize("need_weights", [False])
    @parametrize("average_attn_weights", [False, True])
    @parametrize("fused", [False, True])
    @torch.no_grad()
    def test_native_multihead_self_attention(self, device, dtype, use_nt,
                                             need_weights, average_attn_weights,
                                             use_padding, pad_all, fused):
        if TEST_WITH_ROCM:
            if use_nt:
                self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
            if use_padding and not pad_all and fused:
                self.skipTest("Large numerical errors on ROCM to investigate.")
        # Note: this loop deliberately overrides the parametrized need_weights
        # value; weights are only requested when pad_all is False.
        for need_weights in (False, not pad_all):
            with self.subTest(use_padding=use_padding, pad_all=pad_all,
                              use_nt=use_nt, need_weights=need_weights,
                              average_attn_weights=average_attn_weights):
                # Enable or disable the fused SDP backends together.
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=fused, enable_mem_efficient=fused
                ):
                    self._test_multihead_attention_impl(
                        device,
                        dtype,
                        "self",
                        use_nt=use_nt,
                        use_padding=use_padding,
                        pad_all=pad_all,
                        need_weights=need_weights,
                        average_attn_weights=average_attn_weights,
                    )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_encoder_decoder_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "encdec",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "generic",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )


instantiate_device_type_tests(TestMHADeviceType, globals())

if __name__ == "__main__":
    run_tests()