# Owner(s): ["module: nn"]

import math
import unittest
from typing import List, Tuple, Union

import torch
from torch._inductor import config
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


default_atol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1e-5,
}
default_rtol = {
    torch.float16: 1e-3,
    torch.bfloat16: float("infinity"),
    torch.float32: 1.3e-6,
}

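# Note: the bfloat16 tolerances are set to infinity, so for bfloat16 inputs
# run_comp_nocomp below skips the numeric comparison (see the isinf check
# there) and only verifies that the compiled function runs.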

def rand_math_tensor(
    shape: Tuple[Union[int, List[int]]],
    device: str,
    dtype: torch.dtype,
    requires_grad: bool = False,
    packed: bool = False,
) -> torch.Tensor:
    """Creates a random dense tensor with the given shape and dtype.

    Args:
        shape (Tuple[int]): Shape of the tensor to construct
        device (str): Which device to create the tensor on
        dtype (torch.dtype): The tensor's dtype
        requires_grad (bool, optional): The tensor's grad status. Defaults to False.
        packed (bool, optional): Unused by this helper. Defaults to False.

    Returns:
        torch.Tensor: A new tensor
    """
    return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)

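# init_tensor builds a float tensor from a nested list and then casts it to the
# requested dtype/device. The small integer literals used in the tests below are
# exactly representable in float32, so the cast is lossless.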
def init_tensor(tensor_list, **kwargs) -> torch.Tensor:
    return torch.Tensor(tensor_list).to(**kwargs)


def run_comp_nocomp(function, *inputs, **kwargs):
    c_function = torch.compile(function)

    f_res = function(*inputs)
    cf_res = c_function(*inputs)

    if not (math.isinf(kwargs.get("atol", 0.0)) or math.isinf(kwargs.get("rtol", 0.0))):
        torch.testing.assert_close(f_res, cf_res, **kwargs)

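# Illustrative usage of run_comp_nocomp (a minimal sketch; the parametrized
# tests below exercise the same pattern over many shapes and dtypes):
#
#   a = torch.randn(8, 16)
#   b = torch.randn(16, 8)
#   run_comp_nocomp(torch.mm, a, b, rtol=1.3e-6, atol=1e-5)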
# The following helper functions are shared by several tests
def torch_mm(a, b):
    return torch.mm(a, b)


def torch_addmm(add, b, c):
    return torch.addmm(add, b, c)


def torch_bmm(a, b):
    return torch.bmm(a, b)


def torch_baddbmm(add, b, c, alpha, beta):
    return torch.baddbmm(add, b, c, alpha=alpha, beta=beta)

# The shapes we test on: each tuple is (a1_0, a1_1, a2_0, a2_1), i.e. an
# (a1_0 x a1_1) @ (a2_0 x a2_1) matmul whose inner dimensions match
ts_list = [
    (1, 32, 32, 1),
    (1, 10, 10, 1),
    (1, 3, 3, 1),
    (32, 1, 1, 32),
    (3, 1, 1, 3),
    (4, 1, 1, 9),
    (9, 1, 1, 4),
]


class TestDecomp(NNTestCase):
    _do_cuda_memory_leak_check = GPU_TYPE == "cuda"
    _do_cuda_non_default_stream = GPU_TYPE == "cuda"

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_simple_mm(self, device, dtype):
        fudge = 10
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            a1_0, a1_1, a2_0, a2_1 = t_size

            t1 = rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_mm, t1, t2, rtol=rtol, atol=atol)
            run_comp_nocomp(torch_addmm, tadd, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize(
        "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float]
    )
    @parametrize("bs", [1, 2, 4, 10])
    def test_batched_mm(self, device, dtype, bs):
        fudge = 3
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for t_size in ts_list:
            a1_0, a1_1, a2_0, a2_1 = t_size

            t1 = rand_math_tensor((bs, a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((bs, a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((bs, a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

            for alpha in (0, 1, -1, 0.5, -0.5):
                for beta in (0, 1, -1, 0.5, -0.5):
                    run_comp_nocomp(
                        torch_baddbmm, tadd, t1, t2, alpha, beta, rtol=rtol, atol=atol
                    )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @config.patch(coordinate_descent_tuning=True)
    def test_bmm_batch2_last_dim_size_is_one(self, device):
        fudge = 3
        rtol = default_rtol[torch.float32] * fudge
        atol = default_atol[torch.float32] * fudge

        t1 = torch.randn(1, 32, 2, device=device)
        t2 = torch.randn(1, 2, 1, device=device)

        run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    def test_some(self, device, dtype):
        # This PyTorch data type (torch.int) is not fully supported on CUDA today.
        # Unfortunately we can't use skipIf because the parametrized values
        # aren't visible inside skipIf.
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_mm,
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_mm,
            init_tensor([[1, 2, 3, 4]], dtype=dtype, device=device),
            init_tensor([[1], [2], [3], [4]], dtype=dtype, device=device),
        )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    @parametrize("bs", [1, 2, 4, 10])
    def test_some_batched(self, device, dtype, bs):
        # This PyTorch data type (torch.int) is not fully supported on CUDA today.
        # Unfortunately we can't use skipIf because the parametrized values
        # aren't visible inside skipIf.
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            return

        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_bmm,
            init_tensor([[[1, 2, 3, 4]]] * bs, dtype=dtype, device=device),
            init_tensor([[[1], [2], [3], [4]]] * bs, dtype=dtype, device=device),
        )


device_types = ("cpu", GPU_TYPE)
instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types)

if __name__ == "__main__":
    # We don't support torch.compile() on Windows
    if not IS_WINDOWS:
        run_tests()