# Owner(s): ["module: nn"]

import math
import unittest
from typing import Tuple

import torch
from torch._inductor import config
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


# Baseline tolerances per dtype. The infinite bfloat16 tolerances make
# run_comp_nocomp (below) skip the numerical comparison for that dtype.
default_atol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1e-5,
}
default_rtol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1.3e-6,
}


def rand_math_tensor(
    shape: Tuple[int, ...],
    device: str,
    dtype: torch.dtype,
    requires_grad: bool = False,
    packed: bool = False,
) -> torch.Tensor:
    """Creates a random dense tensor with the given shape and dtype.

    Args:
        shape (Tuple[int, ...]): Shape of the tensor to construct
        device (str): Which device to create the tensor on
        dtype (torch.dtype): The tensor's dtype
        requires_grad (bool, optional): The tensor's grad status. Defaults to False.
        packed (bool, optional): Unused here; kept for signature compatibility
            with similar helpers. Defaults to False.

    Returns:
        torch.Tensor: A new tensor
    """
    return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)


def init_tensor(tensor_list, **kwargs) -> torch.Tensor:
    # Build a tensor from a (nested) Python list, then cast/move it via kwargs.
    return torch.tensor(tensor_list).to(**kwargs)


def run_comp_nocomp(function, *inputs, **kwargs):
    """Runs `function` eagerly and through torch.compile, then compares results."""
    c_function = torch.compile(function)

    f_res = function(*inputs)
    cf_res = c_function(*inputs)

    # An infinite atol/rtol (see the bfloat16 defaults above) means "run both
    # versions but skip the numerical comparison".
    if not (math.isinf(kwargs.get("atol", 0.0)) or math.isinf(kwargs.get("rtol", 0.0))):
        torch.testing.assert_close(f_res, cf_res, **kwargs)


# Thin wrappers around the ops under test; shared by several tests below
def torch_mm(a, b):
    return torch.mm(a, b)


def torch_addmm(add, b, c):
    return torch.addmm(add, b, c)


def torch_bmm(a, b):
    return torch.bmm(a, b)


def torch_baddbmm(add, b, c, alpha, beta):
    return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)


# The (m, k, k, n) shape quadruples we test on: t1 is (m, k), t2 is (k, n)
ts_list = [
    (1, 32, 32, 1),
    (1, 10, 10, 1),
    (1, 3, 3, 1),
    (32, 1, 1, 32),
    (3, 1, 1, 3),
    (4, 1, 1, 9),
    (9, 1, 1, 4),
]


class TestDecomp(NNTestCase):
    _do_cuda_memory_leak_check = GPU_TYPE == "cuda"
    _do_cuda_non_default_stream = GPU_TYPE == "cuda"

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_simple_mm(self, device, dtype):
        fudge = 10
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            a1_0, a1_1, a2_0, a2_1 = t_size

            t1 = rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_mm, t1, t2, rtol=rtol, atol=atol)
            run_comp_nocomp(torch_addmm, tadd, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize(
        "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float]
    )
    @parametrize("bs", [1, 2, 4, 10])
    def test_batched_mm(self, device, dtype, bs):
        fudge = 3
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

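        # Each (m, k, k, n) quadruple in ts_list is expanded into batched
        # operands of shape (bs, m, k) and (bs, k, n); baddbmm is additionally
        # swept over a small grid of alpha/beta values, including zero and
        # negative cases.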
        for t_size in ts_list:
            a1_0, a1_1, a2_0, a2_1 = t_size

            t1 = rand_math_tensor((bs, a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((bs, a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((bs, a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

            for alpha in (0, 1, -1, 0.5, -0.5):
                for beta in (0, 1, -1, 0.5, -0.5):
                    run_comp_nocomp(
                        torch_baddbmm, tadd, t1, t2, alpha, beta, rtol=rtol, atol=atol
                    )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @config.patch(coordinate_descent_tuning=True)
    def test_bmm_batch2_last_dim_size_is_one(self, device):
        fudge = 3
        rtol = default_rtol[torch.float32] * fudge
        atol = default_atol[torch.float32] * fudge

        t1 = torch.randn(1, 32, 2, device=device)
        t2 = torch.randn(1, 2, 1, device=device)

        run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    def test_some(self, device, dtype):
        # torch.int is not fully supported on CUDA today; unfortunately we
        # can't use skipIf because the parametrized args aren't visible there
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_mm,
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_mm,
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
        )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    @parametrize("bs", [1, 2, 4, 10])
    def test_some_batched(self, device, dtype, bs):
        # torch.int is not fully supported on CUDA today; unfortunately we
        # can't use skipIf because the parametrized args aren't visible there
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
        )


device_types = ("cpu", GPU_TYPE)
instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)

if __name__ == "__main__":
    # torch.compile() is not supported on Windows
    if not IS_WINDOWS:
        run_tests()
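
# Illustrative only: a minimal sketch of how the harness above is driven,
# assuming a GPU build. The shapes and tolerances are arbitrary examples,
# not part of the suite:
#
#     t1 = torch.randn(4, 8, device=GPU_TYPE)
#     t2 = torch.randn(8, 4, device=GPU_TYPE)
#     run_comp_nocomp(torch_mm, t1, t2, rtol=1e-5, atol=1e-5)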