# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.nn.utils.rnn import PackedSequence
from torch.ao.quantization import (
    quantize,
    prepare,
    convert,
    prepare_qat,
    quantize_dynamic,
    QuantWrapper,
    QuantStub,
    DeQuantStub,
    default_qconfig,
    default_dynamic_qconfig,
    per_channel_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    FixedQParamsObserver,
    PerChannelMinMaxObserver,
    default_dynamic_quant_observer,
    default_weight_observer,
    QConfig,
)

from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    AnnotatedSingleLayerLinearModel,
    QuantStubModel,
    ModelWithFunctionals,
    SingleLayerLinearDynamicModel,
    TwoLayerLinearModel,
    NestedModel,
    ResNetBase,
    RNNDynamicModel,
    RNNCellDynamicModel,
    ActivationsTestModel,
    NormalizationTestModel,
    test_only_eval_fn,
    prepare_dynamic,
    convert_dynamic,
    skipIfNoFBGEMM,
    EmbeddingBagModule,
    EmbeddingModule,
    EmbeddingWithStaticLinear,
    LinearReluLinearModel,
)

# annotated models
from torch.testing._internal.common_quantization import (
    AnnotatedTwoLayerLinearModel,
    AnnotatedNestedModel,
    AnnotatedSubNestedModel,
    AnnotatedCustomConfigNestedModel,
    AnnotatedSkipQuantModel,
)

from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
    override_qengines,
)

from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

# Standard library
from typing import Tuple
import numpy as np

class TestQuantizeEagerOps(QuantizationTestCase):
    @override_qengines
    def _test_reference_module_impl(self,
                                    float_module_class,
                                    quantized_module_class,
                                    extra_module_kwargs,
                                    input_size):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = float_module_class(**extra_module_kwargs)
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.dequant(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = float_module_class(**extra_module_kwargs)
                self.quant1 = QuantStub()
                self.dequant1 = DeQuantStub()
                self.quant2 = QuantStub()
                self.dequant2 = DeQuantStub()

            def forward(self, x):
                x = self.quant1(x)
                x = self.dequant1(x)
                x = self.conv(x)
                x = self.quant2(x)
                x = self.dequant2(x)
                return x

        qengine = torch.backends.quantized.engine
        if qengine not in supported_qengines or qengine == 'qnnpack':
            return  # qnnpack does not support nnq.ConvTranspose3d

        data = torch.randn(*input_size, dtype=torch.float)
        original_m = M()
        original_ref_m = RefM()

        original_ref_m.conv.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv.bias = torch.nn.Parameter(original_m.conv.bias.detach())

        original_m.qconfig = torch.ao.quantization.default_qconfig

        m = prepare(original_m)
        # calibration
        m(data)
        m = convert(m)
        # check if the module is properly quantized
        self.assertEqual(type(m.quant), nnq.Quantize)
        self.assertEqual(type(m.conv), quantized_module_class)
        self.assertEqual(type(m.dequant), nnq.DeQuantize)
        res = m(data)

        # quantize the reference model
        original_ref_m.eval()
        original_ref_m.qconfig = torch.ao.quantization.default_qconfig

        ref_m = prepare(original_ref_m)
        ref_m(data)
        ref_m = convert(ref_m, is_reference=True)
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)

    def test_conv_1d(self):
        self._test_reference_module_impl(
            nn.Conv1d,
            nnq.Conv1d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 1)
        )

    def test_conv_2d(self):
        self._test_reference_module_impl(
            nn.Conv2d,
            nnq.Conv2d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10)
        )

    def test_conv_3d(self):
        self._test_reference_module_impl(
            nn.Conv3d,
            nnq.Conv3d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10, 10)
        )

    def test_conv_transpose_1d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose1d,
            nnq.ConvTranspose1d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 1)
        )

    def test_conv_transpose_2d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose2d,
            nnq.ConvTranspose2d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10)
        )

    def test_conv_transpose_3d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose3d,
            nnq.ConvTranspose3d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10, 10)
        )

    def test_linear(self):
        self._test_reference_module_impl(
            nn.Linear,
            nnq.Linear,
            {'in_features': 5, 'out_features': 10},
            (16, 5)
        )

    @override_qengines
    def test_int16_reference_module(self):

        class RefM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.ConvTranspose2d(1, 1, 1)
                self.quant1 = QuantStub()
                self.dequant1 = DeQuantStub()
                self.quant2 = QuantStub()
                self.dequant2 = DeQuantStub()

            def forward(self, x):
                x = self.quant1(x)
                x = self.dequant1(x)
                x = self.conv(x)
                x = self.quant2(x)
                x = self.dequant2(x)
                return x


        input_size = (16, 1, 10, 10)
        data = torch.randn(*input_size, dtype=torch.float)

        original_ref_m = RefM()
        rand_w = torch.randn_like(original_ref_m.conv.weight)
        rand_b = torch.randn_like(original_ref_m.conv.bias)
        original_ref_m.conv.weight = torch.nn.Parameter(rand_w, requires_grad=False)
        original_ref_m.conv.bias = torch.nn.Parameter(rand_b, requires_grad=False)

        qengine = torch.backends.quantized.engine
        if qengine not in supported_qengines:
            return
        from torch.ao.quantization.observer import MovingAverageMinMaxObserver

        weight_obs = MovingAverageMinMaxObserver.with_args(
            dtype=torch.qint32,
            # set qmin and qmax to represent qint16
            quant_min=-1 * (2 ** 15),
            quant_max=(2 ** 15) - 1,
            qscheme=torch.per_tensor_symmetric,
        )
        act_obs = MovingAverageMinMaxObserver.with_args(
            dtype=torch.qint32,
            quant_min=-1 * (2 ** 15),
            quant_max=(2 ** 15) - 1,
        )
        custom_qconfig = QConfig(activation=act_obs, weight=weight_obs)

        # quantize the reference model
        original_ref_m.eval()
        original_ref_m.qconfig = custom_qconfig

        ref_m = prepare(original_ref_m)
        # calibration
        ref_m(torch.randn(*input_size, dtype=torch.float))

        ref_m = convert(ref_m, is_reference=True)
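        # Observe the same weight with an equivalent int16-style observer and check that
        # the reference module recorded the same weight scale during convert.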
        myobs = MovingAverageMinMaxObserver(averaging_constant=0.5,
                                            dtype=torch.qint32,
                                            # set qmin and qmax to represent qint16
                                            quant_min=-1 * (2 ** 15),
                                            quant_max=(2 ** 15) - 1,
                                            qscheme=torch.per_tensor_symmetric,
                                            )
        result = myobs(rand_w)
        qparams = myobs.calculate_qparams()
        self.assertEqual(ref_m.conv.weight_scale, qparams[0])


    def _test_activation_op_impl(
            self, float_module_class, quantized_module_class, extra_module_kwargs):
        """ Implementation for testing common activation ops like leaky relu
        Args:
            extra_module_kwargs: keyword args to instantiate the float module
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activation_op = float_module_class(**extra_module_kwargs)
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.activation_op(x)
                x = self.dequant(x)
                return x

        m = M().eval()
        m.qconfig = default_qconfig
        m = prepare(m)
        self.checkObservers(m)
        m = convert(m)
        self.assertEqual(type(m.activation_op), quantized_module_class)

    def test_leaky_relu(self):
        self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False})

    def test_relu(self):
        self._test_activation_op_impl(nn.ReLU, nn.ReLU, {'inplace': False})

    # Histogram Observers are slow, so have no-deadline to ensure test doesn't time out
    @given(train_mode=st.booleans())
    def test_functional_module(self, train_mode):
        model = ModelWithFunctionals()
        x = torch.rand(10, 1, dtype=torch.float)
        xq = torch.quantize_per_tensor(x, 0.01, 30, torch.quint8)
        self.checkScriptable(model, [[x]], check_save_load=True)
        if train_mode:
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
            model = prepare_qat(model)
        else:
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            model = prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkObservers(model)
        # Calibrate
        model(xq.dequantize())
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.assertEqual(type(model.myadd), torch.ao.nn.quantized.QFunctional)
            self.assertEqual(type(model.mycat), torch.ao.nn.quantized.QFunctional)
            self.assertEqual(type(model.myadd_relu), torch.ao.nn.quantized.QFunctional)
            self.assertEqual(type(model.mymatmul), torch.ao.nn.quantized.QFunctional)
            self.checkNoQconfig(model)

        checkQuantized(model)
        self.checkScriptable(model, [[xq]], check_save_load=True)

class TestQuantizeEagerPTQStatic(QuantizationTestCase):

    def test_single_layer(self):
        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
        to nnq.Linear which is the quantized version of the module
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                model = AnnotatedSingleLayerLinearModel(qengine)
                model.qconfig = qconfig
                model = prepare(model)
                # Check if observers and quant/dequant nodes are inserted
                self.checkNoPrepModules(model)
                self.checkHasPrepModules(model.fc1)
                self.checkObservers(model)

                test_only_eval_fn(model, self.calib_data)
                model = convert(model)

                def checkQuantized(model):
                    self.checkNoPrepModules(model)
                    self.checkHasPrepModules(model.fc1)
                    self.checkWrappedQuantizedLinear(model.fc1)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                # test one line API - out of place version
                base = AnnotatedSingleLayerLinearModel(qengine)
                base.qconfig = qconfig
                keys_before = set(base.state_dict().keys())
                model = quantize(base, test_only_eval_fn, [self.calib_data])
                checkQuantized(model)
                keys_after = set(base.state_dict().keys())
                self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

                # in-place version
                model = AnnotatedSingleLayerLinearModel(qengine)
                model.qconfig = qconfig
                quantize(model, test_only_eval_fn, [self.calib_data], inplace=True)
                checkQuantized(model)

    @skipIfNoFBGEMM
    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
        `fc2`, and `fc1` is not quantized
        """
        with override_quantized_engine('fbgemm'):
            model = AnnotatedTwoLayerLinearModel()
            model = prepare(model)

            self.checkNoPrepModules(model)
            self.checkObservers(model)
            self.checkNoPrepModules(model.fc1)
            self.checkHasPrepModules(model.fc2)

            test_only_eval_fn(model, self.calib_data)
            model = convert(model)

            def checkQuantized(model):
                self.checkNoPrepModules(model)
                self.checkNoPrepModules(model.fc1)
                self.checkHasPrepModules(model.fc2)
                self.assertEqual(type(model.fc1), torch.nn.Linear)
                self.checkWrappedQuantizedLinear(model.fc2)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

            checkQuantized(model)

            # test one line API
            model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
                             [self.calib_data])
            checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model, top level 'fc3' and
        'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedNestedModel(qengine)

                def checkPrepModules(model, before_calib=False):
                    if before_calib:
                        self.checkObservers(model)
                    self.checkNoPrepModules(model)
                    self.checkNoPrepModules(model.sub1)
                    self.checkNoPrepModules(model.sub1.fc)
                    self.checkNoPrepModules(model.sub1.relu)
                    self.checkNoPrepModules(model.sub2)
                    self.checkHasPrepModules(model.sub2.fc1)
                    self.checkNoPrepModules(model.sub2.fc2)
                    self.checkHasPrepModules(model.fc3)

                model = prepare(model)
                checkPrepModules(model, True)
                test_only_eval_fn(model, self.calib_data)
                model = convert(model)

                def checkQuantized(model):
                    checkPrepModules(model)
                    self.checkLinear(model.sub1.fc)
                    self.checkWrappedQuantizedLinear(model.fc3)
                    self.checkWrappedQuantizedLinear(model.sub2.fc1)
                    self.checkLinear(model.sub2.fc2)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                # test one line API
                model = quantize(AnnotatedNestedModel(qengine), test_only_eval_fn,
                                 [self.calib_data])
                checkQuantized(model)


    @skipIfNoFBGEMM
    def test_nested2(self):
        model = AnnotatedSubNestedModel()
        model = prepare(model)

        def checkPrepModules(model, before_calib=False):
            if before_calib:
                self.checkObservers(model)
            self.checkNoPrepModules(model)
            self.checkNoPrepModules(model.sub1)
            self.checkNoPrepModules(model.sub1.fc)
            self.checkNoPrepModules(model.sub1.relu)
            self.checkHasPrepModules(model.sub2)
            self.checkNoPrepModules(model.sub2.module.fc1)
            self.checkNoPrepModules(model.sub2.module.fc2)
            self.checkHasPrepModules(model.fc3)

        checkPrepModules(model, True)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            checkPrepModules(model)
            self.checkLinear(model.sub1.fc)
            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
            self.checkQuantizedLinear(model.sub2.module.fc1)
            self.checkQuantizedLinear(model.sub2.module.fc2)
            self.checkWrappedQuantizedLinear(model.fc3)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # test one line API
        model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
                         [self.calib_data])
        checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where a child qconfig overrides the
        parent qconfig
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedCustomConfigNestedModel()
                model = prepare(model)

                def checkPrepModules(model, before_calib=False):
                    if before_calib:
                        self.checkObservers(model)
                    self.checkNoPrepModules(model)
                    self.checkNoPrepModules(model.sub1)
                    self.checkNoPrepModules(model.sub1.fc)
                    self.checkNoPrepModules(model.sub1.relu)
                    self.checkNoPrepModules(model.sub2)
                    self.checkHasPrepModules(model.sub2.fc1)
                    self.checkHasPrepModules(model.sub2.fc2)
                    self.checkHasPrepModules(model.fc3)

                checkPrepModules(model, True)

                test_only_eval_fn(model, self.calib_data)
                model = convert(model)

                def checkQuantized(model):
                    checkPrepModules(model)
                    self.checkWrappedQuantizedLinear(model.sub2.fc1)
                    self.checkWrappedQuantizedLinear(model.sub2.fc2)
                    self.checkWrappedQuantizedLinear(model.fc3)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                # test one line API
                model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
                                 [self.calib_data])
                checkQuantized(model)

    def test_skip_quant(self):
        r"""The case when we want to skip quantizing some layers
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedSkipQuantModel(qengine)
                model = prepare(model)
                self.checkObservers(model)

                test_only_eval_fn(model, self.calib_data)
                model = convert(model)

                def checkQuantized(model):
                    self.checkLinear(model.fc)
                    self.checkQuantDequant(model.sub)
                    self.checkQuantizedLinear(model.sub.module.fc1)
                    self.checkQuantizedLinear(model.sub.module.fc2)
                    self.assertEqual(type(model.sub.module.relu1), nn.ReLU)
                    self.assertEqual(type(model.sub.module.relu2), nn.ReLU)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                # test one line API
                model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, [self.calib_data])
                checkQuantized(model)

    @skipIfNoFBGEMM
    def test_manual(self):
        r"""User inserts QuantStub and DeQuantStub in model code
        and calls the quantization utility functions.
        """
        model = QuantStubModel()
        # propagate the qconfig of parents to children, model is changed
        # inplace
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # test one line API
        model = quantize(QuantStubModel(), test_only_eval_fn, [self.calib_data])
        checkQuantized(model)

    def test_resnet_base(self):
        r"""Test quantization for bottleneck topology used in resnet/resnext
        and add coverage for conversion of average pool and float functional
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                model = ResNetBase().float().eval()
                model.fuse_model()
                model = QuantWrapper(model)
                model.qconfig = qconfig
                model = prepare(model)
                self.checkObservers(model)
                test_only_eval_fn(model, self.img_data_2d)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d)
                    self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
                    self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
                    self.assertEqual(type(model.module.fc), nnq.Linear)

                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

    @skipIfNoFBGEMM
    def test_normalization(self):
        r"""
        Test quantization of normalization layers
        """
        model = NormalizationTestModel()
        model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        prepare(model, inplace=True)
        self.checkObservers(model)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model.layer_norm)
            self.checkNoPrepModules(model.group_norm)
            self.checkNoPrepModules(model.instance_norm1d)
            self.checkNoPrepModules(model.instance_norm2d)
            self.checkNoPrepModules(model.instance_norm3d)
            self.assertEqual(type(model.layer_norm), nnq.LayerNorm)
            self.assertEqual(type(model.group_norm), nnq.GroupNorm)
            self.assertEqual(type(model.instance_norm1d), nnq.InstanceNorm1d)
            self.assertEqual(type(model.instance_norm2d), nnq.InstanceNorm2d)
            self.assertEqual(type(model.instance_norm3d), nnq.InstanceNorm3d)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        checkQuantized(model)

        model_oneline = quantize(
            NormalizationTestModel(), test_only_eval_fn, [self.calib_data])
        checkQuantized(model_oneline)

    def test_save_load_state_dict(self):
        r"""Test the PTQ flow of creating a model, quantizing it, and saving the quantized state_dict.
        Load the quantized state_dict for eval and compare results against the original model.
        """

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)

                model = prepare(model)
                # calibrate
                test_only_eval_fn(model, self.calib_data)
                model = convert(model)
                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                quant_state_dict = model.state_dict()

                # Create model again for eval
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                model = prepare(model)
                model = convert(model)
                new_state_dict = model.state_dict()

                # Check to make sure the state dict keys match original model after convert.
                self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys()))

                model.load_state_dict(quant_state_dict)

                out = model(x)
                self.assertEqual(ref, out)

    @skipIfNoFBGEMM
    def test_activations(self):
        r"""
        Test quantization of activations
        """
        model = ActivationsTestModel()
        model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        prepare(model, inplace=True)
        self.checkObservers(model)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model.hardswish)
            self.assertEqual(type(model.hardswish), nnq.Hardswish)
            self.assertEqual(type(model.elu), nnq.ELU)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # test one line API
        model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn,
                                 [self.calib_data])
        checkQuantized(model_oneline)

    @override_qengines
    def test_forward_hooks_preserved(self):
        r"""Test that post-training static quantization preserves the
        pre-forward and post-forward hooks of the original model
        """
        qengine = torch.backends.quantized.engine
        model = QuantStubModel()
        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }

        def fw_pre_hook(h_module, input):
            counter['pre_forwards'] += 1

        def fw_hook(h_module, input, output):
            counter['forwards'] += 1

        model.fc.register_forward_pre_hook(fw_pre_hook)
        model.fc.register_forward_hook(fw_hook)

        model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
        model = prepare(model)

        def checkHooksIsPresent(model, before_convert=True):
            num_fwd_hooks = 1
            if before_convert:
                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
                                 "Quantization observer hook has disappeared")
                num_fwd_hooks = 2

            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
                             "Extra pre forward hooks have appeared on a layer")
            # During static quantization non-stub layers are provided with a quantization observer hook too
            self.assertEqual(len(model.fc._forward_hooks.values()), num_fwd_hooks,
                             "Extra post forward hooks have appeared on a layer")
            # Implicitly check that fw_hook goes after _observer_forward_hook
            self.assertEqual(list(model.fc._forward_hooks.values())[-1], fw_hook,
                             "_observer_forward_hook is not a first entry of the hooks list")

        checkHooksIsPresent(model, True)
        test_only_eval_fn(model, self.calib_data)
        torch.ao.quantization.convert(model, inplace=True)
        checkHooksIsPresent(model, False)

    @skipIfNoFBGEMM
    def test_quantized_embedding(self):
        r""" Test the post-training quantization flow, serialization and scripting
        of embedding modules
        """

        for qconfig in [float_qparams_weight_only_qconfig,
                        float_qparams_weight_only_qconfig_4bit]:
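            # Both qconfigs quantize only the embedding weight (8-bit and 4-bit
            # float-qparams respectively); activations stay in fp32.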
            model = EmbeddingModule().eval()
            indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
            weights = torch.randn(10, 12, dtype=torch.float32)
            model.qconfig = qconfig
            prepare(model, inplace=True)
            convert(model, inplace=True)
            self.assertTrue('QuantizedEmbedding' in str(model))
            self.assertEqual(type(model.emb), torch.ao.nn.quantized.Embedding)
            self.checkScriptable(model, [[indices]], check_save_load=True)

        idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.LongTensor([0, 4])
        x = torch.randn(2, 4)
        model = EmbeddingWithStaticLinear().eval()
        prepare(model, inplace=True)
        convert(model, inplace=True)
        self.assertTrue('QuantizedEmbedding' in str(model))
        self.assertTrue('QuantizedLinear' in str(model))
        self.checkQuantizedLinear(model.fc)
        model(idx, offsets, x)

    @skipIfNoFBGEMM
    def test_dequant_stub(self):
        m = QuantStubModel().eval()
        prepare(m, inplace=True)
        self.checkObservers(m)
        convert(m, inplace=True)
        self.assertEqual(type(m.quant), nnq.Quantize)
        self.assertEqual(type(m.fc), nnq.Linear)
        self.assertEqual(type(m.dequant), nnq.DeQuantize)

        # check DeQuantStub is not swapped when it doesn't have a qconfig
        m2 = QuantStubModel().eval()
        m2.dequant.qconfig = None
        prepare(m2, inplace=True)
        self.checkObservers(m2)
        convert(m2, inplace=True)
        self.assertEqual(type(m2.quant), nnq.Quantize)
        self.assertEqual(type(m2.fc), nnq.Linear)
        self.assertEqual(type(m2.dequant), DeQuantStub)


    def test_quantized_embedding_bag(self):
        r""" Test the post-training quantization flow, serialization and scripting
        of embedding_bag modules
        """
        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        weights = torch.randn(10, 12, dtype=torch.float32)

        for dtype in [torch.quint8, torch.quint4x2]:
            model = EmbeddingBagModule().eval()
            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                        qscheme=torch.per_channel_affine_float_qparams,
                                                                        ch_axis=0)
            float_qparams_qconfig = QConfig(activation=default_dynamic_quant_observer,
                                            weight=float_qparams_observer)
            model.qconfig = float_qparams_qconfig

            prepare(model, inplace=True)
            quantized_model = convert(model)

            per_sample_weights = torch.from_numpy(np.random.uniform(
                low=0.01, high=0.5, size=[len(indices)]).astype(np.float32))

            # Test to make sure module is quantized correctly.
            self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
            self.checkDynamicQuantizedModule(quantized_model.emb, torch.ao.nn.quantized.EmbeddingBag, torch.quint8)
            self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True)

            class EmbeddingBagWithLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                     include_last_offset=True, scale_grad_by_freq=False, mode='sum')
                    self.fc = torch.nn.Linear(5, 5)

                def forward(self, indices, offsets, per_sample_weights, linear_in):
                    return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in)

            # Test quantization of embedding_bag layer only
            model2 = EmbeddingBagWithLinear().eval()
            model2.emb.qconfig = float_qparams_qconfig
            prepare(model2, inplace=True)
            quantized_model = convert(model2)

            self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
            self.checkLinear(model2.fc)
            self.checkDynamicQuantizedModule(quantized_model.emb, torch.ao.nn.quantized.EmbeddingBag, torch.quint8)

    @skipIfNoFBGEMM
    def test_custom_module_class(self):
        class CustomModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        class ObservedCustomModule(torch.nn.Module):
            def __init__(self, conv):
                super().__init__()
                self.conv = conv

            def forward(self, x):
                return self.conv(x)

            @classmethod
            def from_float(cls, float_module):
                assert hasattr(float_module, 'qconfig')
                observed = cls(float_module.conv)
                observed.qconfig = float_module.qconfig
                return observed

        class QuantizedCustomModule(torch.nn.Module):
            def __init__(self, conv):
                super().__init__()
                self.conv = conv

            def forward(self, x):
                return self.conv(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                assert hasattr(observed_module, 'activation_post_process')
                observed_module.conv.activation_post_process = \
                    observed_module.activation_post_process
                quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
                return quantized

        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.custom = CustomModule()

            def forward(self, x):
                return self.custom(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = QuantStub()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.sub = Sub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.sub(x)
                x = self.dequant(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = QuantStub()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.dequant(x)
                return x

        data = torch.randn(1, 1, 1, 1)
        # instantiate M and RefM and align the parameters
        original_m = M()
        original_ref_m = RefM()
        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.sub.custom.conv.weight.detach())
        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.sub.custom.conv.bias.detach())

        original_m.qconfig = default_qconfig
        prepare_custom_config_dict = {
            "float_to_observed_custom_module_class": {
                CustomModule: ObservedCustomModule
            }
        }
        convert_custom_config_dict = {
            "observed_to_quantized_custom_module_class": {
                ObservedCustomModule: QuantizedCustomModule
            }
        }
        m = prepare(
            original_m,
            prepare_custom_config_dict=prepare_custom_config_dict)
        self.checkObservers(m, None, prepare_custom_config_dict)
        # calibration
        m(data)
        # all activation observers are inserted in the top level module

        # check converted/quantized model
        m = convert(
            m,
            convert_custom_config_dict=convert_custom_config_dict)
        # check if the module is properly quantized
        self.assertEqual(type(m.quant), nnq.Quantize)
        self.assertEqual(type(m.conv), nnq.Conv2d)
        self.assertEqual(type(m.sub), Sub)
        self.assertEqual(type(m.sub.custom), QuantizedCustomModule)
        self.assertEqual(type(m.sub.custom.conv), nnq.Conv2d)
        self.assertEqual(type(m.dequant), nnq.DeQuantize)
        res = m(data)

        # quantize the reference model
        original_ref_m.eval()
        original_ref_m.qconfig = default_qconfig
        ref_m = prepare(original_ref_m)
        ref_m(data)
        ref_m = convert(ref_m)
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        with self.assertRaises(AssertionError) as context:
            mp = torch.ao.quantization.prepare(m)
        self.assertTrue(
            str(context.exception) ==
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_qconfig_none(self):
        r"""
        Verifies that having qconfig == None for a ConvTranspose module does not crash
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        m[0].qconfig = None
        mp = torch.ao.quantization.prepare(m)

    @skipIfNoFBGEMM
    def test_quantwrapper_attaches_qconfig_to_dequant(self):
        qconfig = torch.ao.quantization.default_qconfig

        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        for i in range(len(m)):
            m[i].qconfig = qconfig
            m[i] = torch.ao.quantization.QuantWrapper(m[i])

        mp = torch.ao.quantization.prepare(m)
        mq = torch.ao.quantization.convert(mp)
        self.assertTrue(isinstance(mq[0].dequant, nnq.DeQuantize))

    def test_activations_in_non_leaf_module_list(self):
        """
        Ensure activations like `nn.Sigmoid` and `nn.Tanh` are properly handled in
        `non_leaf_module_list`.
        """
        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = QuantStub()
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.softmax = torch.nn.Softmax()
                self.tanh = torch.nn.Tanh()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.sigmoid(x)
                x = self.hardsigmoid(x)
                x = self.softmax(x)
                x = self.tanh(x)
                x = self.dequant(x)
                return x

        qconfig = QConfig(
            activation=FixedQParamsObserver.with_args(scale=123.0, zero_point=0),
            weight=default_weight_observer
        )
        m = MyModel()
        m.qconfig = qconfig
        m = prepare(m, observer_non_leaf_module_list=[
            torch.nn.Sigmoid,
            torch.nn.Hardsigmoid,
            torch.nn.Softmax,
            torch.nn.Tanh,
        ])

        # Should use the observer specified in the QConfig instead of the default (FixedQParamsFakeQuantize)
        self.assertTrue(isinstance(m.sigmoid.activation_post_process, FixedQParamsObserver))
        self.assertTrue(isinstance(m.hardsigmoid.activation_post_process, FixedQParamsObserver))
        self.assertTrue(isinstance(m.softmax.activation_post_process, FixedQParamsObserver))
        self.assertTrue(isinstance(m.tanh.activation_post_process, FixedQParamsObserver))

    @skipIfNoFBGEMM
    def test_mha_batch_first_attr_is_copied_in_prepare(self):
        class TransformerDecoderLayer(nn.Module):
            def __init__(self, d_model, nhead, batch_first):
                super().__init__()
                self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.1, batch_first=batch_first)

        qengine = torch.backends.quantized.engine
        for batch_first in [True, False]:
            model = TransformerDecoderLayer(512, 8, batch_first)
            quantization_config = torch.ao.quantization.get_default_qconfig(qengine)
            model.qconfig = quantization_config
            prepared_model = torch.ao.quantization.prepare(model, inplace=False)
            self.assertTrue(prepared_model.self_attn.batch_first == model.self_attn.batch_first)

@skipIfNoFBGEMM
class TestQuantizeEagerPTQDynamic(QuantizationTestCase):
    def test_single_layer(self):
        r"""Dynamically quantize SingleLayerLinearDynamicModel which has one Linear module,
        make sure it is swapped to nnqd.Linear which is the quantized version of
        the module
        """
        for dtype in [torch.qint8, torch.float16]:
            model = SingleLayerLinearDynamicModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc1': qconfig
            }
            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.fc1, dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)
                self.checkNoQconfig(model)

            checkQuantized(model)

            # test one line API - out of place version
            base = SingleLayerLinearDynamicModel()
            keys_before = set(base.state_dict().keys())
            model = quantize_dynamic(base, qconfig_dict)
            checkQuantized(model)
            keys_after = set(base.state_dict().keys())
            self.assertEqual(keys_before, keys_after)  # simple check that nothing changed

            # in-place version
            model = SingleLayerLinearDynamicModel()
            quantize_dynamic(model, qconfig_dict, inplace=True)
            checkQuantized(model)

            # Test set qconfig
            model = SingleLayerLinearDynamicModel()
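            # Passing a set of module types (instead of a name->qconfig dict) applies the
            # default dynamic qconfig for the given dtype to every matching module.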
            quantize_dynamic(model, {nn.Linear}, inplace=True, dtype=dtype)
            checkQuantized(model)

    def test_two_layers(self):
        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
        `fc2`, and `fc1` is not quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = TwoLayerLinearModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc2': qconfig
            }
            prepare_dynamic(model, qconfig_dict)

            convert_dynamic(model)

            def checkQuantized(model):
                self.assertEqual(type(model.fc1), torch.nn.Linear)
                self.checkDynamicQuantizedLinear(model.fc2, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)
                self.checkNoQconfig(model)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}, dtype=dtype)
            checkQuantized(model)

    def test_nested1(self):
        r"""Test quantization for nested model, top level 'fc3' and
        'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc3': qconfig,
                'sub2.fc1': qconfig
            }

            prepare_dynamic(model, qconfig_dict)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkLinear(model.sub1.fc)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkLinear(model.sub2.fc2)
                self.checkScriptable(model, self.calib_data, check_save_load=True)
                self.checkNoQconfig(model)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
            checkQuantized(model)

            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}, dtype=dtype)
            checkQuantized(model)

    def test_nested2(self):
        r"""Another nested test case: here we quantize all submodules
        of submodule sub2
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc3': qconfig,
                'sub2': qconfig
            }
            prepare_dynamic(model, qconfig_dict)

            convert_dynamic(model)

            def checkQuantized(model):
                self.checkLinear(model.sub1.fc)
                self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)
                self.checkNoQconfig(model)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}, dtype=dtype)
            checkQuantized(model)

    def test_nested3(self):
        r"""More complicated nested test case where a child qconfig overrides the
        parent qconfig
        """
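        # In eager mode the most specific module-name entry wins, so 'sub2.fc1'
        # takes precedence over the qconfig inherited from 'sub2'.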
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dynamic_dict = {
                'fc3': qconfig,
                'sub2': qconfig,
                'sub2.fc1': qconfig
            }
            prepare_dynamic(model, qconfig_dynamic_dict)

            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
                self.checkScriptable(model, self.calib_data, check_save_load=True)
                self.checkNoQconfig(model)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
            checkQuantized(model)

            # Test set API
            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}, dtype=dtype)
            checkQuantized(model)

    def test_type_match_rule(self):
        r"""Test quantization for nested model: top level 'fc3' and
        'fc1' of submodule 'sub2' are excluded, all other 'torch.nn.Linear' modules are quantized
        """
        for dtype in [torch.qint8, torch.float16]:
            model = NestedModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc3': None,
                'sub2.fc1': None,
                torch.nn.Linear: qconfig
            }

            prepare_dynamic(model, qconfig_dict)
            test_only_eval_fn(model, self.calib_data)
            convert_dynamic(model)

            def checkQuantized(model):
                self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=dtype)
                self.checkLinear(model.fc3)
                self.checkLinear(model.sub2.fc1)
                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data, check_save_load=True)
                self.checkNoQconfig(model)

            checkQuantized(model)

            # test one line API
            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
            checkQuantized(model)

    def test_per_channel_linear_quantize(self):
        r"""Test quantization for per_channel dynamic quantization
        """
        model = NestedModel().eval()
        qconfig_dict = {
            torch.nn.Linear: per_channel_dynamic_qconfig
        }

        prepare_dynamic(model, qconfig_dict)
        test_only_eval_fn(model, self.calib_data)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.fc3, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=torch.qint8)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data, check_save_load=True)
            self.checkNoQconfig(model)

        checkQuantized(model)
        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)

    def test_linear_relu_fusion(self):
        dtype = torch.qint8
        model = LinearReluLinearModel().eval()
        qconfig = default_dynamic_qconfig
        qconfig_dict = {'' : qconfig}
        torch.ao.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
        prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def checkQuantized(model):
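            # After fusion, fc1 should convert to a dynamically quantized LinearReLU module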
            self.checkDynamicQuantizedLinearRelu(model.fc1, dtype)
            self.checkDynamicQuantizedLinear(model.fc2, dtype)
            self.checkScriptable(model, self.calib_data, check_save_load=True)
            self.checkNoQconfig(model)

        checkQuantized(model)

    @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
           dtype=st.sampled_from([torch.qint8, torch.float16]))
    def test_quantized_rnn(self, qconfig, dtype):
        r"""Test dynamic quantization, scriptability and serialization for dynamic quantized LSTM and GRU modules on int8 and fp16
        """
        niter = 10
        x = torch.tensor([[100, -155],
                          [-155, 100],
                          [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
        qconfig_dict = {
            torch.nn.LSTM : qconfig,
            torch.nn.GRU: qconfig
        }

        def checkQuantized(model, module_type):
            mod_type_map = {'LSTM': torch.ao.nn.quantized.dynamic.LSTM,
                            'GRU': torch.ao.nn.quantized.dynamic.GRU}
            mod_repr_map = {'LSTM': 'DynamicQuantizedLSTM',
                            'GRU': 'DynamicQuantizedGRU'}
            self.assertTrue(mod_repr_map[module_type] in str(model_quantized))
            self.checkDynamicQuantizedModule(model_quantized.mod, mod_type_map[module_type], dtype)

        for module_type in ['LSTM', 'GRU']:
            model = RNNDynamicModel(module_type).eval()

            if dtype == torch.float16:
                model_quantized = quantize_dynamic(model=model, dtype=dtype)
            else:
                model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype)

            checkQuantized(model_quantized, module_type)
            self.checkScriptable(model_quantized, [[x]], check_save_load=True)

            class ScriptWrapperPackedLSTM(torch.nn.Module):
                def __init__(self, cell):
                    super().__init__()
                    self.cell = cell

                def forward(self, x: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
                    return self.cell(x)

            class ScriptWrapperPackedGRU(torch.nn.Module):
                def __init__(self, cell):
                    super().__init__()
                    self.cell = cell

                def forward(self, x: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
                    return self.cell(x)

            script_wrapper_map = {'LSTM': ScriptWrapperPackedLSTM,
                                  'GRU': ScriptWrapperPackedGRU}
            packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2]))
            model_with_packed_input = script_wrapper_map[module_type](model_quantized.mod)
            model_with_packed_input(packed_input)
            scripted = torch.jit.script(model_with_packed_input)
            scripted(packed_input)
            # We cannot trace with input dtype being a packed sequence
            self._checkScriptable(model_with_packed_input, scripted, [[packed_input]], True)


    @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
           dtype=st.sampled_from([torch.qint8, torch.float16]))
    def test_quantized_rnn_cell(self, qconfig, dtype):
        r"""Test dynamic quantization, scriptability and serialization for dynamic quantized rnn cell modules on int8 and fp16
        """
        qconfig_dict = {
            torch.nn.LSTMCell : qconfig,
            torch.nn.GRUCell : qconfig,
            torch.nn.RNNCell : qconfig
        }

        for module_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']:
            model = RNNCellDynamicModel(module_type).eval()
            x = torch.tensor([[100, -155],
                              [-155, 100],
                              [100, -155]], dtype=torch.float)

            if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16:
                # fp16 dynamic quant is not supported for qnnpack
                continue

            if dtype == torch.float16:
                model_quantized = quantize_dynamic(model=model, dtype=dtype)
            else:
                model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype)

            def checkQuantized(model, module_type):
                mod_type_map = {'LSTMCell': torch.ao.nn.quantized.dynamic.LSTMCell,
                                'GRUCell': torch.ao.nn.quantized.dynamic.GRUCell,
                                'RNNTanh': torch.ao.nn.quantized.dynamic.RNNCell,
                                'RNNReLU': torch.ao.nn.quantized.dynamic.RNNCell}

                mod_repr_map = {'LSTMCell': 'DynamicQuantizedLSTMCell',
                                'GRUCell': 'DynamicQuantizedGRUCell',
                                'RNNTanh': 'DynamicQuantizedRNNCell',
                                'RNNReLU': 'DynamicQuantizedRNNCell'}

                self.assertTrue(mod_repr_map[module_type] in str(model_quantized))
                self.checkDynamicQuantizedModule(model_quantized.mod, mod_type_map[module_type], dtype)
                self.checkNoQconfig(model)

            # Smoke test extra reprs
            checkQuantized(model_quantized, module_type)
            self.checkScriptable(model_quantized, [[x]], check_save_load=True)


    def test_forward_hooks_preserved(self):
        r"""Test that post-training dynamic quantization preserves the
        pre-forward and post-forward hooks of the original model
        """
        for dtype in [torch.qint8, torch.float16]:
            model = SingleLayerLinearDynamicModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc1': qconfig
            }
            convert_dynamic(model)

            counter = {
                'pre_forwards': 0,
                'forwards': 0,
            }

            def fw_pre_hook(h_module, input):
                counter['pre_forwards'] += 1

            def fw_hook(h_module, input, output):
                counter['forwards'] += 1

            model.fc1.register_forward_pre_hook(fw_pre_hook)
            model.fc1.register_forward_hook(fw_hook)
            prepare_dynamic(model, qconfig_dict)

            def checkHooksIsPresent(model):
                self.assertObjectIn(fw_pre_hook, model.fc1._forward_pre_hooks.values())
                self.assertObjectIn(fw_hook, model.fc1._forward_hooks.values())
                self.assertEqual(len(model.fc1._forward_pre_hooks.values()), 1,
                                 "Extra pre forward hooks have appeared on a layer")
                self.assertEqual(len(model.fc1._forward_hooks.values()), 1,
                                 "Extra post forward hooks have appeared on a layer")

            checkHooksIsPresent(model)
            test_only_eval_fn(model, self.calib_data)
            convert_dynamic(model)
            checkHooksIsPresent(model)

    @skipIfNoFBGEMM
    def test_embedding_bag_dynamic(self):
        class EmbeddingBagWithLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                 include_last_offset=True, scale_grad_by_freq=False, mode='sum')
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, indices, offsets, linear_in):
                return self.emb(indices, offsets), self.fc(linear_in)
        model = EmbeddingBagWithLinear().eval()

        qconfig_dict = {
            torch.nn.EmbeddingBag : float_qparams_weight_only_qconfig,
            torch.nn.Linear: default_dynamic_qconfig
        }
        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        q_model = quantize_dynamic(model, qconfig_dict)

        q_model(indices, offsets, torch.randn(5, 5))
        self.assertTrue('QuantizedEmbeddingBag' in str(q_model.emb))
        self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc))

    @skipIfNoFBGEMM
    def test_embedding_ops_dynamic(self):
        class EmbeddingWithLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(
                    num_embeddings=10, embedding_dim=12, scale_grad_by_freq=False)
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, indices, linear_in):
                return self.emb(indices), self.fc(linear_in)
        model = EmbeddingWithLinear().eval()
        qconfig_dict = {
            torch.nn.Embedding : float_qparams_weight_only_qconfig,
            torch.nn.Linear: default_dynamic_qconfig
        }
        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        q_model = quantize_dynamic(model, qconfig_dict)
        self.assertTrue('QuantizedEmbedding' in str(q_model.emb))
        self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc))
        q_model(indices, torch.randn(5, 5))

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")