1# Owner(s): ["oncall: quantization"] 2import copy 3import itertools 4import sys 5from enum import Enum 6 7import torch 8import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq 9import torch.nn as nn 10from torch._export import capture_pre_autograd_graph 11from torch.ao.quantization import ObserverBase 12from torch.ao.quantization.quantize_pt2e import ( 13 convert_pt2e, 14 prepare_pt2e, 15 prepare_qat_pt2e, 16) 17from torch.ao.quantization.quantizer.x86_inductor_quantizer import ( 18 QUANT_ANNOTATION_KEY, 19 X86InductorQuantizer, 20) 21from torch.testing._internal.common_quantization import ( 22 NodeSpec as ns, 23 QuantizationTestCase, 24 skipIfNoInductorSupport, 25 skipIfNoX86, 26) 27from torch.testing._internal.common_quantized import override_quantized_engine 28from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfTorchDynamo 29 30 31if IS_WINDOWS and IS_CI: 32 sys.stderr.write("Windows CI still has some issue to be fixed.\n") 33 sys.exit(0) 34 35 36class NodePosType(Enum): 37 left = 1 38 right = 2 39 both = 3 40 41 42class TestHelperModules: 43 class SingleConv2dModule(torch.nn.Module): 44 def __init__(self, with_bn=False) -> None: 45 super().__init__() 46 self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1)) 47 self.bn = torch.nn.BatchNorm2d(6) 48 self.with_bn = with_bn 49 50 def forward(self, x): 51 x = self.conv(x) 52 if self.with_bn: 53 x = self.bn(x) 54 return x 55 56 class Conv2dUnaryModule(torch.nn.Module): 57 def __init__(self, post_op, use_bias: bool = False, with_bn=False) -> None: 58 super().__init__() 59 self.conv = nn.Conv2d( 60 3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias 61 ) 62 self.post_op = post_op 63 self.bn = torch.nn.BatchNorm2d(6) 64 self.with_bn = with_bn 65 self.maxpool = torch.nn.MaxPool2d((3, 3)) 66 67 def forward(self, x): 68 x = self.conv(x) 69 if self.with_bn: 70 x = self.bn(x) 71 x = self.post_op(x) 72 x = self.maxpool(x) 73 return x 74 75 class Conv2dAddModule(torch.nn.Module): 76 def __init__( 77 self, 78 inplace_add: bool = False, 79 conv2d_type: NodePosType = NodePosType.left, 80 use_bias: bool = False, 81 with_bn: bool = False, 82 ) -> None: 83 super().__init__() 84 self.conv = torch.nn.Conv2d( 85 in_channels=3, 86 out_channels=3, 87 kernel_size=3, 88 stride=1, 89 padding=1, 90 bias=use_bias, 91 ) 92 self.conv2 = torch.nn.Conv2d( 93 in_channels=3, 94 out_channels=3, 95 kernel_size=3, 96 stride=1, 97 padding=1, 98 bias=use_bias, 99 ) 100 self.relu = nn.ReLU() 101 self.inplace_add = inplace_add 102 self.conv2d_type = conv2d_type 103 self.bn = torch.nn.BatchNorm2d(3) 104 self.with_bn = with_bn 105 106 def forward(self, x): 107 if self.conv2d_type == NodePosType.left: 108 if self.inplace_add: 109 tmp = self.conv(x) 110 if self.with_bn: 111 tmp = self.bn(tmp) 112 tmp += self.relu(x) 113 return tmp 114 else: 115 tmp = self.conv(x) 116 if self.with_bn: 117 tmp = self.bn(tmp) 118 return tmp + self.relu(x) 119 elif self.conv2d_type == NodePosType.right: 120 if self.inplace_add: 121 tmp = self.relu(x) 122 tmp += self.conv(x) 123 return tmp 124 else: 125 return self.relu(x) + self.conv(x) 126 elif self.conv2d_type == NodePosType.both: 127 if self.inplace_add: 128 tmp = self.conv(x) 129 tmp += self.conv2(x) 130 return tmp 131 else: 132 return self.conv(x) + self.conv2(x) 133 134 class Conv2dAddReLUModule(torch.nn.Module): 135 def __init__( 136 self, 137 inplace_add: bool = False, 138 conv2d_type: NodePosType = NodePosType.left, 139 inplace_relu: bool = False, 140 use_bias: bool = False, 141 with_bn: bool = False, 142 ) -> None: 143 super().__init__() 144 self.conv = torch.nn.Conv2d( 145 in_channels=3, 146 out_channels=3, 147 kernel_size=3, 148 stride=1, 149 padding=1, 150 bias=use_bias, 151 ) 152 self.conv2 = torch.nn.Conv2d( 153 in_channels=3, 154 out_channels=3, 155 kernel_size=3, 156 stride=1, 157 padding=1, 158 bias=use_bias, 159 ) 160 self.relu = nn.ReLU() 161 self.inplace_add = inplace_add 162 self.conv2d_type = conv2d_type 163 self.relu2 = nn.ReLU(inplace=inplace_relu) 164 self.bn = torch.nn.BatchNorm2d(3) 165 self.with_bn = with_bn 166 167 def forward(self, x): 168 if self.conv2d_type == NodePosType.left: 169 if self.inplace_add: 170 tmp = self.conv(x) 171 if self.with_bn: 172 tmp = self.bn(tmp) 173 tmp += self.relu(x) 174 return self.relu2(tmp) 175 else: 176 tmp = self.conv(x) 177 if self.with_bn: 178 tmp = self.bn(tmp) 179 return self.relu2(tmp + self.relu(x)) 180 elif self.conv2d_type == NodePosType.right: 181 if self.inplace_add: 182 tmp = self.relu(x) 183 tmp += self.conv(x) 184 return self.relu2(tmp) 185 else: 186 return self.relu2(self.relu(x) + self.conv(x)) 187 elif self.conv2d_type == NodePosType.both: 188 if self.inplace_add: 189 tmp = self.conv(x) 190 tmp += self.conv2(x) 191 return self.relu2(tmp) 192 else: 193 return self.relu2(self.conv(x) + self.conv2(x)) 194 195 class Conv2dSingleOpPowModule(nn.Module): 196 def __init__(self, single_op): 197 super().__init__() 198 self.conv = nn.Conv2d(2, 2, 1) 199 self.single_op = single_op 200 201 def forward(self, x): 202 x = self.conv(x) 203 x = self.single_op(x) 204 return torch.pow(x, 2) 205 206 class SerialsConv2dAddReLUModule(torch.nn.Module): 207 """Serials of 2 Conv2d -> Add -> ReLU Pattern.""" 208 209 def __init__( 210 self, 211 ) -> None: 212 super().__init__() 213 self.conv = torch.nn.Conv2d( 214 in_channels=3, 215 out_channels=3, 216 kernel_size=3, 217 stride=1, 218 padding=1, 219 bias=True, 220 ) 221 self.conv2 = torch.nn.Conv2d( 222 in_channels=3, 223 out_channels=3, 224 kernel_size=3, 225 stride=1, 226 padding=1, 227 bias=True, 228 ) 229 self.conv3 = torch.nn.Conv2d( 230 in_channels=3, 231 out_channels=3, 232 kernel_size=3, 233 stride=1, 234 padding=1, 235 bias=True, 236 ) 237 self.conv4 = torch.nn.Conv2d( 238 in_channels=3, 239 out_channels=3, 240 kernel_size=3, 241 stride=1, 242 padding=1, 243 bias=True, 244 ) 245 self.relu = nn.ReLU() 246 self.relu2 = nn.ReLU() 247 248 def forward(self, x): 249 x1 = self.conv(x) 250 res1 = self.relu(self.conv2(x1) + self.conv3(x1)) 251 res2 = self.relu2(self.conv4(res1) + res1) 252 return res2 253 254 class Conv2dCatMaxpool2d(torch.nn.Module): 255 def __init__( 256 self, 257 ): 258 super().__init__() 259 self.conv = torch.nn.Conv2d( 260 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 261 ) 262 self.conv2 = torch.nn.Conv2d( 263 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 264 ) 265 self.relu = torch.nn.ReLU() 266 self.maxpool = torch.nn.MaxPool2d(3, stride=2, padding=1) 267 self.conv3 = torch.nn.Conv2d( 268 32, 32, 7, bias=True, stride=2, padding=3, dilation=1 269 ) 270 271 def forward(self, x): 272 temp1 = self.relu(self.conv(x)) 273 temp2 = self.conv2(x + 1) 274 temp3 = torch.cat((temp1, temp2), 1) 275 temp4 = self.maxpool(temp3) 276 temp5 = self.conv3(temp4) 277 return temp5 278 279 class Conv2dAvgPool2d(torch.nn.Module): 280 def __init__( 281 self, 282 ): 283 super().__init__() 284 self.conv = torch.nn.Conv2d( 285 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 286 ) 287 self.avgpool = torch.nn.AvgPool2d(3, stride=2, padding=1) 288 289 def forward(self, x): 290 temp1 = self.avgpool(self.conv(x)) 291 return temp1 292 293 class Conv2dCatSameInputs(torch.nn.Module): 294 def __init__( 295 self, 296 ): 297 super().__init__() 298 self.conv = torch.nn.Conv2d( 299 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 300 ) 301 self.relu = torch.nn.ReLU() 302 303 def forward(self, x): 304 temp1 = self.relu(self.conv(x)) 305 temp3 = torch.cat((temp1, temp1), 1) 306 return temp3 307 308 class Conv2dCatSingleInput(torch.nn.Module): 309 def __init__( 310 self, 311 ): 312 super().__init__() 313 self.conv = torch.nn.Conv2d( 314 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 315 ) 316 self.relu = torch.nn.ReLU() 317 318 def forward(self, x): 319 temp1 = self.relu(self.conv(x)) 320 temp3 = torch.cat((temp1,), 1) 321 return temp3 322 323 class SingleLinearModule(torch.nn.Module): 324 def __init__(self, use_bias) -> None: 325 super().__init__() 326 self.linear = nn.Linear(4, 4, bias=use_bias) 327 328 def forward(self, x): 329 return self.linear(x) 330 331 class LinearUnaryModule(torch.nn.Module): 332 def __init__( 333 self, use_bias, postop, inplace_postop=False, post_op_algo="none" 334 ) -> None: 335 super().__init__() 336 self.linear = nn.Linear(4, 4, bias=use_bias) 337 if postop == nn.GELU: 338 self.postop = postop(approximate=post_op_algo) 339 else: 340 self.postop = postop(inplace=inplace_postop) 341 342 def forward(self, x): 343 return self.postop(self.linear(x)) 344 345 class LinearAddModule(torch.nn.Module): 346 def __init__( 347 self, 348 inplace_add: bool = False, 349 linear_pos: NodePosType = NodePosType.left, 350 use_bias: bool = False, 351 ) -> None: 352 super().__init__() 353 self.linear = torch.nn.Linear( 354 in_features=16, out_features=16, bias=use_bias 355 ) 356 self.linear2 = torch.nn.Linear( 357 in_features=16, out_features=16, bias=use_bias 358 ) 359 self.relu = nn.ReLU() 360 self.inplace_add = inplace_add 361 self.linear_pos = linear_pos 362 363 def forward(self, x): 364 if self.linear_pos == NodePosType.left: 365 if self.inplace_add: 366 tmp = self.linear(x) 367 tmp += self.relu(x) 368 return tmp 369 else: 370 tmp = self.linear(x) 371 return tmp + self.relu(x) 372 elif self.linear_pos == NodePosType.right: 373 if self.inplace_add: 374 tmp = self.relu(x) 375 tmp += self.linear(x) 376 return tmp 377 else: 378 return self.relu(x) + self.linear(x) 379 elif self.linear_pos == NodePosType.both: 380 if self.inplace_add: 381 tmp = self.linear(x) 382 tmp += self.linear2(x) 383 return tmp 384 else: 385 return self.linear(x) + self.linear2(x) 386 387 class LinearAddReLUModule(torch.nn.Module): 388 def __init__( 389 self, 390 inplace_add: bool = False, 391 linear_pos: NodePosType = NodePosType.left, 392 inplace_relu: bool = False, 393 use_bias: bool = False, 394 ) -> None: 395 super().__init__() 396 self.linear = torch.nn.Linear( 397 in_features=16, out_features=16, bias=use_bias 398 ) 399 self.linear2 = torch.nn.Linear( 400 in_features=16, out_features=16, bias=use_bias 401 ) 402 self.relu = nn.ReLU() 403 self.inplace_add = inplace_add 404 self.linear_pos = linear_pos 405 self.relu2 = nn.ReLU(inplace=inplace_relu) 406 407 def forward(self, x): 408 if self.linear_pos == NodePosType.left: 409 if self.inplace_add: 410 tmp = self.linear(x) 411 tmp += self.relu(x) 412 return self.relu2(tmp) 413 else: 414 tmp = self.linear(x) 415 return self.relu2(tmp + self.relu(x)) 416 elif self.linear_pos == NodePosType.right: 417 if self.inplace_add: 418 tmp = self.relu(x) 419 tmp += self.linear(x) 420 return self.relu2(tmp) 421 else: 422 return self.relu2(self.relu(x) + self.linear(x)) 423 elif self.linear_pos == NodePosType.both: 424 if self.inplace_add: 425 tmp = self.linear(x) 426 tmp += self.linear2(x) 427 return self.relu2(tmp) 428 else: 429 return self.relu2(self.linear(x) + self.linear2(x)) 430 431 class SerialsLinearAddReLUModule(torch.nn.Module): 432 """Serials of 2 Linear -> Add -> ReLU Pattern.""" 433 434 def __init__( 435 self, 436 ) -> None: 437 super().__init__() 438 self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True) 439 self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True) 440 self.linear3 = torch.nn.Linear(in_features=16, out_features=16, bias=True) 441 self.linear4 = torch.nn.Linear(in_features=16, out_features=16, bias=True) 442 self.relu = nn.ReLU() 443 self.relu2 = nn.ReLU() 444 445 def forward(self, x): 446 x1 = self.linear(x) 447 res1 = self.relu(self.linear2(x1) + self.linear3(x1)) 448 res2 = self.relu2(self.linear4(res1) + res1) 449 return res2 450 451 class LinearAddModule2(torch.nn.Module): 452 def __init__( 453 self, 454 inplace_add: bool = False, 455 ) -> None: 456 super().__init__() 457 self.linear = torch.nn.Linear(in_features=16, out_features=16, bias=True) 458 self.linear2 = torch.nn.Linear(in_features=16, out_features=16, bias=True) 459 self.inplace_add = inplace_add 460 461 def forward(self, x): 462 if self.inplace_add: 463 tmp = self.linear(x) 464 tmp += self.linear2(tmp) 465 return tmp 466 else: 467 tmp = self.linear(x) 468 return tmp + self.linear2(tmp) 469 470 class Conv2dAddModule2(torch.nn.Module): 471 def __init__( 472 self, 473 inplace_add: bool = False, 474 ) -> None: 475 super().__init__() 476 self.conv = torch.nn.Conv2d( 477 in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 478 ) 479 self.conv2 = torch.nn.Conv2d( 480 in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1 481 ) 482 self.inplace_add = inplace_add 483 self.bn = torch.nn.BatchNorm2d(3) 484 self.bn2 = torch.nn.BatchNorm2d(3) 485 486 def forward(self, x): 487 if self.inplace_add: 488 tmp = self.bn(self.conv(x)) 489 tmp += self.bn2(self.conv2(tmp)) 490 return tmp 491 else: 492 tmp = self.bn(self.conv(x)) 493 return tmp + self.bn2(self.conv2(tmp)) 494 495 class SelfAttnLikeModule(torch.nn.Module): 496 def __init__( 497 self, 498 input_dim, 499 transpose_for_score=False, 500 num_attention_heads=None, 501 attention_head_size=None, 502 ) -> None: 503 super().__init__() 504 self.input_dim = input_dim 505 self.q_proj = nn.Linear(input_dim, input_dim, bias=False) 506 self.k_proj = nn.Linear(input_dim, input_dim, bias=False) 507 self.v_proj = nn.Linear(input_dim, input_dim, bias=False) 508 self.softmax = nn.Softmax(dim=-1) 509 self.transpose_for_score = transpose_for_score 510 if self.transpose_for_score: 511 assert num_attention_heads is not None 512 assert attention_head_size is not None 513 self.num_attention_heads = num_attention_heads 514 self.attention_head_size = attention_head_size 515 516 def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: 517 new_x_shape = x.size()[:-1] + ( 518 self.num_attention_heads, 519 self.attention_head_size, 520 ) 521 x = x.view(new_x_shape) 522 return x.permute(0, 2, 1, 3) 523 524 def forward(self, x): 525 q = self.q_proj(x) 526 k = self.k_proj(x) 527 v = self.v_proj(x) 528 if self.transpose_for_score: 529 q = self.transpose_for_scores(q) 530 k = self.transpose_for_scores(k) 531 v = self.transpose_for_scores(v) 532 scores = torch.matmul(q, k.transpose(-1, -2)) / (self.input_dim**0.5) 533 attention = self.softmax(scores) 534 weighted = torch.matmul(attention, v) 535 return weighted 536 537 538class X86InductorQuantTestCase(QuantizationTestCase): 539 def _test_quantizer( 540 self, 541 model, 542 example_inputs, 543 quantizer, 544 expected_node_occurrence, 545 expected_node_list=None, 546 is_qat=False, 547 debug=False, 548 ): 549 m_eager = model.train() if is_qat else model.eval() 550 551 # program capture 552 m = copy.deepcopy(m_eager) 553 m = capture_pre_autograd_graph( 554 m, 555 example_inputs, 556 ) 557 558 # QAT Model failed to deepcopy 559 export_model = m if is_qat else copy.deepcopy(m) 560 m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer) 561 # Calibrate 562 m(*example_inputs) 563 prepare_model = copy.deepcopy(m) 564 m = convert_pt2e(m) 565 convert_model = copy.deepcopy(m) 566 if debug: 567 convert_model.print_readable(True) 568 pt2_quant_output = m(*example_inputs) 569 node_occurrence = { 570 ns.call_function(k): v for k, v in expected_node_occurrence.items() 571 } 572 if expected_node_list is None: 573 expected_node_list = [] 574 node_list = [ns.call_function(n) for n in expected_node_list] 575 self.checkGraphModuleNodes( 576 m, expected_node_occurrence=node_occurrence, expected_node_list=node_list 577 ) 578 return export_model, prepare_model, convert_model 579 580 581@skipIfNoInductorSupport 582class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): 583 @skipIfNoX86 584 def test_conv2d(self): 585 """ 586 Test pattern of single conv2d with X86InductorQuantizer. 587 """ 588 with override_quantized_engine("x86"), torch.no_grad(): 589 m = TestHelperModules.SingleConv2dModule().eval() 590 example_inputs = (torch.randn(2, 3, 16, 16),) 591 quantizer = X86InductorQuantizer().set_global( 592 xiq.get_default_x86_inductor_quantization_config() 593 ) 594 node_occurrence = { 595 # one for input and weight of the conv 596 torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, 597 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, 598 # note: quantize op for weights are const propagated 599 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 600 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 601 } 602 node_list = [ 603 torch.ops.quantized_decomposed.quantize_per_tensor.default, 604 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 605 torch.ops.aten.conv2d.default, 606 ] 607 self._test_quantizer( 608 m, 609 example_inputs, 610 quantizer, 611 node_occurrence, 612 node_list, 613 ) 614 615 @skipIfNoX86 616 def test_conv2d_unary(self): 617 """ 618 Test pattern of conv2d with unary post ops (such as relu, hardtanh, hardswish, relu6) with X86InductorQuantizer. 619 """ 620 unary_map = { 621 "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], 622 "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], 623 "hardtanh": [ 624 torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), 625 torch.ops.aten.hardtanh.default, 626 ], 627 "hardtanh_inplace": [ 628 torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), 629 torch.ops.aten.hardtanh_.default, 630 ], 631 "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], 632 "relu6_inplace": [ 633 torch.nn.ReLU6(inplace=True), 634 torch.ops.aten.hardtanh_.default, 635 ], 636 "hardswish": [ 637 torch.nn.Hardswish(inplace=False), 638 torch.ops.aten.hardswish.default, 639 ], 640 "hardswish_inplace": [ 641 torch.nn.Hardswish(inplace=True), 642 torch.ops.aten.hardswish_.default, 643 ], 644 "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default], 645 "swish_inplace": [ 646 torch.nn.SiLU(inplace=True), 647 torch.ops.aten.silu_.default, 648 ], 649 } 650 use_bias_list = [True, False] 651 with override_quantized_engine("x86"), torch.no_grad(): 652 for unary_op, use_bias in itertools.product( 653 unary_map.keys(), use_bias_list 654 ): 655 m = TestHelperModules.Conv2dUnaryModule( 656 unary_map[unary_op][0], use_bias=use_bias 657 ).eval() 658 example_inputs = (torch.randn(2, 3, 16, 16),) 659 quantizer = X86InductorQuantizer().set_global( 660 xiq.get_default_x86_inductor_quantization_config() 661 ) 662 node_occurrence = { 663 # one for input and weight of the conv 664 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 665 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 666 # note: quantize op for weights are const propagated 667 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 668 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 669 } 670 node_list = [ 671 torch.ops.quantized_decomposed.quantize_per_tensor.default, 672 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 673 torch.ops.aten.conv2d.default, 674 unary_map[unary_op][1], 675 ] 676 self._test_quantizer( 677 m, 678 example_inputs, 679 quantizer, 680 node_occurrence, 681 node_list, 682 ) 683 684 @skipIfNoX86 685 def test_conv2d_binary(self): 686 """ 687 Test pattern of conv2d with binary post ops (such as add) with X86InductorQuantizer. 688 Currently, only add as binary post op is supported. 689 """ 690 conv2d_type_list = [NodePosType.left, NodePosType.both] 691 example_inputs = (torch.randn(2, 3, 6, 6),) 692 quantizer = X86InductorQuantizer().set_global( 693 xiq.get_default_x86_inductor_quantization_config() 694 ) 695 with override_quantized_engine("x86"), torch.no_grad(): 696 for conv2d_type in conv2d_type_list: 697 m = TestHelperModules.Conv2dAddModule(conv2d_type=conv2d_type).eval() 698 if conv2d_type != NodePosType.both: 699 node_occurrence = { 700 # one for input and weight of the conv 701 # one for extra input node of add 702 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 703 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, 704 # quantize_per_channel for weights are const propagated 705 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 706 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 707 } 708 else: 709 node_occurrence = { 710 # one for input of the conv 711 # one for input of another conv 712 # 2 conv will share same input quant/dequant 713 # one for extra input node of add 714 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 715 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 716 # quantize_per_channel for weights are const propagated 717 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 718 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 719 } 720 node_list = [ 721 torch.ops.quantized_decomposed.quantize_per_tensor.default, 722 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 723 torch.ops.aten.conv2d.default, 724 torch.ops.aten.add.Tensor, 725 ] 726 self._test_quantizer( 727 m, 728 example_inputs, 729 quantizer, 730 node_occurrence, 731 node_list, 732 ) 733 734 @skipIfNoX86 735 def test_conv2d_binary2(self): 736 """ 737 Test Pattern: 738 tmp = conv2d_1(x) 739 tmp2 = conv2d_2(tmp) 740 return tmp + tmp2 741 Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 742 """ 743 example_inputs = (torch.randn(2, 3, 6, 6),) 744 quantizer = X86InductorQuantizer().set_global( 745 xiq.get_default_x86_inductor_quantization_config() 746 ) 747 inplace_add_list = [True, False] 748 with override_quantized_engine("x86"), torch.no_grad(): 749 for inplace_add in inplace_add_list: 750 m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add).eval() 751 node_occurrence = { 752 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 753 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 754 # quantize_per_channel for weights are const propagated 755 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 756 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 757 } 758 node_list = [ 759 torch.ops.quantized_decomposed.quantize_per_tensor.default, 760 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 761 torch.ops.aten.conv2d.default, 762 torch.ops.quantized_decomposed.quantize_per_tensor.default, 763 ( 764 torch.ops.aten.add_.Tensor 765 if inplace_add 766 else torch.ops.aten.add.Tensor 767 ), 768 ] 769 self._test_quantizer( 770 m, 771 example_inputs, 772 quantizer, 773 node_occurrence, 774 node_list, 775 ) 776 777 @skipIfNoX86 778 def test_conv2d_binary_unary(self): 779 """ 780 Test pattern of conv2d with binary + unary post ops (such as add + relu) with X86InductorQuantizer. 781 Currently, only add as binary post op and relu as unary post op are supported. 782 """ 783 conv2d_type_list = [NodePosType.left, NodePosType.both] 784 example_inputs = (torch.randn(2, 3, 6, 6),) 785 quantizer = X86InductorQuantizer().set_global( 786 xiq.get_default_x86_inductor_quantization_config() 787 ) 788 with override_quantized_engine("x86"), torch.no_grad(): 789 for conv2d_type in conv2d_type_list: 790 m = TestHelperModules.Conv2dAddReLUModule( 791 conv2d_type=conv2d_type, 792 ).eval() 793 if conv2d_type != NodePosType.both: 794 node_occurrence = { 795 # one for input for conv 796 # one for extra input node of add 797 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 798 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, 799 # note: quantize op for weights are const propagated 800 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 801 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 802 } 803 else: 804 node_occurrence = { 805 # one for input of the conv 806 # one for input of another conv 807 # 2 conv will share same input quant/dequant 808 # one for extra input node of add 809 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 810 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 811 # note: quantize op for weights are const propagated 812 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 813 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 814 } 815 node_list = [ 816 torch.ops.quantized_decomposed.quantize_per_tensor.default, 817 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 818 torch.ops.aten.conv2d.default, 819 torch.ops.aten.add.Tensor, 820 ] 821 self._test_quantizer( 822 m, 823 example_inputs, 824 quantizer, 825 node_occurrence, 826 node_list, 827 ) 828 829 @skipIfNoX86 830 def test_conv2d_serials_binary_unary(self): 831 """ 832 Test pattern of 2 following up conv2d add relu with X86InductorQuantizer. 833 """ 834 with override_quantized_engine("x86"), torch.no_grad(): 835 m = TestHelperModules.SerialsConv2dAddReLUModule().eval() 836 example_inputs = (torch.randn(2, 3, 16, 16),) 837 quantizer = X86InductorQuantizer().set_global( 838 xiq.get_default_x86_inductor_quantization_config() 839 ) 840 node_occurrence = { 841 torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, 842 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6, 843 # quantize_per_channel for weights are const propagated 844 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 845 torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, 846 } 847 node_list = [ 848 torch.ops.quantized_decomposed.quantize_per_tensor.default, 849 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 850 torch.ops.aten.conv2d.default, 851 torch.ops.quantized_decomposed.quantize_per_tensor.default, 852 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 853 torch.ops.aten.conv2d.default, 854 torch.ops.aten.conv2d.default, 855 torch.ops.aten.add.Tensor, 856 torch.ops.aten.relu.default, 857 ] 858 self._test_quantizer( 859 m, 860 example_inputs, 861 quantizer, 862 node_occurrence, 863 node_list, 864 ) 865 866 def _single_op_share_observer_recipe_test_helper(self, m, x, single_op): 867 quantizer = X86InductorQuantizer().set_global( 868 xiq.get_default_x86_inductor_quantization_config() 869 ) 870 example_inputs = (x,) 871 node_occurrence = { 872 # one for input and weight of the conv, two for input/output for the maxpool2d 873 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 874 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 875 # quantize_per_channel for weights are const propagated 876 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 877 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 878 } 879 node_list = [ 880 torch.ops.quantized_decomposed.quantize_per_tensor.default, 881 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 882 torch.ops.aten.conv2d.default, 883 torch.ops.quantized_decomposed.quantize_per_tensor.default, 884 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 885 single_op, 886 torch.ops.quantized_decomposed.quantize_per_tensor.default, 887 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 888 ] 889 _, prepare_model, _ = self._test_quantizer( 890 m, 891 example_inputs, 892 quantizer, 893 node_occurrence, 894 node_list, 895 ) 896 # Check Maxpool2d has share observer at input and output 897 for node in prepare_model.graph.nodes: 898 if node.op == "call_function" and node.target is single_op: 899 single_op_node = node 900 input_obs_of_single_op = getattr( 901 prepare_model, single_op_node.args[0].target 902 ) 903 output_obs_of_single_op = getattr( 904 prepare_model, next(iter(single_op_node.users)).target 905 ) 906 elif ( 907 node.op == "call_function" 908 and node.target is torch.ops.aten.conv2d.default 909 ): 910 conv_node = node 911 input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target) 912 self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase)) 913 self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase)) 914 self.assertTrue(isinstance(input_obs_of_conv, ObserverBase)) 915 self.assertTrue(input_obs_of_single_op is output_obs_of_single_op) 916 self.assertTrue(input_obs_of_single_op is not input_obs_of_conv) 917 918 @skipIfNoX86 919 def test_maxpool2d_recipe(self): 920 r""" 921 Test pattern: int8_in_int8_out_ops(maxpool) - non_quantizable op(pow) 922 Since maxpool is a int8_in_int8_out_op, there is obs between maxpool and pow. 923 """ 924 self._single_op_share_observer_recipe_test_helper( 925 TestHelperModules.Conv2dSingleOpPowModule(nn.MaxPool2d(1, 1)).eval(), 926 torch.rand(1, 2, 14, 14), 927 torch.ops.aten.max_pool2d.default, 928 ) 929 930 @skipIfNoX86 931 def test_adaptive_avg_pool2d_recipe(self): 932 r""" 933 Test pattern: int8_in_int8_out_ops(adaptive_avg_pool2d) - non_quantizable op(pow) 934 Since adaptive_avg_pool2d is a int8_in_int8_out_op, there is obs between adaptive_avg_pool2d and pow. 935 """ 936 self._single_op_share_observer_recipe_test_helper( 937 TestHelperModules.Conv2dSingleOpPowModule( 938 nn.AdaptiveAvgPool2d((1, 1)) 939 ).eval(), 940 torch.rand(1, 2, 14, 14), 941 torch.ops.aten.adaptive_avg_pool2d.default, 942 ) 943 944 @skipIfNoX86 945 def test_flatten_recipe(self): 946 r""" 947 Test pattern: int8_in_int8_out_ops(flatten) - non_quantizable op(pow) 948 Since flatten is a int8_in_int8_out_op, there is obs between flatten and pow. 949 """ 950 self._single_op_share_observer_recipe_test_helper( 951 TestHelperModules.Conv2dSingleOpPowModule( 952 lambda x: torch.flatten(x, 1) 953 ).eval(), 954 torch.rand(1, 2, 14, 14), 955 torch.ops.aten.flatten.using_ints, 956 ) 957 958 @skipIfNoX86 959 def test_cat_recipe(self): 960 r""" 961 Test pattern: conv -> cat -> maxpool2d 962 Since cat, maxpool is a int8_in_int8_out_op, the inputs and outputs should with same observer. 963 """ 964 m = TestHelperModules.Conv2dCatMaxpool2d().eval() 965 x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) 966 quantizer = X86InductorQuantizer().set_global( 967 xiq.get_default_x86_inductor_quantization_config() 968 ) 969 example_inputs = (x,) 970 node_occurrence = { 971 torch.ops.quantized_decomposed.quantize_per_tensor.default: 6, 972 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 6, 973 # quantize_per_channel for weights are const propagated 974 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 975 torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, 976 } 977 node_list = [ 978 torch.ops.quantized_decomposed.quantize_per_tensor.default, 979 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 980 torch.ops.aten.conv2d.default, 981 torch.ops.quantized_decomposed.quantize_per_tensor.default, 982 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 983 torch.ops.aten.cat.default, 984 torch.ops.quantized_decomposed.quantize_per_tensor.default, 985 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 986 torch.ops.aten.max_pool2d.default, 987 torch.ops.quantized_decomposed.quantize_per_tensor.default, 988 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 989 ] 990 _, prepare_model, _ = self._test_quantizer( 991 m, 992 example_inputs, 993 quantizer, 994 node_occurrence, 995 node_list, 996 ) 997 # Check Cat/Maxpool2d has share observer at input and output 998 for node in prepare_model.graph.nodes: 999 if node.op == "call_function" and node.target == torch.ops.aten.cat.default: 1000 cat_act_obs0 = getattr(prepare_model, node.all_input_nodes[0].target) 1001 cat_act_obs1 = getattr(prepare_model, node.all_input_nodes[1].target) 1002 cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) 1003 elif ( 1004 node.op == "call_function" 1005 and node.target is torch.ops.aten.max_pool2d.default 1006 ): 1007 maxpool_node = node 1008 input_obs_of_maxpool = getattr( 1009 prepare_model, maxpool_node.args[0].target 1010 ) 1011 output_obs_of_maxpool = getattr( 1012 prepare_model, next(iter(maxpool_node.users)).target 1013 ) 1014 self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) 1015 self.assertTrue(isinstance(cat_act_obs1, ObserverBase)) 1016 self.assertTrue(isinstance(cat_out_obs, ObserverBase)) 1017 self.assertTrue(isinstance(input_obs_of_maxpool, ObserverBase)) 1018 self.assertTrue(isinstance(output_obs_of_maxpool, ObserverBase)) 1019 self.assertTrue(cat_act_obs0 is cat_act_obs1) 1020 self.assertTrue(cat_act_obs0 is cat_out_obs) 1021 self.assertTrue(cat_out_obs is input_obs_of_maxpool) 1022 self.assertTrue(input_obs_of_maxpool is output_obs_of_maxpool) 1023 1024 @skipIfNoX86 1025 def test_cat_recipe_same_inputs(self): 1026 r""" 1027 Test pattern: conv -> cat([input0, input0]) 1028 Since cat has 2 input node of same tensor, they should also be with same observer. 1029 """ 1030 m = TestHelperModules.Conv2dCatSameInputs().eval() 1031 x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) 1032 quantizer = X86InductorQuantizer().set_global( 1033 xiq.get_default_x86_inductor_quantization_config() 1034 ) 1035 example_inputs = (x,) 1036 node_occurrence = { 1037 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1038 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 1039 # quantize_per_channel for weights are const propagated 1040 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1041 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1042 } 1043 node_list = [ 1044 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1045 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1046 torch.ops.aten.conv2d.default, 1047 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1048 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1049 torch.ops.aten.cat.default, 1050 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1051 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1052 ] 1053 _, prepare_model, _ = self._test_quantizer( 1054 m, 1055 example_inputs, 1056 quantizer, 1057 node_occurrence, 1058 node_list, 1059 ) 1060 # Check Cat has share observer at input and output 1061 for node in prepare_model.graph.nodes: 1062 if node.op == "call_function" and node.target == torch.ops.aten.cat.default: 1063 cat_act_obs0 = getattr(prepare_model, node.args[0][0].target) 1064 cat_act_obs1 = getattr(prepare_model, node.args[0][1].target) 1065 cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) 1066 self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) 1067 self.assertTrue(isinstance(cat_act_obs1, ObserverBase)) 1068 self.assertTrue(isinstance(cat_out_obs, ObserverBase)) 1069 self.assertTrue(cat_act_obs0 is cat_act_obs1) 1070 self.assertTrue(cat_act_obs0 is cat_out_obs) 1071 1072 @skipIfNoX86 1073 def test_cat_recipe_single_input(self): 1074 r""" 1075 Test pattern: conv -> cat([input0,]) 1076 Since cat has 1 input node, they should also be with same observer. 1077 """ 1078 m = TestHelperModules.Conv2dCatSingleInput().eval() 1079 x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) 1080 quantizer = X86InductorQuantizer().set_global( 1081 xiq.get_default_x86_inductor_quantization_config() 1082 ) 1083 example_inputs = (x,) 1084 node_occurrence = { 1085 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1086 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 1087 # quantize_per_channel for weights are const propagated 1088 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1089 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1090 } 1091 node_list = [ 1092 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1093 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1094 torch.ops.aten.conv2d.default, 1095 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1096 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1097 torch.ops.aten.cat.default, 1098 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1099 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1100 ] 1101 _, prepare_model, _ = self._test_quantizer( 1102 m, 1103 example_inputs, 1104 quantizer, 1105 node_occurrence, 1106 node_list, 1107 ) 1108 # Check Cat has share observer at input and output 1109 for node in prepare_model.graph.nodes: 1110 if node.op == "call_function" and node.target == torch.ops.aten.cat.default: 1111 cat_act_obs0 = getattr(prepare_model, node.args[0][0].target) 1112 cat_out_obs = getattr(prepare_model, next(iter(node.users)).target) 1113 self.assertTrue(isinstance(cat_act_obs0, ObserverBase)) 1114 self.assertTrue(isinstance(cat_out_obs, ObserverBase)) 1115 self.assertTrue(cat_act_obs0 is cat_out_obs) 1116 1117 @skipIfNoX86 1118 def test_avg_pool2d_recipe(self): 1119 r""" 1120 Test pattern: conv -> AvgPool2d 1121 Since AvgPool2d is a int8_in_int8_out_op, the inputs and outputs should with same observer. 1122 """ 1123 m = TestHelperModules.Conv2dAvgPool2d().eval() 1124 x = torch.randn(16, 3, 16, 16).contiguous(memory_format=torch.channels_last) 1125 quantizer = X86InductorQuantizer().set_global( 1126 xiq.get_default_x86_inductor_quantization_config() 1127 ) 1128 example_inputs = (x,) 1129 node_occurrence = { 1130 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1131 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 1132 # quantize_per_channel for weights are const propagated 1133 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1134 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1135 } 1136 node_list = [ 1137 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1138 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1139 torch.ops.aten.conv2d.default, 1140 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1141 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1142 torch.ops.aten.avg_pool2d.default, 1143 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1144 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1145 ] 1146 _, prepare_model, _ = self._test_quantizer( 1147 m, 1148 example_inputs, 1149 quantizer, 1150 node_occurrence, 1151 node_list, 1152 ) 1153 for node in prepare_model.graph.nodes: 1154 if ( 1155 node.op == "call_function" 1156 and node.target is torch.ops.aten.avg_pool2d.default 1157 ): 1158 avgpool_node = node 1159 input_obs_of_avgpool = getattr( 1160 prepare_model, avgpool_node.args[0].target 1161 ) 1162 output_obs_of_avgpool = getattr( 1163 prepare_model, next(iter(avgpool_node.users)).target 1164 ) 1165 elif ( 1166 node.op == "call_function" 1167 and node.target is torch.ops.aten.conv2d.default 1168 ): 1169 conv_node = node 1170 output_obs_of_conv = getattr( 1171 prepare_model, next(iter(conv_node.users)).target 1172 ) 1173 self.assertTrue(isinstance(input_obs_of_avgpool, ObserverBase)) 1174 self.assertTrue(isinstance(output_obs_of_avgpool, ObserverBase)) 1175 self.assertTrue(isinstance(output_obs_of_conv, ObserverBase)) 1176 self.assertTrue(input_obs_of_avgpool is output_obs_of_avgpool) 1177 self.assertTrue(input_obs_of_avgpool is output_obs_of_conv) 1178 1179 @skipIfNoX86 1180 def test_linear(self): 1181 """ 1182 Test pattern of single linear with X86InductorQuantizer. 1183 """ 1184 with override_quantized_engine("x86"), torch.no_grad(): 1185 for use_bias in [True, False]: 1186 m = TestHelperModules.SingleLinearModule(use_bias).eval() 1187 example_inputs = (torch.randn(2, 4),) 1188 quantizer = X86InductorQuantizer().set_global( 1189 xiq.get_default_x86_inductor_quantization_config() 1190 ) 1191 node_occurrence = { 1192 # one for input and weight, one for output 1193 torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, 1194 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, 1195 # quantize_per_channel for weights are const propagated 1196 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1197 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1198 } 1199 node_list = [ 1200 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1201 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1202 torch.ops.aten.linear.default, 1203 ] 1204 self._test_quantizer( 1205 m, 1206 example_inputs, 1207 quantizer, 1208 node_occurrence, 1209 node_list, 1210 ) 1211 1212 def _test_linear_unary_helper( 1213 self, 1214 post_op_module, 1215 post_op_aten, 1216 post_op_aten_inplace, 1217 post_op_algo_list=None, 1218 is_qat=False, 1219 is_dynamic=False, 1220 ): 1221 """ 1222 Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. 1223 """ 1224 use_bias_list = [True, False] 1225 # TODO test for inplace add after refactoring of capture_pre_autograd_graph 1226 inplace_list = [False] 1227 if post_op_algo_list is None: 1228 post_op_algo_list = [None] 1229 cases = itertools.product(use_bias_list, inplace_list, post_op_algo_list) 1230 with override_quantized_engine("x86"), torch.no_grad(): 1231 for use_bias, inplace, post_op_algo in cases: 1232 if inplace and post_op_aten_inplace is None: 1233 continue 1234 m = TestHelperModules.LinearUnaryModule( 1235 use_bias=use_bias, 1236 postop=post_op_module, 1237 inplace_postop=inplace, 1238 post_op_algo=post_op_algo, 1239 ).eval() 1240 example_inputs = (torch.randn(2, 4),) 1241 quantizer = X86InductorQuantizer().set_global( 1242 xiq.get_default_x86_inductor_quantization_config( 1243 is_qat=is_qat, 1244 is_dynamic=is_dynamic, 1245 ) 1246 ) 1247 quantize_per_tensor_op = ( 1248 torch.ops.quantized_decomposed.quantize_per_tensor.tensor 1249 if is_dynamic 1250 else torch.ops.quantized_decomposed.quantize_per_tensor.default 1251 ) 1252 dequantize_per_tensor_op = ( 1253 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor 1254 if is_dynamic 1255 else torch.ops.quantized_decomposed.dequantize_per_tensor.default 1256 ) 1257 node_occurrence = { 1258 # one for input of the linear 1259 quantize_per_tensor_op: 1, 1260 dequantize_per_tensor_op: 1, 1261 # quantize_per_channel for weights are const propagated 1262 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1263 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1264 } 1265 node_list = [ 1266 quantize_per_tensor_op, 1267 dequantize_per_tensor_op, 1268 torch.ops.aten.linear.default, 1269 post_op_aten_inplace if inplace else post_op_aten, 1270 ] 1271 self._test_quantizer( 1272 m, 1273 example_inputs, 1274 quantizer, 1275 node_occurrence, 1276 node_list, 1277 is_qat=is_qat, 1278 ) 1279 1280 @skipIfNoX86 1281 def test_linear_unary(self): 1282 aten = torch.ops.aten 1283 self._test_linear_unary_helper(nn.ReLU, aten.relu.default, aten.relu_.default) 1284 self._test_linear_unary_helper( 1285 nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default 1286 ) 1287 self._test_linear_unary_helper( 1288 nn.GELU, aten.gelu.default, None, ["none", "tanh"] 1289 ) 1290 1291 @skipIfNoX86 1292 def test_linear_unary_qat(self): 1293 aten = torch.ops.aten 1294 self._test_linear_unary_helper( 1295 nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True 1296 ) 1297 self._test_linear_unary_helper( 1298 nn.LeakyReLU, aten.leaky_relu.default, aten.leaky_relu_.default, is_qat=True 1299 ) 1300 self._test_linear_unary_helper( 1301 nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_qat=True 1302 ) 1303 1304 @skipIfNoX86 1305 def test_linear_unary_dynamic(self): 1306 aten = torch.ops.aten 1307 self._test_linear_unary_helper( 1308 nn.ReLU, aten.relu.default, aten.relu_.default, is_dynamic=True 1309 ) 1310 self._test_linear_unary_helper( 1311 nn.LeakyReLU, 1312 aten.leaky_relu.default, 1313 aten.leaky_relu_.default, 1314 is_dynamic=True, 1315 ) 1316 self._test_linear_unary_helper( 1317 nn.GELU, aten.gelu.default, None, ["none", "tanh"], is_dynamic=True 1318 ) 1319 1320 @skipIfNoX86 1321 def test_linear_unary_dynamic_qat(self): 1322 aten = torch.ops.aten 1323 self._test_linear_unary_helper( 1324 nn.ReLU, aten.relu.default, aten.relu_.default, is_qat=True, is_dynamic=True 1325 ) 1326 self._test_linear_unary_helper( 1327 nn.LeakyReLU, 1328 aten.leaky_relu.default, 1329 aten.leaky_relu_.default, 1330 is_qat=True, 1331 is_dynamic=True, 1332 ) 1333 self._test_linear_unary_helper( 1334 nn.GELU, 1335 aten.gelu.default, 1336 None, 1337 ["none", "tanh"], 1338 is_qat=True, 1339 is_dynamic=True, 1340 ) 1341 1342 def _check_annotation_stat(self, gm, expected_stat_dict): 1343 # Check expected annotation statistics to ensure the annotation is correct 1344 1345 def _check_annotation(node): 1346 annot = node.meta.get(QUANT_ANNOTATION_KEY, None) 1347 if annot is None: 1348 return False, False 1349 return annot._annotated, annot._is_output_of_quantized_pattern 1350 1351 for node in gm.graph.nodes: 1352 if node.target in expected_stat_dict.keys(): 1353 annotated, is_quant_out = _check_annotation(node) 1354 expected_stat_dict[node.target]["annotated"] -= annotated 1355 expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out 1356 for op_stat in expected_stat_dict.values(): 1357 assert all(v == 0 for v in op_stat.values()) 1358 1359 def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): 1360 """ 1361 Test pattern of linear with binary post ops (such as add) with X86InductorQuantizer. 1362 Currently, only add as binary post op is supported. 1363 """ 1364 linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] 1365 # TODO test for inplace add after refactoring of capture_pre_autograd_graph 1366 inplace_add_list = [False] 1367 example_inputs = (torch.randn(2, 16),) 1368 quantizer = X86InductorQuantizer().set_global( 1369 xiq.get_default_x86_inductor_quantization_config( 1370 is_qat=is_qat, 1371 is_dynamic=is_dynamic, 1372 ) 1373 ) 1374 quantize_per_tensor_op = ( 1375 torch.ops.quantized_decomposed.quantize_per_tensor.tensor 1376 if is_dynamic 1377 else torch.ops.quantized_decomposed.quantize_per_tensor.default 1378 ) 1379 dequantize_per_tensor_op = ( 1380 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor 1381 if is_dynamic 1382 else torch.ops.quantized_decomposed.dequantize_per_tensor.default 1383 ) 1384 cases = itertools.product(linear_pos_list, inplace_add_list) 1385 with override_quantized_engine("x86"), torch.no_grad(): 1386 for linear_pos, inplace_add in cases: 1387 m = TestHelperModules.LinearAddModule( 1388 inplace_add=inplace_add, linear_pos=linear_pos 1389 ).eval() 1390 if linear_pos != NodePosType.both: 1391 node_occurrence = { 1392 # Only one 1 q-dq for input of the linear 1393 # No q-dq for extra input node of add 1394 quantize_per_tensor_op: 1, 1395 dequantize_per_tensor_op: 1, 1396 # quantize_per_channel for weights are const propagated 1397 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1398 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1399 } 1400 else: 1401 # convert_pt2e disables duplicate dequant for dynamic quant 1402 num_dequant = 1 if is_dynamic else 2 1403 node_occurrence = { 1404 # One quantize_per_tensor for both linear nodes (shared) 1405 # Two dequantize_per_tensor for two linear nodes 1406 # No q-dq for extra input node of add 1407 quantize_per_tensor_op: 1, 1408 dequantize_per_tensor_op: num_dequant, 1409 # quantize_per_channel for weights are const propagated 1410 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1411 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 1412 } 1413 node_list = [ 1414 quantize_per_tensor_op, 1415 dequantize_per_tensor_op, 1416 torch.ops.aten.linear.default, 1417 ( 1418 torch.ops.aten.add_.Tensor 1419 if inplace_add 1420 else torch.ops.aten.add.Tensor 1421 ), 1422 ] 1423 fq_m = self._test_quantizer( 1424 m, 1425 example_inputs, 1426 quantizer, 1427 node_occurrence, 1428 node_list, 1429 is_qat=is_qat, 1430 )[-1] 1431 # One linear and add are fused. The other linear is quantized alone if present 1432 aten = torch.ops.aten 1433 add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor 1434 expected_annotation_stat = { 1435 aten.linear.default: { 1436 "annotated": 2 if linear_pos == NodePosType.both else 1, 1437 "is_quant_out": 1 if linear_pos == NodePosType.both else 0, 1438 }, 1439 add_op: {"annotated": 1, "is_quant_out": 1}, 1440 } 1441 self._check_annotation_stat(fq_m, expected_annotation_stat) 1442 1443 @skipIfNoX86 1444 def test_linear_binary(self): 1445 self._test_linear_binary_helper() 1446 1447 @skipIfNoX86 1448 def test_linear_binary_qat(self): 1449 self._test_linear_binary_helper(is_qat=True) 1450 1451 @skipIfNoX86 1452 def test_linear_binary_dynamic(self): 1453 self._test_linear_binary_helper(is_dynamic=True) 1454 1455 @skipIfNoX86 1456 def test_linear_binary_dynamic_qat(self): 1457 self._test_linear_binary_helper(is_qat=True, is_dynamic=True) 1458 1459 @skipIfNoX86 1460 def test_linear_binary2(self): 1461 """ 1462 Test Pattern: 1463 tmp = linear_1(x) 1464 tmp2 = linear_2(tmp) 1465 return tmp + tmp2 1466 Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 1467 """ 1468 example_inputs = (torch.randn(2, 16),) 1469 # TODO test for inplace add after refactoring of capture_pre_autograd_graph 1470 inplace_add_list = [False] 1471 is_qat_list = [False, True] 1472 is_dynamic_list = [False, True] 1473 cases = itertools.product(inplace_add_list, is_qat_list, is_dynamic_list) 1474 with override_quantized_engine("x86"), torch.no_grad(): 1475 for inplace_add, is_qat, is_dynamic in cases: 1476 quantizer = X86InductorQuantizer().set_global( 1477 xiq.get_default_x86_inductor_quantization_config( 1478 is_qat=is_qat, is_dynamic=is_dynamic 1479 ) 1480 ) 1481 m = TestHelperModules.LinearAddModule2(inplace_add=inplace_add).eval() 1482 quantize_per_tensor_op = ( 1483 torch.ops.quantized_decomposed.quantize_per_tensor.tensor 1484 if is_dynamic 1485 else torch.ops.quantized_decomposed.quantize_per_tensor.default 1486 ) 1487 dequantize_per_tensor_op = ( 1488 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor 1489 if is_dynamic 1490 else torch.ops.quantized_decomposed.dequantize_per_tensor.default 1491 ) 1492 # Two q-dq nodes for inputs of linear nodes 1493 # No q-dq for extra input node of add 1494 node_occurrence = { 1495 quantize_per_tensor_op: 2, 1496 dequantize_per_tensor_op: 2, 1497 # quantize_per_channel for weights are const propagated 1498 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1499 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 1500 } 1501 node_list = [ 1502 torch.ops.quantized_decomposed.dequantize_per_channel.default, 1503 quantize_per_tensor_op, 1504 dequantize_per_tensor_op, 1505 torch.ops.aten.linear.default, 1506 ( 1507 torch.ops.aten.add_.Tensor 1508 if inplace_add 1509 else torch.ops.aten.add.Tensor 1510 ), 1511 ] 1512 fq_m = self._test_quantizer( 1513 m, 1514 example_inputs, 1515 quantizer, 1516 node_occurrence, 1517 node_list, 1518 )[-1] 1519 # One linear and add are fused. The other linear is quantized alone if present 1520 aten = torch.ops.aten 1521 add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor 1522 expected_annotation_stat = { 1523 aten.linear.default: { 1524 "annotated": 2, 1525 "is_quant_out": 1, 1526 }, 1527 add_op: {"annotated": 1, "is_quant_out": 1}, 1528 } 1529 self._check_annotation_stat(fq_m, expected_annotation_stat) 1530 1531 @skipIfNoX86 1532 def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): 1533 """ 1534 Test pattern of linear with binary + unary post ops (such as add + relu) with X86InductorQuantizer. 1535 Currently, only add as binary post op and relu as unary post op are supported. 1536 """ 1537 linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] 1538 # TODO test for inplace add after refactoring of capture_pre_autograd_graph 1539 inplace_add_list = [False] 1540 # TODO test for inplace relu after refactoring of capture_pre_autograd_graph 1541 inplace_relu_list = [False] 1542 example_inputs = (torch.randn(2, 16),) 1543 quantizer = X86InductorQuantizer().set_global( 1544 xiq.get_default_x86_inductor_quantization_config( 1545 is_qat=is_qat, 1546 is_dynamic=is_dynamic, 1547 ) 1548 ) 1549 quantize_per_tensor_op = ( 1550 torch.ops.quantized_decomposed.quantize_per_tensor.tensor 1551 if is_dynamic 1552 else torch.ops.quantized_decomposed.quantize_per_tensor.default 1553 ) 1554 dequantize_per_tensor_op = ( 1555 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor 1556 if is_dynamic 1557 else torch.ops.quantized_decomposed.dequantize_per_tensor.default 1558 ) 1559 cases = itertools.product(linear_pos_list, inplace_add_list, inplace_relu_list) 1560 with override_quantized_engine("x86"), torch.no_grad(): 1561 for linear_pos, inplace_add, inplace_relu in cases: 1562 m = TestHelperModules.LinearAddReLUModule( 1563 inplace_add=inplace_add, 1564 linear_pos=linear_pos, 1565 inplace_relu=inplace_relu, 1566 ).eval() 1567 if linear_pos != NodePosType.both: 1568 node_occurrence = { 1569 # Only one q-dq node for input of the linear 1570 # No q-dq node for extra input node of add 1571 quantize_per_tensor_op: 1, 1572 dequantize_per_tensor_op: 1, 1573 # note: quantize op for weights are const propagated 1574 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1575 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1576 } 1577 else: 1578 # convert_pt2e disables duplicate dequant for dynamic quant 1579 num_dequant = 1 if is_dynamic else 2 1580 node_occurrence = { 1581 # One quantize_per_tensor for both linear nodes (shared) 1582 # Two dequantize_per_tensor for two linear nodes 1583 # No q-dq for extra input node of add 1584 quantize_per_tensor_op: 1, 1585 dequantize_per_tensor_op: num_dequant, 1586 # note: quantize op for weights are const propagated 1587 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1588 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 1589 } 1590 node_list = [ 1591 quantize_per_tensor_op, 1592 dequantize_per_tensor_op, 1593 torch.ops.aten.linear.default, 1594 ( 1595 torch.ops.aten.add_.Tensor 1596 if inplace_add 1597 else torch.ops.aten.add.Tensor 1598 ), 1599 ] 1600 fq_m = self._test_quantizer( 1601 m, 1602 example_inputs, 1603 quantizer, 1604 node_occurrence, 1605 node_list, 1606 )[-1] 1607 # linear, add, relu are fused 1608 # The other linear is quantized alone if present 1609 aten = torch.ops.aten 1610 add_op = aten.add_.Tensor if inplace_add else aten.add.Tensor 1611 relu_op = aten.relu_.default if inplace_relu else aten.relu.default 1612 expected_annotation_stat = { 1613 aten.linear.default: { 1614 "annotated": 2 if linear_pos == NodePosType.both else 1, 1615 "is_quant_out": 1 if linear_pos == NodePosType.both else 0, 1616 }, 1617 add_op: {"annotated": 1, "is_quant_out": 0}, 1618 relu_op: {"annotated": 1, "is_quant_out": 1}, 1619 } 1620 self._check_annotation_stat(fq_m, expected_annotation_stat) 1621 1622 @skipIfNoX86 1623 def test_linear_binary_unary(self): 1624 self._test_linear_binary_unary_helper() 1625 1626 @skipIfNoX86 1627 def test_linear_binary_unary_qat(self): 1628 self._test_linear_binary_unary_helper(is_qat=True) 1629 1630 @skipIfNoX86 1631 def test_linear_binary_unary_dynamic(self): 1632 self._test_linear_binary_unary_helper(is_dynamic=True) 1633 1634 @skipIfNoX86 1635 def test_linear_binary_unary_dynamic_qat(self): 1636 self._test_linear_binary_unary_helper(is_qat=True, is_dynamic=True) 1637 1638 @skipIfNoX86 1639 def test_linear_binary_unary_serials(self): 1640 """ 1641 Test pattern of 2 following up linear add relu with X86InductorQuantizer. 1642 """ 1643 is_qat_list = [False, True] 1644 is_dynamic_list = [False, True] 1645 cases = itertools.product(is_qat_list, is_dynamic_list) 1646 with override_quantized_engine("x86"), torch.no_grad(): 1647 for is_qat, is_dynamic in cases: 1648 m = TestHelperModules.SerialsLinearAddReLUModule().eval() 1649 example_inputs = (torch.randn(2, 16),) 1650 quantizer = X86InductorQuantizer().set_global( 1651 xiq.get_default_x86_inductor_quantization_config( 1652 is_qat=is_qat, 1653 is_dynamic=is_dynamic, 1654 ) 1655 ) 1656 quantize_per_tensor_op = ( 1657 torch.ops.quantized_decomposed.quantize_per_tensor.tensor 1658 if is_dynamic 1659 else torch.ops.quantized_decomposed.quantize_per_tensor.default 1660 ) 1661 dequantize_per_tensor_op = ( 1662 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor 1663 if is_dynamic 1664 else torch.ops.quantized_decomposed.dequantize_per_tensor.default 1665 ) 1666 # convert_pt2e disables duplicate dequant for dynamic quant 1667 num_dequant = 3 if is_dynamic else 4 1668 node_occurrence = { 1669 # quantize_per_tensor: 1 for linear_1, 1 for linear_2/3 (shared), 1 for linear_4 1670 # dequantize_per_tensor: 1 for each linear 1671 # No q-dq for extra input node of add 1672 quantize_per_tensor_op: 3, 1673 dequantize_per_tensor_op: num_dequant, 1674 # quantize_per_channel for weights are const propagated 1675 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1676 torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, 1677 } 1678 node_list = [ 1679 torch.ops.quantized_decomposed.dequantize_per_channel.default, 1680 quantize_per_tensor_op, 1681 dequantize_per_tensor_op, 1682 torch.ops.aten.linear.default, 1683 torch.ops.aten.linear.default, 1684 torch.ops.aten.linear.default, 1685 torch.ops.aten.add.Tensor, 1686 torch.ops.aten.relu.default, 1687 ] 1688 fq_m = self._test_quantizer( 1689 m, 1690 example_inputs, 1691 quantizer, 1692 node_occurrence, 1693 node_list, 1694 )[-1] 1695 # Two linear nodes are quantized alone 1696 # The other two are fused with add and relu 1697 aten = torch.ops.aten 1698 expected_annotation_stat = { 1699 aten.linear.default: { 1700 "annotated": 4, 1701 "is_quant_out": 2, 1702 }, 1703 aten.add.Tensor: {"annotated": 2, "is_quant_out": 0}, 1704 aten.relu.default: {"annotated": 2, "is_quant_out": 2}, 1705 } 1706 self._check_annotation_stat(fq_m, expected_annotation_stat) 1707 1708 @skipIfTorchDynamo("very slow") 1709 @skipIfNoX86 1710 def test_qat_conv2d(self): 1711 """ 1712 Test QAT pattern of conv2d_bn with X86InductorQuantizer. 1713 """ 1714 with override_quantized_engine("x86"): 1715 m = TestHelperModules.SingleConv2dModule(with_bn=True) 1716 example_inputs = (torch.randn(2, 3, 16, 16),) 1717 quantizer = X86InductorQuantizer().set_global( 1718 xiq.get_default_x86_inductor_quantization_config(is_qat=True) 1719 ) 1720 node_occurrence = { 1721 # one for input and weight of the conv, one for output for the conv 1722 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 1723 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, 1724 # note: quantize op for weights are const propagated 1725 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1726 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1727 # BN should be folded into Conv 1728 torch.ops.aten._native_batch_norm_legit.default: 0, 1729 } 1730 node_list = [ 1731 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1732 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1733 torch.ops.aten.conv2d.default, 1734 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1735 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1736 ] 1737 self._test_quantizer( 1738 m, 1739 example_inputs, 1740 quantizer, 1741 node_occurrence, 1742 node_list, 1743 is_qat=True, 1744 ) 1745 1746 @skipIfTorchDynamo("very slow") 1747 @skipIfNoX86 1748 def test_qat_conv2d_unary(self): 1749 """ 1750 Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. 1751 Currently, only relu as unary post op is supported. 1752 """ 1753 unary_map = { 1754 "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], 1755 "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], 1756 "hardtanh": [ 1757 torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), 1758 torch.ops.aten.hardtanh.default, 1759 ], 1760 "hardtanh_inplace": [ 1761 torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), 1762 torch.ops.aten.hardtanh_.default, 1763 ], 1764 "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], 1765 "relu6_inplace": [ 1766 torch.nn.ReLU6(inplace=True), 1767 torch.ops.aten.hardtanh_.default, 1768 ], 1769 "hardswish": [ 1770 torch.nn.Hardswish(inplace=False), 1771 torch.ops.aten.hardswish.default, 1772 ], 1773 "hardswish_inplace": [ 1774 torch.nn.Hardswish(inplace=True), 1775 torch.ops.aten.hardswish_.default, 1776 ], 1777 "swish": [torch.nn.SiLU(inplace=False), torch.ops.aten.silu.default], 1778 "swish_inplace": [ 1779 torch.nn.SiLU(inplace=True), 1780 torch.ops.aten.silu_.default, 1781 ], 1782 } 1783 1784 with override_quantized_engine("x86"): 1785 for unary_op in unary_map.keys(): 1786 m = TestHelperModules.Conv2dUnaryModule( 1787 unary_map[unary_op][0], with_bn=True 1788 ) 1789 example_inputs = (torch.randn(2, 3, 16, 16),) 1790 quantizer = X86InductorQuantizer().set_global( 1791 xiq.get_default_x86_inductor_quantization_config(is_qat=True) 1792 ) 1793 node_occurrence = { 1794 # one for input and weight of the conv, one for output for the relu 1795 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1796 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 1797 # note: quantize op for weights are const propagated 1798 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1799 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1800 # BN should be folded into Conv 1801 torch.ops.aten._native_batch_norm_legit.default: 0, 1802 } 1803 node_list = [ 1804 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1805 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1806 torch.ops.aten.conv2d.default, 1807 unary_map[unary_op][1], 1808 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1809 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1810 ] 1811 self._test_quantizer( 1812 m, 1813 example_inputs, 1814 quantizer, 1815 node_occurrence, 1816 node_list, 1817 is_qat=True, 1818 ) 1819 1820 @skipIfTorchDynamo("very slow") 1821 @skipIfNoX86 1822 def test_qat_conv2d_binary(self): 1823 """ 1824 Test qat pattern of conv2d_bn with binary post ops (such as add) with X86InductorQuantizer. 1825 Currently, only add as binary post op is supported. 1826 """ 1827 example_inputs = (torch.randn(2, 3, 6, 6),) 1828 quantizer = X86InductorQuantizer().set_global( 1829 xiq.get_default_x86_inductor_quantization_config(is_qat=True) 1830 ) 1831 with override_quantized_engine("x86"): 1832 for inplace_add in [True, False]: 1833 m = TestHelperModules.Conv2dAddModule( 1834 inplace_add=inplace_add, with_bn=True 1835 ) 1836 node_occurrence = { 1837 # one for input and weight of the conv 1838 # one for output for the add 1839 # one for extra input node of add 1840 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1841 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 1842 # quantize_per_channel for weights are const propagated 1843 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1844 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1845 # BN should be folded into Conv 1846 torch.ops.aten._native_batch_norm_legit.default: 0, 1847 } 1848 node_list = [ 1849 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1850 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1851 torch.ops.aten.conv2d.default, 1852 ( 1853 torch.ops.aten.add_.Tensor 1854 if inplace_add 1855 else torch.ops.aten.add.Tensor 1856 ), 1857 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1858 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1859 ] 1860 self._test_quantizer( 1861 m, 1862 example_inputs, 1863 quantizer, 1864 node_occurrence, 1865 node_list, 1866 is_qat=True, 1867 ) 1868 1869 @skipIfTorchDynamo("very slow") 1870 @skipIfNoX86 1871 def test_qat_conv2d_binary2(self): 1872 """ 1873 Test qat Pattern: 1874 tmp = bn1(conv2d_1(x)) 1875 tmp2 = bn2(conv2d_2(tmp)) 1876 return tmp + tmp2 1877 Since conv2d_1 has 2 users, we should annotate conv2d_2 for binary fusion instead of conv2d_1 1878 """ 1879 example_inputs = (torch.randn(2, 3, 6, 6),) 1880 quantizer = X86InductorQuantizer().set_global( 1881 xiq.get_default_x86_inductor_quantization_config(is_qat=True) 1882 ) 1883 inplace_add_list = [True, False] 1884 with override_quantized_engine("x86"), torch.no_grad(): 1885 for inplace_add in inplace_add_list: 1886 m = TestHelperModules.Conv2dAddModule2(inplace_add=inplace_add) 1887 node_occurrence = { 1888 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1889 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, 1890 # quantize_per_channel for weights are const propagated 1891 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1892 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 1893 # BN should be folded into Conv 1894 torch.ops.aten._native_batch_norm_legit.default: 0, 1895 } 1896 node_list = [ 1897 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1898 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1899 torch.ops.aten.conv2d.default, 1900 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1901 ( 1902 torch.ops.aten.add_.Tensor 1903 if inplace_add 1904 else torch.ops.aten.add.Tensor 1905 ), 1906 ] 1907 self._test_quantizer( 1908 m, 1909 example_inputs, 1910 quantizer, 1911 node_occurrence, 1912 node_list, 1913 is_qat=True, 1914 ) 1915 1916 @skipIfTorchDynamo("very slow") 1917 @skipIfNoX86 1918 def test_qat_conv2d_binary_unary(self): 1919 """ 1920 Test QAT pattern of conv2d_bn with binary + unary post ops (such as add + relu) with X86InductorQuantizer. 1921 Currently, only add as binary post op and relu as unary post op are supported. 1922 """ 1923 example_inputs = (torch.randn(2, 3, 6, 6),) 1924 quantizer = X86InductorQuantizer().set_global( 1925 xiq.get_default_x86_inductor_quantization_config(is_qat=True) 1926 ) 1927 with override_quantized_engine("x86"): 1928 m = TestHelperModules.Conv2dAddReLUModule(with_bn=True) 1929 node_occurrence = { 1930 # one for input for conv 1931 # one for output for the relu 1932 # one for extra input node of add 1933 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1934 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 1935 # note: quantize op for weights are const propagated 1936 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1937 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1938 # BN should be folded into Conv 1939 torch.ops.aten._native_batch_norm_legit.default: 0, 1940 } 1941 node_list = [ 1942 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1943 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1944 torch.ops.aten.conv2d.default, 1945 torch.ops.aten.add.Tensor, 1946 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1947 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1948 ] 1949 self._test_quantizer( 1950 m, 1951 example_inputs, 1952 quantizer, 1953 node_occurrence, 1954 node_list, 1955 is_qat=True, 1956 ) 1957 1958 @skipIfNoX86 1959 def test_dynamic_quant_linear(self): 1960 """ 1961 Test pattern of dynamic quantization of linear with X86InductorQuantizer. 1962 """ 1963 with override_quantized_engine("x86"), torch.no_grad(): 1964 m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() 1965 example_inputs = (torch.randn(1, 4, 64),) 1966 quantizer = X86InductorQuantizer().set_global( 1967 xiq.get_default_x86_inductor_quantization_config(is_dynamic=True) 1968 ) 1969 node_occurrence = { 1970 torch.ops.quantized_decomposed.choose_qparams.tensor: 1, 1971 torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, 1972 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, 1973 # quantize_per_channel for weights are const propagated 1974 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1975 torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, 1976 } 1977 node_list = [ 1978 torch.ops.quantized_decomposed.choose_qparams.tensor, 1979 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 1980 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 1981 torch.ops.aten.linear.default, 1982 ] 1983 self._test_quantizer( 1984 m, 1985 example_inputs, 1986 quantizer, 1987 node_occurrence, 1988 node_list, 1989 ) 1990 1991 @skipIfNoX86 1992 def test_qat_dynamic_quant_linear(self): 1993 """ 1994 Test pattern of qat dynamic quantization of linear with X86InductorQuantizer. 1995 """ 1996 with override_quantized_engine("x86"), torch.no_grad(): 1997 m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() 1998 example_inputs = (torch.randn(1, 4, 64),) 1999 quantizer = X86InductorQuantizer().set_global( 2000 xiq.get_default_x86_inductor_quantization_config( 2001 is_qat=True, is_dynamic=True 2002 ) 2003 ) 2004 node_occurrence = { 2005 torch.ops.quantized_decomposed.choose_qparams.tensor: 1, 2006 torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, 2007 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, 2008 # quantize_per_channel for weights are const propagated 2009 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 2010 torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, 2011 } 2012 node_list = [ 2013 torch.ops.quantized_decomposed.choose_qparams.tensor, 2014 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 2015 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 2016 torch.ops.aten.linear.default, 2017 ] 2018 self._test_quantizer( 2019 m, 2020 example_inputs, 2021 quantizer, 2022 node_occurrence, 2023 node_list, 2024 is_qat=True, 2025 ) 2026 2027 @skipIfNoX86 2028 def test_set_module_name_qconfig(self): 2029 """Test case for quantizing a specific submodule by configuring `set_module_name_qconfig`. 2030 2031 Expect that all linear layers within the submodule `sub` are quantized. 2032 """ 2033 2034 class Sub(torch.nn.Module): 2035 def __init__(self) -> None: 2036 super().__init__() 2037 self.linear1 = torch.nn.Linear(5, 10) 2038 self.relu1 = torch.nn.ReLU(inplace=False) 2039 self.linear2 = torch.nn.Linear(10, 5) 2040 2041 def forward(self, x): 2042 x = self.linear1(x) 2043 x = self.relu1(x) 2044 x = self.linear2(x) 2045 return x 2046 2047 class M(torch.nn.Module): 2048 def __init__(self) -> None: 2049 super().__init__() 2050 self.linear = torch.nn.Linear(5, 5) 2051 self.sub = Sub() 2052 2053 def forward(self, x): 2054 x = self.linear(x) 2055 x = self.sub(x) 2056 return x 2057 2058 m = M().eval() 2059 example_inputs = (torch.randn(3, 5),) 2060 # Set global to `None` and then default config for a specific submodule. 2061 quantizer = X86InductorQuantizer() 2062 quantizer.set_module_name_qconfig( 2063 "sub", xiq.get_default_x86_inductor_quantization_config() 2064 ) 2065 node_occurrence = { 2066 torch.ops.aten.linear.default: 3, 2067 # quantize and dequantize the input of two linear layers from `sub` 2068 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 2069 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, 2070 # dequantize the weight of two linear layers from `sub` 2071 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 2072 } 2073 node_list = [ 2074 # first linear is not quantized 2075 torch.ops.aten.linear.default, 2076 # two Q/DQ pairs for two linear layers from `sub` 2077 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2078 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2079 torch.ops.aten.linear.default, 2080 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2081 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2082 torch.ops.aten.linear.default, 2083 ] 2084 self._test_quantizer( 2085 m, 2086 example_inputs, 2087 quantizer, 2088 node_occurrence, 2089 node_list, 2090 ) 2091 2092 @skipIfNoX86 2093 def test_set_module_name_qconfig_with_underscores(self) -> None: 2094 """Test that if a module name has an underscore, we can still quantize it.""" 2095 2096 class M(torch.nn.Module): 2097 def __init__(self) -> None: 2098 super().__init__() 2099 # This module name has underscores, which can be part of a mangled name. 2100 self.foo_bar = torch.nn.Linear(2, 2) 2101 self.baz = torch.nn.Linear(2, 2) 2102 2103 def forward(self, x): 2104 return self.baz(self.foo_bar(x)) 2105 2106 # Set global to no quantization and then default config for a specific submodule whose name includes an underscore. 2107 quantizer = X86InductorQuantizer() 2108 quantizer.set_module_name_qconfig( 2109 "foo_bar", xiq.get_default_x86_inductor_quantization_config() 2110 ) 2111 example_inputs = (torch.randn(2, 2),) 2112 m = M().eval() 2113 m = capture_pre_autograd_graph(m, example_inputs) 2114 m = prepare_pt2e(m, quantizer) 2115 # Use a linear count instead of names because the names might change, but 2116 # the order should be the same. 2117 count = 0 2118 for n in m.graph.nodes: 2119 if n.op == "call_function" and n.target == torch.ops.aten.linear.default: 2120 # Get the weight observer to see the per-channel vs per-tensor. 2121 weight_observer_node = n.args[1] 2122 if count == 0: 2123 # for foo_bar. 2124 self.assertEqual( 2125 weight_observer_node.op, 2126 "call_module", 2127 f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", 2128 ) 2129 observer_instance = getattr(m, weight_observer_node.target) 2130 self.assertEqual( 2131 observer_instance.qscheme, torch.per_channel_symmetric 2132 ) 2133 else: 2134 # For baz it should have no observer at all. 2135 self.assertNotEqual( 2136 weight_observer_node.op, 2137 "call_module", 2138 f"The op of linear({count})'s weight_observer_node is {weight_observer_node.op} instead call_module", 2139 ) 2140 count += 1 2141 2142 @skipIfNoX86 2143 def test_set_module_name_and_module_type_case1(self): 2144 """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. 2145 2146 Expect that all linear layers are not quantized except the last one. 2147 """ 2148 2149 class M(torch.nn.Module): 2150 def __init__(self) -> None: 2151 super().__init__() 2152 self.linear1 = torch.nn.Linear(5, 10) 2153 self.linear2 = torch.nn.Linear(10, 5) 2154 self.sub = torch.nn.Linear(5, 5) 2155 2156 def forward(self, x): 2157 x = self.linear1(x) 2158 x = self.linear2(x) 2159 x = self.sub(x) 2160 return x 2161 2162 m = M().eval() 2163 example_inputs = (torch.randn(3, 5),) 2164 # Set `sub` with default config and then `None` for all `Linear`. 2165 # The config set by `set_module_name_qconfig` has higher priority than `set_module_type_qconfig`. 2166 quantizer = X86InductorQuantizer() 2167 quantizer.set_module_name_qconfig( 2168 "sub", xiq.get_default_x86_inductor_quantization_config() 2169 ).set_module_type_qconfig(torch.nn.Linear, None) 2170 2171 node_occurrence = { 2172 torch.ops.aten.linear.default: 3, 2173 # quantize and dequantize the input of the last linear 2174 torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, 2175 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, 2176 # dequantize the weight of the last linear 2177 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 2178 } 2179 node_list = [ 2180 # first and second linear are not quantized 2181 torch.ops.aten.linear.default, 2182 torch.ops.aten.linear.default, 2183 # last linear is quantized 2184 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2185 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2186 torch.ops.aten.linear.default, 2187 ] 2188 self._test_quantizer( 2189 m, 2190 example_inputs, 2191 quantizer, 2192 node_occurrence, 2193 node_list, 2194 ) 2195 2196 @skipIfNoX86 2197 def test_set_module_name_and_module_type_case2(self): 2198 """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time. 2199 2200 Expect that all linear layers are quantized except the last one. 2201 """ 2202 2203 class M(torch.nn.Module): 2204 def __init__(self) -> None: 2205 super().__init__() 2206 self.linear1 = torch.nn.Linear(5, 10) 2207 self.linear2 = torch.nn.Linear(10, 5) 2208 self.sub = torch.nn.Linear(5, 5) 2209 2210 def forward(self, x): 2211 x = self.linear1(x) 2212 x = self.linear2(x) 2213 x = self.sub(x) 2214 return x 2215 2216 m = M().eval() 2217 example_inputs = (torch.randn(3, 5),) 2218 # Set `sub` with None and then default config for a all `Linear`. 2219 quantizer = X86InductorQuantizer() 2220 quantizer.set_module_name_qconfig("sub", None).set_module_type_qconfig( 2221 torch.nn.Linear, xiq.get_default_x86_inductor_quantization_config() 2222 ) 2223 2224 node_occurrence = { 2225 torch.ops.aten.linear.default: 3, 2226 # quantize and dequantize the input and output of the first and second linear 2227 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 2228 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, 2229 # dequantize the weight of the first and second linear 2230 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 2231 } 2232 node_list = [ 2233 # Q/DQ for first lienar 2234 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2235 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2236 torch.ops.aten.linear.default, 2237 # Q/DQ for second lienar 2238 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2239 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2240 torch.ops.aten.linear.default, 2241 # last linear is not quantized 2242 torch.ops.aten.linear.default, 2243 ] 2244 self._test_quantizer( 2245 m, 2246 example_inputs, 2247 quantizer, 2248 node_occurrence, 2249 node_list, 2250 ) 2251 2252 @skipIfNoX86 2253 def test_set_module_name_qconfig_for_dynamic_quant(self): 2254 """Test that quantize a specific submodule for dynamic quantization.""" 2255 2256 with override_quantized_engine("x86"), torch.no_grad(): 2257 for is_qat in [False, True]: 2258 m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() 2259 example_inputs = (torch.randn(1, 4, 64),) 2260 # only quantize `q_proj` `v_proj` 2261 dynamic_config = xiq.get_default_x86_inductor_quantization_config( 2262 is_dynamic=True, is_qat=is_qat 2263 ) 2264 quantizer = ( 2265 X86InductorQuantizer() 2266 .set_module_name_qconfig("q_proj", dynamic_config) 2267 .set_module_name_qconfig("v_proj", dynamic_config) 2268 ) 2269 node_occurrence = { 2270 # quantize and dequantize the input 2271 torch.ops.quantized_decomposed.choose_qparams.tensor: 1, 2272 torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, 2273 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, 2274 # dequantize the weight of q_proj and v_proj 2275 torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, 2276 } 2277 node_list = [ 2278 # quantize and dequantize the input 2279 torch.ops.quantized_decomposed.choose_qparams.tensor, 2280 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 2281 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 2282 # q_proj 2283 torch.ops.aten.linear.default, 2284 # k_proj 2285 torch.ops.aten.linear.default, 2286 # v_proj 2287 torch.ops.aten.linear.default, 2288 ] 2289 self._test_quantizer( 2290 m, 2291 example_inputs, 2292 quantizer, 2293 node_occurrence, 2294 node_list, 2295 is_qat=is_qat, 2296 ) 2297 2298 @skipIfNoX86 2299 def test_set_module_name_with_mixed_configs(self): 2300 """Test case for setting module names with mixed static/dynamic or QAT/non-QAT configurations. 2301 2302 The config for 'v_proj' will always be ignored and raise a warning. 2303 """ 2304 with override_quantized_engine("x86"), torch.no_grad(): 2305 with self.assertWarns(UserWarning) as context: 2306 for q_is_dynamic, v_is_dynamic, q_is_qat, v_is_qat in itertools.product( 2307 [False, True], repeat=4 2308 ): 2309 if q_is_dynamic == v_is_dynamic and q_is_qat == v_is_qat: 2310 continue 2311 m = TestHelperModules.SelfAttnLikeModule(input_dim=64).eval() 2312 example_inputs = (torch.randn(1, 4, 64),) 2313 quantizer = ( 2314 X86InductorQuantizer() 2315 .set_module_name_qconfig( 2316 "q_proj", 2317 xiq.get_default_x86_inductor_quantization_config( 2318 is_qat=q_is_qat, is_dynamic=q_is_dynamic 2319 ), 2320 ) 2321 .set_module_name_qconfig( 2322 "v_proj", 2323 xiq.get_default_x86_inductor_quantization_config( 2324 is_qat=v_is_qat, is_dynamic=v_is_dynamic 2325 ), 2326 ) 2327 ) 2328 quant_op = ( 2329 torch.ops.quantized_decomposed.quantize_per_tensor.tensor 2330 if q_is_dynamic 2331 else torch.ops.quantized_decomposed.quantize_per_tensor.default 2332 ) 2333 dequant_op = ( 2334 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor 2335 if q_is_dynamic 2336 else torch.ops.quantized_decomposed.dequantize_per_tensor.default 2337 ) 2338 node_occurrence = { 2339 # quantize and dequantize the input 2340 quant_op: 1, 2341 dequant_op: 1, 2342 # only `q_proj` was quantized, dequantize its weight 2343 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 2344 } 2345 node_list = [ 2346 # quantize and dequantize the input 2347 quant_op, 2348 dequant_op, 2349 # q_proj 2350 torch.ops.aten.linear.default, 2351 # k_proj/v_proj 2352 torch.ops.aten.linear.default, 2353 torch.ops.aten.linear.default, 2354 ] 2355 self._test_quantizer( 2356 m, 2357 example_inputs, 2358 quantizer, 2359 node_occurrence, 2360 node_list, 2361 is_qat=q_is_qat, 2362 ) 2363 warning_msg = ( 2364 "Mixed QAT and Non-QAT" 2365 if q_is_qat != v_is_qat 2366 else "Mixed dynamic and static" 2367 ) 2368 self.assertTrue( 2369 any( 2370 warning_msg in msg 2371 for msg in [str(w.message) for w in context.warnings] 2372 ) 2373 ) 2374 2375 @skipIfNoX86 2376 def test_set_module_name_and_module_type_with_mixed_configs(self): 2377 """Test that set `module_name_qconfig` and `module_type_qconfig` at the same time with mixed the configs. 2378 2379 Expect that only the last linear(`sub`) is quantized using static quantization. 2380 """ 2381 2382 class M(torch.nn.Module): 2383 def __init__(self) -> None: 2384 super().__init__() 2385 self.linear1 = torch.nn.Linear(5, 10) 2386 self.linear2 = torch.nn.Linear(10, 5) 2387 self.sub = torch.nn.Linear(5, 5) 2388 2389 def forward(self, x): 2390 x = self.linear1(x) 2391 x = self.linear2(x) 2392 x = self.sub(x) 2393 return x 2394 2395 m = M().eval() 2396 example_inputs = (torch.randn(3, 5),) 2397 # Set `sub` with static config and then dynamic config for a all `Linear`(ignored). 2398 quantizer = X86InductorQuantizer() 2399 quantizer.set_module_name_qconfig( 2400 "sub", xiq.get_default_x86_inductor_quantization_config(is_dynamic=False) 2401 ).set_module_type_qconfig( 2402 torch.nn.Linear, 2403 xiq.get_default_x86_inductor_quantization_config(is_dynamic=True), 2404 ) 2405 2406 node_occurrence = { 2407 torch.ops.aten.linear.default: 3, 2408 # quantize and dequantize the input of the last linear 2409 torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, 2410 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, 2411 # dequantize the weight of the last linear 2412 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 2413 } 2414 node_list = [ 2415 # first and second linear are not quantized 2416 torch.ops.aten.linear.default, 2417 torch.ops.aten.linear.default, 2418 # Q/DQ pairs for the last linear 2419 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2420 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2421 torch.ops.aten.linear.default, 2422 ] 2423 self._test_quantizer( 2424 m, 2425 example_inputs, 2426 quantizer, 2427 node_occurrence, 2428 node_list, 2429 ) 2430 2431 @skipIfNoX86 2432 def test_filter_conv2d_recipe(self): 2433 """ 2434 Test removing conv2d from default recipe of X86InductorQuantizer. 2435 """ 2436 with override_quantized_engine("x86"), torch.no_grad(): 2437 m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval() 2438 example_inputs = (torch.randn(2, 3, 16, 16),) 2439 quantizer = X86InductorQuantizer().set_global( 2440 xiq.get_default_x86_inductor_quantization_config() 2441 ) 2442 quantizer.set_module_type_qconfig(torch.nn.Conv2d, None) 2443 node_occurrence = { 2444 # one for input and weight of the conv 2445 torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, 2446 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, 2447 # note: quantize op for weights are const propagated 2448 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 2449 torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, 2450 } 2451 node_list = [ 2452 torch.ops.aten.conv2d.default, 2453 torch.ops.aten.relu.default, 2454 ] 2455 self._test_quantizer( 2456 m, 2457 example_inputs, 2458 quantizer, 2459 node_occurrence, 2460 node_list, 2461 ) 2462 2463 @skipIfNoX86 2464 def test_filter_linear_recipe(self): 2465 """ 2466 Test removing linear from default recipe of X86InductorQuantizer. 2467 """ 2468 with override_quantized_engine("x86"), torch.no_grad(): 2469 m = TestHelperModules.LinearUnaryModule( 2470 use_bias=True, 2471 postop=nn.ReLU, 2472 ).eval() 2473 example_inputs = (torch.randn(2, 4),) 2474 quantizer = X86InductorQuantizer().set_global( 2475 xiq.get_default_x86_inductor_quantization_config() 2476 ) 2477 quantizer.set_function_type_qconfig(torch.nn.functional.linear, None) 2478 node_occurrence = { 2479 # one for input and weight of the conv 2480 torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, 2481 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, 2482 # note: quantize op for weights are const propagated 2483 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 2484 torch.ops.quantized_decomposed.dequantize_per_channel.default: 0, 2485 } 2486 node_list = [ 2487 torch.ops.aten.linear.default, 2488 torch.ops.aten.relu.default, 2489 ] 2490 self._test_quantizer( 2491 m, 2492 example_inputs, 2493 quantizer, 2494 node_occurrence, 2495 node_list, 2496 ) 2497 2498 @skipIfNoX86 2499 def test_filter_maxpool2d_recipe(self): 2500 """ 2501 Test removing maxpool2d from default recipe of X86InductorQuantizer. 2502 """ 2503 with override_quantized_engine("x86"), torch.no_grad(): 2504 m = TestHelperModules.Conv2dUnaryModule(torch.nn.ReLU(inplace=False)).eval() 2505 example_inputs = (torch.randn(2, 3, 16, 16),) 2506 quantizer = X86InductorQuantizer().set_global( 2507 xiq.get_default_x86_inductor_quantization_config() 2508 ) 2509 quantizer.set_function_type_qconfig(torch.nn.functional.max_pool2d, None) 2510 node_occurrence = { 2511 # one for input and weight of the conv 2512 torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, 2513 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, 2514 # note: quantize op for weights are const propagated 2515 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 2516 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 2517 } 2518 node_list = [ 2519 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2520 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2521 torch.ops.aten.conv2d.default, 2522 torch.ops.aten.relu.default, 2523 torch.ops.aten.max_pool2d.default, 2524 ] 2525 self._test_quantizer( 2526 m, 2527 example_inputs, 2528 quantizer, 2529 node_occurrence, 2530 node_list, 2531 ) 2532 2533 @skipIfNoX86 2534 def test_attention_block(self): 2535 """ 2536 Test pattern of Attention like Block with X86InductorQuantizer. 2537 """ 2538 for annotate_matmul in [False, True]: 2539 with override_quantized_engine("x86"), torch.no_grad(): 2540 m = TestHelperModules.SelfAttnLikeModule( 2541 input_dim=64 * 16, 2542 transpose_for_score=True, 2543 num_attention_heads=16, 2544 attention_head_size=64, 2545 ).eval() 2546 example_inputs = (torch.randn(2, 384, 1024),) 2547 2548 m(*example_inputs) 2549 2550 quantizer = X86InductorQuantizer().set_global( 2551 xiq.get_default_x86_inductor_quantization_config() 2552 ) 2553 2554 if annotate_matmul: 2555 quantizer.set_function_type_qconfig( 2556 torch.matmul, quantizer.get_global_quantization_config() 2557 ) 2558 2559 node_occurrence = { 2560 torch.ops.quantized_decomposed.quantize_per_tensor.default: ( 2561 5 if annotate_matmul else 1 2562 ), 2563 torch.ops.quantized_decomposed.dequantize_per_tensor.default: ( 2564 7 if annotate_matmul else 3 2565 ), 2566 # quantize_per_channel for weights are const propagated 2567 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 2568 torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, 2569 } 2570 if annotate_matmul: 2571 node_list = [ 2572 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2573 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2574 torch.ops.aten.linear.default, 2575 torch.ops.aten.view.default, 2576 torch.ops.aten.permute.default, 2577 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2578 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2579 torch.ops.aten.matmul.default, 2580 torch.ops.aten.div.Tensor, 2581 torch.ops.aten.softmax.int, 2582 ] 2583 else: 2584 node_list = [ 2585 torch.ops.quantized_decomposed.quantize_per_tensor.default, 2586 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 2587 torch.ops.aten.linear.default, 2588 torch.ops.aten.view.default, 2589 torch.ops.aten.permute.default, 2590 torch.ops.aten.matmul.default, 2591 torch.ops.aten.div.Tensor, 2592 torch.ops.aten.softmax.int, 2593 ] 2594 self._test_quantizer( 2595 m, 2596 example_inputs, 2597 quantizer, 2598 node_occurrence, 2599 node_list, 2600 ) 2601