# Owner(s): ["oncall: cpu inductor"]
import contextlib
import functools
import logging
import os
import sys
import unittest
from typing import Optional
from unittest.mock import patch

import torch
import torch._dynamo.config
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
from torch._dynamo.utils import counters
from torch._inductor.cpu_vec_isa import VecAMX
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_quantization import _generate_qdq_quantized_model
from torch.testing._internal.common_quantized import (
    _calculate_dynamic_per_channel_qparams,
)
from torch.testing._internal.common_utils import (
    IS_MACOS,
    parametrize,
    skipIfWindows,
    TEST_MKL,
)


log = logging.getLogger(__name__)


try:
    try:
        from . import test_cpu_repro, test_torchinductor
    except ImportError:
        import test_cpu_repro
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise

check_model = test_torchinductor.check_model
set_num_threads = test_cpu_repro.set_num_threads

aten = torch.ops.aten


def patches(fn):
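    """Decorate a test to exercise the CPP GEMM templates.

    Stacks dynamo/inductor config patches (max-autotune with the "CPP,ATEN"
    backends, epilogue fusion, runtime verification), bypasses the algorithm
    selector cache so choices are always benchmarked, and clears counters and
    seeds the RNG before the wrapped test runs.
    """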
    def skip_cache(self, choices, name, key, benchmark):
        if benchmark is None:
            return {}
        timings = benchmark(choices)
        for choice, timing in timings.items():
            if isinstance(choice, select_algorithm.ExternKernelCaller):
                # We intentionally make the ATEN kernel slower to cover the cases
                # where template kernels are always chosen with fusions applied
                # and correctness checks at runtime.
                timings[choice] = timing * 1000
        return timings

    for patcher in [
        dynamo_config.patch(verbose=True),
        dynamo_config.patch(inline_inbuilt_nn_modules=True),
        inductor_config.patch(
            debug=True,
            max_autotune=True,
            epilogue_fusion=True,
            max_autotune_gemm_backends="CPP,ATEN",
        ),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
    ]:
        fn = patcher(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        return fn(*args, **kwargs)

    return wrapped


@contextlib.contextmanager
def verify(dtype):
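    """Patch select_algorithm.VERIFY with dtype-appropriate tolerances and
    yield the (atol, rtol) pair used for the runtime correctness check.
    """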
    # For bfloat16 and half, we have to relax the tolerance
    # due to the different associative orders in different
    # kernel implementations.
    atol, rtol = 1e-4, 1e-4
    if dtype == torch.half or dtype == torch.bfloat16:
        atol, rtol = 1e-2, 1e-2
    with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
        yield atol, rtol


def _get_epilogue(epilogue: str, other: Optional[torch.Tensor] = None):
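    """Map an epilogue name to a callable: either a pointwise activation
    module, or a lambda applying the named binary op against `other`.
    """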
    if epilogue == "none":
        return lambda x: x
    elif epilogue == "relu":
        return torch.nn.ReLU()
    elif epilogue == "gelu":
        return torch.nn.GELU()
    elif epilogue == "silu":
        return torch.nn.SiLU()
    elif epilogue == "sigmoid":
        return torch.nn.Sigmoid()
    elif epilogue == "tanh":
        return torch.nn.Tanh()
    elif epilogue == "hardswish":
        return torch.nn.Hardswish()
    elif epilogue == "hardsigmoid":
        return torch.nn.Hardsigmoid()
    elif epilogue == "leaky_relu":
        return torch.nn.LeakyReLU()
    elif epilogue == "hardtanh":
        return torch.nn.Hardtanh()
    elif epilogue == "add":
        return lambda x: x + other
    elif epilogue == "sub":
        return lambda x: x - other
    elif epilogue == "mul":
        return lambda x: x * other
    elif epilogue == "div":
        return lambda x: x / other


class BaseTestSelectAlgorithm(TestCase):
    def _check_amx_counter(self, vec_amx):
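        """Expect the AMX micro-gemm counter to be non-zero when `vec_amx` is
        truthy (AMX is usable) and zero otherwise.
        """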
        if vec_amx:
            self.assertTrue(counters["inductor"]["cpp_micro_gemm_amx_counter"] > 0)
        else:
            self.assertEqual(counters["inductor"]["cpp_micro_gemm_amx_counter"], 0)


class TestSelectAlgorithm(BaseTestSelectAlgorithm):
    common = check_model

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (1, 2, 1000))
    @parametrize("in_features", (1, 1000))
    @parametrize("out_features", (1, 1024))
    @parametrize("bias", (True, False))
    @parametrize("input_3d", (True, False))
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_linear_static_shapes(
        self, batch_size, in_features, out_features, bias, input_3d, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            def forward(self, x):
                return self.linear(x)

        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        B = (2, batch_size) if input_3d else (batch_size,)
        v = torch.randn(*B, in_features).to(dtype=dtype)
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        if (
            counters["inductor"]["decompose_mm"] > 0
            or counters["inductor"]["decompose_addmm"] > 0
        ):
            # This is a special case where we go directly with vectorized codegen
            self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
        else:
            self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("in_features", (1000,))
    @parametrize("out_features", (1024,))
    @parametrize("bias", (True,))
    @dtypes(
        torch.float,
    )
    def test_linear_wgt_multi_users(self, in_features, out_features, bias, dtype):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.embeddings = torch.nn.Embedding(out_features, in_features)
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self.linear.weight = self.embeddings.weight

            def forward(self, x):
                x = self.embeddings(x)
                return self.linear(x)

        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        v = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bias", (True, False))
    @dtypes(torch.float)
    def test_linear_input_transpose(self, bias, dtype):
        batch_size = 384
        in_features = 196
        out_features = 384

        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            @torch.compile
            def forward(self, x):
                return self.linear(x)

        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        v = torch.randn(in_features, batch_size).to(dtype=dtype)
        self.common(mod, (v.transpose(0, 1),))
        # TODO(jgong5): support transposed input
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (384,))
    @parametrize("in_features", (196,))
    @parametrize("out_features", (384, 385))
    @parametrize("bias", (True, False))
    @parametrize(
        "epilogue",
        (
            "relu",
            "gelu",
            "silu",
            "sigmoid",
            "tanh",
            "hardswish",
            "hardsigmoid",
            "leaky_relu",
            "hardtanh",
            "add",
            "sub",
            "mul",
            "div",
        ),
    )
    @dtypes(torch.float, torch.bfloat16, torch.half)
    @torch.fx.experimental._config.patch(use_duck_shape=False)
    def test_linear_with_pointwise(
        self, batch_size, in_features, out_features, bias, epilogue, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias, epilogue, other):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self.epilogue = _get_epilogue(epilogue, other)

            def forward(self, x):
                return self.epilogue(self.linear(x))

        # TODO: debug utils, safe to remove in Oct 2024
        if inductor_config.is_fbcode():
            log.warning(
                f"DEBUG: torch.backends.mkl.is_available() is {torch.backends.mkl.is_available()}, "  # noqa: G004
                f"torch.ops.mkldnn._is_mkldnn_fp16_supported() is {torch.ops.mkldnn._is_mkldnn_fp16_supported()}, "
                f"torch.ops.mkldnn._is_mkldnn_bf16_supported() is {torch.ops.mkldnn._is_mkldnn_bf16_supported()}, "
                f"inductor_config.freezing is {inductor_config.freezing}, "
                f"mkldnn._is_mkldnn_acl_supported() is {torch.ops.mkldnn._is_mkldnn_acl_supported()}, "
                f"torch._C.has_mkl is {torch._C.has_mkl}, "
                f"PYTORCH_TEST_FBCODE is {os.getenv('PYTORCH_TEST_FBCODE')}, "
                f"PYTORCH_TEST_REMOTE_GPU is {os.getenv('PYTORCH_TEST_REMOTE_GPU')}, "
            )

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        u = torch.randn(batch_size, out_features).to(dtype=dtype)
        mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        if (
            (
                dtype == torch.bfloat16
                or (
                    dtype == torch.float16
                    and torch.ops.mkldnn._is_mkldnn_fp16_supported()
                )
            )
            and epilogue != "mul"
            and epilogue != "div"
            or (dtype == torch.half and epilogue == "add" and not bias)
            or (
                dtype == torch.float32
                and epilogue == "add"
                and not bias
                and dynamo_config.dynamic_shapes
                and not dynamo_config.assume_static_by_default
            )
        ):
            # Several scenarios where the epilogue fusion is not counted:
            # 1. For bfloat16, the epilogue fusion is part of the template,
            #    not fused via the scheduler. This is also true for float16 when
            #    the hardware has float16 instructions. The exception is mul or
            #    div fusion, which is not supported for oneDNN linear.
            # 2. For float16, since oneDNN linear is not applied, linear w/o bias
            #    plus epilogue add is treated as linear w/ bias.
            # 3. For float32, when dynamic shapes are enabled, mkl linear is not applied,
            #    and linear w/o bias plus epilogue add is treated as addmm.
            self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
        else:
            self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (384,))
    @parametrize("in_features", (196,))
    @parametrize("out_features", (128, 129))
    @parametrize("bias", (True, False))
    @parametrize(
        "epilogue",
        (
            "none",
            "relu",
            "add",
            "sub",
            "mul",
        ),
    )
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_linear_with_transpose(
        self, batch_size, in_features, out_features, bias, epilogue, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias, epilogue, other):
                super().__init__()
                self.epilogue = _get_epilogue(epilogue, other)
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            def forward(self, x, y):
                return self.epilogue(self.linear(x)).transpose(0, 1) + y

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        u = torch.randn(out_features, batch_size).to(dtype=dtype)
        other = torch.randn(batch_size, out_features).to(dtype=dtype)
        mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v, u), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @parametrize("batch_size", (1,))
    @parametrize("in_features", (16,))
    @parametrize("image_size", (18,))
    @parametrize("out_features", (32,))
    @parametrize(
        "bias",
        (
            False,
            True,
        ),
    )
    @parametrize(
        "has_non_epilogue_users",
        (
            True,
            False,
        ),
    )
    @dtypes(torch.bfloat16)
    def test_linear_with_permute(
        self,
        batch_size,
        in_features,
        image_size,
        out_features,
        bias,
        has_non_epilogue_users,
        dtype,
    ):
        # Reproducer from the convnext model in timm
        class M(torch.nn.Module):
            def __init__(self, bias, has_non_epilogue_users):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self._frozen_param398 = torch.randn(batch_size, out_features, 1, 1)
                self.conv = torch.nn.Conv2d(
                    out_features,
                    out_features,
                    kernel_size=7,
                    padding=3,
                    groups=out_features,
                )
                self.linear2 = torch.nn.Linear(out_features, out_features, bias)
                self._frozen_param400 = torch.randn(batch_size, out_features, 1, 1)
                self.has_non_epilogue_users = has_non_epilogue_users

            def forward(self, mul_272, _convolution_pointwise_default_31):
                out1 = torch.ops.prims.convert_element_type.default(
                    mul_272, torch.bfloat16
                )
                mul_272 = None

                _linear_pointwise_default_131 = self.linear(out1)
                permute_188 = torch.ops.aten.permute.default(
                    _linear_pointwise_default_131, [0, 3, 1, 2]
                )

                mul_273 = torch.ops.aten.mul.Tensor(permute_188, self._frozen_param398)
                add_187 = torch.ops.aten.add.Tensor(
                    mul_273, _convolution_pointwise_default_31
                )
                convert_element_type_847 = torch.ops.prims.convert_element_type.default(
                    add_187, torch.bfloat16
                )
                _convolution_pointwise_default_29 = self.conv(convert_element_type_847)
                permute_189 = torch.ops.aten.permute.default(
                    _convolution_pointwise_default_29, [0, 2, 3, 1]
                )
                permute_189 = self.linear2(permute_189)
                permute_189 = torch.ops.aten.permute.default(permute_189, [0, 3, 1, 2])
                permute_189 = torch.ops.aten.mul.Tensor(
                    permute_189, self._frozen_param400
                )
                # If the template_buffer is used by nodes other than the epilogue nodes,
                # we can't alias the template_buffer with the Y buffer.
                if self.has_non_epilogue_users:
                    add_191 = torch.ops.aten.add.Tensor(permute_189, add_187)
                    return add_191
                return permute_189

        view_12 = torch.randn(batch_size, image_size, image_size, in_features)
        _convolution_pointwise_default_31 = torch.randn(
            batch_size, out_features, image_size, image_size
        ).to(memory_format=torch.channels_last)

        mod = M(bias=bias, has_non_epilogue_users=has_non_epilogue_users).eval()
        with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast():
            self.common(
                mod,
                (
                    view_12,
                    _convolution_pointwise_default_31,
                ),
                atol=atol,
                rtol=rtol,
            )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (8,))
    @parametrize("in_features", (3,))
    @parametrize("linear_in_features", (384,))
    @parametrize("out_features", (196,))
    @parametrize("bias", (True,))
    @dtypes(torch.float)
    def test_linear_with_input_of_flexible_layout(
        self, batch_size, in_features, linear_in_features, out_features, bias, dtype
    ):
        # Reproducer from the resmlp_12_224 model in timm
        flatten_BS = int(batch_size * linear_in_features)

        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_features,
                    linear_in_features,
                    kernel_size=16,
                    padding=0,
                    stride=16,
                    dilation=1,
                    groups=1,
                )
                self._frozen_param151 = torch.randn(1, 1, linear_in_features)
                self._frozen_param3 = torch.randn(1, 1, linear_in_features)
                self._frozen_param2 = torch.randn(linear_in_features)

                self.linear = torch.nn.Linear(out_features, out_features, bias)

            def forward(self, arg150_1):
                _convolution_pointwise_default = self.conv(arg150_1)
                view_73 = torch.ops.aten.reshape.default(
                    _convolution_pointwise_default,
                    [batch_size, linear_in_features, out_features],
                )
                _convolution_pointwise_default = None
                permute_62 = torch.ops.aten.permute.default(view_73, [0, 2, 1])
                view_73 = None
                mul_111 = torch.ops.aten.mul.Tensor(self._frozen_param151, permute_62)
                add_73 = torch.ops.aten.add.Tensor(self._frozen_param3, mul_111)
                permute_63 = torch.ops.aten.permute.default(add_73, [0, 2, 1])
                add_73 = None
                view_74 = torch.ops.aten.reshape.default(
                    permute_63, [flatten_BS, out_features]
                )
                permute_63 = None
                _mkl_linear_36 = self.linear(view_74)
                view_75 = torch.ops.aten.reshape.default(
                    _mkl_linear_36, [batch_size, linear_in_features, out_features]
                )
                _mkl_linear_36 = None
                permute_65 = torch.ops.aten.permute.default(view_75, [0, 2, 1])
                view_75 = None
                mul_112 = torch.ops.aten.mul.Tensor(self._frozen_param2, permute_65)
                _frozen_param2 = permute_65 = None
                add_74 = torch.ops.aten.add.Tensor(permute_62, mul_112)
                permute_62 = mul_112 = None
                return add_74

        v = torch.randn(batch_size, in_features, 224, 224).to(dtype=dtype)
        mod = M(bias=bias).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (384,))
    @parametrize("in_features", (196,))
    @parametrize("out_features", (384, 385))
    @parametrize("bias", (True, False))
    @parametrize(
        "unary",
        ("relu",),
    )
    @parametrize(
        "binary",
        (
            "add",
            "sub",
            "mul",
            "div",
        ),
    )
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_linear_with_unary_binary(
        self, batch_size, in_features, out_features, bias, unary, binary, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias, unary, binary, other):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self.unary = _get_epilogue(unary)
                self.binary = _get_epilogue(binary, other)

            def forward(self, x):
                return self.binary(self.unary(self.linear(x)))

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        u = torch.randn(batch_size, out_features).to(dtype=dtype)
        mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (384,))
    @parametrize("in_features", (196,))
    @parametrize("out_features", (384,))
    @parametrize("bias", (True, False))
    @parametrize(
        "binary",
        ("add",),
    )
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_linear_with_binary_input_3d(
        self, batch_size, in_features, out_features, bias, binary, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias, binary, other):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self.binary = _get_epilogue(binary, other)

            def forward(self, x):
                return self.binary(self.linear(x))

        counters.clear()
        B = (2, batch_size)
        v = torch.randn(*B, in_features).to(dtype=dtype)
        u = torch.randn(*B, out_features).to(dtype=dtype)
        mod = M(bias=bias, binary=binary, other=u).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @parametrize("batch_size", (1024,))
    @parametrize("in_features", (1024,))
    @parametrize("out_features", (1024, 1025))
    @parametrize("bias", (True, False))
    @dtypes(torch.bfloat16)
    def test_linear_amx(self, batch_size, in_features, out_features, bias, dtype):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            def forward(self, x):
                return self.linear(x)

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        mod = M(bias=bias).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        vec_amx = VecAMX()
        self._check_amx_counter(vec_amx)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (8,))
    @parametrize("in_features", (128,))
    @parametrize("in_features_2", (196,))
    @parametrize("out_features", (256,))
    @parametrize(
        "bias",
        (True,),
    )
    @dtypes(torch.float32)
    def test_linear_with_multiple_reindexers(
        self,
        batch_size,
        in_features,
        in_features_2,
        out_features,
        bias,
        dtype,
    ):
        flatten_BS = int(batch_size * in_features_2)

        # Reproducer from the levit_128 model in timm
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    64,
                    128,
                    kernel_size=3,
                    padding=1,
                    stride=2,
                    dilation=1,
                    groups=1,
                )
                self.linear = torch.nn.Linear(in_features, out_features, bias=False)
                self._frozen_param221 = torch.randn(out_features)
                self._frozen_param389 = torch.randn(out_features)
                self._frozen_param20 = torch.randn(out_features)
                self._frozen_param21 = torch.randn(out_features)

            def forward(self, view_368):
                _mkl_linear_57 = self.linear(view_368)
                view_369 = torch.ops.aten.reshape.default(
                    _mkl_linear_57, [batch_size, in_features_2, out_features]
                )
                _mkl_linear_57 = None

                view_370 = torch.ops.aten.reshape.default(
                    view_369, [flatten_BS, out_features]
                )
                view_369 = None
                sub_85 = torch.ops.aten.sub.Tensor(view_370, self._frozen_param221)
                view_370 = _frozen_param221 = None
                mul_261 = torch.ops.aten.mul.Tensor(sub_85, self._frozen_param389)
                sub_85 = _frozen_param389 = None
                mul_262 = torch.ops.aten.mul.Tensor(mul_261, self._frozen_param20)
                mul_261 = _frozen_param20 = None
                add_219 = torch.ops.aten.add.Tensor(mul_262, self._frozen_param21)
                mul_262 = _frozen_param21 = None
                view_371 = torch.ops.aten.reshape.default(
                    add_219, [batch_size, in_features_2, out_features]
                )
                add_219 = None

                add_220 = torch.ops.aten.add.Tensor(view_371, 3)
                clamp_min_35 = torch.ops.aten.clamp_min.default(add_220, 0)
                add_220 = None
                clamp_max_35 = torch.ops.aten.clamp_max.default(clamp_min_35, 6)
                clamp_min_35 = None
                mul_263 = torch.ops.aten.mul.Tensor(view_371, clamp_max_35)
                view_371 = clamp_max_35 = None
                div_51 = torch.ops.aten.div.Tensor(mul_263, 6)
                mul_263 = None

                return div_51

        view_368 = torch.randn(flatten_BS, in_features)

        mod = M(bias=bias).eval()
        with verify(dtype) as (atol, rtol):
            self.common(
                mod,
                (view_368,),
                atol=atol,
                rtol=rtol,
            )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @parametrize("batch_size", (384,))
    @parametrize("in_features", (196,))
    @parametrize("out_features", (384,))
    @parametrize("bias", (True, False))
    @dtypes(torch.bfloat16)
    def test_linear_with_embedding(
        self, batch_size, in_features, out_features, bias, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias).to(
                    dtype=dtype
                )
                self.emb = torch.nn.Embedding(64, out_features)

            def forward(self, idx, x):
                return self.emb(idx) + self.linear(x)

        idx = torch.randint(0, 64, (batch_size,))
        x = torch.randn(batch_size, in_features).to(dtype=dtype)
        mod = M(bias=bias).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (idx, x), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @parametrize("batch_size", (2,))
    @parametrize("in_features", (16,))
    @parametrize("seq_lens", (128,))
    @parametrize("out_features", (32,))
    @parametrize("bias", (True,))
    @dtypes(torch.bfloat16)
    def test_linear_with_indirect_indexing(
        self, batch_size, in_features, seq_lens, out_features, bias, dtype
    ):
        # Reproducer from the GPT2ForSequenceClassification model in HuggingFace
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.wte = torch.nn.Embedding(128, seq_lens)
                self.wpe = torch.nn.Embedding(in_features, seq_lens)
                self.linear = torch.nn.Linear(out_features, seq_lens, bias)

            def forward(self, view_12, input_ids, view_9):
                inputs_embeds = self.wte(input_ids)

                position_ids = torch.arange(0, in_features, dtype=torch.long)
                position_ids = position_ids.unsqueeze(0)
                position_embeds = self.wpe(position_ids)

                add = inputs_embeds + position_embeds
                add_4 = view_9 + add

                _linear_pointwise_default_45 = self.linear(view_12)

                view_13 = torch.ops.aten.reshape.default(
                    _linear_pointwise_default_45, [batch_size, in_features, seq_lens]
                )
                out = torch.ops.aten.add.Tensor(add_4, view_13)

                return out

        view_12 = torch.randn(batch_size * in_features, out_features)
        input_ids = torch.randint(0, 128, (batch_size, in_features))
        view_9 = torch.randn(batch_size, in_features, seq_lens)
        mod = M(bias=bias).eval()
        with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast():
            self.common(
                mod,
                (
                    view_12,
                    input_ids,
                    view_9,
                ),
                atol=atol,
                rtol=rtol,
            )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @parametrize("batch_size", (8,))
    @parametrize("in_features", (3,))
    @parametrize("in_features2", (192,))
    @parametrize("image_size", (224,))
    @parametrize("out_features", (64,))
    @parametrize(
        "bias",
        (True,),
    )
    @dtypes(torch.float32)
    def test_linear_with_in_out_buffer(
        self,
        batch_size,
        in_features,
        in_features2,
        image_size,
        out_features,
        bias,
        dtype,
    ):
        # Reproducer from the coat_lite_mini model in timm
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self._frozen_param398 = torch.randn(batch_size, out_features, 1, 1)
                self.conv = torch.nn.Conv2d(
                    in_features,
                    out_features,
                    kernel_size=4,
                    padding=0,
                    stride=4,
                    dilation=1,
                    groups=1,
                )
                self.conv2 = torch.nn.Conv2d(
                    out_features,
                    out_features,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                    dilation=1,
                    groups=out_features,
                )

                self.conv3 = torch.nn.Conv2d(
                    16,
                    16,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                    dilation=1,
                    groups=16,
                )

                self.conv4 = torch.nn.Conv2d(
                    24,
                    24,
                    kernel_size=5,
                    padding=2,
                    stride=1,
                    dilation=1,
                    groups=24,
                )

                self.conv5 = torch.nn.Conv2d(
                    24,
                    24,
                    kernel_size=7,
                    padding=3,
                    stride=1,
                    dilation=1,
                    groups=24,
                )

                self.linear = torch.nn.Linear(out_features, in_features2, bias)

                self.linear2 = torch.nn.Linear(out_features, out_features, bias)
                self._frozen_param2 = torch.randn(out_features)
                self._frozen_param3 = torch.randn(out_features)
                self._frozen_param7 = torch.randn(out_features)
                self._frozen_param8 = torch.randn(out_features)
                self._frozen_param153 = torch.randn(batch_size, 1, out_features)

            def forward(self, arg152_1):
                _convolution_pointwise_default_35 = self.conv(arg152_1)
                arg152_1 = None

                view_168 = torch.ops.aten.reshape.default(
                    _convolution_pointwise_default_35, [8, 64, 3136]
                )
                _convolution_pointwise_default_35 = None
                permute_97 = torch.ops.aten.permute.default(view_168, [0, 2, 1])
                view_168 = None
                clone_65 = torch.ops.aten.clone.default(
                    permute_97, memory_format=torch.contiguous_format
                )
                permute_97 = None
                var_mean_21 = torch.ops.aten.var_mean.correction(
                    clone_65, [2], correction=0, keepdim=True
                )
                getitem_90 = var_mean_21[0]
                getitem_91 = var_mean_21[1]
                var_mean_21 = None
                add_82 = torch.ops.aten.add.Tensor(getitem_90, 1e-05)
                getitem_90 = None
                rsqrt_21 = torch.ops.aten.rsqrt.default(add_82)
                add_82 = None
                sub_29 = torch.ops.aten.sub.Tensor(clone_65, getitem_91)
                clone_65 = getitem_91 = None
                mul_82 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_21)
                sub_29 = rsqrt_21 = None
                mul_83 = torch.ops.aten.mul.Tensor(mul_82, self._frozen_param2)
                mul_82 = None
                add_83 = torch.ops.aten.add.Tensor(mul_83, self._frozen_param3)
                mul_83 = None
                _frozen_param153 = self._frozen_param153
                cat_20 = torch.ops.aten.cat.default([_frozen_param153, add_83], 1)
                _frozen_param153 = add_83 = None
                slice_111 = torch.ops.aten.slice.Tensor(cat_20, 1, 0, 1)
                slice_113 = torch.ops.aten.slice.Tensor(
                    cat_20, 1, 1, 9223372036854775807
                )
                cat_20 = None
                permute_98 = torch.ops.aten.permute.default(slice_113, [0, 2, 1])
                slice_113 = None
                view_169 = torch.ops.aten.reshape.default(permute_98, [8, 64, 56, 56])
                permute_98 = None
                _convolution_pointwise_default_34 = self.conv2(view_169)

                add_84 = torch.ops.aten.add.Tensor(
                    _convolution_pointwise_default_34, view_169
                )
                _convolution_pointwise_default_34 = view_169 = None
                view_170 = torch.ops.aten.reshape.default(add_84, [8, 64, 3136])
                add_84 = None
                permute_99 = torch.ops.aten.permute.default(view_170, [0, 2, 1])
                view_170 = None
                cat_21 = torch.ops.aten.cat.default([slice_111, permute_99], 1)
                slice_111 = permute_99 = None
                var_mean_22 = torch.ops.aten.var_mean.correction(
                    cat_21, [2], correction=0, keepdim=True
                )
                getitem_92 = var_mean_22[0]
                getitem_93 = var_mean_22[1]
                var_mean_22 = None
                add_85 = torch.ops.aten.add.Tensor(getitem_92, 1e-06)
                getitem_92 = None
                rsqrt_22 = torch.ops.aten.rsqrt.default(add_85)
                add_85 = None
                sub_30 = torch.ops.aten.sub.Tensor(cat_21, getitem_93)
                getitem_93 = None
                mul_84 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_22)
                sub_30 = rsqrt_22 = None
                mul_85 = torch.ops.aten.mul.Tensor(mul_84, self._frozen_param7)
                mul_84 = None
                add_86 = torch.ops.aten.add.Tensor(mul_85, self._frozen_param8)
                mul_85 = None
                view_171 = torch.ops.aten.reshape.default(add_86, [25096, 64])
                add_86 = None

                _mkl_linear_32 = self.linear(view_171)
                view_171 = None

                view_172 = torch.ops.aten.reshape.default(
                    _mkl_linear_32, [8, 3137, 192]
                )
                _mkl_linear_32 = None
                view_173 = torch.ops.aten.reshape.default(view_172, [8, 3137, 3, 8, 8])
                view_172 = None
                permute_101 = torch.ops.aten.permute.default(view_173, [2, 0, 3, 1, 4])
                view_173 = None
                unbind_8 = torch.ops.aten.unbind.int(permute_101)
                permute_101 = None
                getitem_94 = unbind_8[0]
                getitem_95 = unbind_8[1]
                getitem_96 = unbind_8[2]
                unbind_8 = None
                clone_66 = torch.ops.aten.clone.default(
                    getitem_95, memory_format=torch.contiguous_format
                )
                getitem_95 = None
                amax_8 = torch.ops.aten.amax.default(clone_66, [2], True)
                sub_31 = torch.ops.aten.sub.Tensor(clone_66, amax_8)
                clone_66 = amax_8 = None
                exp_8 = torch.ops.aten.exp.default(sub_31)
                sub_31 = None
                sum_9 = torch.ops.aten.sum.dim_IntList(exp_8, [2], True)
                div_8 = torch.ops.aten.div.Tensor(exp_8, sum_9)
                exp_8 = sum_9 = None
                permute_102 = torch.ops.aten.permute.default(div_8, [0, 1, 3, 2])
                div_8 = None
                expand_37 = torch.ops.aten.expand.default(permute_102, [8, 8, 8, 3137])
                permute_102 = None
                view_174 = torch.ops.aten.reshape.default(expand_37, [64, 8, 3137])
                expand_37 = None
                expand_38 = torch.ops.aten.expand.default(getitem_96, [8, 8, 3137, 8])
                clone_67 = torch.ops.aten.clone.default(
                    expand_38, memory_format=torch.contiguous_format
                )
                expand_38 = None
                view_175 = torch.ops.aten.reshape.default(clone_67, [64, 3137, 8])
                clone_67 = None
                bmm_16 = torch.ops.aten.bmm.default(view_174, view_175)
                view_174 = view_175 = None
                view_176 = torch.ops.aten.reshape.default(bmm_16, [8, 8, 8, 8])
                bmm_16 = None
                expand_39 = torch.ops.aten.expand.default(getitem_94, [8, 8, 3137, 8])
                clone_68 = torch.ops.aten.clone.default(
                    expand_39, memory_format=torch.contiguous_format
                )
                expand_39 = None
                view_177 = torch.ops.aten.reshape.default(clone_68, [64, 3137, 8])
                clone_68 = None
                expand_40 = torch.ops.aten.expand.default(view_176, [8, 8, 8, 8])
                view_176 = None
                view_178 = torch.ops.aten.reshape.default(expand_40, [64, 8, 8])
                expand_40 = None
                bmm_17 = torch.ops.aten.bmm.default(view_177, view_178)
                view_177 = view_178 = None
                view_179 = torch.ops.aten.reshape.default(bmm_17, [8, 8, 3137, 8])
                bmm_17 = None
                slice_116 = torch.ops.aten.slice.Tensor(
                    getitem_94, 2, 1, 9223372036854775807
                )
                getitem_94 = None
                slice_120 = torch.ops.aten.slice.Tensor(
                    getitem_96, 2, 1, 9223372036854775807
                )
                getitem_96 = None
                permute_103 = torch.ops.aten.permute.default(slice_120, [0, 1, 3, 2])
                slice_120 = None
                view_180 = torch.ops.aten.reshape.default(permute_103, [8, 64, 56, 56])
                permute_103 = None
                split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(
                    view_180, [16, 24, 24], 1
                )
                view_180 = None
                getitem_97 = split_with_sizes_8[0]
                getitem_98 = split_with_sizes_8[1]
                getitem_99 = split_with_sizes_8[2]
                split_with_sizes_8 = None

                _convolution_pointwise_default_33 = self.conv3(getitem_97)
                _convolution_pointwise_default_32 = self.conv4(getitem_98)
                _convolution_pointwise_default_31 = self.conv5(getitem_99)

                cat_22 = torch.ops.aten.cat.default(
                    [
                        _convolution_pointwise_default_33,
                        _convolution_pointwise_default_32,
                        _convolution_pointwise_default_31,
                    ],
                    1,
                )
                _convolution_pointwise_default_33 = (
                    _convolution_pointwise_default_32
                ) = _convolution_pointwise_default_31 = None
                view_181 = torch.ops.aten.reshape.default(cat_22, [8, 8, 8, 3136])
                cat_22 = None
                permute_104 = torch.ops.aten.permute.default(view_181, [0, 1, 3, 2])
                view_181 = None

                mul_86 = torch.ops.aten.mul.Tensor(slice_116, permute_104)
                slice_116 = permute_104 = None
                constant_pad_nd_8 = torch.ops.aten.constant_pad_nd.default(
                    mul_86, [0, 0, 1, 0, 0, 0], 0.0
                )
                mul_86 = None
                mul_87 = torch.ops.aten.mul.Tensor(view_179, 0.3535533905932738)
                view_179 = None
                add_87 = torch.ops.aten.add.Tensor(mul_87, constant_pad_nd_8)
                mul_87 = constant_pad_nd_8 = None
                return add_87

        view_12 = torch.randn(batch_size, in_features, image_size, image_size)

        mod = M(bias=bias).eval()
        with verify(dtype) as (atol, rtol):
            self.common(
                mod,
                (view_12,),
                atol=atol,
                rtol=rtol,
            )
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (32,))
    @parametrize("in_features", (128,))
    @parametrize("out_features", (64, 65))
    @parametrize("bias", (False, True))
    @parametrize("input_3d", (False, True))
    @dtypes(torch.float32, torch.bfloat16)
    @parametrize(
        "epilogue",
        (
            "none",
            "relu",
            "gelu",
        ),
    )
    @skipIfWindows(msg="Windows doesn't support quantization.")
    def test_quantized_linear_with_pointwise(
        self, batch_size, in_features, out_features, bias, input_3d, dtype, epilogue
    ):
        B = (2, batch_size) if input_3d else (batch_size,)
        input = torch.randn(*B, in_features).to(dtype=torch.float32)

        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self.epilogue = _get_epilogue(epilogue)
                self.linear2 = torch.nn.Linear(out_features, out_features, bias)
                self.epilogue2 = _get_epilogue(epilogue)

            def forward(self, x):
                res = self.epilogue(self.linear(x))
                res = self.epilogue2(self.linear2(res))
                return res

        counters.clear()
        ref_quantized_mod = _generate_qdq_quantized_model(
            M(bias=bias).eval(),
            (input,),
        )

        atol, rtol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            atol, rtol = 5e-2, 5e-2

        with patch.object(
            select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)
        ), torch.no_grad(), torch.autocast(
            "cpu", enabled=(dtype == torch.bfloat16), dtype=dtype
        ):
            ref_res = ref_quantized_mod(input)
            cfn = torch.compile(ref_quantized_mod)
            res = cfn(input)
            self.assertEqual(
                res,
                ref_res,
                atol=atol,
                rtol=rtol,
                equal_nan=True,
                exact_dtype=True,
            )
            self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
            self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @dtypes(torch.bfloat16)
    @parametrize("batch_size", (32,))
    @parametrize("in_features", (128,))
    @parametrize("out_features", (64, 65))
    def test_int8_woq_mm(self, dtype, batch_size, in_features, out_features):
        # x will be reshaped from 3d to 2d
        second_dim_size = 8

        def _convert_weight_to_int8pack(w):
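            """Quantize `w` to int8 per output channel with dynamically
            computed qparams; return the int8 weight and bfloat16 scales.
            """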
            scale, zp = _calculate_dynamic_per_channel_qparams(
                w.to(torch.float), torch.int8
            )
            scale = torch.from_numpy(scale)
            zp = torch.from_numpy(zp)
            w_int8 = torch.ao.quantization.fx._decomposed.quantize_per_channel(
                input=w,
                scales=scale,
                zero_points=zp,
                axis=0,
                quant_min=-128,
                quant_max=127,
                dtype=torch.int8,
            )
            return w_int8, scale.to(torch.bfloat16)

        class M(torch.nn.Module):
            def __init__(self, w):
                super().__init__()
                self.linear_weight = torch.nn.Parameter(w, requires_grad=False)

            def forward(self, x, scale):
                return (
                    torch.nn.functional.linear(x, self.linear_weight.to(x.dtype))
                    * scale
                )

        counters.clear()
        # Currently, the corresponding torch.fx pattern only supports 3D x.
        # Add a 2D x case once the corresponding pattern-matcher pattern is added.
        x = torch.rand((batch_size, second_dim_size, in_features), dtype=dtype)
        w = torch.rand((out_features, in_features), dtype=dtype)
        w_int8pack, w_scales = _convert_weight_to_int8pack(w)
        mod = M(w_int8pack).eval()
        self.common(mod, (x, w_scales))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        vec_amx = VecAMX()
        self._check_amx_counter(vec_amx)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (32,))
    @parametrize("in_features", (128,))
    @parametrize("out_features", (64, 65))
    @parametrize("bias", (False, True))
    @parametrize("input_3d", (False, True))
    @parametrize("int8_mixed_bf16", (False, True))
    @dtypes(torch.float32, torch.bfloat16)
    @parametrize(
        "epilogue",
        (
            "none",
            "relu",
        ),
    )
    @skipIfWindows(msg="Windows doesn't support quantization.")
    def test_quantized_linear_with_pointwise_binary(
        self,
        batch_size,
        in_features,
        out_features,
        bias,
        input_3d,
        int8_mixed_bf16,
        dtype,
        epilogue,
    ):
        if not int8_mixed_bf16 and dtype == torch.bfloat16:
            return
        B = (2, batch_size) if input_3d else (batch_size,)
        input = torch.randn(*B, in_features).to(dtype=torch.float32)

        other = torch.randn(*B, out_features).to(dtype=dtype)
        # Avoid hitting the qlinear inplace sum fusion
        if input_3d:
            other2 = torch.randn(B[0] * B[1], out_features).to(dtype=dtype)
        else:
            other2 = torch.randn(1, *B, out_features).to(dtype=dtype)

        class M(torch.nn.Module):
            def __init__(self, bias, input_3d):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self.epilogue = _get_epilogue(epilogue)
                self.linear2 = torch.nn.Linear(out_features, out_features, bias)
                self.epilogue2 = _get_epilogue(epilogue)
                self.input_3d = input_3d

            def forward(self, x, other, other2):
                res = self.epilogue(self.linear(x) + other)
                # Avoid hitting the qlinear inplace sum fusion
                if self.input_3d:
                    other2 = other2.view(2, other2.size(0) // 2, other2.size(1))
                else:
                    other2 = other2.view(other2.size(1), other2.size(2))
                res = self.epilogue2(self.linear2(res) + other2)
                return res

        counters.clear()
        ref_quantized_mod = _generate_qdq_quantized_model(
            M(bias=bias, input_3d=input_3d).eval(),
            (input, other, other2),
        )
        atol, rtol = 5e-2, 5e-2
        with patch.object(
            select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)
        ), torch.no_grad(), torch.autocast(
            "cpu", enabled=int8_mixed_bf16, dtype=torch.bfloat16
        ):
            ref_res = ref_quantized_mod(input, other, other2)
            cfn = torch.compile(ref_quantized_mod)
            res = cfn(input, other, other2)
            self.assertEqual(
                res,
                ref_res,
                atol=atol,
                rtol=rtol,
                equal_nan=True,
                exact_dtype=True,
            )
            self.assertEqual(
                counters["inductor"]["select_algorithm_autotune"],
                2,
            )
            self.assertEqual(
                counters["inductor"]["cpp_epilogue_fusion_counter"],
                0,
            )

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @parametrize("batch_size", (3, 16, 32, 49))
    @parametrize("in_features", (4, 68, 128))  # k should be a multiple of 4
    @parametrize("out_features", (64, 65))
    @parametrize("bias", (True, False))
    @skipIfWindows(msg="Windows doesn't support quantization.")
    def test_quantized_linear_amx(self, batch_size, in_features, out_features, bias):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            def forward(self, x):
                return self.linear(x)

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=torch.float32)
        ref_quantized_mod = _generate_qdq_quantized_model(
            M(bias=bias).eval(),
            (v,),
        )
        atol, rtol = 1e-2, 1e-2
        with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
            self.common(ref_quantized_mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        vec_amx = VecAMX()
        self._check_amx_counter(vec_amx)

    @inductor_config.patch({"freezing": True})
    @inductor_config.patch({"cpp.gemm_max_k_slices": 0})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (2,))
    @parametrize("in_features", (1000,))
    @parametrize("out_features", (2,))
    @parametrize("bias", (True, False))
    @parametrize(
        "epilogue",
        (
            "none",
            "relu",
        ),
    )
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_linear_k_slicing(
        self, batch_size, in_features, out_features, bias, epilogue, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias, epilogue, other):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                self.epilogue = _get_epilogue(epilogue, other)

            def forward(self, x):
                return self.epilogue(self.linear(x))

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        u = torch.randn(batch_size, out_features).to(dtype=dtype)
        mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    @inductor_config.patch({"cpp.gemm_cache_blocking": "2,2,2"})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @set_num_threads(1)
    @parametrize("batch_size", (1024,))
    @parametrize("in_features", (1024,))
    @parametrize("out_features", (1024,))
    @parametrize("bias", (True, False))
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_linear_cache_blocking(
        self, batch_size, in_features, out_features, bias, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            def forward(self, x):
                return self.linear(x)

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        mod = M(bias=bias).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    @inductor_config.patch({"cpp.gemm_thread_factors": "4,2,7"})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @set_num_threads(56)
    @parametrize("batch_size", (1024,))
    @parametrize("in_features", (1024,))
    @parametrize("out_features", (1024,))
    @parametrize("bias", (True, False))
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_linear_thread_factors(
        self, batch_size, in_features, out_features, bias, dtype
    ):
        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            def forward(self, x):
                return self.linear(x)

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        mod = M(bias=bias).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
    pass


class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
    common = check_model
    test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes
    test_linear_with_pointwise_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_with_pointwise
    )
    test_linear_with_transpose_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_with_transpose
    )
    test_linear_with_unary_binary_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_with_unary_binary
    )
    test_linear_amx_dynamic_shapes = TestSelectAlgorithm.test_linear_amx
    test_linear_with_embedding_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_with_embedding
    )
    test_quantized_linear_with_pointwise_dynamic_shapes = (
        TestSelectAlgorithm.test_quantized_linear_with_pointwise
    )
    test_quantized_linear_with_pointwise_binary_dynamic_shapes = (
        TestSelectAlgorithm.test_quantized_linear_with_pointwise_binary
    )
    test_quantized_linear_amx_dynamic_shapes = (
        TestSelectAlgorithm.test_quantized_linear_amx
    )
    test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing
    test_linear_cache_blocking_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_cache_blocking
    )
    test_linear_thread_factors_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_thread_factors
    )


instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
instantiate_device_type_tests(
    TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu"
)


if __name__ == "__main__":
    from torch.testing._internal.inductor_utils import HAS_CPU

    if HAS_CPU and not IS_MACOS:
        run_tests()
1502