1# Owner(s): ["oncall: quantization"] 2from typing import List, Tuple 3 4import torch 5from torch import Tensor 6from torch._export import capture_pre_autograd_graph 7from torch._utils_internal import capture_pre_autograd_graph_using_training_ir 8from torch.ao.quantization import observer, ObserverOrFakeQuantize, QConfigMapping 9from torch.ao.quantization.qconfig import ( 10 default_per_channel_symmetric_qnnpack_qconfig, 11 float_qparams_weight_only_qconfig, 12 per_channel_weight_observer_range_neg_127_to_127, 13 QConfig, 14 weight_observer_range_neg_127_to_127, 15) 16from torch.ao.quantization.quantize_pt2e import ( 17 convert_pt2e, 18 prepare_pt2e, 19 prepare_qat_pt2e, 20) 21from torch.ao.quantization.quantizer import ( 22 DerivedQuantizationSpec, 23 FixedQParamsQuantizationSpec, 24 QuantizationAnnotation, 25 QuantizationSpec, 26 Quantizer, 27 SharedQuantizationSpec, 28) 29from torch.ao.quantization.quantizer.composable_quantizer import ( # noqa: F811 30 ComposableQuantizer, 31) 32from torch.ao.quantization.quantizer.embedding_quantizer import ( # noqa: F811 33 EmbeddingQuantizer, 34) 35from torch.ao.quantization.quantizer.xnnpack_quantizer import ( 36 get_symmetric_quantization_config, 37 XNNPACKQuantizer, 38) 39from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( 40 OP_TO_ANNOTATOR, 41 QuantizationConfig, 42) 43from torch.fx import Node 44from torch.testing._internal.common_quantization import ( 45 NodeSpec as ns, 46 PT2EQuantizationTestCase, 47 skipIfNoQNNPACK, 48 TestHelperModules, 49) 50from torch.testing._internal.common_utils import ( 51 instantiate_parametrized_tests, 52 parametrize, 53 TemporaryFileName, 54 TEST_CUDA, 55 TEST_WITH_ROCM, 56) 57 58 59@skipIfNoQNNPACK 60class TestQuantizePT2E(PT2EQuantizationTestCase): 61 def test_simple_quantizer(self): 62 # TODO: use OP_TO_ANNOTATOR 63 class BackendAQuantizer(Quantizer): 64 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 65 for node in model.graph.nodes: 66 if ( 67 node.op == "call_function" 68 and node.target == torch.ops.aten.conv2d.default 69 ): 70 input_act = node.args[0] 71 assert isinstance(input_act, Node) 72 weight = node.args[1] 73 assert isinstance(weight, Node) 74 bias = node.args[2] 75 assert isinstance(bias, Node) 76 act_qspec = QuantizationSpec( 77 dtype=torch.uint8, 78 quant_min=0, 79 quant_max=255, 80 qscheme=torch.per_tensor_affine, 81 is_dynamic=False, 82 observer_or_fake_quant_ctr=observer.default_observer, 83 ) 84 weight_qspec = QuantizationSpec( 85 dtype=torch.int8, 86 quant_min=-128, 87 quant_max=127, 88 qscheme=torch.per_tensor_affine, 89 is_dynamic=False, 90 observer_or_fake_quant_ctr=observer.default_weight_observer, 91 ) 92 bias_qspec = QuantizationSpec( 93 dtype=torch.float32, 94 is_dynamic=False, 95 observer_or_fake_quant_ctr=observer.PlaceholderObserver, 96 ) 97 node.meta["quantization_annotation"] = QuantizationAnnotation( 98 input_qspec_map={ 99 input_act: act_qspec, 100 weight: weight_qspec, 101 bias: bias_qspec, 102 }, 103 output_qspec=act_qspec, 104 _annotated=True, 105 ) 106 107 def validate(self, model: torch.fx.GraphModule) -> None: 108 pass 109 110 example_inputs = (torch.randn(1, 3, 5, 5),) 111 node_occurrence = { 112 # two for input of the first conv, one for output for the first conv 113 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 114 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 115 } 116 node_list = [ 117 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 118 
torch.ops.quantized_decomposed.dequantize_per_tensor.default, 119 torch.ops.aten.conv2d.default, 120 torch.ops.quantized_decomposed.quantize_per_tensor.default, 121 ] 122 self._test_quantizer( 123 TestHelperModules.ConvWithBNRelu(relu=False, bn=False), 124 example_inputs, 125 BackendAQuantizer(), 126 node_occurrence, 127 node_list, 128 ) 129 130 def test_wo_annotate_conv_output_quantizer(self): 131 # TODO: use OP_TO_ANNOTATOR 132 class BackendAQuantizer(Quantizer): 133 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 134 act_qspec = QuantizationSpec( 135 dtype=torch.uint8, 136 quant_min=0, 137 quant_max=255, 138 qscheme=torch.per_tensor_affine, 139 is_dynamic=False, 140 observer_or_fake_quant_ctr=observer.default_observer, 141 ) 142 weight_qspec = QuantizationSpec( 143 dtype=torch.int8, 144 quant_min=-128, 145 quant_max=127, 146 qscheme=torch.per_tensor_affine, 147 is_dynamic=False, 148 observer_or_fake_quant_ctr=observer.default_weight_observer, 149 ) 150 bias_qspec = QuantizationSpec( 151 dtype=torch.float32, 152 is_dynamic=False, 153 observer_or_fake_quant_ctr=observer.PlaceholderObserver, 154 ) 155 for node in model.graph.nodes: 156 if ( 157 node.op == "call_function" 158 and node.target == torch.ops.aten.conv2d.default 159 ): 160 input_act = node.args[0] 161 assert isinstance(input_act, Node) 162 weight = node.args[1] 163 assert isinstance(weight, Node) 164 bias = node.args[2] 165 assert isinstance(bias, Node) 166 node.meta["quantization_annotation"] = QuantizationAnnotation( 167 input_qspec_map={ 168 input_act: act_qspec, 169 weight: weight_qspec, 170 bias: bias_qspec, 171 }, 172 _annotated=True, 173 ) 174 175 def validate(self, model: torch.fx.GraphModule) -> None: 176 pass 177 178 m = torch.nn.Conv2d(2, 2, 1) 179 x = torch.rand(1, 2, 14, 14) 180 example_inputs = (x,) 181 m = self._quantize(m, BackendAQuantizer(), example_inputs) 182 # Ensure the conv has no observer inserted at output 183 node_occurrence = { 184 # two for input of conv 185 ns.call_function( 186 torch.ops.quantized_decomposed.quantize_per_tensor.default 187 ): 1, 188 ns.call_function( 189 torch.ops.quantized_decomposed.dequantize_per_tensor.default 190 ): 2, 191 } 192 node_list = [ 193 ns.call_function( 194 torch.ops.quantized_decomposed.dequantize_per_tensor.default 195 ), 196 ns.call_function( 197 torch.ops.quantized_decomposed.dequantize_per_tensor.default 198 ), 199 ns.call_function(torch.ops.aten.conv2d.default), 200 ] 201 self.checkGraphModuleNodes( 202 m, expected_node_list=node_list, expected_node_occurrence=node_occurrence 203 ) 204 205 def test_max_pool2d_quantizer(self): 206 # TODO: use OP_TO_ANNOTATOR 207 class BackendAQuantizer(Quantizer): 208 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 209 act_qspec = QuantizationSpec( 210 dtype=torch.uint8, 211 quant_min=0, 212 quant_max=255, 213 qscheme=torch.per_tensor_affine, 214 is_dynamic=False, 215 observer_or_fake_quant_ctr=observer.default_observer, 216 ) 217 weight_qspec = QuantizationSpec( 218 dtype=torch.int8, 219 quant_min=-128, 220 quant_max=127, 221 qscheme=torch.per_tensor_affine, 222 is_dynamic=False, 223 observer_or_fake_quant_ctr=observer.default_weight_observer, 224 ) 225 bias_qspec = QuantizationSpec( 226 dtype=torch.float32, 227 is_dynamic=False, 228 observer_or_fake_quant_ctr=observer.PlaceholderObserver, 229 ) 230 for node in model.graph.nodes: 231 if ( 232 node.op == "call_function" 233 and node.target == torch.ops.aten.conv2d.default 234 ): 235 input_act = node.args[0] 236 assert 
isinstance(input_act, Node) 237 weight = node.args[1] 238 assert isinstance(weight, Node) 239 bias = node.args[2] 240 assert isinstance(bias, Node) 241 node.meta["quantization_annotation"] = QuantizationAnnotation( 242 input_qspec_map={ 243 input_act: act_qspec, 244 weight: weight_qspec, 245 bias: bias_qspec, 246 }, 247 _annotated=True, 248 ) 249 if ( 250 node.op == "call_function" 251 and node.target == torch.ops.aten.max_pool2d.default 252 ): 253 maxpool_node = node 254 input_act = maxpool_node.args[0] 255 assert isinstance(input_act, Node) 256 maxpool_node.meta[ 257 "quantization_annotation" 258 ] = QuantizationAnnotation( 259 input_qspec_map={ 260 input_act: act_qspec, 261 }, 262 output_qspec=SharedQuantizationSpec( 263 (input_act, maxpool_node) 264 ), 265 _annotated=True, 266 ) 267 268 def validate(self, model: torch.fx.GraphModule) -> None: 269 pass 270 271 m = TestHelperModules.ConvMaxPool2d() 272 x = torch.rand(1, 2, 14, 14) 273 example_inputs = (x,) 274 m = self._quantize(m, BackendAQuantizer(), example_inputs) 275 node_occurrence = { 276 # two for input of conv 277 # one for input of maxpool 278 # one for output of maxpool 279 ns.call_function( 280 torch.ops.quantized_decomposed.quantize_per_tensor.default 281 ): 3, 282 ns.call_function( 283 torch.ops.quantized_decomposed.dequantize_per_tensor.default 284 ): 4, 285 } 286 node_list = [ 287 ns.call_function( 288 torch.ops.quantized_decomposed.dequantize_per_tensor.default 289 ), 290 ns.call_function( 291 torch.ops.quantized_decomposed.dequantize_per_tensor.default 292 ), 293 ns.call_function(torch.ops.aten.conv2d.default), 294 ns.call_function( 295 torch.ops.quantized_decomposed.quantize_per_tensor.default 296 ), 297 ns.call_function( 298 torch.ops.quantized_decomposed.dequantize_per_tensor.default 299 ), 300 ns.call_function(torch.ops.aten.max_pool2d.default), 301 ] 302 self.checkGraphModuleNodes( 303 m, expected_node_list=node_list, expected_node_occurrence=node_occurrence 304 ) 305 306 def test_derived_qspec(self): 307 # TODO: use OP_TO_ANNOTATOR 308 class BackendAQuantizer(Quantizer): 309 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 310 for node in model.graph.nodes: 311 if ( 312 node.op == "call_function" 313 and node.target == torch.ops.aten.conv2d.default 314 ): 315 input_act = node.args[0] 316 assert isinstance(input_act, Node) 317 weight = node.args[1] 318 assert isinstance(weight, Node) 319 bias = node.args[2] 320 assert isinstance(bias, Node) 321 act_qspec = QuantizationSpec( 322 dtype=torch.uint8, 323 quant_min=0, 324 quant_max=255, 325 qscheme=torch.per_tensor_affine, 326 is_dynamic=False, 327 observer_or_fake_quant_ctr=observer.default_observer, 328 ) 329 weight_qspec = QuantizationSpec( 330 dtype=torch.int8, 331 quant_min=-128, 332 quant_max=127, 333 qscheme=torch.per_tensor_affine, 334 is_dynamic=False, 335 observer_or_fake_quant_ctr=observer.default_weight_observer, 336 ) 337 338 def derive_qparams_fn( 339 obs_or_fqs: List[ObserverOrFakeQuantize], 340 ) -> Tuple[Tensor, Tensor]: 341 assert ( 342 len(obs_or_fqs) == 2 343 ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" 344 act_obs_or_fq = obs_or_fqs[0] 345 weight_obs_or_fq = obs_or_fqs[1] 346 act_scale, act_zp = act_obs_or_fq.calculate_qparams() 347 ( 348 weight_scale, 349 weight_zp, 350 ) = weight_obs_or_fq.calculate_qparams() 351 return torch.tensor([act_scale * weight_scale]).to( 352 torch.float32 353 ), torch.tensor([0]).to(torch.int32) 354 355 bias_qspec = DerivedQuantizationSpec( 356 
derived_from=[(input_act, node), (weight, node)], 357 derive_qparams_fn=derive_qparams_fn, 358 dtype=torch.int32, 359 quant_min=-(2**31), 360 quant_max=2**31 - 1, 361 qscheme=torch.per_tensor_symmetric, 362 ) 363 node.meta["quantization_annotation"] = QuantizationAnnotation( 364 input_qspec_map={ 365 input_act: act_qspec, 366 weight: weight_qspec, 367 bias: bias_qspec, 368 }, 369 output_qspec=act_qspec, 370 _annotated=True, 371 ) 372 373 def validate(self, model: torch.fx.GraphModule) -> None: 374 pass 375 376 m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() 377 example_inputs = (torch.randn(1, 3, 5, 5),) 378 379 m = self._quantize(m, BackendAQuantizer(), example_inputs) 380 node_occurrence = { 381 # input, weight, bias, output for the conv 382 # note: quantize op for weight and bias are const propagated 383 ns.call_function( 384 torch.ops.quantized_decomposed.quantize_per_tensor.default 385 ): 2, 386 ns.call_function( 387 torch.ops.quantized_decomposed.dequantize_per_tensor.default 388 ): 4, 389 } 390 node_list = [ 391 ns.call_function( 392 torch.ops.quantized_decomposed.dequantize_per_tensor.default 393 ), 394 ns.call_function( 395 torch.ops.quantized_decomposed.dequantize_per_tensor.default 396 ), 397 ns.call_function( 398 torch.ops.quantized_decomposed.dequantize_per_tensor.default 399 ), 400 ns.call_function(torch.ops.aten.conv2d.default), 401 ns.call_function( 402 torch.ops.quantized_decomposed.quantize_per_tensor.default 403 ), 404 ] 405 self.checkGraphModuleNodes( 406 m, expected_node_list=node_list, expected_node_occurrence=node_occurrence 407 ) 408 409 def test_derived_qspec_per_channel(self): 410 class BackendAQuantizer(Quantizer): 411 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 412 for node in model.graph.nodes: 413 if ( 414 node.op == "call_function" 415 and node.target == torch.ops.aten.conv2d.default 416 ): 417 input_act = node.args[0] 418 assert isinstance(input_act, Node) 419 weight = node.args[1] 420 assert isinstance(weight, Node) 421 bias = node.args[2] 422 assert isinstance(bias, Node) 423 act_qspec = QuantizationSpec( 424 dtype=torch.uint8, 425 quant_min=0, 426 quant_max=255, 427 qscheme=torch.per_tensor_affine, 428 is_dynamic=False, 429 observer_or_fake_quant_ctr=observer.default_observer, 430 ) 431 weight_qspec = QuantizationSpec( 432 dtype=torch.int8, 433 quant_min=-128, 434 quant_max=127, 435 qscheme=torch.per_channel_affine, 436 is_dynamic=False, 437 ch_axis=0, 438 observer_or_fake_quant_ctr=observer.default_per_channel_weight_observer, 439 ) 440 441 def derive_qparams_fn( 442 obs_or_fqs: List[ObserverOrFakeQuantize], 443 ) -> Tuple[Tensor, Tensor]: 444 assert ( 445 len(obs_or_fqs) == 1 446 ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}" 447 weight_obs_or_fq = obs_or_fqs[0] 448 ( 449 weight_scale, 450 weight_zp, 451 ) = weight_obs_or_fq.calculate_qparams() 452 return weight_scale, torch.zeros_like(weight_scale) 453 454 bias_qspec = DerivedQuantizationSpec( 455 derived_from=[(weight, node)], 456 derive_qparams_fn=derive_qparams_fn, 457 dtype=torch.int32, 458 quant_min=-(2**31), 459 quant_max=2**31 - 1, 460 qscheme=torch.per_channel_symmetric, 461 ch_axis=0, 462 ) 463 node.meta["quantization_annotation"] = QuantizationAnnotation( 464 input_qspec_map={ 465 input_act: act_qspec, 466 weight: weight_qspec, 467 bias: bias_qspec, 468 }, 469 output_qspec=act_qspec, 470 _annotated=True, 471 ) 472 473 def validate(self, model: torch.fx.GraphModule) -> None: 474 pass 475 476 m = 
TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() 477 example_inputs = (torch.randn(1, 3, 5, 5),) 478 479 m = self._quantize(m, BackendAQuantizer(), example_inputs) 480 481 node_occurrence = { 482 # input, output for the conv 483 ns.call_function( 484 torch.ops.quantized_decomposed.quantize_per_tensor.default 485 ): 2, 486 ns.call_function( 487 torch.ops.quantized_decomposed.dequantize_per_tensor.default 488 ): 2, 489 # weight and bias for conv 490 # note: quantize op for weight and bias are const propagated 491 ns.call_function( 492 torch.ops.quantized_decomposed.quantize_per_channel.default 493 ): 0, 494 ns.call_function( 495 torch.ops.quantized_decomposed.dequantize_per_channel.default 496 ): 2, 497 } 498 node_list = [ 499 ns.call_function( 500 torch.ops.quantized_decomposed.dequantize_per_channel.default 501 ), 502 ns.call_function( 503 torch.ops.quantized_decomposed.dequantize_per_channel.default 504 ), 505 ns.call_function(torch.ops.aten.conv2d.default), 506 ns.call_function( 507 torch.ops.quantized_decomposed.quantize_per_tensor.default 508 ), 509 ] 510 self.checkGraphModuleNodes( 511 m, expected_node_list=node_list, expected_node_occurrence=node_occurrence 512 ) 513 514 def test_fixed_qparams_qspec_ptq(self): 515 self._test_fixed_qparams_qspec(is_qat=False) 516 517 # TODO: refactor and move this to test_quantize_pt2_qat.py 518 def test_fixed_qparams_qspec_qat(self): 519 self._test_fixed_qparams_qspec(is_qat=True) 520 521 def _test_fixed_qparams_qspec(self, is_qat: bool): 522 class M(torch.nn.Module): 523 def forward(self, x): 524 return torch.sigmoid(x) 525 526 class BackendAQuantizer(Quantizer): 527 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 528 for node in model.graph.nodes: 529 if ( 530 node.op == "call_function" 531 and node.target == torch.ops.aten.sigmoid.default 532 ): 533 input_act = node.args[0] 534 assert isinstance(input_act, Node) 535 act_qspec = FixedQParamsQuantizationSpec( 536 dtype=torch.uint8, 537 quant_min=0, 538 quant_max=255, 539 qscheme=torch.per_tensor_affine, 540 scale=1.0 / 256.0, 541 zero_point=0, 542 ) 543 node.meta["quantization_annotation"] = QuantizationAnnotation( 544 input_qspec_map={ 545 input_act: act_qspec, 546 }, 547 output_qspec=act_qspec, 548 _annotated=True, 549 ) 550 551 def validate(self, model: torch.fx.GraphModule) -> None: 552 pass 553 554 m = M().eval() 555 example_inputs = (torch.randn(1, 3, 5, 5),) 556 557 m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat) 558 fixed_scale = 1.0 / 256.0 559 fixed_zero_point = 0 560 for n in m.graph.nodes: 561 if n.op == "call_function": 562 if ( 563 n.target 564 == torch.ops.quantized_decomposed.quantize_per_tensor.default 565 ): 566 scale_0 = n.args[1] 567 zero_point_0 = n.args[2] 568 if ( 569 n.target 570 == torch.ops.quantized_decomposed.dequantize_per_tensor.default 571 ): 572 scale_1 = n.args[1] 573 zero_point_1 = n.args[2] 574 self.assertEqual(scale_0, fixed_scale) 575 self.assertEqual(zero_point_0, fixed_zero_point) 576 self.assertEqual(scale_1, fixed_scale) 577 self.assertEqual(zero_point_1, fixed_zero_point) 578 node_occurrence = { 579 # two for input of the first conv, one for output for the first conv 580 ns.call_function( 581 torch.ops.quantized_decomposed.quantize_per_tensor.default 582 ): 2, 583 ns.call_function( 584 torch.ops.quantized_decomposed.dequantize_per_tensor.default 585 ): 2, 586 } 587 node_list = [ 588 ns.call_function( 589 torch.ops.quantized_decomposed.dequantize_per_tensor.default 590 ), 591 
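            # descriptive note: sigmoid sits between a dequantize and a quantize that
            # reuse the fixed scale/zero_point asserted above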
            ns.call_function(torch.ops.aten.sigmoid.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_fixed_qparams_qspec_observer_dedup(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.sigmoid.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        act_qspec = FixedQParamsQuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            scale=1.0 / 256.0,
                            zero_point=0,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )
                    elif (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.add.Tensor
                    ):
                        input_act0 = node.args[0]
                        assert isinstance(input_act0, Node)
                        input_act1 = node.args[1]
                        assert isinstance(input_act1, Node)
                        act_qspec = QuantizationSpec(
                            observer_or_fake_quant_ctr=observer.default_observer,
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act0: act_qspec,
                                input_act1: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def forward(self, x, y):
                return torch.sigmoid(x) + y

            def example_inputs(self):
                return (
                    torch.randn(1, 3, 5, 5),
                    torch.randn(1, 3, 5, 5),
                )

        m = M().eval()
        example_inputs = m.example_inputs()
        m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat=False)

        node_occurrence = {
            # one quantize/dequantize pair each for: the sigmoid input, the add's
            # second input, the shared sigmoid-output/add-input tensor, and the
            # add's output
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 4,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 4,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.sigmoid.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.add.Tensor),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_shared_qspec(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
qscheme=torch.per_tensor_affine, 714 is_dynamic=False, 715 observer_or_fake_quant_ctr=observer.default_observer, 716 ) 717 weight_qspec = QuantizationSpec( 718 dtype=torch.int8, 719 quant_min=-128, 720 quant_max=127, 721 qscheme=torch.per_tensor_affine, 722 is_dynamic=False, 723 observer_or_fake_quant_ctr=observer.default_weight_observer, 724 ) 725 bias_qspec = QuantizationSpec( 726 dtype=torch.float32, 727 is_dynamic=False, 728 observer_or_fake_quant_ctr=observer.PlaceholderObserver, 729 ) 730 node.meta["quantization_annotation"] = QuantizationAnnotation( 731 input_qspec_map={ 732 input_act: act_qspec, 733 weight: weight_qspec, 734 bias: bias_qspec, 735 }, 736 output_qspec=act_qspec, 737 _annotated=True, 738 ) 739 elif node.target is torch.ops.aten.cat.default: 740 cat_node = node 741 input_nodes = cat_node.args[0] 742 first_input_node = input_nodes[0] 743 input_qspec_map = {} 744 act_qspec = QuantizationSpec( 745 dtype=torch.uint8, 746 quant_min=0, 747 quant_max=255, 748 qscheme=torch.per_tensor_affine, 749 is_dynamic=False, 750 observer_or_fake_quant_ctr=observer.default_observer, 751 ) 752 input_qspec_map[first_input_node] = act_qspec 753 share_qparams_with_input_act0_qspec = SharedQuantizationSpec( 754 (first_input_node, cat_node) 755 ) 756 for input_node in input_nodes[1:]: 757 input_qspec_map[ 758 input_node 759 ] = share_qparams_with_input_act0_qspec 760 761 cat_node.meta[ 762 "quantization_annotation" 763 ] = QuantizationAnnotation( 764 input_qspec_map=input_qspec_map, 765 output_qspec=share_qparams_with_input_act0_qspec, 766 _annotated=True, 767 ) 768 769 def validate(self, model: torch.fx.GraphModule) -> None: 770 pass 771 772 m = TestHelperModules.Conv2dWithCat().eval() 773 example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) 774 775 # program capture 776 m = capture_pre_autograd_graph( 777 m, 778 example_inputs, 779 ) 780 m = prepare_pt2e(m, BackendAQuantizer()) 781 # make sure the two observers for input are shared 782 conv_output_obs = [] 783 for n in m.graph.nodes: 784 if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default: 785 conv_output_obs.append(getattr(m, next(iter(n.users)).target)) 786 if n.op == "call_function" and n.target == torch.ops.aten.cat.default: 787 inputs = n.args[0] 788 input0 = inputs[0] 789 input1 = inputs[1] 790 assert input0.op == "call_module" 791 assert input1.op == "call_module" 792 obs_ins0 = getattr(m, input0.target) 793 obs_ins1 = getattr(m, input1.target) 794 assert obs_ins0 == obs_ins1 795 assert ( 796 len(conv_output_obs) == 2 797 ), "expecting two observer that follows conv2d ops" 798 # checking that the output observers for the two convs are shared as well 799 assert conv_output_obs[0] == conv_output_obs[1] 800 801 m(*example_inputs) 802 m = convert_pt2e(m) 803 804 node_occurrence = { 805 # two for input of the first conv, one for output for the first conv 806 ns.call_function( 807 torch.ops.quantized_decomposed.quantize_per_tensor.default 808 ): 5, 809 ns.call_function( 810 torch.ops.quantized_decomposed.dequantize_per_tensor.default 811 ): 7, 812 } 813 node_list = [ 814 ns.call_function( 815 torch.ops.quantized_decomposed.dequantize_per_tensor.default 816 ), 817 ns.call_function( 818 torch.ops.quantized_decomposed.dequantize_per_tensor.default 819 ), 820 ns.call_function(torch.ops.aten.cat.default), 821 ns.call_function( 822 torch.ops.quantized_decomposed.quantize_per_tensor.default 823 ), 824 ] 825 self.checkGraphModuleNodes( 826 m, expected_node_list=node_list, expected_node_occurrence=node_occurrence 
        )

    def _test_transitive_sharing_with_cat_helper(self, quantizer):
        m = TestHelperModules.Conv2dWithTwoCat().eval()
        example_inputs = (
            torch.randn(1, 3, 5, 5),
            torch.randn(1, 3, 5, 5),
            torch.randn(1, 6, 3, 3),
            torch.randn(1, 6, 3, 3),
        )

        # program capture
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        # make sure the two input observers and output are shared
        conv_output_obs = []
        for n in m.graph.nodes:
            if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default:
                conv_output_obs.append(getattr(m, next(iter(n.users)).target))
            if n.op == "call_function" and n.target == torch.ops.aten.cat.default:
                inputs = n.args[0]
                input0 = inputs[0]
                input1 = inputs[1]
                assert input0.op == "call_module"
                assert input1.op == "call_module"
                obs_ins0 = getattr(m, input0.target)
                obs_ins1 = getattr(m, input1.target)
                assert obs_ins0 == obs_ins1

                output_obs = next(iter(n.users))
                assert output_obs.op == "call_module"
                obs_ins2 = getattr(m, output_obs.target)
                assert obs_ins0 == obs_ins2, "input observer does not match output"

        assert (
            len(conv_output_obs) == 2
        ), "expecting two observers that follow the conv2d ops"
        # checking that the output observers for the two convs are shared as well
        assert conv_output_obs[0] == conv_output_obs[1]

        m(*example_inputs)
        m = convert_pt2e(m)

        node_occurrence = {
            # quantize/dequantize pairs for the conv and cat inputs and outputs;
            # edges that share qparams reuse the same observer, reducing the count
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 7,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 9,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.cat.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.cat.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
        )

    def test_shared_qspec_transitivity(self):
        """This tests the transitivity of SharedQuantizationSpec: if A is shared
        with B, and B is shared with C, then C should be shared with A as well.

        x1 -> conv1 -> cat1 -----> cat2
        x2 -> conv2 -/            /
        x3 -> add ---------------/
        x4 -/

        Both cat nodes share input and output qparams, and because the output of
        cat1 is the same Tensor as one input of cat2, there is implicit sharing
        between them, so after transitive sharing all tensors connected to cat1
        and cat2 end up in the same sharing group.
        """

        # TODO: refactor this to a common util
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.conv2d.default
                    ):
                        input_act = node.args[0]
                        assert isinstance(input_act, Node)
                        weight = node.args[1]
                        assert isinstance(weight, Node)
                        bias = node.args[2]
                        assert isinstance(bias, Node)
                        act_qspec =
QuantizationSpec( 935 dtype=torch.uint8, 936 quant_min=0, 937 quant_max=255, 938 qscheme=torch.per_tensor_affine, 939 is_dynamic=False, 940 observer_or_fake_quant_ctr=observer.default_observer, 941 ) 942 weight_qspec = QuantizationSpec( 943 dtype=torch.int8, 944 quant_min=-128, 945 quant_max=127, 946 qscheme=torch.per_tensor_affine, 947 is_dynamic=False, 948 observer_or_fake_quant_ctr=observer.default_weight_observer, 949 ) 950 bias_qspec = QuantizationSpec( 951 dtype=torch.float32, 952 is_dynamic=False, 953 observer_or_fake_quant_ctr=observer.PlaceholderObserver, 954 ) 955 node.meta["quantization_annotation"] = QuantizationAnnotation( 956 input_qspec_map={ 957 input_act: act_qspec, 958 weight: weight_qspec, 959 bias: bias_qspec, 960 }, 961 output_qspec=act_qspec, 962 _annotated=True, 963 ) 964 elif node.target is torch.ops.aten.cat.default: 965 cat_node = node 966 input_nodes = cat_node.args[0] 967 first_input_node = input_nodes[0] 968 input_qspec_map = {} 969 act_qspec = QuantizationSpec( 970 dtype=torch.uint8, 971 quant_min=0, 972 quant_max=255, 973 qscheme=torch.per_tensor_affine, 974 is_dynamic=False, 975 observer_or_fake_quant_ctr=observer.default_observer, 976 ) 977 input_qspec_map[first_input_node] = act_qspec 978 share_qparams_with_input_act0_qspec = SharedQuantizationSpec( 979 (first_input_node, cat_node) 980 ) 981 for input_node in input_nodes[1:]: 982 input_qspec_map[ 983 input_node 984 ] = share_qparams_with_input_act0_qspec 985 986 cat_node.meta[ 987 "quantization_annotation" 988 ] = QuantizationAnnotation( 989 input_qspec_map=input_qspec_map, 990 output_qspec=share_qparams_with_input_act0_qspec, 991 _annotated=True, 992 ) 993 994 def validate(self, model: torch.fx.GraphModule) -> None: 995 pass 996 997 self._test_transitive_sharing_with_cat_helper(BackendAQuantizer()) 998 999 def test_shared_qspec_transitivity_case_2(self): 1000 """This tests the transitivity of SharedQuantizationSpec, that is 1001 if A is shared with B, B is shared with C, then C should be shared with A as well 1002 1003 x1 -> conv1 -> cat1 -----> cat2 1004 x2 -> conv2 -/ / 1005 x3 -> add / 1006 x4 / 1007 1008 both cat has shared input and output, and because of cat and (cat1 -> cat2) is the same Tensor 1009 so there is an implicit sharing here, all tensors connect to cat1 and cat2 are in the same 1010 sharing group after transitive sharing 1011 1012 the difference is that for this one, all edges and nodes are shared with the second input edge of cat 1013 instead of the first input edge of cat as in previous example 1014 """ 1015 1016 # TODO: refactor this to a common util 1017 class BackendAQuantizer(Quantizer): 1018 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 1019 for node in model.graph.nodes: 1020 if ( 1021 node.op == "call_function" 1022 and node.target == torch.ops.aten.conv2d.default 1023 ): 1024 input_act = node.args[0] 1025 assert isinstance(input_act, Node) 1026 weight = node.args[1] 1027 assert isinstance(weight, Node) 1028 bias = node.args[2] 1029 assert isinstance(bias, Node) 1030 act_qspec = QuantizationSpec( 1031 dtype=torch.uint8, 1032 quant_min=0, 1033 quant_max=255, 1034 qscheme=torch.per_tensor_affine, 1035 is_dynamic=False, 1036 observer_or_fake_quant_ctr=observer.default_observer, 1037 ) 1038 weight_qspec = QuantizationSpec( 1039 dtype=torch.int8, 1040 quant_min=-128, 1041 quant_max=127, 1042 qscheme=torch.per_tensor_affine, 1043 is_dynamic=False, 1044 observer_or_fake_quant_ctr=observer.default_weight_observer, 1045 ) 1046 bias_qspec = QuantizationSpec( 
                            dtype=torch.float32,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act: act_qspec,
                                weight: weight_qspec,
                                bias: bias_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )
                    elif node.target is torch.ops.aten.cat.default:
                        cat_node = node
                        input_nodes = cat_node.args[0]
                        first_input_node = input_nodes[0]
                        second_input_node = input_nodes[1]
                        input_qspec_map = {}
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        input_qspec_map[second_input_node] = act_qspec
                        share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
                            (second_input_node, cat_node)
                        )
                        input_qspec_map[
                            first_input_node
                        ] = share_qparams_with_input_act1_qspec

                        cat_node.meta[
                            "quantization_annotation"
                        ] = QuantizationAnnotation(
                            input_qspec_map=input_qspec_map,
                            output_qspec=share_qparams_with_input_act1_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())

    def test_allow_implicit_sharing(self):
        """This tests the allow_implicit_sharing flag of QuantizationAnnotation:
        if a node is configured with allow_implicit_sharing=False, we will not have
        implicit sharing for node and (node, consumer) even if they refer to the
        same Tensor.

        x1 -> add1 -----> add3
        x2 -/            /
        x3 -> add2 -----/
        x4 -/

        Every add shares its input and output qparams, and the second input uses a
        shared quantization spec pointing to the first input, but we set
        allow_implicit_sharing to False for all add nodes, so the inputs and output
        of add1, add2 and add3 each belong to their own sharing group, and we'll have:

        x1 -> obs1 -> add1 -> obs1 -> obs3--> add3 -> obs3
        x2 -> obs1 -/                 /
        x3 -> obs2 -> add2 -> obs2 -> obs3
        x4 -> obs2 -/
        """

        # TODO: refactor this to a common util
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if node.target is torch.ops.aten.add.Tensor:
                        add_node = node
                        first_input_node = add_node.args[0]
                        second_input_node = add_node.args[1]
                        input_qspec_map = {}
                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=observer.default_observer,
                        )
                        input_qspec_map[second_input_node] = act_qspec
                        share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
                            (second_input_node, add_node)
                        )
                        input_qspec_map[
                            first_input_node
                        ] = share_qparams_with_input_act1_qspec

                        add_node.meta[
                            "quantization_annotation"
                        ] = QuantizationAnnotation(
                            input_qspec_map=input_qspec_map,
                            output_qspec=share_qparams_with_input_act1_qspec,
                            allow_implicit_sharing=False,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        m = TestHelperModules.ThreeAdd().eval()
        example_inputs = (
            torch.randn(1, 3, 5, 5),
            torch.randn(1, 3, 5, 5),
torch.randn(1, 3, 5, 5), 1157 torch.randn(1, 3, 5, 5), 1158 ) 1159 1160 # program capture 1161 m = capture_pre_autograd_graph( 1162 m, 1163 example_inputs, 1164 ) 1165 quantizer = BackendAQuantizer() 1166 m = prepare_pt2e(m, quantizer) 1167 m(*example_inputs) 1168 observers = [] 1169 for n in m.graph.nodes: 1170 if n.target == torch.ops.aten.add.Tensor: 1171 input_obs1 = getattr(m, n.args[0].target) 1172 input_obs2 = getattr(m, n.args[1].target) 1173 output_obs = getattr(m, next(iter(n.users)).target) 1174 self.assertIs(input_obs1, input_obs2) 1175 self.assertIs(input_obs1, output_obs) 1176 observers.append(input_obs1) 1177 assert len(observers) == 3 1178 self.assertIsNot(observers[0], observers[1]) 1179 self.assertIsNot(observers[0], observers[2]) 1180 self.assertIsNot(observers[1], observers[2]) 1181 1182 @parametrize("dtype", (torch.float32, torch.bfloat16)) 1183 @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn)) 1184 def test_quantization_dtype(self, dtype, quant_dtype): 1185 class DtypeActQuantizer(Quantizer): 1186 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 1187 info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo 1188 activate_qspec = QuantizationSpec( 1189 dtype=quant_dtype, 1190 quant_min=int(info_fun(quant_dtype).min), 1191 quant_max=int(info_fun(quant_dtype).max), 1192 qscheme=torch.per_tensor_affine, 1193 is_dynamic=False, 1194 observer_or_fake_quant_ctr=observer.default_observer, 1195 ) 1196 int8_qspec = QuantizationSpec( 1197 dtype=torch.int8, 1198 quant_min=-128, 1199 quant_max=127, 1200 qscheme=torch.per_tensor_symmetric, 1201 is_dynamic=False, 1202 observer_or_fake_quant_ctr=observer.default_weight_observer, 1203 ) 1204 quantization_config = QuantizationConfig( 1205 input_activation=activate_qspec, 1206 weight=int8_qspec, 1207 bias=None, 1208 output_activation=activate_qspec, 1209 ) 1210 OP_TO_ANNOTATOR["conv"](model, quantization_config) 1211 1212 def validate(self, model: torch.fx.GraphModule) -> None: 1213 pass 1214 1215 class M(torch.nn.Module): 1216 def __init__(self, dtype): 1217 super().__init__() 1218 self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype) 1219 1220 def forward(self, x): 1221 return self.conv(x) 1222 1223 quantizer = DtypeActQuantizer() 1224 node_occurrence = { 1225 # one for input of the first conv, one for output for the first conv 1226 torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, 1227 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, 1228 } 1229 node_list = [ 1230 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1231 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1232 torch.ops.aten.conv2d.default, 1233 torch.ops.quantized_decomposed.quantize_per_tensor.default, 1234 ] 1235 example_inputs = (torch.randn(1, 3, 3, 3, dtype=dtype),) 1236 m = self._test_quantizer( 1237 M(dtype).eval(), 1238 example_inputs, 1239 quantizer, 1240 node_occurrence, 1241 node_list, 1242 ) 1243 1244 def verify_quant_dequant_iotypes(m): 1245 for node in m.graph.nodes: 1246 if ( 1247 node.op == "call_function" 1248 and node.target.__name__ == "dequantize_per_tensor.default" 1249 ): 1250 # Check dequantize node 1251 dequant_node = node 1252 dequant_in_dtype = dequant_node.args[5] 1253 dequant_out_dtype = torch.float32 1254 if "out_dtype" in dequant_node.kwargs: 1255 dequant_out_dtype = dequant_node.kwargs["out_dtype"] 1256 1257 # Check preceding quantize node 1258 # Depending on fold_quantize flag, quantize node may be absent 1259 quant_node = 
node.args[0] 1260 if ( 1261 quant_node.op == "call_function" 1262 and quant_node.target.__name__ == "quantize_per_tensor.default" 1263 ): 1264 quant_in_dtype = torch.float32 1265 if "val" in quant_node.args[0].meta: 1266 quant_in_dtype = quant_node.args[0].meta["val"].dtype 1267 quant_out_dtype = quant_node.args[5] 1268 assert ( 1269 quant_in_dtype == dequant_out_dtype 1270 and quant_out_dtype == dequant_in_dtype 1271 ), "quant dequant io dtype check failed!" 1272 1273 verify_quant_dequant_iotypes(m) 1274 1275 def test_input_edge_sanity_check(self): 1276 class M(torch.nn.Module): 1277 def forward(self, x): 1278 return x + 6 1279 1280 class BackendAQuantizer(Quantizer): 1281 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 1282 for node in model.graph.nodes: 1283 if ( 1284 node.op == "call_function" 1285 and node.target == torch.ops.aten.add.Tensor 1286 ): 1287 input_act1 = node.args[0] 1288 # this is a constant, so not valid for annotation 1289 input_act2 = node.args[1] 1290 act_qspec = QuantizationSpec( 1291 dtype=torch.uint8, 1292 quant_min=0, 1293 quant_max=255, 1294 qscheme=torch.per_tensor_affine, 1295 is_dynamic=False, 1296 observer_or_fake_quant_ctr=observer.default_observer, 1297 ) 1298 node.meta["quantization_annotation"] = QuantizationAnnotation( 1299 input_qspec_map={ 1300 input_act1: act_qspec, 1301 # this is supposed to error out 1302 input_act2: act_qspec, 1303 }, 1304 output_qspec=act_qspec, 1305 _annotated=True, 1306 ) 1307 1308 def validate(self, model: torch.fx.GraphModule) -> None: 1309 pass 1310 1311 m = M().eval() 1312 example_inputs = torch.randn(1, 2, 3, 3) 1313 m = capture_pre_autograd_graph(m, (example_inputs,)) 1314 with self.assertRaises(Exception): 1315 m = prepare_pt2e(m, BackendAQuantizer()) 1316 1317 def test_fold_quantize(self): 1318 """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)""" 1319 m = self._get_pt2e_quantized_linear() 1320 node_occurrence = { 1321 # quantize op for weight node is folded 1322 ns.call_function( 1323 torch.ops.quantized_decomposed.quantize_per_tensor.default 1324 ): 2, 1325 ns.call_function( 1326 torch.ops.quantized_decomposed.dequantize_per_tensor.default 1327 ): 3, 1328 } 1329 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 1330 1331 def test_fold_quantize_per_channel(self): 1332 """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)""" 1333 m = self._get_pt2e_quantized_linear(is_per_channel=True) 1334 node_occurrence = { 1335 # quantize op for weight node is folded 1336 ns.call_function( 1337 torch.ops.quantized_decomposed.quantize_per_tensor.default 1338 ): 2, 1339 ns.call_function( 1340 torch.ops.quantized_decomposed.dequantize_per_channel.default 1341 ): 1, 1342 ns.call_function( 1343 torch.ops.quantized_decomposed.dequantize_per_tensor.default 1344 ): 2, 1345 } 1346 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 1347 1348 def test_dont_fold_other_constant(self): 1349 """Make sure the constant propagation does not apply to things unrelated to 1350 quantization 1351 """ 1352 1353 class M(torch.nn.Module): 1354 def __init__(self) -> None: 1355 super().__init__() 1356 self.linear = torch.nn.Linear(2, 2) 1357 self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2)) 1358 1359 def forward(self, x): 1360 t = self.dont_fold_me.t() 1361 return self.linear(x) + t 1362 1363 quantizer = XNNPACKQuantizer() 1364 operator_config = 
get_symmetric_quantization_config(is_per_channel=False) 1365 # only quantize linear, so add is not quantized and the constant Tensor 1366 # should not be folded 1367 quantizer.set_module_type(torch.nn.Linear, operator_config) 1368 example_inputs = (torch.randn(2, 2),) 1369 m = M().eval() 1370 m = self._quantize(m, quantizer, example_inputs) 1371 node_occurrence = { 1372 # quantize op for weight node is folded 1373 ns.call_function( 1374 torch.ops.quantized_decomposed.quantize_per_tensor.default 1375 ): 2, 1376 ns.call_function( 1377 torch.ops.quantized_decomposed.dequantize_per_tensor.default 1378 ): 3, 1379 # transpose op not folded 1380 ns.call_function(torch.ops.aten.t.default): 1, 1381 } 1382 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 1383 1384 def test_fold_all_ops_before_quantize(self): 1385 """Test folding all ops that's before quantized operator: 1386 Before: 1387 get_attr(weight) -> transpose -> quantize -> dequantize 1388 After: 1389 get_attr(folded_weight) -> dequantize 1390 """ 1391 1392 class M(torch.nn.Module): 1393 def __init__(self) -> None: 1394 super().__init__() 1395 self.weight = torch.randn(2, 2) 1396 1397 def forward(self, x): 1398 t = self.weight.t() 1399 return torch.nn.functional.linear(x, t) 1400 1401 quantizer = XNNPACKQuantizer() 1402 operator_config = get_symmetric_quantization_config(is_per_channel=False) 1403 quantizer.set_global(operator_config) 1404 example_inputs = (torch.randn(2, 2),) 1405 m = M().eval() 1406 m = self._quantize(m, quantizer, example_inputs) 1407 node_occurrence = { 1408 # quantize op for weight node is folded 1409 ns.call_function( 1410 torch.ops.quantized_decomposed.quantize_per_tensor.default 1411 ): 2, 1412 ns.call_function( 1413 torch.ops.quantized_decomposed.dequantize_per_tensor.default 1414 ): 3, 1415 } 1416 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 1417 1418 def test_constant_prop_preserve_metadata(self): 1419 """Test to make sure the get_attr node for const propagated weight Tensor gets the correct 1420 metadata (from original get_attr node from weight) 1421 """ 1422 1423 class M(torch.nn.Module): 1424 def __init__(self) -> None: 1425 super().__init__() 1426 self.linear = torch.nn.Linear(2, 2) 1427 1428 def forward(self, x): 1429 return self.linear(x) 1430 1431 quantizer = XNNPACKQuantizer() 1432 operator_config = get_symmetric_quantization_config() 1433 quantizer.set_global(operator_config) 1434 example_inputs = (torch.randn(2, 2),) 1435 m = M().eval() 1436 m = capture_pre_autograd_graph( 1437 m, 1438 example_inputs, 1439 ) 1440 weight_meta = None 1441 for n in m.graph.nodes: 1442 if ( 1443 n.op == "get_attr" 1444 and next(iter(n.users)).target == torch.ops.aten.linear.default 1445 ): 1446 weight_meta = n.meta 1447 break 1448 assert weight_meta is not None, "Expect to find metadata for weight node" 1449 1450 m = prepare_pt2e(m, quantizer) 1451 m(*example_inputs) 1452 m = convert_pt2e(m) 1453 1454 for n in m.graph.nodes: 1455 if n.op == "get_attr" and "frozen_param" in n.target: 1456 for key in n.meta: 1457 self.assertEqual(n.meta[key], weight_meta[key]) 1458 1459 def test_save_load(self): 1460 """Test save/load a quantized model""" 1461 m = self._get_pt2e_quantized_linear() 1462 example_inputs = (torch.randn(2, 2),) 1463 ref_res = m(*example_inputs) 1464 1465 with TemporaryFileName() as fname: 1466 # serialization 1467 quantized_ep = torch.export.export(m, example_inputs) 1468 torch.export.save(quantized_ep, fname) 1469 # deserialization 1470 loaded_ep = 
torch.export.load(fname) 1471 loaded_quantized_model = loaded_ep.module() 1472 res = loaded_quantized_model(*example_inputs) 1473 self.assertEqual(ref_res, res) 1474 1475 def test_composable_quantizer_throw(self): 1476 class BadQuantizer(Quantizer): 1477 def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: 1478 for n in gm.graph.nodes: 1479 n.meta["quantization_annotation"] = None 1480 1481 def validate(self, model: torch.fx.GraphModule) -> None: 1482 pass 1483 1484 quantizer = XNNPACKQuantizer() 1485 quantization_config = get_symmetric_quantization_config(is_per_channel=True) 1486 quantizer.set_global(quantization_config) 1487 bad_quantizer = BadQuantizer() 1488 composable_quantizer = ComposableQuantizer([quantizer, bad_quantizer]) 1489 m_eager = TestHelperModules.ConvLinearWPermute().eval() 1490 example_inputs = (torch.randn(2, 3, 4, 4),) 1491 self.assertRaises( 1492 RuntimeError, 1493 lambda: self._test_quantizer( 1494 m_eager, example_inputs, composable_quantizer, {} 1495 ), 1496 ) 1497 1498 def test_transform_for_annotation(self): 1499 class TestQuantizer(Quantizer): 1500 def transform_for_annotation( 1501 self, model: torch.fx.GraphModule 1502 ) -> torch.fx.GraphModule: 1503 # Make a copy of the graph to ensure that we are using the 1504 # return value of this function. 1505 graph = torch.fx.Graph() 1506 graph.graph_copy(model.graph, {}) 1507 for n in graph.nodes: 1508 if n.target == torch.ops.aten.add.Tensor: 1509 n.target = torch.ops.aten.mul.Tensor 1510 model = torch.fx.GraphModule(model, graph) 1511 return model 1512 1513 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 1514 return model 1515 1516 def validate(self, model: torch.fx.GraphModule) -> None: 1517 pass 1518 1519 class M(torch.nn.Module): 1520 def forward(self, x): 1521 return x + 3 1522 1523 m = M().eval() 1524 quantizer = TestQuantizer() 1525 example_inputs = (torch.randn(1, 2, 3, 3),) 1526 m = capture_pre_autograd_graph(m, example_inputs) 1527 m = prepare_pt2e(m, quantizer) 1528 m(*example_inputs) 1529 node_occurrence = { 1530 ns.call_function(torch.ops.aten.add.Tensor): 0, 1531 ns.call_function(torch.ops.aten.mul.Tensor): 1, 1532 } 1533 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 1534 1535 def test_composable_quantizer_transform_for_annotation(self): 1536 class TestQuantizer1(Quantizer): 1537 def transform_for_annotation( 1538 self, model: torch.fx.GraphModule 1539 ) -> torch.fx.GraphModule: 1540 for n in model.graph.nodes: 1541 if n.target == torch.ops.aten.add.Tensor: 1542 n.target = torch.ops.aten.mul.Tensor 1543 return model 1544 1545 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 1546 return model 1547 1548 def validate(self, model: torch.fx.GraphModule) -> None: 1549 pass 1550 1551 class TestQuantizer2(Quantizer): 1552 def transform_for_annotation( 1553 self, model: torch.fx.GraphModule 1554 ) -> torch.fx.GraphModule: 1555 for n in model.graph.nodes: 1556 if n.target == torch.ops.aten.sub.Tensor: 1557 n.target = torch.ops.aten.div.Tensor 1558 return model 1559 1560 def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: 1561 return model 1562 1563 def validate(self, model: torch.fx.GraphModule) -> None: 1564 pass 1565 1566 class M(torch.nn.Module): 1567 def forward(self, x, y, z): 1568 return x + y - z 1569 1570 m = M().eval() 1571 quantizer = ComposableQuantizer([TestQuantizer1(), TestQuantizer2()]) 1572 example_inputs = ( 1573 torch.randn(1, 2, 3, 3), 1574 torch.randn(1, 2, 3, 3), 1575 
torch.randn(1, 2, 3, 3), 1576 ) 1577 m = capture_pre_autograd_graph(m, example_inputs) 1578 m = prepare_pt2e(m, quantizer) 1579 m(*example_inputs) 1580 node_occurrence = { 1581 ns.call_function(torch.ops.aten.add.Tensor): 0, 1582 ns.call_function(torch.ops.aten.sub.Tensor): 0, 1583 ns.call_function(torch.ops.aten.mul.Tensor): 1, 1584 ns.call_function(torch.ops.aten.div.Tensor): 1, 1585 } 1586 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 1587 1588 def test_embedding_quantizer(self): 1589 m_eager = TestHelperModules.EmbeddingModule().eval() 1590 indices = torch.tensor( 1591 [ 1592 9, 1593 6, 1594 5, 1595 7, 1596 8, 1597 8, 1598 9, 1599 2, 1600 8, 1601 6, 1602 6, 1603 9, 1604 1, 1605 6, 1606 8, 1607 8, 1608 3, 1609 2, 1610 3, 1611 6, 1612 3, 1613 6, 1614 5, 1615 7, 1616 0, 1617 8, 1618 4, 1619 6, 1620 5, 1621 8, 1622 2, 1623 3, 1624 ] 1625 ) 1626 example_inputs = (indices,) 1627 1628 quantizer = EmbeddingQuantizer() 1629 node_occurrence = { 1630 # note: quantize op for weights are const propagated 1631 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1632 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1633 } 1634 node_list = [ 1635 torch.ops.quantized_decomposed.dequantize_per_channel.default, 1636 torch.ops.aten.embedding.default, 1637 ] 1638 # Compare against short term workflow 1639 # cannot compare against fx quant because of the numerical differences coming 1640 # from quantize and dequantize ops 1641 qconfig = default_per_channel_symmetric_qnnpack_qconfig 1642 qconfig_mapping = QConfigMapping().set_global(qconfig) 1643 qconfig_mapping = qconfig_mapping.set_object_type( 1644 torch.nn.Embedding, float_qparams_weight_only_qconfig 1645 ) 1646 self._test_quantizer( 1647 m_eager, 1648 example_inputs, 1649 quantizer, 1650 node_occurrence, 1651 node_list, 1652 True, 1653 qconfig_mapping, 1654 ) 1655 1656 def test_composable_quantizer_linear_conv(self): 1657 dynamic_quantizer = XNNPACKQuantizer() 1658 quantization_config_dynamic = get_symmetric_quantization_config( 1659 is_per_channel=False, is_dynamic=True 1660 ) 1661 dynamic_quantizer.set_global(quantization_config_dynamic) 1662 static_quantizer = XNNPACKQuantizer() 1663 quantization_config = get_symmetric_quantization_config(is_per_channel=True) 1664 static_quantizer.set_global(quantization_config) 1665 # Note that dynamic quantization must be applied first here. 
1666 # this is because static quantizer also quantizes linear with static qspec 1667 # and if we apply static_quantizer first then dynamic_quantizer cannot be applied 1668 composable_quantizer = ComposableQuantizer( 1669 [dynamic_quantizer, static_quantizer] 1670 ) 1671 m_eager = TestHelperModules.ConvLinearWPermute().eval() 1672 1673 node_occurrence = { 1674 torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, 1675 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, 1676 # note: quantize op for weights are const propagated 1677 torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, 1678 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, 1679 # note: quantize op for weights are const propagated 1680 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1681 torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, 1682 } 1683 act_affine_quant_obs = observer.PlaceholderObserver.with_args( 1684 dtype=torch.qint8, 1685 qscheme=torch.per_tensor_affine, 1686 quant_min=-128, 1687 quant_max=127, 1688 eps=2**-12, 1689 is_dynamic=True, 1690 ) 1691 dynamic_qconfig = QConfig( 1692 activation=act_affine_quant_obs, 1693 weight=weight_observer_range_neg_127_to_127, 1694 ) 1695 # Test with 2d inputs 1696 example_inputs = (torch.randn(2, 3, 4, 4),) 1697 qconfig = default_per_channel_symmetric_qnnpack_qconfig 1698 qconfig_mapping = QConfigMapping().set_global(qconfig) 1699 qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) 1700 # Had to turn off check against fx because fx quant workflow does not seem 1701 # to propagate observers for permute node for this model. 1702 # Suprisingly it does propagate it for EmbeddingConvLinearModule 1703 # TODO: Figure out the right behavior for propagation 1704 self._test_quantizer( 1705 m_eager, 1706 example_inputs, 1707 composable_quantizer, 1708 node_occurrence, 1709 [], 1710 False, 1711 qconfig_mapping, 1712 ) 1713 1714 def test_embedding_conv_linear_quantization(self): 1715 m_eager = TestHelperModules.EmbeddingConvLinearModule().eval() 1716 indices = torch.tensor( 1717 [ 1718 9, 1719 6, 1720 5, 1721 7, 1722 8, 1723 8, 1724 9, 1725 2, 1726 8, 1727 6, 1728 6, 1729 9, 1730 1, 1731 6, 1732 8, 1733 8, 1734 3, 1735 2, 1736 3, 1737 6, 1738 3, 1739 6, 1740 5, 1741 7, 1742 0, 1743 8, 1744 4, 1745 6, 1746 5, 1747 8, 1748 2, 1749 3, 1750 ] 1751 ) 1752 indices = torch.unsqueeze(indices, 0) 1753 example_inputs = (indices,) 1754 1755 embedding_quantizer = EmbeddingQuantizer() 1756 dynamic_quantizer = XNNPACKQuantizer() 1757 quantization_config_dynamic = get_symmetric_quantization_config( 1758 is_per_channel=True, is_dynamic=True 1759 ) 1760 dynamic_quantizer.set_global(quantization_config_dynamic) 1761 static_quantizer = XNNPACKQuantizer() 1762 quantization_config = get_symmetric_quantization_config(is_per_channel=True) 1763 static_quantizer.set_global(quantization_config) 1764 composed_quantizer = ComposableQuantizer( 1765 [embedding_quantizer, dynamic_quantizer, static_quantizer] 1766 ) 1767 1768 act_affine_quant_obs = observer.PlaceholderObserver.with_args( 1769 dtype=torch.qint8, 1770 qscheme=torch.per_tensor_affine, 1771 quant_min=-128, 1772 quant_max=127, 1773 eps=2**-12, 1774 is_dynamic=True, 1775 ) 1776 dynamic_qconfig = QConfig( 1777 activation=act_affine_quant_obs, 1778 weight=per_channel_weight_observer_range_neg_127_to_127, 1779 ) 1780 qconfig = default_per_channel_symmetric_qnnpack_qconfig 1781 qconfig_mapping = QConfigMapping().set_global(qconfig) 1782 
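        # descriptive note: the reference QConfigMapping below mirrors the composed
        # PT2E quantizers above: a global static per-channel qconfig, a dynamic
        # qconfig for Linear, and a weight-only qconfig for Embedding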
qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig) 1783 qconfig_mapping = qconfig_mapping.set_object_type( 1784 torch.nn.Embedding, float_qparams_weight_only_qconfig 1785 ) 1786 1787 node_occurrence = { 1788 torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, 1789 torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, 1790 torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, 1791 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, 1792 # note: quantize op for weights are const propagated 1793 torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 1794 torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, 1795 } 1796 self._test_quantizer( 1797 m_eager, 1798 example_inputs, 1799 composed_quantizer, 1800 node_occurrence, 1801 [], 1802 True, 1803 qconfig_mapping, 1804 ) 1805 1806 def _get_node(self, m: torch.fx.GraphModule, target: torch._ops.OpOverload): 1807 """ 1808 Return the first node matching the specified target, throwing an exception 1809 if no such batch norm node is found. 1810 """ 1811 for n in m.graph.nodes: 1812 if n.target == target: 1813 return n 1814 raise ValueError("Did not find node with target ", target) 1815 1816 def _test_move_exported_model_dropout(self, inplace: bool): 1817 """ 1818 Test switching dropout behavior between train and eval modes using 1819 `move_exported_model_to_eval` and `move_exported_model_to_train` APIs. 1820 """ 1821 1822 class M(torch.nn.Module): 1823 def __init__(self) -> None: 1824 super().__init__() 1825 self.dropout = torch.nn.Dropout(0.5, inplace=inplace) 1826 1827 def forward(self, x): 1828 return self.dropout(x) 1829 1830 example_inputs = (torch.randn(1),) 1831 m = M().train() 1832 m = capture_pre_autograd_graph(m, example_inputs) 1833 if inplace: 1834 target = torch.ops.aten.dropout_.default 1835 else: 1836 target = torch.ops.aten.dropout.default 1837 1838 # Assert that dropout op exists and is in train mode 1839 dropout_node = self._get_node(m, target) 1840 self.assertTrue(dropout_node is not None) 1841 self.assertTrue(dropout_node.args[2]) 1842 1843 # Move to eval 1844 torch.ao.quantization.move_exported_model_to_eval(m) 1845 1846 # Assert that dropout op is now in eval mode 1847 dropout_node = self._get_node(m, target) 1848 self.assertTrue(dropout_node is not None) 1849 self.assertTrue(not dropout_node.args[2]) 1850 1851 # Move back to train 1852 torch.ao.quantization.move_exported_model_to_train(m) 1853 1854 # Assert that dropout op is now in train mode again 1855 dropout_node = self._get_node(m, target) 1856 self.assertTrue(dropout_node is not None) 1857 self.assertTrue(dropout_node.args[2]) 1858 1859 def test_move_exported_model_dropout(self): 1860 self._test_move_exported_model_dropout(inplace=False) 1861 1862 def test_move_exported_model_dropout_inplace(self): 1863 self._test_move_exported_model_dropout(inplace=True) 1864 1865 def _get_bn_train_eval_ops(self): 1866 if capture_pre_autograd_graph_using_training_ir(): 1867 return ( 1868 torch.ops.aten.batch_norm.default, 1869 torch.ops.aten.batch_norm.default, 1870 ) 1871 # TODO: This branch is going through a deprecated branch and should be deleted soon, 1872 # after capture_pre_autograd_graph fully migrate to training IR 1873 # T199018392 1874 if TEST_WITH_ROCM: 1875 return ( 1876 torch.ops.aten.miopen_batch_norm.default, 1877 torch.ops.aten.miopen_batch_norm.default, 1878 ) 1879 elif TEST_CUDA: 1880 return ( 1881 torch.ops.aten.cudnn_batch_norm.default, 1882 torch.ops.aten.cudnn_batch_norm.default, 1883 ) 
        else:
            return (
                torch.ops.aten._native_batch_norm_legit.default,
                torch.ops.aten._native_batch_norm_legit_no_training.default,
            )

    def test_move_exported_model_bn(self):
        """
        Test switching batch_norm behavior between train and eval modes using
        `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(x)

        if TEST_CUDA:
            m = M().train().cuda()
            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
        else:
            m = M().train()
            example_inputs = (torch.randn(1, 3, 3, 3),)
        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
        m = capture_pre_autograd_graph(m, example_inputs)

        # Assert that batch norm op exists and is in train mode
        bn_node = self._get_node(m, bn_train_op)
        self.assertTrue(bn_node is not None)
        self.assertTrue(bn_node.args[5])

        # Move to eval
        torch.ao.quantization.move_exported_model_to_eval(m)

        # Assert that batch norm op is now in eval mode
        bn_node = self._get_node(m, bn_eval_op)
        self.assertTrue(bn_node is not None)

        # Move to train
        torch.ao.quantization.move_exported_model_to_train(m)

        # Assert that batch norm op is now in train mode again
        bn_node = self._get_node(m, bn_train_op)
        self.assertTrue(bn_node is not None)
        self.assertTrue(bn_node.args[5])

    def test_disallow_eval_train(self):
        m = TestHelperModules.ConvWithBNRelu(relu=True)
        example_inputs = (torch.rand(3, 3, 5, 5),)

        # Before export: this is OK
        m.eval()
        m.train()

        # After export: this is not OK
        m = capture_pre_autograd_graph(m, example_inputs)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After prepare: still not OK
        quantizer = XNNPACKQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After convert: still not OK
        m = convert_pt2e(m)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

    def test_allow_exported_model_train_eval(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)
                self.dropout = torch.nn.Dropout(0.5)

            def forward(self, x):
                x = self.bn(x)
                x = self.dropout(x)
                return x

        if TEST_CUDA:
            m = M().train().cuda()
            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
        else:
            m = M().train()
            example_inputs = (torch.randn(1, 3, 3, 3),)
        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
        m = capture_pre_autograd_graph(m, example_inputs)

        def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
            targets = [n.target for n in m.graph.nodes]
            bn_op = bn_train_op if train else bn_eval_op
            bn_node = self._get_node(m, bn_op)
            self.assertTrue(bn_node is not None)
            if TEST_CUDA:
                self.assertEqual(bn_node.args[5], train)
            dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
            self.assertEqual(dropout_node.args[2], train)

        # Before wrapping: this is not OK
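        # (as in test_disallow_eval_train above, calling .eval()/.train() on the
        # exported model raises NotImplementedError until it is wrapped with
        # allow_exported_model_train_eval)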
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After wrapping: does not error and swaps the ops accordingly
        torch.ao.quantization.allow_exported_model_train_eval(m)
        m.eval()
        _assert_ops_are_correct(m, train=False)
        m.train()
        _assert_ops_are_correct(m, train=True)

        # After prepare but before wrapping: this is not OK
        quantizer = XNNPACKQuantizer()
        m = prepare_qat_pt2e(m, quantizer)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After prepare and after wrapping: does not error and swaps the ops accordingly
        torch.ao.quantization.allow_exported_model_train_eval(m)
        m.eval()
        _assert_ops_are_correct(m, train=False)
        m.train()
        _assert_ops_are_correct(m, train=True)

        # After convert but before wrapping: this is not OK
        m = convert_pt2e(m, fold_quantize=True)
        with self.assertRaises(NotImplementedError):
            m.eval()
        with self.assertRaises(NotImplementedError):
            m.train()

        # After convert and after wrapping: does not error and swaps the ops accordingly
        torch.ao.quantization.allow_exported_model_train_eval(m)
        m.eval()
        _assert_ops_are_correct(m, train=False)
        m.train()
        _assert_ops_are_correct(m, train=True)

    def test_model_is_exported(self):
        m = TestHelperModules.ConvWithBNRelu(relu=True)
        example_inputs = (torch.rand(3, 3, 5, 5),)
        exported_gm = capture_pre_autograd_graph(m, example_inputs)
        fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
        self.assertTrue(
            torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
        )
        self.assertFalse(
            torch.ao.quantization.pt2e.export_utils.model_is_exported(fx_traced_gm)
        )
        self.assertFalse(torch.ao.quantization.pt2e.export_utils.model_is_exported(m))

    def test_reentrant(self):
        """Test that we can safely call quantization APIs multiple times"""
        m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
        example_inputs = (torch.randn(3, 3, 10, 10),)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
        )
        m.conv_bn_relu = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs)
        m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
        m(*example_inputs)
        m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)

        quantizer = XNNPACKQuantizer().set_module_type(
            torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
        )
        m = capture_pre_autograd_graph(m, example_inputs)
        m = prepare_pt2e(m, quantizer)
        m = convert_pt2e(m)

        node_occurrence = {
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ): 4,
            # one extra dequantize for the (per-tensor) linear weight
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ): 5,
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_channel.default
            ): 1,
        }
        node_list = [
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.conv2d.default),
            ns.call_function(torch.ops.aten.relu.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
            ns.call_function(
                torch.ops.quantized_decomposed.dequantize_per_tensor.default
            ),
            ns.call_function(torch.ops.aten.linear.default),
            ns.call_function(
                torch.ops.quantized_decomposed.quantize_per_tensor.default
            ),
        ]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

    def test_groupwise_per_channel_quant(self):
        m = TestHelperModules.GroupwiseConv2d()
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config(is_per_channel=True)
        quantizer.set_global(operator_config)
        example_inputs = m.example_inputs()
        m = self._quantize(m, quantizer, example_inputs)
        # make sure it runs
        m(*example_inputs)

    def test_observer_callback(self):
        from torch.library import impl, Library

        test_lib = Library("test_int4", "DEF")  # noqa: TOR901
        test_lib.define(
            "quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
        )

        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
        def quantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            inv_scale = 1.0 / scale
            return (
                torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15)
                .to(torch.uint8)
                .view(torch.bits8)
            )

        test_lib.define(
            "dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
        )

        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
        def dequantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale

        from torch.ao.quantization.observer import ObserverBase

        class Int4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: torch.fx.GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.test_int4.quantize_per_tensor_int4,
                        (observer_node.args[0], 1.0, 0),
                        {},
                    )
                    dq_node = model.graph.call_function(
                        torch.ops.test_int4.dequantize_per_tensor_int4,
                        (q_node, 1.0, 0),
                        {},
                    )
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                for node in model.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == torch.ops.aten.add.Tensor
                    ):
                        input_act0 = node.args[0]
                        assert isinstance(input_act0, Node)
                        input_act1 = node.args[1]
                        assert isinstance(input_act1, Node)

                        act_qspec = QuantizationSpec(
                            dtype=torch.uint8,
                            quant_min=0,
                            quant_max=255,
                            qscheme=torch.per_tensor_affine,
                            is_dynamic=False,
                            observer_or_fake_quant_ctr=Int4Observer,
                        )
                        node.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={
                                input_act0: act_qspec,
                                input_act1: act_qspec,
                            },
                            output_qspec=act_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass
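
        # The quantizer above annotates the `add` inputs/outputs with Int4Observer;
        # at convert time the observer's `convert` callback rewrites each observer node
        # into the custom test_int4 quantize/dequantize pair, which is what the
        # node_occurrence / node_list checks below count.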
        class M(torch.nn.Module):
            def forward(self, x1, x2):
                return x1 + x2

        example_inputs = (
            torch.randn(1, 3, 5, 5),
            torch.randn(1, 3, 5, 5),
        )
        node_occurrence = {
            # two for the inputs of the add, one for its output
            torch.ops.test_int4.quantize_per_tensor_int4: 3,
            torch.ops.test_int4.dequantize_per_tensor_int4: 3,
        }
        node_list = [
            torch.ops.test_int4.dequantize_per_tensor_int4,
            torch.ops.test_int4.dequantize_per_tensor_int4,
            torch.ops.aten.add.Tensor,
            torch.ops.test_int4.quantize_per_tensor_int4,
        ]
        self._test_quantizer(
            M().eval(),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
        )

    def test_speed(self):
        import time

        def dynamic_quantize_pt2e(model, example_inputs):
            torch._dynamo.reset()
            model = capture_pre_autograd_graph(model, example_inputs)
            # Per-channel quantization for weights,
            # dynamic quantization for activations.
            # For details, see: https://fburl.com/code/30zds51q
            embedding_quantizer = EmbeddingQuantizer()
            dynamic_quantizer = XNNPACKQuantizer()
            operator_config_dynamic = get_symmetric_quantization_config(
                is_per_channel=True, is_dynamic=True
            )
            dynamic_quantizer.set_global(operator_config_dynamic)
            composed_quantizer = ComposableQuantizer(
                [embedding_quantizer, dynamic_quantizer]
            )
            prev = time.time()
            model = prepare_qat_pt2e(model, composed_quantizer)
            cur = time.time()
            # uncomment to see the time
            # print("prepare time:", cur - prev)
            # Without calibration, scale/zero_point will keep their initialized value of 1.0.
            # Per-channel quantization needs properly shaped scale/zero_point values,
            # so we need to run calibration before converting to the quantized model.
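            # A single forward pass over the example inputs serves as the calibration run.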
            model(*example_inputs)
            prev = time.time()
            model = convert_pt2e(model)
            cur = time.time()
            # uncomment to see the time
            # print("convert time:", cur - prev)
            return model

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        m = M().eval()
        example_inputs = (torch.randn(5, 5),)
        _ = dynamic_quantize_pt2e(m, example_inputs)

    def test_conv_transpose_bn_relu(self):
        class BackendAQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int8_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                # conv_transpose + bn is fused automatically in PTQ (not configurable),
                # so annotating conv_transpose + relu is enough to cover the
                # conv_transpose + bn + relu pattern
                OP_TO_ANNOTATOR["conv_transpose_relu"](model, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_occurrence = {
            # one quantize each for the input activation and the relu output;
            # the weight quantize is const propagated, so only its dequantize remains
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv_transpose2d.input,
            torch.ops.aten.relu.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        self._test_quantizer(
            TestHelperModules.ConvTWithBNRelu(relu=True, bn=True),
            example_inputs,
            BackendAQuantizer(),
            node_occurrence,
            node_list,
        )

    def test_multi_users_without_output_observer(self):
        """
        Test the case in which a node is used by multiple users
        and has had its output observer removed.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                x = self.conv(x)
                return x, x + 1

        example_inputs = (torch.randn(1, 3, 5, 5),)
        m = M()
        m = capture_pre_autograd_graph(m, example_inputs)
        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(),
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)

        # Remove output observer
        observer_to_remove = None
        for n in m.graph.nodes:
            if n.op == "output":
                observer_to_remove = n.args[0][0]
                assert observer_to_remove.op == "call_module"
                assert observer_to_remove.target.startswith("activation_post_process_")
                break
        assert observer_to_remove is not None
        observer_to_remove.replace_all_uses_with(observer_to_remove.args[0])
        m.graph.erase_node(observer_to_remove)
        m.recompile()

        # Convert should succeed
        m = convert_pt2e(m)
        m(*example_inputs)


instantiate_parametrized_tests(TestQuantizePT2E)
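

# A minimal sketch (not exercised by the tests above; the helper name below is made
# up for illustration) of the basic PTQ flow that most tests in this class follow:
# export, annotate with a Quantizer, insert observers, calibrate, then convert.
def _example_pt2e_ptq_flow():
    m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
    example_inputs = (torch.randn(1, 3, 5, 5),)
    # export to a pre-autograd ATen graph
    m = capture_pre_autograd_graph(m, example_inputs)
    # annotate with the XNNPACK quantizer and insert observers
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    # one forward pass as calibration, then lower observers to quantize/dequantize ops
    m(*example_inputs)
    return convert_pt2e(m)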