1# Owner(s): ["oncall: cpu inductor"] 2import contextlib 3import functools 4import logging 5import os 6import sys 7import unittest 8from typing import Optional 9from unittest.mock import patch 10 11import torch 12import torch._dynamo.config 13import torch._dynamo.config as dynamo_config 14import torch._inductor.config as inductor_config 15import torch._inductor.select_algorithm as select_algorithm 16from torch._dynamo.utils import counters 17from torch._inductor.cpu_vec_isa import VecAMX 18from torch._inductor.test_case import run_tests, TestCase 19from torch.testing._internal.common_device_type import ( 20 dtypes, 21 instantiate_device_type_tests, 22) 23from torch.testing._internal.common_quantization import _generate_qdq_quantized_model 24from torch.testing._internal.common_quantized import ( 25 _calculate_dynamic_per_channel_qparams, 26) 27from torch.testing._internal.common_utils import ( 28 IS_MACOS, 29 parametrize, 30 skipIfWindows, 31 TEST_MKL, 32) 33 34 35log = logging.getLogger(__name__) 36 37 38try: 39 try: 40 from . import test_cpu_repro, test_torchinductor 41 except ImportError: 42 import test_cpu_repro 43 import test_torchinductor 44except unittest.SkipTest: 45 if __name__ == "__main__": 46 sys.exit(0) 47 raise 48 49check_model = test_torchinductor.check_model 50set_num_threads = test_cpu_repro.set_num_threads 51 52aten = torch.ops.aten 53 54 55def patches(fn): 56 def skip_cache(self, choices, name, key, benchmark): 57 if benchmark is None: 58 return {} 59 timings = benchmark(choices) 60 for choice, timing in timings.items(): 61 if isinstance(choice, select_algorithm.ExternKernelCaller): 62 # we intentionally make ATEN kernel slower to cover the cases 63 # where template kernels are always chosen with fusions applied 64 # and correctness checks at runtime. 65 timings[choice] = timing * 1000 66 return timings 67 68 for patcher in [ 69 dynamo_config.patch(verbose=True), 70 dynamo_config.patch(inline_inbuilt_nn_modules=True), 71 inductor_config.patch( 72 debug=True, 73 max_autotune=True, 74 epilogue_fusion=True, 75 max_autotune_gemm_backends="CPP,ATEN", 76 ), 77 patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)), 78 patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache), 79 ]: 80 fn = patcher(fn) 81 82 @functools.wraps(fn) 83 def wrapped(*args, **kwargs): 84 counters.clear() 85 torch.manual_seed(12345) 86 return fn(*args, **kwargs) 87 88 return wrapped 89 90 91@contextlib.contextmanager 92def verify(dtype): 93 # For bfloat16 and half, we have to relax the tolerance 94 # due to the difference associave orders in different 95 # kernel implementations 96 atol, rtol = 1e-4, 1e-4 97 if dtype == torch.half or dtype == torch.bfloat16: 98 atol, rtol = 1e-2, 1e-2 99 with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): 100 yield atol, rtol 101 102 103def _get_epilogue(epilogue: str, other: Optional[torch.Tensor] = None): 104 if epilogue == "none": 105 return lambda x: x 106 elif epilogue == "relu": 107 return torch.nn.ReLU() 108 elif epilogue == "gelu": 109 return torch.nn.GELU() 110 elif epilogue == "silu": 111 return torch.nn.SiLU() 112 elif epilogue == "sigmoid": 113 return torch.nn.Sigmoid() 114 elif epilogue == "tanh": 115 return torch.nn.Tanh() 116 elif epilogue == "hardswish": 117 return torch.nn.Hardswish() 118 elif epilogue == "hardsigmoid": 119 return torch.nn.Hardsigmoid() 120 elif epilogue == "leaky_relu": 121 return torch.nn.LeakyReLU() 122 elif epilogue == "hardtanh": 123 return torch.nn.Hardtanh() 124 elif epilogue == "add": 125 return lambda x: x + other 126 elif epilogue == "sub": 127 return lambda x: x - other 128 elif epilogue == "mul": 129 return lambda x: x * other 130 elif epilogue == "div": 131 return lambda x: x / other 132 133 134class BaseTestSelectAlgorithm(TestCase): 135 def _check_amx_counter(self, vec_amx): 136 if vec_amx: 137 self.assertTrue(counters["inductor"]["cpp_micro_gemm_amx_counter"] > 0) 138 else: 139 self.assertEqual(counters["inductor"]["cpp_micro_gemm_amx_counter"], 0) 140 141 142class TestSelectAlgorithm(BaseTestSelectAlgorithm): 143 common = check_model 144 145 @inductor_config.patch({"freezing": True}) 146 @patches 147 @torch.no_grad 148 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 149 @parametrize("batch_size", (1, 2, 1000)) 150 @parametrize("in_features", (1, 1000)) 151 @parametrize("out_features", (1, 1024)) 152 @parametrize("bias", (True, False)) 153 @parametrize("input_3d", (True, False)) 154 @dtypes(torch.float, torch.bfloat16, torch.half) 155 def test_linear_static_shapes( 156 self, batch_size, in_features, out_features, bias, input_3d, dtype 157 ): 158 class M(torch.nn.Module): 159 def __init__(self, bias): 160 super().__init__() 161 self.linear = torch.nn.Linear(in_features, out_features, bias) 162 163 def forward(self, x): 164 return self.linear(x) 165 166 counters.clear() 167 mod = M(bias=bias).to(dtype=dtype).eval() 168 B = (2, batch_size) if input_3d else (batch_size,) 169 v = torch.randn(*B, in_features).to(dtype=dtype) 170 with verify(dtype) as (atol, rtol): 171 self.common(mod, (v,), atol=atol, rtol=rtol) 172 if ( 173 counters["inductor"]["decompose_mm"] > 0 174 or counters["inductor"]["decompose_addmm"] > 0 175 ): 176 # This is a special case where we go directly with vectorized codegen 177 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) 178 else: 179 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 180 181 @inductor_config.patch({"freezing": True}) 182 @patches 183 @torch.no_grad 184 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 185 @parametrize("in_features", (1000,)) 186 @parametrize("out_features", (1024,)) 187 @parametrize("bias", (True,)) 188 @dtypes( 189 torch.float, 190 ) 191 def test_linear_wgt_multi_users(self, in_features, out_features, bias, dtype): 192 class M(torch.nn.Module): 193 def __init__(self, bias): 194 super().__init__() 195 self.embeddings = torch.nn.Embedding(out_features, in_features) 196 self.linear = torch.nn.Linear(in_features, out_features, bias) 197 self.linear.weight = self.embeddings.weight 198 199 def forward(self, x): 200 x = self.embeddings(x) 201 return self.linear(x) 202 203 counters.clear() 204 mod = M(bias=bias).to(dtype=dtype).eval() 205 v = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) 206 with verify(dtype) as (atol, rtol): 207 self.common(mod, (v,), atol=atol, rtol=rtol) 208 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 209 210 @inductor_config.patch({"freezing": True}) 211 @patches 212 @torch.no_grad 213 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 214 @parametrize("bias", (True, False)) 215 @dtypes(torch.float) 216 def test_linear_input_transpose(self, bias, dtype): 217 batch_size = 384 218 in_features = 196 219 out_features = 384 220 221 class M(torch.nn.Module): 222 def __init__(self, bias): 223 super().__init__() 224 self.linear = torch.nn.Linear(in_features, out_features, bias) 225 226 @torch.compile 227 def forward(self, x): 228 return self.linear(x) 229 230 counters.clear() 231 mod = M(bias=bias).to(dtype=dtype).eval() 232 v = torch.randn(in_features, batch_size).to(dtype=dtype) 233 self.common(mod, (v.transpose(0, 1),)) 234 # TODO(jgong5): support transposed input 235 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) 236 237 @inductor_config.patch({"freezing": True}) 238 @patches 239 @torch.no_grad 240 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 241 @parametrize("batch_size", (384,)) 242 @parametrize("in_features", (196,)) 243 @parametrize("out_features", (384, 385)) 244 @parametrize("bias", (True, False)) 245 @parametrize( 246 "epilogue", 247 ( 248 "relu", 249 "gelu", 250 "silu", 251 "sigmoid", 252 "tanh", 253 "hardswish", 254 "hardsigmoid", 255 "leaky_relu", 256 "hardtanh", 257 "add", 258 "sub", 259 "mul", 260 "div", 261 ), 262 ) 263 @dtypes(torch.float, torch.bfloat16, torch.half) 264 @torch.fx.experimental._config.patch(use_duck_shape=False) 265 def test_linear_with_pointwise( 266 self, batch_size, in_features, out_features, bias, epilogue, dtype 267 ): 268 class M(torch.nn.Module): 269 def __init__(self, bias, epilogue, other): 270 super().__init__() 271 self.linear = torch.nn.Linear(in_features, out_features, bias) 272 self.epilogue = _get_epilogue(epilogue, other) 273 274 def forward(self, x): 275 return self.epilogue(self.linear(x)) 276 277 # TODO: debug utils, safe to remove in Oct 2024 278 if inductor_config.is_fbcode(): 279 log.warning( 280 f"DEBUG: torch.backends.mkl.is_available() is {torch.backends.mkl.is_available()}, " # noqa: G004 281 f"torch.ops.mkldnn._is_mkldnn_fp16_supported() is {torch.ops.mkldnn._is_mkldnn_fp16_supported()}, " 282 f"torch.ops.mkldnn._is_mkldnn_bf16_supported() is {torch.ops.mkldnn._is_mkldnn_bf16_supported()}, " 283 f"inductor_config.freezing is {inductor_config.freezing}, " 284 f"mkldnn._is_mkldnn_acl_supported() is {torch.ops.mkldnn._is_mkldnn_acl_supported()}, " 285 f"torch._C.has_mkl is {torch._C.has_mkl}, " 286 f"PYTORCH_TEST_FBCODE is {os.getenv('PYTORCH_TEST_FBCODE')}, " 287 f"PYTORCH_TEST_REMOTE_GPU is {os.getenv('PYTORCH_TEST_REMOTE_GPU')}, " 288 ) 289 290 counters.clear() 291 v = torch.randn(batch_size, in_features).to(dtype=dtype) 292 u = torch.randn(batch_size, out_features).to(dtype=dtype) 293 mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval() 294 with verify(dtype) as (atol, rtol): 295 self.common(mod, (v,), atol=atol, rtol=rtol) 296 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 297 if ( 298 ( 299 dtype == torch.bfloat16 300 or ( 301 dtype == torch.float16 302 and torch.ops.mkldnn._is_mkldnn_fp16_supported() 303 ) 304 ) 305 and epilogue != "mul" 306 and epilogue != "div" 307 or (dtype == torch.half and epilogue == "add" and not bias) 308 or ( 309 dtype == torch.float32 310 and epilogue == "add" 311 and not bias 312 and dynamo_config.dynamic_shapes 313 and not dynamo_config.assume_static_by_default 314 ) 315 ): 316 # Several scenarios where epilogue fusion is not counted in: 317 # 1. For bfloat16, the epilogue fusion is part of the template, 318 # not fused via scheduler. This will also be true for float16 when 319 # hardware has the float16 instruction. The exception is mul or 320 # div fusion which is not supported for oneDNN linear. 321 # 2. For float16, since oneDNN linear is not applied, linear w/o bias 322 # plus epilogue add is treated as linear w/ bias. 323 # 3. For float32, when dynamic shapes is enabled, mkl linear is not applied. 324 # and linear w/o bias plus epilogue add is treated as addmm. 325 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0) 326 else: 327 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) 328 329 @inductor_config.patch({"freezing": True}) 330 @patches 331 @torch.no_grad 332 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 333 @parametrize("batch_size", (384,)) 334 @parametrize("in_features", (196,)) 335 @parametrize("out_features", (128, 129)) 336 @parametrize("bias", (True, False)) 337 @parametrize( 338 "epilogue", 339 ( 340 "none", 341 "relu", 342 "add", 343 "sub", 344 "mul", 345 ), 346 ) 347 @dtypes(torch.float, torch.bfloat16, torch.half) 348 def test_linear_with_transpose( 349 self, batch_size, in_features, out_features, bias, epilogue, dtype 350 ): 351 class M(torch.nn.Module): 352 def __init__(self, bias, epilogue, other): 353 super().__init__() 354 self.epilogue = _get_epilogue(epilogue, other) 355 self.linear = torch.nn.Linear(in_features, out_features, bias) 356 357 def forward(self, x, y): 358 return self.epilogue(self.linear(x)).transpose(0, 1) + y 359 360 counters.clear() 361 v = torch.randn(batch_size, in_features).to(dtype=dtype) 362 u = torch.randn(out_features, batch_size).to(dtype=dtype) 363 other = torch.randn(batch_size, out_features).to(dtype=dtype) 364 mod = M(bias=bias, epilogue=epilogue, other=other).to(dtype=dtype).eval() 365 with verify(dtype) as (atol, rtol): 366 self.common(mod, (v, u), atol=atol, rtol=rtol) 367 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 368 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) 369 370 @inductor_config.patch({"freezing": True}) 371 @patches 372 @torch.no_grad 373 @parametrize("batch_size", (1,)) 374 @parametrize("in_features", (16,)) 375 @parametrize("image_size", (18,)) 376 @parametrize("out_features", (32,)) 377 @parametrize( 378 "bias", 379 ( 380 False, 381 True, 382 ), 383 ) 384 @parametrize( 385 "has_non_epilogue_users", 386 ( 387 True, 388 False, 389 ), 390 ) 391 @dtypes(torch.bfloat16) 392 def test_linear_with_permute( 393 self, 394 batch_size, 395 in_features, 396 image_size, 397 out_features, 398 bias, 399 has_non_epilogue_users, 400 dtype, 401 ): 402 # Reproducer from the convnext model in timm 403 class M(torch.nn.Module): 404 def __init__(self, bias, has_non_epilogue_users): 405 super().__init__() 406 self.linear = torch.nn.Linear(in_features, out_features, bias) 407 self._frozen_param398 = torch.randn(batch_size, out_features, 1, 1) 408 self.conv = torch.nn.Conv2d( 409 out_features, 410 out_features, 411 kernel_size=7, 412 padding=3, 413 groups=out_features, 414 ) 415 self.linear2 = torch.nn.Linear(out_features, out_features, bias) 416 self._frozen_param400 = torch.randn(batch_size, out_features, 1, 1) 417 self.has_non_epilogue_users = has_non_epilogue_users 418 419 def forward(self, mul_272, _convolution_pointwise_default_31): 420 out1 = torch.ops.prims.convert_element_type.default( 421 mul_272, torch.bfloat16 422 ) 423 mul_272 = None 424 425 _linear_pointwise_default_131 = self.linear(out1) 426 permute_188 = torch.ops.aten.permute.default( 427 _linear_pointwise_default_131, [0, 3, 1, 2] 428 ) 429 430 mul_273 = torch.ops.aten.mul.Tensor(permute_188, self._frozen_param398) 431 add_187 = torch.ops.aten.add.Tensor( 432 mul_273, _convolution_pointwise_default_31 433 ) 434 convert_element_type_847 = torch.ops.prims.convert_element_type.default( 435 add_187, torch.bfloat16 436 ) 437 _convolution_pointwise_default_29 = self.conv(convert_element_type_847) 438 permute_189 = torch.ops.aten.permute.default( 439 _convolution_pointwise_default_29, [0, 2, 3, 1] 440 ) 441 permute_189 = self.linear2(permute_189) 442 permute_189 = torch.ops.aten.permute.default(permute_189, [0, 3, 1, 2]) 443 permute_189 = torch.ops.aten.mul.Tensor( 444 permute_189, self._frozen_param400 445 ) 446 # If template_buffer will be used by nodes other than the epilogue nodes, 447 # we can't alias the template_buffer with the Y buffer. 448 if self.has_non_epilogue_users: 449 add_191 = torch.ops.aten.add.Tensor(permute_189, add_187) 450 return add_191 451 return permute_189 452 453 view_12 = torch.randn(batch_size, image_size, image_size, in_features) 454 _convolution_pointwise_default_31 = torch.randn( 455 batch_size, out_features, image_size, image_size 456 ).to(memory_format=torch.channels_last) 457 458 mod = M(bias=bias, has_non_epilogue_users=has_non_epilogue_users).eval() 459 with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast(): 460 self.common( 461 mod, 462 ( 463 view_12, 464 _convolution_pointwise_default_31, 465 ), 466 atol=atol, 467 rtol=rtol, 468 ) 469 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) 470 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) 471 472 @inductor_config.patch({"freezing": True}) 473 @patches 474 @torch.no_grad 475 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 476 @parametrize("batch_size", (8,)) 477 @parametrize("in_features", (3,)) 478 @parametrize("linear_in_features", (384,)) 479 @parametrize("out_features", (196,)) 480 @parametrize("bias", (True,)) 481 @dtypes(torch.float) 482 def test_linear_with_input_of_flexible_layout( 483 self, batch_size, in_features, linear_in_features, out_features, bias, dtype 484 ): 485 # Reproducer from the resmlp_12_224 model in timm 486 flatten_BS = int(batch_size * linear_in_features) 487 488 class M(torch.nn.Module): 489 def __init__(self, bias): 490 super().__init__() 491 self.conv = torch.nn.Conv2d( 492 in_features, 493 linear_in_features, 494 kernel_size=16, 495 padding=0, 496 stride=16, 497 dilation=1, 498 groups=1, 499 ) 500 self._frozen_param151 = torch.randn(1, 1, linear_in_features) 501 self._frozen_param3 = torch.randn(1, 1, linear_in_features) 502 self._frozen_param2 = torch.randn(linear_in_features) 503 504 self.linear = torch.nn.Linear(out_features, out_features, bias) 505 506 def forward(self, arg150_1): 507 _convolution_pointwise_default = self.conv(arg150_1) 508 view_73 = torch.ops.aten.reshape.default( 509 _convolution_pointwise_default, 510 [batch_size, linear_in_features, out_features], 511 ) 512 _convolution_pointwise_default = None 513 permute_62 = torch.ops.aten.permute.default(view_73, [0, 2, 1]) 514 view_73 = None 515 mul_111 = torch.ops.aten.mul.Tensor(self._frozen_param151, permute_62) 516 add_73 = torch.ops.aten.add.Tensor(self._frozen_param3, mul_111) 517 permute_63 = torch.ops.aten.permute.default(add_73, [0, 2, 1]) 518 add_73 = None 519 view_74 = torch.ops.aten.reshape.default( 520 permute_63, [flatten_BS, out_features] 521 ) 522 permute_63 = None 523 _mkl_linear_36 = self.linear(view_74) 524 view_75 = torch.ops.aten.reshape.default( 525 _mkl_linear_36, [batch_size, linear_in_features, out_features] 526 ) 527 _mkl_linear_36 = None 528 permute_65 = torch.ops.aten.permute.default(view_75, [0, 2, 1]) 529 view_75 = None 530 mul_112 = torch.ops.aten.mul.Tensor(self._frozen_param2, permute_65) 531 _frozen_param2 = permute_65 = None 532 add_74 = torch.ops.aten.add.Tensor(permute_62, mul_112) 533 permute_62 = mul_112 = None 534 return add_74 535 536 v = torch.randn(batch_size, in_features, 224, 224).to(dtype=dtype) 537 mod = M(bias=bias).to(dtype=dtype).eval() 538 with verify(dtype) as (atol, rtol): 539 self.common(mod, (v,), atol=atol, rtol=rtol) 540 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 541 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) 542 543 @inductor_config.patch({"freezing": True}) 544 @patches 545 @torch.no_grad 546 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 547 @parametrize("batch_size", (384,)) 548 @parametrize("in_features", (196,)) 549 @parametrize("out_features", (384, 385)) 550 @parametrize("bias", (True, False)) 551 @parametrize( 552 "unary", 553 ("relu",), 554 ) 555 @parametrize( 556 "binary", 557 ( 558 "add", 559 "sub", 560 "mul", 561 "div", 562 ), 563 ) 564 @dtypes(torch.float, torch.bfloat16, torch.half) 565 def test_linear_with_unary_binary( 566 self, batch_size, in_features, out_features, bias, unary, binary, dtype 567 ): 568 class M(torch.nn.Module): 569 def __init__(self, bias, unary, binary, other): 570 super().__init__() 571 self.linear = torch.nn.Linear(in_features, out_features, bias) 572 self.unary = _get_epilogue(unary) 573 self.binary = _get_epilogue(binary, other) 574 575 def forward(self, x): 576 return self.binary(self.unary(self.linear(x))) 577 578 counters.clear() 579 v = torch.randn(batch_size, in_features).to(dtype=dtype) 580 u = torch.randn(batch_size, out_features).to(dtype=dtype) 581 mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval() 582 with verify(dtype) as (atol, rtol): 583 self.common(mod, (v,), atol=atol, rtol=rtol) 584 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 585 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) 586 587 @inductor_config.patch({"freezing": True}) 588 @patches 589 @torch.no_grad 590 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 591 @parametrize("batch_size", (384,)) 592 @parametrize("in_features", (196,)) 593 @parametrize("out_features", (384,)) 594 @parametrize("bias", (True, False)) 595 @parametrize( 596 "binary", 597 ("add",), 598 ) 599 @dtypes(torch.float, torch.bfloat16, torch.half) 600 def test_linear_with_binary_input_3d( 601 self, batch_size, in_features, out_features, bias, binary, dtype 602 ): 603 class M(torch.nn.Module): 604 def __init__(self, bias, binary, other): 605 super().__init__() 606 self.linear = torch.nn.Linear(in_features, out_features, bias) 607 self.binary = _get_epilogue(binary, other) 608 609 def forward(self, x): 610 return self.binary(self.linear(x)) 611 612 counters.clear() 613 B = (2, batch_size) 614 v = torch.randn(*B, in_features).to(dtype=dtype) 615 u = torch.randn(*B, out_features).to(dtype=dtype) 616 mod = M(bias=bias, binary=binary, other=u).to(dtype=dtype).eval() 617 with verify(dtype) as (atol, rtol): 618 self.common(mod, (v,), atol=atol, rtol=rtol) 619 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 620 621 @inductor_config.patch({"freezing": True}) 622 @patches 623 @torch.no_grad 624 @parametrize("batch_size", (1024,)) 625 @parametrize("in_features", (1024,)) 626 @parametrize("out_features", (1024, 1025)) 627 @parametrize("bias", (True, False)) 628 @dtypes(torch.bfloat16) 629 def test_linear_amx(self, batch_size, in_features, out_features, bias, dtype): 630 class M(torch.nn.Module): 631 def __init__(self, bias): 632 super().__init__() 633 self.linear = torch.nn.Linear(in_features, out_features, bias) 634 635 def forward(self, x): 636 return self.linear(x) 637 638 counters.clear() 639 v = torch.randn(batch_size, in_features).to(dtype=dtype) 640 mod = M(bias=bias).to(dtype=dtype).eval() 641 with verify(dtype) as (atol, rtol): 642 self.common(mod, (v,), atol=atol, rtol=rtol) 643 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 644 vec_amx = VecAMX() 645 self._check_amx_counter(vec_amx) 646 647 @inductor_config.patch({"freezing": True}) 648 @patches 649 @torch.no_grad 650 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 651 @parametrize("batch_size", (8,)) 652 @parametrize("in_features", (128,)) 653 @parametrize("in_features_2", (196,)) 654 @parametrize("out_features", (256,)) 655 @parametrize( 656 "bias", 657 (True,), 658 ) 659 @dtypes(torch.float32) 660 def test_linear_with_multiple_reindexers( 661 self, 662 batch_size, 663 in_features, 664 in_features_2, 665 out_features, 666 bias, 667 dtype, 668 ): 669 flatten_BS = int(batch_size * in_features_2) 670 671 # Reproducer from the levit_128 model in timm 672 class M(torch.nn.Module): 673 def __init__(self, bias): 674 super().__init__() 675 self.conv = torch.nn.Conv2d( 676 64, 677 128, 678 kernel_size=3, 679 padding=1, 680 stride=2, 681 dilation=1, 682 groups=1, 683 ) 684 self.linear = torch.nn.Linear(in_features, out_features, bias=False) 685 self._frozen_param221 = torch.randn(out_features) 686 self._frozen_param389 = torch.randn(out_features) 687 self._frozen_param20 = torch.randn(out_features) 688 self._frozen_param21 = torch.randn(out_features) 689 690 def forward(self, view_368): 691 _mkl_linear_57 = self.linear(view_368) 692 view_369 = torch.ops.aten.reshape.default( 693 _mkl_linear_57, [batch_size, in_features_2, out_features] 694 ) 695 _mkl_linear_57 = None 696 697 view_370 = torch.ops.aten.reshape.default( 698 view_369, [flatten_BS, out_features] 699 ) 700 view_369 = None 701 sub_85 = torch.ops.aten.sub.Tensor(view_370, self._frozen_param221) 702 view_370 = _frozen_param221 = None 703 mul_261 = torch.ops.aten.mul.Tensor(sub_85, self._frozen_param389) 704 sub_85 = _frozen_param389 = None 705 mul_262 = torch.ops.aten.mul.Tensor(mul_261, self._frozen_param20) 706 mul_261 = _frozen_param20 = None 707 add_219 = torch.ops.aten.add.Tensor(mul_262, self._frozen_param21) 708 mul_262 = _frozen_param21 = None 709 view_371 = torch.ops.aten.reshape.default( 710 add_219, [batch_size, in_features_2, out_features] 711 ) 712 add_219 = None 713 714 add_220 = torch.ops.aten.add.Tensor(view_371, 3) 715 clamp_min_35 = torch.ops.aten.clamp_min.default(add_220, 0) 716 add_220 = None 717 clamp_max_35 = torch.ops.aten.clamp_max.default(clamp_min_35, 6) 718 clamp_min_35 = None 719 mul_263 = torch.ops.aten.mul.Tensor(view_371, clamp_max_35) 720 view_371 = clamp_max_35 = None 721 div_51 = torch.ops.aten.div.Tensor(mul_263, 6) 722 mul_263 = None 723 724 return div_51 725 726 view_368 = torch.randn(flatten_BS, in_features) 727 728 mod = M(bias=bias).eval() 729 with verify(dtype) as (atol, rtol): 730 self.common( 731 mod, 732 (view_368,), 733 atol=atol, 734 rtol=rtol, 735 ) 736 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 737 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2) 738 739 @inductor_config.patch({"freezing": True}) 740 @patches 741 @torch.no_grad 742 @parametrize("batch_size", (384,)) 743 @parametrize("in_features", (196,)) 744 @parametrize("out_features", (384,)) 745 @parametrize("bias", (True, False)) 746 @dtypes(torch.bfloat16) 747 def test_linear_with_embedding( 748 self, batch_size, in_features, out_features, bias, dtype 749 ): 750 class M(torch.nn.Module): 751 def __init__(self, bias): 752 super().__init__() 753 self.linear = torch.nn.Linear(in_features, out_features, bias).to( 754 dtype=dtype 755 ) 756 self.emb = torch.nn.Embedding(64, out_features) 757 758 def forward(self, idx, x): 759 return self.emb(idx) + self.linear(x) 760 761 idx = torch.randint(0, 64, (batch_size,)) 762 x = torch.randn(batch_size, in_features).to(dtype=dtype) 763 mod = M(bias=bias).eval() 764 with verify(dtype) as (atol, rtol): 765 self.common(mod, (idx, x), atol=atol, rtol=rtol) 766 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 767 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) 768 769 @inductor_config.patch({"freezing": True}) 770 @patches 771 @torch.no_grad 772 @parametrize("batch_size", (2,)) 773 @parametrize("in_features", (16,)) 774 @parametrize("seq_lens", (128,)) 775 @parametrize("out_features", (32,)) 776 @parametrize("bias", (True,)) 777 @dtypes(torch.bfloat16) 778 def test_linear_with_indirect_indexing( 779 self, batch_size, in_features, seq_lens, out_features, bias, dtype 780 ): 781 # Reproducer from the GPT2ForSequenceClassification model in HuggingFace 782 class M(torch.nn.Module): 783 def __init__(self, bias): 784 super().__init__() 785 self.wte = torch.nn.Embedding(128, seq_lens) 786 self.wpe = torch.nn.Embedding(in_features, seq_lens) 787 self.linear = torch.nn.Linear(out_features, seq_lens, bias) 788 789 def forward(self, view_12, input_ids, view_9): 790 inputs_embeds = self.wte(input_ids) 791 792 position_ids = torch.arange(0, in_features, dtype=torch.long) 793 position_ids = position_ids.unsqueeze(0) 794 position_embeds = self.wpe(position_ids) 795 796 add = inputs_embeds + position_embeds 797 add_4 = view_9 + add 798 799 _linear_pointwise_default_45 = self.linear(view_12) 800 801 view_13 = torch.ops.aten.reshape.default( 802 _linear_pointwise_default_45, [batch_size, in_features, seq_lens] 803 ) 804 out = torch.ops.aten.add.Tensor(add_4, view_13) 805 806 return out 807 808 view_12 = torch.randn(batch_size * in_features, out_features) 809 input_ids = torch.randint(0, 128, (batch_size, in_features)) 810 view_9 = torch.randn(batch_size, in_features, seq_lens) 811 mod = M(bias=bias).eval() 812 with verify(dtype) as (atol, rtol), torch.cpu.amp.autocast(): 813 self.common( 814 mod, 815 ( 816 view_12, 817 input_ids, 818 view_9, 819 ), 820 atol=atol, 821 rtol=rtol, 822 ) 823 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 824 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) 825 826 @inductor_config.patch({"freezing": True}) 827 @patches 828 @torch.no_grad 829 @parametrize("batch_size", (8,)) 830 @parametrize("in_features", (3,)) 831 @parametrize("in_features2", (192,)) 832 @parametrize("image_size", (224,)) 833 @parametrize("out_features", (64,)) 834 @parametrize( 835 "bias", 836 (True,), 837 ) 838 @dtypes(torch.float32) 839 def test_linear_with_in_out_buffer( 840 self, 841 batch_size, 842 in_features, 843 in_features2, 844 image_size, 845 out_features, 846 bias, 847 dtype, 848 ): 849 # Reproducer from the coat_lite_mini model in timm 850 class M(torch.nn.Module): 851 def __init__(self, bias): 852 super().__init__() 853 self._frozen_param398 = torch.randn(batch_size, out_features, 1, 1) 854 self.conv = torch.nn.Conv2d( 855 in_features, 856 out_features, 857 kernel_size=4, 858 padding=0, 859 stride=4, 860 dilation=1, 861 groups=1, 862 ) 863 self.conv2 = torch.nn.Conv2d( 864 out_features, 865 out_features, 866 kernel_size=3, 867 padding=1, 868 stride=1, 869 dilation=1, 870 groups=out_features, 871 ) 872 873 self.conv3 = torch.nn.Conv2d( 874 16, 875 16, 876 kernel_size=3, 877 padding=1, 878 stride=1, 879 dilation=1, 880 groups=16, 881 ) 882 883 self.conv4 = torch.nn.Conv2d( 884 24, 885 24, 886 kernel_size=5, 887 padding=2, 888 stride=1, 889 dilation=1, 890 groups=24, 891 ) 892 893 self.conv5 = torch.nn.Conv2d( 894 24, 895 24, 896 kernel_size=7, 897 padding=3, 898 stride=1, 899 dilation=1, 900 groups=24, 901 ) 902 903 self.linear = torch.nn.Linear(out_features, in_features2, bias) 904 905 self.linear2 = torch.nn.Linear(out_features, out_features, bias) 906 self._frozen_param2 = torch.randn(out_features) 907 self._frozen_param3 = torch.randn(out_features) 908 self._frozen_param7 = torch.randn(out_features) 909 self._frozen_param8 = torch.randn(out_features) 910 self._frozen_param153 = torch.randn(batch_size, 1, out_features) 911 912 def forward(self, arg152_1): 913 _convolution_pointwise_default_35 = self.conv(arg152_1) 914 arg152_1 = None 915 916 view_168 = torch.ops.aten.reshape.default( 917 _convolution_pointwise_default_35, [8, 64, 3136] 918 ) 919 _convolution_pointwise_default_35 = None 920 permute_97 = torch.ops.aten.permute.default(view_168, [0, 2, 1]) 921 view_168 = None 922 clone_65 = torch.ops.aten.clone.default( 923 permute_97, memory_format=torch.contiguous_format 924 ) 925 permute_97 = None 926 var_mean_21 = torch.ops.aten.var_mean.correction( 927 clone_65, [2], correction=0, keepdim=True 928 ) 929 getitem_90 = var_mean_21[0] 930 getitem_91 = var_mean_21[1] 931 var_mean_21 = None 932 add_82 = torch.ops.aten.add.Tensor(getitem_90, 1e-05) 933 getitem_90 = None 934 rsqrt_21 = torch.ops.aten.rsqrt.default(add_82) 935 add_82 = None 936 sub_29 = torch.ops.aten.sub.Tensor(clone_65, getitem_91) 937 clone_65 = getitem_91 = None 938 mul_82 = torch.ops.aten.mul.Tensor(sub_29, rsqrt_21) 939 sub_29 = rsqrt_21 = None 940 mul_83 = torch.ops.aten.mul.Tensor(mul_82, self._frozen_param2) 941 mul_82 = None 942 add_83 = torch.ops.aten.add.Tensor(mul_83, self._frozen_param3) 943 mul_83 = None 944 _frozen_param153 = self._frozen_param153 945 cat_20 = torch.ops.aten.cat.default([_frozen_param153, add_83], 1) 946 _frozen_param153 = add_83 = None 947 slice_111 = torch.ops.aten.slice.Tensor(cat_20, 1, 0, 1) 948 slice_113 = torch.ops.aten.slice.Tensor( 949 cat_20, 1, 1, 9223372036854775807 950 ) 951 cat_20 = None 952 permute_98 = torch.ops.aten.permute.default(slice_113, [0, 2, 1]) 953 slice_113 = None 954 view_169 = torch.ops.aten.reshape.default(permute_98, [8, 64, 56, 56]) 955 permute_98 = None 956 _convolution_pointwise_default_34 = self.conv2(view_169) 957 958 add_84 = torch.ops.aten.add.Tensor( 959 _convolution_pointwise_default_34, view_169 960 ) 961 _convolution_pointwise_default_34 = view_169 = None 962 view_170 = torch.ops.aten.reshape.default(add_84, [8, 64, 3136]) 963 add_84 = None 964 permute_99 = torch.ops.aten.permute.default(view_170, [0, 2, 1]) 965 view_170 = None 966 cat_21 = torch.ops.aten.cat.default([slice_111, permute_99], 1) 967 slice_111 = permute_99 = None 968 var_mean_22 = torch.ops.aten.var_mean.correction( 969 cat_21, [2], correction=0, keepdim=True 970 ) 971 getitem_92 = var_mean_22[0] 972 getitem_93 = var_mean_22[1] 973 var_mean_22 = None 974 add_85 = torch.ops.aten.add.Tensor(getitem_92, 1e-06) 975 getitem_92 = None 976 rsqrt_22 = torch.ops.aten.rsqrt.default(add_85) 977 add_85 = None 978 sub_30 = torch.ops.aten.sub.Tensor(cat_21, getitem_93) 979 getitem_93 = None 980 mul_84 = torch.ops.aten.mul.Tensor(sub_30, rsqrt_22) 981 sub_30 = rsqrt_22 = None 982 mul_85 = torch.ops.aten.mul.Tensor(mul_84, self._frozen_param7) 983 mul_84 = None 984 add_86 = torch.ops.aten.add.Tensor(mul_85, self._frozen_param8) 985 mul_85 = None 986 view_171 = torch.ops.aten.reshape.default(add_86, [25096, 64]) 987 add_86 = None 988 989 _mkl_linear_32 = self.linear(view_171) 990 view_171 = None 991 992 view_172 = torch.ops.aten.reshape.default( 993 _mkl_linear_32, [8, 3137, 192] 994 ) 995 _mkl_linear_32 = None 996 view_173 = torch.ops.aten.reshape.default(view_172, [8, 3137, 3, 8, 8]) 997 view_172 = None 998 permute_101 = torch.ops.aten.permute.default(view_173, [2, 0, 3, 1, 4]) 999 view_173 = None 1000 unbind_8 = torch.ops.aten.unbind.int(permute_101) 1001 permute_101 = None 1002 getitem_94 = unbind_8[0] 1003 getitem_95 = unbind_8[1] 1004 getitem_96 = unbind_8[2] 1005 unbind_8 = None 1006 clone_66 = torch.ops.aten.clone.default( 1007 getitem_95, memory_format=torch.contiguous_format 1008 ) 1009 getitem_95 = None 1010 amax_8 = torch.ops.aten.amax.default(clone_66, [2], True) 1011 sub_31 = torch.ops.aten.sub.Tensor(clone_66, amax_8) 1012 clone_66 = amax_8 = None 1013 exp_8 = torch.ops.aten.exp.default(sub_31) 1014 sub_31 = None 1015 sum_9 = torch.ops.aten.sum.dim_IntList(exp_8, [2], True) 1016 div_8 = torch.ops.aten.div.Tensor(exp_8, sum_9) 1017 exp_8 = sum_9 = None 1018 permute_102 = torch.ops.aten.permute.default(div_8, [0, 1, 3, 2]) 1019 div_8 = None 1020 expand_37 = torch.ops.aten.expand.default(permute_102, [8, 8, 8, 3137]) 1021 permute_102 = None 1022 view_174 = torch.ops.aten.reshape.default(expand_37, [64, 8, 3137]) 1023 expand_37 = None 1024 expand_38 = torch.ops.aten.expand.default(getitem_96, [8, 8, 3137, 8]) 1025 clone_67 = torch.ops.aten.clone.default( 1026 expand_38, memory_format=torch.contiguous_format 1027 ) 1028 expand_38 = None 1029 view_175 = torch.ops.aten.reshape.default(clone_67, [64, 3137, 8]) 1030 clone_67 = None 1031 bmm_16 = torch.ops.aten.bmm.default(view_174, view_175) 1032 view_174 = view_175 = None 1033 view_176 = torch.ops.aten.reshape.default(bmm_16, [8, 8, 8, 8]) 1034 bmm_16 = None 1035 expand_39 = torch.ops.aten.expand.default(getitem_94, [8, 8, 3137, 8]) 1036 clone_68 = torch.ops.aten.clone.default( 1037 expand_39, memory_format=torch.contiguous_format 1038 ) 1039 expand_39 = None 1040 view_177 = torch.ops.aten.reshape.default(clone_68, [64, 3137, 8]) 1041 clone_68 = None 1042 expand_40 = torch.ops.aten.expand.default(view_176, [8, 8, 8, 8]) 1043 view_176 = None 1044 view_178 = torch.ops.aten.reshape.default(expand_40, [64, 8, 8]) 1045 expand_40 = None 1046 bmm_17 = torch.ops.aten.bmm.default(view_177, view_178) 1047 view_177 = view_178 = None 1048 view_179 = torch.ops.aten.reshape.default(bmm_17, [8, 8, 3137, 8]) 1049 bmm_17 = None 1050 slice_116 = torch.ops.aten.slice.Tensor( 1051 getitem_94, 2, 1, 9223372036854775807 1052 ) 1053 getitem_94 = None 1054 slice_120 = torch.ops.aten.slice.Tensor( 1055 getitem_96, 2, 1, 9223372036854775807 1056 ) 1057 getitem_96 = None 1058 permute_103 = torch.ops.aten.permute.default(slice_120, [0, 1, 3, 2]) 1059 slice_120 = None 1060 view_180 = torch.ops.aten.reshape.default(permute_103, [8, 64, 56, 56]) 1061 permute_103 = None 1062 split_with_sizes_8 = torch.ops.aten.split_with_sizes.default( 1063 view_180, [16, 24, 24], 1 1064 ) 1065 view_180 = None 1066 getitem_97 = split_with_sizes_8[0] 1067 getitem_98 = split_with_sizes_8[1] 1068 getitem_99 = split_with_sizes_8[2] 1069 split_with_sizes_8 = None 1070 1071 _convolution_pointwise_default_33 = self.conv3(getitem_97) 1072 _convolution_pointwise_default_32 = self.conv4(getitem_98) 1073 _convolution_pointwise_default_31 = self.conv5(getitem_99) 1074 1075 cat_22 = torch.ops.aten.cat.default( 1076 [ 1077 _convolution_pointwise_default_33, 1078 _convolution_pointwise_default_32, 1079 _convolution_pointwise_default_31, 1080 ], 1081 1, 1082 ) 1083 _convolution_pointwise_default_33 = ( 1084 _convolution_pointwise_default_32 1085 ) = _convolution_pointwise_default_31 = None 1086 view_181 = torch.ops.aten.reshape.default(cat_22, [8, 8, 8, 3136]) 1087 cat_22 = None 1088 permute_104 = torch.ops.aten.permute.default(view_181, [0, 1, 3, 2]) 1089 view_181 = None 1090 1091 mul_86 = torch.ops.aten.mul.Tensor(slice_116, permute_104) 1092 slice_116 = permute_104 = None 1093 constant_pad_nd_8 = torch.ops.aten.constant_pad_nd.default( 1094 mul_86, [0, 0, 1, 0, 0, 0], 0.0 1095 ) 1096 mul_86 = None 1097 mul_87 = torch.ops.aten.mul.Tensor(view_179, 0.3535533905932738) 1098 view_179 = None 1099 add_87 = torch.ops.aten.add.Tensor(mul_87, constant_pad_nd_8) 1100 mul_87 = constant_pad_nd_8 = None 1101 return add_87 1102 1103 view_12 = torch.randn(batch_size, in_features, image_size, image_size) 1104 1105 mod = M(bias=bias).eval() 1106 with verify(dtype) as (atol, rtol): 1107 self.common( 1108 mod, 1109 (view_12,), 1110 atol=atol, 1111 rtol=rtol, 1112 ) 1113 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 1114 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1) 1115 1116 @inductor_config.patch({"freezing": True}) 1117 @patches 1118 @torch.no_grad 1119 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 1120 @parametrize("batch_size", (32,)) 1121 @parametrize("in_features", (128,)) 1122 @parametrize("out_features", (64, 65)) 1123 @parametrize("bias", (False, True)) 1124 @parametrize("input_3d", (False, True)) 1125 @dtypes(torch.float32, torch.bfloat16) 1126 @parametrize( 1127 "epilogue", 1128 ( 1129 "none", 1130 "relu", 1131 "gelu", 1132 ), 1133 ) 1134 @skipIfWindows(msg="Windows don't support quantize.") 1135 def test_quantized_linear_with_pointwise( 1136 self, batch_size, in_features, out_features, bias, input_3d, dtype, epilogue 1137 ): 1138 B = (2, batch_size) if input_3d else (batch_size,) 1139 input = torch.randn(*B, in_features).to(dtype=torch.float32) 1140 1141 class M(torch.nn.Module): 1142 def __init__(self, bias): 1143 super().__init__() 1144 self.linear = torch.nn.Linear(in_features, out_features, bias) 1145 self.epilogue = _get_epilogue(epilogue) 1146 self.linear2 = torch.nn.Linear(out_features, out_features, bias) 1147 self.epilogue2 = _get_epilogue(epilogue) 1148 1149 def forward(self, x): 1150 res = self.epilogue(self.linear(x)) 1151 res = self.epilogue2(self.linear2(res)) 1152 return res 1153 1154 counters.clear() 1155 ref_quantized_mod = _generate_qdq_quantized_model( 1156 M(bias=bias).eval(), 1157 (input,), 1158 ) 1159 1160 atol, rtol = 1e-3, 1e-3 1161 if dtype == torch.bfloat16: 1162 atol, rtol = 5e-2, 5e-2 1163 1164 with patch.object( 1165 select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol) 1166 ), torch.no_grad(), torch.autocast( 1167 "cpu", enabled=(dtype == torch.bfloat16), dtype=dtype 1168 ): 1169 ref_res = ref_quantized_mod(input) 1170 cfn = torch.compile(ref_quantized_mod) 1171 res = cfn(input) 1172 self.assertEqual( 1173 res, 1174 ref_res, 1175 atol=atol, 1176 rtol=rtol, 1177 equal_nan=True, 1178 exact_dtype=True, 1179 ) 1180 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2) 1181 self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0) 1182 1183 @inductor_config.patch({"freezing": True}) 1184 @patches 1185 @torch.no_grad 1186 @dtypes(torch.bfloat16) 1187 @parametrize("batch_size", (32,)) 1188 @parametrize("in_features", (128,)) 1189 @parametrize("out_features", (64, 65)) 1190 def test_int8_woq_mm(self, dtype, batch_size, in_features, out_features): 1191 # x will be reshaped from 3d to 2d 1192 second_dim_size = 8 1193 1194 def _convert_weight_to_int8pack(w): 1195 scale, zp = _calculate_dynamic_per_channel_qparams( 1196 w.to(torch.float), torch.int8 1197 ) 1198 scale = torch.from_numpy(scale) 1199 zp = torch.from_numpy(zp) 1200 w_int8 = torch.ao.quantization.fx._decomposed.quantize_per_channel( 1201 input=w, 1202 scales=scale, 1203 zero_points=zp, 1204 axis=0, 1205 quant_min=-128, 1206 quant_max=127, 1207 dtype=torch.int8, 1208 ) 1209 return w_int8, scale.to(torch.bfloat16) 1210 1211 class M(torch.nn.Module): 1212 def __init__(self, w): 1213 super().__init__() 1214 self.linear_weight = torch.nn.Parameter(w, requires_grad=False) 1215 1216 def forward(self, x, scale): 1217 return ( 1218 torch.nn.functional.linear(x, self.linear_weight.to(x.dtype)) 1219 * scale 1220 ) 1221 1222 counters.clear() 1223 # Currently, the corresponding torch.fx pattern only supports 3D x 1224 # Add 2D X case once the corresponding pattern-matcher pattern is added 1225 x = torch.rand((batch_size, second_dim_size, in_features), dtype=dtype) 1226 w = torch.rand((out_features, in_features), dtype=dtype) 1227 w_int8pack, w_scales = _convert_weight_to_int8pack(w) 1228 mod = M(w_int8pack).eval() 1229 self.common(mod, (x, w_scales)) 1230 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 1231 vec_amx = VecAMX() 1232 self._check_amx_counter(vec_amx) 1233 1234 @inductor_config.patch({"freezing": True}) 1235 @patches 1236 @torch.no_grad 1237 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 1238 @parametrize("batch_size", (32,)) 1239 @parametrize("in_features", (128,)) 1240 @parametrize("out_features", (64, 65)) 1241 @parametrize("bias", (False, True)) 1242 @parametrize("input_3d", (False, True)) 1243 @parametrize("int8_mixed_bf16", (False, True)) 1244 @dtypes(torch.float32, torch.bfloat16) 1245 @parametrize( 1246 "epilogue", 1247 ( 1248 "none", 1249 "relu", 1250 ), 1251 ) 1252 @skipIfWindows(msg="Windows don't support quantize.") 1253 def test_quantized_linear_with_pointwise_binary( 1254 self, 1255 batch_size, 1256 in_features, 1257 out_features, 1258 bias, 1259 input_3d, 1260 int8_mixed_bf16, 1261 dtype, 1262 epilogue, 1263 ): 1264 if not int8_mixed_bf16 and dtype == torch.bfloat16: 1265 return 1266 B = (2, batch_size) if input_3d else (batch_size,) 1267 input = torch.randn(*B, in_features).to(dtype=torch.float32) 1268 1269 other = torch.randn(*B, out_features).to(dtype=dtype) 1270 # Avoid hiting qlinear inplace sum fusion 1271 if input_3d: 1272 other2 = torch.randn(B[0] * B[1], out_features).to(dtype=dtype) 1273 else: 1274 other2 = torch.randn(1, *B, out_features).to(dtype=dtype) 1275 1276 class M(torch.nn.Module): 1277 def __init__(self, bias, input_3d): 1278 super().__init__() 1279 self.linear = torch.nn.Linear(in_features, out_features, bias) 1280 self.epilogue = _get_epilogue(epilogue) 1281 self.linear2 = torch.nn.Linear(out_features, out_features, bias) 1282 self.epilogue2 = _get_epilogue(epilogue) 1283 self.input_3d = input_3d 1284 1285 def forward(self, x, other, other2): 1286 res = self.epilogue(self.linear(x) + other) 1287 # Avoid hiting qlinear inplace sum fusion 1288 if self.input_3d: 1289 other2 = other2.view(2, other2.size(0) // 2, other2.size(1)) 1290 else: 1291 other2 = other2.view(other2.size(1), other2.size(2)) 1292 res = self.epilogue2(self.linear2(res) + other2) 1293 return res 1294 1295 counters.clear() 1296 ref_quantized_mod = _generate_qdq_quantized_model( 1297 M(bias=bias, input_3d=input_3d).eval(), 1298 (input, other, other2), 1299 ) 1300 atol, rtol = 5e-2, 5e-2 1301 with patch.object( 1302 select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol) 1303 ), torch.no_grad(), torch.autocast( 1304 "cpu", enabled=int8_mixed_bf16, dtype=torch.bfloat16 1305 ): 1306 ref_res = ref_quantized_mod(input, other, other2) 1307 cfn = torch.compile(ref_quantized_mod) 1308 res = cfn(input, other, other2) 1309 self.assertEqual( 1310 res, 1311 ref_res, 1312 atol=atol, 1313 rtol=rtol, 1314 equal_nan=True, 1315 exact_dtype=True, 1316 ) 1317 self.assertEqual( 1318 counters["inductor"]["select_algorithm_autotune"], 1319 2, 1320 ) 1321 self.assertEqual( 1322 counters["inductor"]["cpp_epilogue_fusion_counter"], 1323 0, 1324 ) 1325 1326 @inductor_config.patch({"freezing": True}) 1327 @patches 1328 @torch.no_grad 1329 @parametrize("batch_size", (3, 16, 32, 49)) 1330 @parametrize("in_features", (4, 68, 128)) # k should be a multiple of 4 1331 @parametrize("out_features", (64, 65)) 1332 @parametrize("bias", (True, False)) 1333 @skipIfWindows(msg="Windows don't support quantize.") 1334 def test_quantized_linear_amx(self, batch_size, in_features, out_features, bias): 1335 class M(torch.nn.Module): 1336 def __init__(self, bias): 1337 super().__init__() 1338 self.linear = torch.nn.Linear(in_features, out_features, bias) 1339 1340 def forward(self, x): 1341 return self.linear(x) 1342 1343 counters.clear() 1344 v = torch.randn(batch_size, in_features).to(dtype=torch.float32) 1345 ref_quantized_mod = _generate_qdq_quantized_model( 1346 M(bias=bias).eval(), 1347 (v,), 1348 ) 1349 atol, rtol = 1e-2, 1e-2 1350 with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)): 1351 self.common(ref_quantized_mod, (v,), atol=atol, rtol=rtol) 1352 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 1353 vec_amx = VecAMX() 1354 self._check_amx_counter(vec_amx) 1355 1356 @inductor_config.patch({"freezing": True}) 1357 @inductor_config.patch({"cpp.gemm_max_k_slices": 0}) 1358 @patches 1359 @torch.no_grad 1360 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 1361 @parametrize("batch_size", (2,)) 1362 @parametrize("in_features", (1000,)) 1363 @parametrize("out_features", (2,)) 1364 @parametrize("bias", (True, False)) 1365 @parametrize( 1366 "epilogue", 1367 ( 1368 "none", 1369 "relu", 1370 ), 1371 ) 1372 @dtypes(torch.float, torch.bfloat16, torch.half) 1373 def test_linear_k_slicing( 1374 self, batch_size, in_features, out_features, bias, epilogue, dtype 1375 ): 1376 class M(torch.nn.Module): 1377 def __init__(self, bias, epilogue, other): 1378 super().__init__() 1379 self.linear = torch.nn.Linear(in_features, out_features, bias) 1380 self.epilogue = _get_epilogue(epilogue, other) 1381 1382 def forward(self, x): 1383 return self.epilogue(self.linear(x)) 1384 1385 counters.clear() 1386 v = torch.randn(batch_size, in_features).to(dtype=dtype) 1387 u = torch.randn(batch_size, out_features).to(dtype=dtype) 1388 mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval() 1389 with verify(dtype) as (atol, rtol): 1390 self.common(mod, (v,), atol=atol, rtol=rtol) 1391 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 1392 1393 @inductor_config.patch({"freezing": True}) 1394 @inductor_config.patch({"cpp.gemm_cache_blocking": "2,2,2"}) 1395 @patches 1396 @torch.no_grad 1397 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 1398 @set_num_threads(1) 1399 @parametrize("batch_size", (1024,)) 1400 @parametrize("in_features", (1024,)) 1401 @parametrize("out_features", (1024,)) 1402 @parametrize("bias", (True, False)) 1403 @dtypes(torch.float, torch.bfloat16, torch.half) 1404 def test_linear_cache_blocking( 1405 self, batch_size, in_features, out_features, bias, dtype 1406 ): 1407 class M(torch.nn.Module): 1408 def __init__(self, bias): 1409 super().__init__() 1410 self.linear = torch.nn.Linear(in_features, out_features, bias) 1411 1412 def forward(self, x): 1413 return self.linear(x) 1414 1415 counters.clear() 1416 v = torch.randn(batch_size, in_features).to(dtype=dtype) 1417 mod = M(bias=bias).to(dtype=dtype).eval() 1418 with verify(dtype) as (atol, rtol): 1419 self.common(mod, (v,), atol=atol, rtol=rtol) 1420 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 1421 1422 @inductor_config.patch({"freezing": True}) 1423 @inductor_config.patch({"cpp.gemm_thread_factors": "4,2,7"}) 1424 @patches 1425 @torch.no_grad 1426 @unittest.skipIf(not TEST_MKL, "Test requires MKL") 1427 @set_num_threads(56) 1428 @parametrize("batch_size", (1024,)) 1429 @parametrize("in_features", (1024,)) 1430 @parametrize("out_features", (1024,)) 1431 @parametrize("bias", (True, False)) 1432 @dtypes(torch.float, torch.bfloat16, torch.half) 1433 def test_linear_thread_factors( 1434 self, batch_size, in_features, out_features, bias, dtype 1435 ): 1436 class M(torch.nn.Module): 1437 def __init__(self, bias): 1438 super().__init__() 1439 self.linear = torch.nn.Linear(in_features, out_features, bias) 1440 1441 def forward(self, x): 1442 return self.linear(x) 1443 1444 counters.clear() 1445 v = torch.randn(batch_size, in_features).to(dtype=dtype) 1446 mod = M(bias=bias).to(dtype=dtype).eval() 1447 with verify(dtype) as (atol, rtol): 1448 self.common(mod, (v,), atol=atol, rtol=rtol) 1449 self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) 1450 1451 1452@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) 1453class _DynamicShapesTestBase(BaseTestSelectAlgorithm): 1454 pass 1455 1456 1457class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): 1458 common = check_model 1459 test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes 1460 test_linear_with_pointwise_dynamic_shapes = ( 1461 TestSelectAlgorithm.test_linear_with_pointwise 1462 ) 1463 test_linear_with_transpose_dynamic_shapes = ( 1464 TestSelectAlgorithm.test_linear_with_transpose 1465 ) 1466 test_linear_with_unary_binary_dynamic_shapes = ( 1467 TestSelectAlgorithm.test_linear_with_unary_binary 1468 ) 1469 test_linear_amx_dynamic_shapes = TestSelectAlgorithm.test_linear_amx 1470 test_linear_with_embedding_dynamic_shapes = ( 1471 TestSelectAlgorithm.test_linear_with_embedding 1472 ) 1473 test_quantized_linear_with_pointwise_dynamic_shapes = ( 1474 TestSelectAlgorithm.test_quantized_linear_with_pointwise 1475 ) 1476 test_quantized_linear_with_pointwise_binary_dynamic_shapes = ( 1477 TestSelectAlgorithm.test_quantized_linear_with_pointwise_binary 1478 ) 1479 test_quantized_linear_amx_dynamic_shapes = ( 1480 TestSelectAlgorithm.test_quantized_linear_amx 1481 ) 1482 test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing 1483 test_linear_cache_blocking_dynamic_shapes = ( 1484 TestSelectAlgorithm.test_linear_cache_blocking 1485 ) 1486 test_linear_thread_factors_dynamic_shapes = ( 1487 TestSelectAlgorithm.test_linear_thread_factors 1488 ) 1489 1490 1491instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") 1492instantiate_device_type_tests( 1493 TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu" 1494) 1495 1496 1497if __name__ == "__main__": 1498 from torch.testing._internal.inductor_utils import HAS_CPU 1499 1500 if HAS_CPU and not IS_MACOS: 1501 run_tests() 1502