# Owner(s): ["oncall: quantization"]
import copy
import operator
import unittest
from typing import Any, Optional, Tuple, Type

import torch
from torch._export import capture_pre_autograd_graph
from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
from torch.ao.quantization import (
    default_fake_quant,
    FusedMovingAvgObsFakeQuantize,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.qconfig import (
    default_per_channel_symmetric_qnnpack_qat_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
)
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.ao.quantization.quantize_pt2e import (
    _convert_to_reference_decomposed_fx,
    convert_pt2e,
    prepare_pt2e,
    prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    skip_if_no_torchvision,
    skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import override_quantized_engine


class PT2EQATTestCase(QuantizationTestCase):
    """
    Base QuantizationTestCase for PT2E QAT with some helper methods.
    """

    class _BaseConvBnModel(torch.nn.Module):
        def __init__(
            self,
            conv_class: Type[torch.nn.Module],
            bn_class: Type[torch.nn.Module],
            has_conv_bias: bool,
            has_bn: bool,
            has_relu: bool,
            **conv_kwargs,
        ):
            super().__init__()
            conv_kwargs.setdefault("in_channels", 3)
            conv_kwargs.setdefault("out_channels", 3)
            conv_kwargs.setdefault("kernel_size", 3)
            conv_kwargs.setdefault("bias", has_conv_bias)
            self.conv = conv_class(**conv_kwargs)
            self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None
            self.relu = torch.nn.ReLU() if has_relu else None

        def forward(self, x):
            x = self.conv(x)
            if self.bn is not None:
                x = self.bn(x)
            if self.relu is not None:
                x = self.relu(x)
            return x

    def _get_conv_bn_model(
        self,
        has_conv_bias: bool = True,
        has_bn: bool = True,
        has_relu: bool = False,
        transpose: bool = False,
        **conv_kwargs,
    ):
        """
        Return an instance of a simple test model containing the
        conv[-bn][-relu] pattern. By default, this returns a
        conv-bn model with conv bias.
        """
        return self._BaseConvBnModel(
            self.conv_transpose_class if transpose else self.conv_class,
            self.bn_class,
            has_conv_bias,
            has_bn,
            has_relu,
            **conv_kwargs,
        )

    def _verify_symmetric_xnnpack_qat_numerics(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
    ):
        self._verify_symmetric_xnnpack_qat_numerics_helper(
            model,
            example_inputs,
            is_per_channel=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics_helper(
            model,
            example_inputs,
            is_per_channel=False,
        )

    def _verify_symmetric_xnnpack_qat_numerics_helper(
        self,
        model: torch.nn.Module,
        example_inputs: Tuple[Any, ...],
        is_per_channel: bool,
        verify_convert: bool = True,
    ):
        """
        Helper method to verify that the QAT numerics for PT2E quantization match those of
        FX graph mode quantization for symmetric qnnpack.
        """
        # resetting dynamo cache
        torch._dynamo.reset()
        MANUAL_SEED = 100

        # PT2 export

        model_pt2e = copy.deepcopy(model)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(
                is_per_channel=is_per_channel, is_qat=True
            )
        )
        model_pt2e = capture_pre_autograd_graph(
            model_pt2e,
            example_inputs,
        )
        model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
        torch.manual_seed(MANUAL_SEED)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)

        model_fx = copy.deepcopy(model)
        if is_per_channel:
            default_qconfig = default_per_channel_symmetric_qnnpack_qat_qconfig
        else:
            default_qconfig = default_symmetric_qnnpack_qat_qconfig
        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
        backend_config = get_qnnpack_backend_config()
        model_fx = prepare_qat_fx(
            model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
        )
        torch.manual_seed(MANUAL_SEED)
        after_prepare_result_fx = model_fx(*example_inputs)

        # Verify that numerics match
        self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

        if verify_convert:
            # We don't want to impose any ordering requirements between
            # move_exported_model_to_eval and convert_pt2e
            torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
            model_pt2e = convert_pt2e(model_pt2e)
            quant_result_pt2e = model_pt2e(*example_inputs)
            model_fx.eval()
            model_fx = _convert_to_reference_decomposed_fx(
                model_fx,
                backend_config=backend_config,
            )
            quant_result_fx = model_fx(*example_inputs)
            self.assertEqual(quant_result_pt2e, quant_result_fx)

    def _verify_symmetric_xnnpack_qat_graph(
        self,
        m: torch.fx.GraphModule,
        example_inputs: Tuple[Any, ...],
        has_relu: bool,
        has_bias: bool = True,
        is_cuda: bool = False,
        expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
        # TODO: set this to true by default
        verify_convert: bool = False,
    ):
        self._verify_symmetric_xnnpack_qat_graph_helper(
            m,
            example_inputs,
            is_per_channel=True,
            has_relu=has_relu,
            has_bias=has_bias,
            is_cuda=is_cuda,
            expected_conv_literal_args=expected_conv_literal_args,
            verify_convert=verify_convert,
        )
        self._verify_symmetric_xnnpack_qat_graph_helper(
            m,
            example_inputs,
            is_per_channel=False,
            has_relu=has_relu,
            has_bias=has_bias,
            is_cuda=is_cuda,
            expected_conv_literal_args=expected_conv_literal_args,
            verify_convert=verify_convert,
        )

    def _verify_symmetric_xnnpack_qat_graph_helper(
        self,
        m: torch.fx.GraphModule,
        example_inputs: Tuple[Any, ...],
        is_per_channel: bool,
        has_relu: bool,
        has_bias: bool = True,
        is_cuda: bool = False,
        expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
        verify_convert: bool = False,
    ):
        """
        Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
        with fake quantizes inserted into the correct places.
        # TODO: also verify that metadata is copied over to the new nodes.
        """
        m = copy.deepcopy(m)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel, is_qat=True)
        )
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)

        # Verify: getitem output activation fake quantize
        output_node = list(m.graph.nodes)[-1]
        output_fq_node = output_node.args[0][0]
        self.assertTrue(output_fq_node.target.startswith("activation_post_process_"))
        output_fq_mod = getattr(m, output_fq_node.target)
        self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(output_fq_mod.dtype, torch.int8)
        self.assertEqual(output_fq_mod.quant_min, -128)
        self.assertEqual(output_fq_mod.quant_max, 127)

        # Verify: getitem(bn, 0) or relu(getitem(bn, 0))
        if has_relu:
            relu_node = output_fq_node.args[0]
            getitem_node = relu_node.args[0]
            self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
        else:
            relu_node = None
            getitem_node = output_fq_node.args[0]

        is_training_ir_flag = capture_pre_autograd_graph_using_training_ir()
        if is_training_ir_flag:
            # The relu node takes in the output of bn.
            # See NOTE [training ir has no getitem for bn node].
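            # With the training IR, batch norm is exported as aten.batch_norm.default,
            # which returns a single tensor, so the node feeding the output (or relu)
            # is the bn node itself rather than a getitem.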
            bn_node = getitem_node
            self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)
        else:
            # TODO: This branch is going through a deprecated path and should be deleted
            # soon, after capture_pre_autograd_graph fully migrates to the training IR
            # T199018392
            self.assertEqual(getitem_node.target, operator.getitem)
            bn_node = getitem_node.args[0]

            expected_bn_op = None
            if is_cuda:
                if torch.version.cuda is not None:
                    expected_bn_op = torch.ops.aten.cudnn_batch_norm.default
                elif torch.version.hip is not None:
                    expected_bn_op = torch.ops.aten.miopen_batch_norm.default
            else:
                expected_bn_op = torch.ops.aten._native_batch_norm_legit.default
            self.assertEqual(bn_node.target, expected_bn_op)

        # Verify: conv / scale_factor.reshape [+ bias.reshape]
        if has_bias:
            add_bias_node = bn_node.args[0]
            (div_scale_factor_node, bias_reshape_node) = add_bias_node.args
            self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor)
            self.assertEqual(bias_reshape_node.target, torch.ops.aten.reshape.default)
        else:
            div_scale_factor_node = bn_node.args[0]
        (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
        conv_op = conv_node.target
        self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
        self.assertTrue(_is_conv_node(conv_node))
        self.assertEqual(
            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
        )

        # Verify: conv literal args
        if expected_conv_literal_args is not None:
            assert (
                len(expected_conv_literal_args) == 6
            ), "wrong num conv args, bad test setup"
            for i in range(6):
                if i + 3 < len(conv_node.args):
                    self.assertEqual(
                        conv_node.args[i + 3], expected_conv_literal_args[i]
                    )

        # Verify: conv input activation fake quantize
        conv_input_fq_node = conv_node.args[0]
        conv_input_node = conv_input_fq_node.args[0]
        self.assertTrue(
            conv_input_fq_node.target.startswith("activation_post_process_")
        )
        conv_input_fq_mod = getattr(m, conv_input_fq_node.target)
        self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver
        )
        self.assertEqual(conv_input_fq_mod.dtype, torch.int8)
        self.assertEqual(conv_input_fq_mod.quant_min, -128)
        self.assertEqual(conv_input_fq_mod.quant_max, 127)
        self.assertEqual(conv_input_node.op, "placeholder")

        # Verify: conv weight fake quantize
        conv_weight_fq_node = conv_node.args[1]
        self.assertTrue(
            conv_weight_fq_node.target.startswith("activation_post_process_")
        )
        conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target)
        if is_per_channel:
            expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver
        else:
            expected_weight_observer_type = MovingAverageMinMaxObserver
        self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize)
        self.assertEqual(
            type(conv_weight_fq_mod.activation_post_process),
            expected_weight_observer_type,
        )
        self.assertEqual(conv_weight_fq_mod.dtype, torch.int8)
        self.assertEqual(conv_weight_fq_mod.quant_min, -127)
        self.assertEqual(conv_weight_fq_mod.quant_max, 127)

        # Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias)
        zero_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
        mul_weight_scale_factor_node = conv_weight_fq_node.args[0]
        (
            conv_weight_fq_node,
            scale_factor_reshape_node,
        ) = mul_weight_scale_factor_node.args
        if has_bias:
            self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default)
        else:
            self.assertTrue(zero_bias_node is None)
        self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor)
        self.assertEqual(
            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
        )

        # Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps)
        scale_factor_node = scale_factor_reshape_node.args[0]
        (bn_weight_node, sqrt_node) = scale_factor_node.args
        bn_running_var_add_node = sqrt_node.args[0]
        (bn_running_var_node, eps) = bn_running_var_add_node.args
        self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
        if is_training_ir_flag:
            self.assertTrue("bn.weight" in bn_weight_node.target)
            self.assertTrue("bn.running_var" in bn_running_var_node.target)
        else:
            self.assertTrue("bn_weight" in bn_weight_node.target)
            self.assertTrue("bn_running_var" in bn_running_var_node.target)
        self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
        self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
        self.assertEqual(eps, 1e-5)

        # Optionally check the converted graph
        if verify_convert:
            m = convert_pt2e(m)
            m(*example_inputs)

            if is_per_channel:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_channel.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_channel.default
                    ): 1,
                }
            else:
                conv_weight_dq_op = (
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                )
                node_occurrence = {
                    ns.call_function(
                        torch.ops.quantized_decomposed.quantize_per_tensor.default
                    ): 2,
                    ns.call_function(
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ): 3,
                }
            node_list = [
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
                ns.call_function(conv_weight_dq_op),
                ns.call_function(conv_op),
                ns.call_function(
                    torch.ops.quantized_decomposed.quantize_per_tensor.default
                ),
                ns.call_function(
                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
                ),
            ]

            self.checkGraphModuleNodes(
                m,
                expected_node_list=node_list,
                expected_node_occurrence=node_occurrence,
            )


class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
    """
    Base TestCase to be used for all conv-bn[-relu] fusion patterns.
    """

    # TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
    # Otherwise it fails with the following error:
    #   torch._dynamo.exc.InternalTorchDynamoError:
    #   'QuantizationConfig' object has no attribute '__bool__'

    def setUp(self):
        # NB: Skip the test if this is a base class. This handles the test discovery
        # logic in buck, which finds and runs all tests here, including the base class,
        # which we don't want to run.
        if self.id() and "_Base" in self.id():
            self.skipTest("Skipping test running from base class")

    def test_qat_conv_no_bias(self):
        m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True)
        m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m1, self.example_inputs)
        self._verify_symmetric_xnnpack_qat_numerics(m2, self.example_inputs)

    def test_qat_conv_bn_fusion(self):
        m = self._get_conv_bn_model()
        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_conv_bn_fusion_cuda(self):
        m = self._get_conv_bn_model().cuda()
        example_inputs = (self.example_inputs[0].cuda(),)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            is_cuda=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_literal_args(self):
        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv = conv_class(3, 3, 3, stride=2, padding=4)
                self.bn = bn_class(3)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2,), (4,), (1,), False, (0,), 1)
            example_inputs = (torch.randn(1, 3, 5),)
        else:
            # stride, padding, dilation, transposed, output_padding, groups
            conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
            example_inputs = (torch.randn(1, 3, 5, 5),)

        m = M(self.conv_class, self.bn_class)

        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=False,
            expected_conv_literal_args=conv_args,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_fusion_no_conv_bias(self):
        class M2(torch.nn.Module):
            """
            Mixed conv + BN with and without conv bias.
            """

            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv1 = conv_class(3, 3, 3, bias=False)
                self.bn1 = bn_class(3)
                self.conv2 = conv_class(3, 3, 3, bias=True)
                self.bn2 = bn_class(3)

            def forward(self, x):
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.conv2(x)
                x = self.bn2(x)
                return x

        m1 = self._get_conv_bn_model(has_conv_bias=False)
        m2 = M2(self.conv_class, self.bn_class)

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(3, 3, 5),)
        else:
            example_inputs = (torch.randn(3, 3, 5, 5),)

        self._verify_symmetric_xnnpack_qat_graph(
            m1,
            example_inputs,
            has_relu=False,
            has_bias=False,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m1, example_inputs)
        self._verify_symmetric_xnnpack_qat_numerics(m2, example_inputs)

    def test_qat_conv_bn_relu_fusion(self):
        m = self._get_conv_bn_model(has_relu=True)
        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_conv_bn_relu_fusion_cuda(self):
        m = self._get_conv_bn_model(has_relu=True).cuda()
        example_inputs = (self.example_inputs[0].cuda(),)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            example_inputs,
            has_relu=True,
            is_cuda=True,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
        m = self._get_conv_bn_model(has_conv_bias=False, has_relu=True)
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            self.example_inputs,
            has_relu=True,
            has_bias=False,
        )
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    def test_qat_inplace_add_relu(self):
        class M(torch.nn.Module):
            def __init__(self, conv_class):
                super().__init__()
                self.conv = conv_class(1, 1, 1)
                self.relu = torch.nn.ReLU(inplace=True)

            def forward(self, x):
                x0 = x
                x = self.conv(x)
                x += x0
                x = self.relu(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(1, 1, 3),)
        else:
            example_inputs = (torch.randn(1, 1, 3, 3),)

        m = M(self.conv_class)
        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self):
        """
        Test the case where the placeholder node for the [conv - bn - getitem] pattern
        is also a getitem node:

          some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem

        We want the metadata to be copied from the `conv_bn_getitem` node, not from
        the `unrelated_getitem` node, which is not part of the conv-bn pattern but
        is returned as part of the match anyway (as a placeholder).
        """

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.bn1 = bn_class(3)
                self.conv = conv_class(3, 3, 3)
                self.bn2 = bn_class(3)

            def forward(self, x):
                x = self.bn1(x)
                x = self.conv(x)
                x = self.bn2(x)
                return x

        def _get_getitem_nodes(m: torch.fx.GraphModule):
            """
            Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph.
            """
            unrelated_getitem_node, conv_bn_getitem_node = None, None
            for node in m.graph.nodes:
                if (
                    node.target != operator.getitem
                    or node.args[0].target
                    != torch.ops.aten._native_batch_norm_legit.default
                ):
                    continue
                if node.args[0].args[0].op == "placeholder":
                    unrelated_getitem_node = node
                else:
                    conv_bn_getitem_node = node
            assert (
                unrelated_getitem_node is not None
            ), "did not find unrelated getitem node, bad test setup"
            assert (
                conv_bn_getitem_node is not None
            ), "did not find conv bn getitem node, bad test setup"
            return (unrelated_getitem_node, conv_bn_getitem_node)

        # Program capture
        m = M(self.conv_class, self.bn_class)
        m = capture_pre_autograd_graph(m, self.example_inputs)
        m.graph.eliminate_dead_code()
        m.recompile()
        (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)

        # Prepare QAT
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
        )
        m = prepare_qat_pt2e(m, quantizer)
        (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m)

        # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem`
        original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[
            "quantization_annotation"
        ]
        conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"]
        self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta)
        self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta)

    def test_qat_update_shared_qspec(self):
        """
        Test the case where nodes used in SharedQuantizationSpec were replaced
        during QAT subgraph rewriting.
        """

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class):
                super().__init__()
                self.conv = conv_class(3, 3, 3)
                self.bn = bn_class(3)
                self.hardtanh = torch.nn.Hardtanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.hardtanh(x)
                return x

        m = M(self.conv_class, self.bn_class)
        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)

    def test_qat_preserve_source_fn_stack(self):
        """
        Test whether `source_fn_stack` is preserved after QAT fusion.
        """

        class M(torch.nn.Module):
            def __init__(self, conv_class, bn_class, backbone):
                super().__init__()
                self.conv = conv_class(5, 3, 3)
                self.bn = bn_class(3)
                self.relu = torch.nn.ReLU()
                self.backbone = backbone

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                x = self.backbone(x)
                return x

        assert self.dim in [1, 2]
        if self.dim == 1:
            example_inputs = (torch.randn(1, 5, 10),)
        else:
            example_inputs = (torch.randn(1, 5, 10, 10),)

        # QAT prepare + convert
        backbone = self._get_conv_bn_model(has_relu=True)
        m = M(self.conv_class, self.bn_class, backbone)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
        m = capture_pre_autograd_graph(m, example_inputs)
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)

        # Extract the conv and relu nodes (bn was folded into conv)
        first_conv, first_relu, second_conv, second_relu = None, None, None, None
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.relu.default:
                if first_relu is None:
                    assert first_conv is None, "bad test setup"
                    first_relu = n
                    first_conv = n.args[0]
                else:
                    assert second_conv is None, "bad test setup"
                    second_relu = n
                    second_conv = n.args[0]

        # Extract the conv weight and bias nodes
        def get_conv_weight_and_bias(conv_node: torch.fx.Node):
            weight_dq_node = conv_node.args[1]
            qweight_node = weight_dq_node.args[0]
            bias_node = conv_node.args[2]
            assert isinstance(qweight_node, torch.fx.Node)
            assert isinstance(bias_node, torch.fx.Node)
            return (qweight_node, bias_node)

        first_conv_qweight, first_conv_bias = get_conv_weight_and_bias(first_conv)
        second_conv_qweight, second_conv_bias = get_conv_weight_and_bias(second_conv)

        # Assert that each set of conv, conv weight, and conv bias are in the same partition
        def get_source_fn(node: torch.fx.Node):
            # E.g. [('l__self___backbone1_conv', <class 'torch.nn.modules.conv.Conv2d'>)]
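            # The first entry's name identifies the module call a node originated from;
            # the assertions below use it as a proxy for the node's fusion partition.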
            return node.meta["source_fn_stack"][0][0]

        # We don't preserve this for the quantized weight currently, since it's folded,
        # but the user can attach "quantization_tag" to the node and it will be preserved
        # self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight))
        # self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight))

        self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias))
        self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias))

        # Assert that different sets of convs and relus have different partitions
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(first_relu))
        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(second_conv))
        self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu))
        self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu))

        # Assert that "backbone" exists only in the second set of conv and relu's partition
        self.assertTrue("backbone" not in get_source_fn(first_conv))
        self.assertTrue("backbone" not in get_source_fn(first_relu))
        self.assertTrue("backbone" in get_source_fn(second_conv))
        self.assertTrue("backbone" in get_source_fn(second_relu))

    def test_qat_conv_bn_bias_derived_qspec(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = capture_pre_autograd_graph(m, example_inputs)
        quantizer = ConvBnDerivedBiasQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that both weight and bias are quantized
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        bias_dq = conv_node.args[2]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        self.assertEqual(
            bias_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        )
        weight_getattr = weight_dq.args[0]
        bias_getattr = bias_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )
        self.assertEqual(
            bias_getattr.op,
            "get_attr",
        )

        # Assert that bias scale = weight scale * input scale
        input_dq = conv_node.args[0]
        input_scale = input_dq.args[1]
        bias_scale = bias_dq.args[1]
        weight_scale = weight_dq.args[1]
        self.assertEqual(bias_scale, input_scale * weight_scale)

        # Assert that args for the bias' quantize and dequantize ops
        # are copied correctly after subgraph rewriting
        (bias_qmin, bias_qmax, bias_dtype) = bias_dq.args[3:]
        self.assertEqual(bias_qmin, -(2**31))
        self.assertEqual(bias_qmax, 2**31 - 1)
        self.assertEqual(bias_dtype, torch.int32)

    def test_qat_per_channel_weight_custom_dtype(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = capture_pre_autograd_graph(m, example_inputs)
        quantizer = ConvBnInt32WeightQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Assert that conv weight is quantized per channel
        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        weight_dq = conv_node.args[1]
        self.assertEqual(
            weight_dq.target,
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
        )
        weight_getattr = weight_dq.args[0]
        self.assertEqual(
            weight_getattr.op,
            "get_attr",
        )

        # Assert that args for the weight's dequantize ops
        # are copied correctly after subgraph rewriting
        (dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:]
        self.assertEqual(dq_axis, 0)
        self.assertEqual(dq_qmin, 0)
        self.assertEqual(dq_qmax, 2**31 - 1)
        self.assertEqual(dq_dtype, torch.int32)

    def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
        # Use different in/out channel sizes to test if conv weight is
        # properly transposed in QAT pattern
        m = self._get_conv_bn_model(
            has_relu=has_relu,
            transpose=True,
            in_channels=3,
            out_channels=5,
            kernel_size=3,
        )
        self._verify_symmetric_xnnpack_qat_graph(
            m,
            self.example_inputs,
            has_relu=has_relu,
            verify_convert=True,
        )

    def test_qat_conv_transpose_bn(self):
        self._do_test_qat_conv_transpose_bn(has_relu=False)

    def test_qat_conv_transpose_bn_relu(self):
        self._do_test_qat_conv_transpose_bn(has_relu=True)

    def test_qat_conv_bn_per_channel_weight_bias(self):
        m = self._get_conv_bn_model()
        example_inputs = self.example_inputs
        m = capture_pre_autograd_graph(m, example_inputs)
        quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
        m = prepare_qat_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m(*example_inputs)

        # Expected graph:
        #      x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output
        # weight -> q_channel -> dq_channel /
        #   bias -> q_channel -> dq_channel /

        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
        conv_op = conv_node.target
        conv_weight_dq_op = (
            torch.ops.quantized_decomposed.dequantize_per_channel.default
        )
        node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 2,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 2,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_weight_dq_op),
            ns.call_function(conv_op),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m,
            expected_node_list=node_list,
            expected_node_occurrence=node_occurrence,
        )

    def test_fold_bn_erases_bn_node(self):
        """
        Ensure the BN node is erased from the graph after folding
        it into conv in `convert_pt2e` even in train mode.
        """
        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
        m = capture_pre_autograd_graph(m, self.example_inputs)
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(
            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
        )
        m = prepare_qat_pt2e(m, quantizer)
        m = convert_pt2e(m)
        (conv_node, bn_node, _) = _get_conv_bn_getitem_nodes(m)
        self.assertTrue(conv_node is not None)
        self.assertTrue(bn_node is None)


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 1
    example_inputs = (torch.randn(1, 3, 5),)
    conv_class = torch.nn.Conv1d
    conv_transpose_class = torch.nn.ConvTranspose1d
    bn_class = torch.nn.BatchNorm1d


@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
    dim = 2
    example_inputs = (torch.randn(1, 3, 5, 5),)
    conv_class = torch.nn.Conv2d
    conv_transpose_class = torch.nn.ConvTranspose2d
    bn_class = torch.nn.BatchNorm2d


def _is_conv_node(n: torch.fx.Node):
    return n.op == "call_function" and n.target in [
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv_transpose1d,
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d,
        torch.ops.aten.conv_transpose2d.input,
    ]


def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
    """
    Return a 3-tuple of (conv, bn, getitem) nodes from the graph.
    """
    model.graph.eliminate_dead_code()
    model.recompile()
    conv_node = None
    bn_node = None
    getitem_node = None
    for n in model.graph.nodes:
        if _is_conv_node(n):
            conv_node = n
        if n.target in (
            torch.ops.aten._native_batch_norm_legit.default,
            torch.ops.aten.batch_norm.default,
        ):
            bn_node = n
        if n.target == operator.getitem:
            getitem_node = n
    assert conv_node is not None, "bad test setup"
    return (conv_node, bn_node, getitem_node)


class ConvBnInt32WeightQuantizer(Quantizer):
    """
    Dummy quantizer that annotates conv bn in such a way that the weights
    are quantized per channel to int32.
    """

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
        act_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        weight_qspec = QuantizationSpec(
            dtype=torch.int32,
            quant_min=0,
            quant_max=2**31 - 1,
            qscheme=torch.per_channel_affine,
            observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
                observer=MovingAveragePerChannelMinMaxObserver,
            ),
        )
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv_node.args[0]: act_qspec,
                conv_node.args[1]: weight_qspec,
            },
            _annotated=True,
        )
        if getitem_node is not None:
            # TODO: This branch is going through a deprecated path and should be deleted
            # soon, after capture_pre_autograd_graph fully migrates to the training IR
            # T199018392
            getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=act_qspec,
                _annotated=True,
            )
        else:
            # See NOTE [training ir has no getitem for bn node].
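            # With the training IR there is no getitem node after batch norm, so the
            # output annotation is attached directly to the bn node.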
            assert capture_pre_autograd_graph_using_training_ir()
            bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=act_qspec,
                _annotated=True,
            )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


class ConvBnDerivedBiasQuantizer(Quantizer):
    """
    Dummy quantizer that annotates conv bn in such a way that the bias qparams are
    derived from the conv input activation and weight qparams.
    """

    def __init__(self, is_per_channel: bool = False):
        super().__init__()
        self.is_per_channel = is_per_channel

    def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs):
        act_scale, _ = obs_or_fqs[0].calculate_qparams()
        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
        if self.is_per_channel:
            bias_scale = act_scale * weight_scale
            bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
        else:
            bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32)
            bias_zero_point = torch.tensor([0], dtype=torch.int32)
        return bias_scale, bias_zero_point

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        if self.is_per_channel:
            weight_qscheme = torch.per_channel_symmetric
            weight_fq = FusedMovingAvgObsFakeQuantize.with_args(
                observer=MovingAveragePerChannelMinMaxObserver,
            )
        else:
            weight_qscheme = torch.per_tensor_affine
            weight_fq = default_fake_quant
        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
        act_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=torch.per_tensor_affine,
            observer_or_fake_quant_ctr=default_fake_quant,
        )
        weight_qspec = QuantizationSpec(
            dtype=torch.uint8,
            quant_min=0,
            quant_max=255,
            qscheme=weight_qscheme,
            observer_or_fake_quant_ctr=weight_fq,
        )
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv_node.args[0], conv_node),
                (conv_node.args[1], conv_node),
            ],
            derive_qparams_fn=self._derive_bias_qparams_from_act_and_weight_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=weight_qscheme,
            ch_axis=0 if self.is_per_channel else None,
        )
        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                conv_node.args[0]: act_qspec,
                conv_node.args[1]: weight_qspec,
                conv_node.args[2]: bias_qspec,
            },
            _annotated=True,
        )

        if getitem_node is not None:
            # TODO: This branch is going through a deprecated path and should be deleted
            # soon, after capture_pre_autograd_graph fully migrates to the training IR
            # T199018392
            getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=act_qspec,
                _annotated=True,
            )
        else:
            # NOTE [training ir has no getitem for bn node]:
            # getitem is None when we use the training IR. It exports batch norm as
            # aten.batch_norm.default, which does not need any getitem node.
            # In this case, we need to annotate the batch norm node instead.
            # The getitem node should only be None if we are using the training IR.
            assert capture_pre_autograd_graph_using_training_ir()
            bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
                output_qspec=act_qspec,
                _annotated=True,
            )
        return model

    def validate(self, model: torch.fx.GraphModule):
        pass


@skipIfNoQNNPACK
class TestQuantizePT2EQATModels(PT2EQATTestCase):
    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_resnet18(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)

    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_qat_mobilenet_v2(self):
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.mobilenet_v2()
            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)


class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
    class TwoLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(16, 8, bias=False)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class QATPTQTestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3)
            self.linears = TestQuantizeMixQATAndPTQ.TwoLinear()
            self.my_linear = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            linear_out = self.linears(permute_out)
            my_linear_out = self.my_linear(linear_out)
            # Hardtanh doesn't get quantized via the xnnpack quantizer in this test
            # because it relies on the propagation rules. Need to fix this.
            return torch.nn.functional.hardtanh(my_linear_out)

    def _prepare_qat_linears(self, model):
        for name, child in model.named_children():
            if isinstance(child, (torch.nn.Linear, TestQuantizeMixQATAndPTQ.TwoLinear)):
                if isinstance(child, torch.nn.Linear):
                    in_channels = child.weight.size(1)
                else:
                    in_channels = child.linear1.weight.size(1)

                example_input = (torch.rand((1, in_channels)),)
                traced_child = capture_pre_autograd_graph(child, example_input)
                quantizer = XNNPACKQuantizer()
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True, is_qat=True
                )
                quantizer.set_global(quantization_config)
                traced_child_prepared = prepare_qat_pt2e(traced_child, quantizer)
                setattr(model, name, traced_child_prepared)
            else:
                self._prepare_qat_linears(child)

    def _convert_qat_linears(self, model):
        for name, child in model.named_children():
            if isinstance(child, torch.fx.GraphModule):
                torch.ao.quantization.move_exported_model_to_eval(child)
                converted_child = convert_pt2e(child)
                setattr(model, name, converted_child)
            else:
                self._convert_qat_linears(child)

    def test_mixing_qat_ptq(self):
        example_inputs = (torch.randn(2, 3, 4, 4),)
        model = TestQuantizeMixQATAndPTQ.QATPTQTestModule()

        self._prepare_qat_linears(model)

        after_prepare_result_pt2e = model(*example_inputs)
        # TODO: model.eval() must be fixed for this flow
        self._convert_qat_linears(model)
        quant_result_pt2e = model(*example_inputs)

        model_pt2e = capture_pre_autograd_graph(
            model,
            example_inputs,
        )

        quantizer = XNNPACKQuantizer()
        quantizer.set_module_type(torch.nn.Linear, None)
        quantization_config = get_symmetric_quantization_config()
        quantizer.set_global(quantization_config)
        model_pt2e = prepare_pt2e(model_pt2e, quantizer)
        after_prepare_result_pt2e = model_pt2e(*example_inputs)
        model_pt2e = convert_pt2e(model_pt2e)
        quant_result_pt2e = model_pt2e(*example_inputs)

        exported_model = torch.export.export(model_pt2e, example_inputs)

        node_occurrence = {
            # conv2d: 1 for act, 1 for weight, 1 for output
            # 3 x linear: 1 for act, 1 for output
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 8,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 9,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 3,
            # There needs to be one for hardtanh
        }
        self.checkGraphModuleNodes(
            exported_model.graph_module, expected_node_occurrence=node_occurrence
        )
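

# A minimal end-to-end sketch (illustrative only, not a test) of the PT2E QAT flow
# exercised by the tests above, using only APIs already imported in this file. The
# Conv2d + BatchNorm2d model and the input shape are arbitrary example choices.
def _example_pt2e_qat_flow():
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.BatchNorm2d(3))
    example_inputs = (torch.randn(1, 3, 5, 5),)
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
    )
    # Export, insert fake quantizes, run the model, then convert to a quantized graph
    exported = capture_pre_autograd_graph(model, example_inputs)
    prepared = prepare_qat_pt2e(exported, quantizer)
    prepared(*example_inputs)  # in real QAT this would be a training loop
    quantized = convert_pt2e(prepared)
    return quantized(*example_inputs)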