# Owner(s): ["oncall: quantization"]

import copy
import math
from functools import reduce

import torch
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.backends.mkldnn
import torch.nn as nn
import torch.testing._internal.hypothesis_utils as hu

from hypothesis import given, strategies as st
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.ao.quantization import (
    convert,
    default_embedding_qat_qconfig,
    default_qat_qconfig,
    default_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    DeQuantStub,
    FixedQParamsFakeQuantize,
    FusedMovingAvgObsFakeQuantize,
    get_default_qat_qconfig,
    get_embedding_qat_module_mappings,
    get_embedding_static_quant_module_mappings,
    NoopObserver,
    prepare,
    prepare_qat,
    quantize_qat,
    QuantStub,
)
from torch.ao.quantization.qconfig import qconfig_equals
from torch.nn import BatchNorm2d, Conv2d, init, ReLU
from torch.nn.modules.utils import _pair
from torch.testing._internal.common_quantization import (
    DeFusedEmbeddingBagLinear,
    ManualConvLinearQATModel,
    ManualConvLinearSymmQATModel,
    ManualDropoutQATModel,
    ManualEmbeddingBagLinear,
    ManualLinearDynamicQATModel,
    ManualLinearQATModel,
    QuantizationTestCase,
    QuantStubModel,
    test_only_eval_fn,
    test_only_train_fn,
    TwoLayerLinearModel,
)

from torch.testing._internal.common_quantized import (
    override_qengines,
    override_quantized_engine,
    supported_qengines,
)

from torch.testing._internal.common_utils import skipIfNoXNNPACK

hu.assert_deadline_disabled()

class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
    """
    Conv-BN fusion implemented with explicit folding. Useful
    to verify numerical equivalence with the non-folded version.
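
    With frozen batch-norm statistics, the folding below reduces to

        W_folded = W * gamma / sqrt(running_var + eps)
        b_folded = beta + gamma * (b - running_mean) / sqrt(running_var + eps)

    so a single convolution with (W_folded, b_folded) reproduces BN(conv(x) + b).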
69 """ 70 def __init__(self, 71 # ConvNd args 72 in_channels, out_channels, kernel_size, stride, 73 padding, dilation, transposed, output_padding, 74 groups, 75 bias, 76 padding_mode, 77 # BatchNormNd args 78 # num_features: out_channels 79 eps=1e-05, momentum=0.1, 80 # affine: True 81 # track_running_stats: True 82 # Args for this module 83 freeze_bn=False, 84 qconfig=None): 85 nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size, 86 stride, padding, dilation, transposed, 87 output_padding, groups, False, padding_mode) 88 assert qconfig, 'qconfig must be provided for QAT module' 89 self.qconfig = qconfig 90 self.eps = eps 91 self.momentum = momentum 92 self.freeze_bn = freeze_bn if self.training else True 93 self.num_features = out_channels 94 self.gamma = nn.Parameter(torch.empty(out_channels)) 95 self.beta = nn.Parameter(torch.empty(out_channels)) 96 self.affine = True 97 self.track_running_stats = True 98 self.running_mean = nn.Buffer(torch.zeros(out_channels)) 99 self.running_var = nn.Buffer(torch.ones(out_channels)) 100 self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long)) 101 self.activation_post_process = self.qconfig.activation() 102 self.weight_fake_quant = self.qconfig.weight() 103 if bias: 104 self.bias = nn.Parameter(torch.empty(out_channels)) 105 else: 106 self.register_parameter('bias', None) 107 self.reset_bn_parameters() 108 109 def reset_running_stats(self): 110 self.running_mean.zero_() 111 self.running_var.fill_(1) 112 self.num_batches_tracked.zero_() 113 114 def reset_bn_parameters(self): 115 self.reset_running_stats() 116 init.uniform_(self.gamma) 117 init.zeros_(self.beta) 118 if self.bias is not None: 119 fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 120 bound = 1 / math.sqrt(fan_in) 121 init.uniform_(self.bias, -bound, bound) 122 123 def reset_parameters(self): 124 super().reset_parameters() 125 # A hack to avoid resetting on undefined parameters 126 if hasattr(self, 'gamma'): 127 self.reset_bn_parameters() 128 129 def update_bn_stats(self): 130 self.freeze_bn = False 131 return self 132 133 def freeze_bn_stats(self): 134 self.freeze_bn = True 135 return self 136 137 def _forward(self, input): 138 # exponential_average_factor is self.momentum set to 139 # (when it is available) only so that if gets updated 140 # in ONNX graph when this node is exported to ONNX. 
        running_std = torch.sqrt(self.running_var + self.eps)
        scale_factor = self.gamma / running_std
        scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
        conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)

        if self.training and not self.freeze_bn:
            # recovering original conv to get original batch_mean and batch_var
            if self.bias is not None:
                conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
            else:
                conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
            batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
            batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
            n = float(conv_orig.numel() / conv_orig.size()[1])
            unbiased_batch_var = batch_var * (n / (n - 1))
            batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)

            conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
                (self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
            self.running_mean = exponential_average_factor * batch_mean.detach() + \
                (1 - exponential_average_factor) * self.running_mean
            self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
                (1 - exponential_average_factor) * self.running_var
        else:
            if self.bias is None:
                conv = conv + (self.beta - self.gamma * self.running_mean /
                               running_std).reshape([1, -1, 1, 1])
            else:
                conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super().extra_repr()

    def forward(self, input):
        return self.activation_post_process(self._forward(input))

    @classmethod
    def from_float(cls, mod, qconfig=None):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization
        utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + \
            '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        if not qconfig:
            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert mod.qconfig, 'Input float module must have a valid qconfig'
            qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.bias is not None,
                         conv.padding_mode,
                         bn.eps, bn.momentum,
                         False,
                         qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.gamma = bn.weight
        qat_convbn.beta = bn.bias
        qat_convbn.running_mean = bn.running_mean
        qat_convbn.running_var = bn.running_var
        qat_convbn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn

class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d

    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=None,
                 padding_mode='zeros',
                 # BatchNorm2d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        _ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
                                    padding, dilation, False, _pair(0), groups, bias, padding_mode,
                                    eps, momentum, freeze_bn, qconfig)
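
# The tests below exercise the eager-mode QAT flow, which in outline is
# (see e.g. test_manual below):
#
#     model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
#     model = prepare_qat(model)  # swap in fake-quantized QAT modules
#     ... train or calibrate ...
#     model = convert(model)      # replace QAT modules with quantized versions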

class TestQuantizeEagerQAT(QuantizationTestCase):
    def setUp(self):
        super().setUp()

        self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long),
                                         torch.randn((12, 1), dtype=torch.float)]
                                        for _ in range(2)]
        self.embed_data = [[torch.randint(0, 10, (12, 1))]]

    def test_manual(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
                                     [self.train_data])
                checkQuantized(model)

    def test_dropout(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualDropoutQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.dropout), nnq.Dropout)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn,
                                     [self.train_data])
                checkQuantized(model)

    def test_eval_only_fake_quant(self):
        r"""Use FakeQuant in evaluation-only mode; this is useful for
        estimating the accuracy loss incurred when quantizing the network.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)

                model = prepare_qat(model)
                self.checkObservers(model)

                model.eval()
                test_only_eval_fn(model, self.calib_data)

    def test_conv_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualConvLinearQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearQATModel()
                model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
                checkQuantized(model)

    @skipIfNoXNNPACK
    def test_conv_linear_symm(self):
        r"""Same as test_conv_linear but with symmetric quantization.
        Supported only with qengine=qnnpack, which uses symmetric
        kernels from the XNNPACK library."""
        for qengine in supported_qengines:
            if qengine != 'qnnpack':
                continue
            with override_quantized_engine(qengine):
                model = ManualConvLinearSymmQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearSymmQATModel()
                model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
                checkQuantized(model)

    def test_dynamic_qat_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # Dynamic QAT without memoryless observers should fail
                with self.assertRaisesRegex(ValueError,
                                            "Dynamic QAT requires a memoryless observer." +
+ 385 "This means a MovingAverage observer with averaging constant equal to 1" 386 ): 387 model = ManualLinearDynamicQATModel(default_qat_qconfig) 388 model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear}) 389 390 model = ManualLinearDynamicQATModel() 391 model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear}) 392 self.assertEqual(type(model.fc1), nnqatd.Linear) 393 self.assertEqual(type(model.fc2), nnqatd.Linear) 394 self.checkObservers(model) 395 test_only_train_fn(model, self.train_data) 396 model = convert(model, mapping={nnqatd.Linear: nnqd.Linear}) 397 self.assertEqual(type(model.fc1), nnqd.Linear) 398 self.assertEqual(type(model.fc2), nnqd.Linear) 399 test_only_eval_fn(model, self.calib_data) 400 self.checkScriptable(model, self.calib_data) 401 self.checkNoQconfig(model) 402 403 def test_defused_embedding_bag_linear(self): 404 for qengine in supported_qengines: 405 with override_quantized_engine(qengine): 406 model = DeFusedEmbeddingBagLinear().train() 407 model = prepare_qat(model, mapping=get_embedding_qat_module_mappings()) 408 self.checkObservers(model) 409 410 test_only_train_fn(model, self.embed_linear_data_train) 411 # make sure activation_post_process is inserted after Linear. 412 self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize) 413 # make sure that Embedding has a noop for activation. 414 self.assertEqual(type(model.emb.activation_post_process), NoopObserver) 415 # make sure that FakeQuant zero_points are correct dtype 416 self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32) 417 self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32) 418 419 model = convert(model, mapping=get_embedding_static_quant_module_mappings()) 420 421 def checkQuantized(model): 422 # make sure Embedding is now a QuantizedEmbedding 423 self.assertEqual(type(model.emb), nn.quantized.Embedding) 424 # make sure Linear is now a QuantizedLinear 425 self.assertEqual(type(model.linear), nn.quantized.Linear) 426 427 test_only_eval_fn(model, self.embed_data) 428 self.checkScriptable(model, self.embed_data) 429 self.checkNoQconfig(model) 430 431 checkQuantized(model) 432 433 434 def test_embedding_bag_linear(self): 435 for qengine in supported_qengines: 436 with override_quantized_engine(qengine): 437 model = ManualEmbeddingBagLinear().train() 438 model = prepare_qat(model, mapping=get_embedding_qat_module_mappings()) 439 self.checkObservers(model) 440 441 test_only_train_fn(model, self.embed_linear_data_train) 442 # make sure not activation_post_process is inserted for EmbeddingBag 443 self.assertFalse(hasattr(model, "activation_post_process")) 444 # make sure that FakeQuant zero_points are correct dtype 445 self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32) 446 self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32) 447 model = convert(model, mapping=get_embedding_static_quant_module_mappings()) 448 449 def checkQuantized(model): 450 # Make sure EmbeddingBag is now a quantized EmbeddingBag. 451 self.assertTrue(type(model.emb), nn.quantized.EmbeddingBag) 452 # Also test that Linear has been quantized. 
                    self.assertEqual(type(model.linear), nnq.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualEmbeddingBagLinear()

    def test_train_save_load_eval(self):
        r"""Test the QAT flow of creating a model, doing QAT, and saving the
        quantized state_dict. During eval, we first call prepare_qat and
        convert on the model, then load the state_dict and compare the
        results against the original model.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                model = prepare_qat(model)

                fq_state_dict = model.state_dict()

                test_only_train_fn(model, self.train_data)
                model = convert(model)

                quant_state_dict = model.state_dict()

                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                # Create model again for eval. Check result using quantized state_dict
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                torch.ao.quantization.prepare_qat(model, inplace=True)
                new_state_dict = model.state_dict()

                # Check to make sure the model after prepare_qat has the same state_dict as original.
                self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

                torch.ao.quantization.convert(model, inplace=True)
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

                # Check model created using prepare has same state dict as quantized state_dict
                model = TwoLayerLinearModel()
                model.eval()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                torch.ao.quantization.prepare(model, inplace=True)
                torch.ao.quantization.convert(model, inplace=True)
                self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

    @override_qengines
    def test_forward_hooks_preserved(self):
        r"""Test that QAT preserves the pre- and post-forward hooks of the
        original model.
        """
        qengine = torch.backends.quantized.engine
        model = QuantStubModel()
        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }

        def fw_pre_hook(h_module, input):
            counter['pre_forwards'] += 1

        def fw_hook(h_module, input, output):
            counter['forwards'] += 1

        model.fc.register_forward_pre_hook(fw_pre_hook)
        model.fc.register_forward_hook(fw_hook)

        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        model = prepare_qat(model)

        def checkHooksIsPresent(model, before_convert=True):
            forward_hooks = 1
            if before_convert:
                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
                                 "Quantization observer hook has disappeared")
                forward_hooks = 2
            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
                             "Extra pre forward hooks have appeared on a layer")
            self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
                             "Extra post forward hooks have appeared on a layer")
"Extra post forward hooks have appeared on a layer") 549 550 checkHooksIsPresent(model, True) 551 x = torch.rand(2, 5, dtype=torch.float) 552 model(x) 553 torch.ao.quantization.convert(model, inplace=True) 554 checkHooksIsPresent(model, False) 555 556 def test_add_scalar_uses_input_qparams(self): 557 class M(torch.nn.Module): 558 def __init__(self) -> None: 559 super().__init__() 560 self.quant = torch.ao.quantization.QuantStub() 561 self.ff = torch.ao.nn.quantized.FloatFunctional() 562 563 def forward(self, x): 564 x = self.quant(x) 565 x = self.ff.add_scalar(x, 1.0) 566 return x 567 568 m = M() 569 m.qconfig = torch.ao.quantization.default_qconfig 570 mp = torch.ao.quantization.prepare_qat(m) 571 mp(torch.randn(4, 4)) 572 mq = torch.ao.quantization.convert(mp) 573 res = mq(torch.randn(4, 4)) 574 eps = 1e-5 575 self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps) 576 577 def test_mul_scalar_uses_input_qparams(self): 578 class M(torch.nn.Module): 579 def __init__(self) -> None: 580 super().__init__() 581 self.quant = torch.ao.quantization.QuantStub() 582 self.ff = torch.ao.nn.quantized.FloatFunctional() 583 584 def forward(self, x): 585 x = self.quant(x) 586 x = self.ff.mul_scalar(x, 2.0) 587 return x 588 589 m = M() 590 m.qconfig = torch.ao.quantization.default_qconfig 591 mp = torch.ao.quantization.prepare_qat(m) 592 mp(torch.randn(4, 4)) 593 mq = torch.ao.quantization.convert(mp) 594 res = mq(torch.randn(4, 4)) 595 eps = 1e-5 596 self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps) 597 598 @override_qengines 599 def test_qat_embedding_bag_errors(self): 600 default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine) 601 602 # Test constructor parameters checks here. 603 with self.assertRaisesRegex(AssertionError, 604 "qconfig must be provided for QAT module"): 605 nnqat.EmbeddingBag(10, 5, qconfig=None) 606 607 with self.assertRaisesRegex(AssertionError, 608 "Embedding Bag weights requires a qscheme of " + 609 "torch.per_channel_affine_float_qparams"): 610 nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig) 611 612 # Test from_float checks here. 613 embed = nn.Embedding(10, 5) 614 with self.assertRaisesRegex(AssertionError, 615 "qat.EmbeddingBag.from_float only works for EmbeddingBag"): 616 nnqat.EmbeddingBag.from_float(embed) 617 embed_bag = nn.EmbeddingBag(10, 5) 618 with self.assertRaisesRegex(AssertionError, 619 "Input float module must have qconfig defined"): 620 nnqat.EmbeddingBag.from_float(embed_bag) 621 embed_bag.qconfig = None 622 with self.assertRaisesRegex(AssertionError, 623 "Input float module must have a valid qconfig"): 624 nnqat.EmbeddingBag.from_float(embed_bag) 625 embed_bag.qconfig = default_qat_qconfig 626 with self.assertRaisesRegex(AssertionError, 627 "Embedding Bag weights requires a qscheme of " + 628 "torch.per_channel_affine_float_qparams"): 629 nnqat.EmbeddingBag.from_float(embed_bag) 630 631 def test_embedding_qat_qconfig_equal(self): 632 # Embedding QAT uses a NoopObserver class for activation, 633 # and a FakeQuant for weight, make sure that qconfig comparison 634 # functions properly for a mix of partial function and class in 635 # qconfig. 

    @override_qengines
    def test_qat_embedding_bag_errors(self):
        default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)

        # Test constructor parameter checks here.
        with self.assertRaisesRegex(AssertionError,
                                    "qconfig must be provided for QAT module"):
            nnqat.EmbeddingBag(10, 5, qconfig=None)

        with self.assertRaisesRegex(AssertionError,
                                    "Embedding Bag weights requires a qscheme of " +
                                    "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

        # Test from_float checks here.
        embed = nn.Embedding(10, 5)
        with self.assertRaisesRegex(AssertionError,
                                    "qat.EmbeddingBag.from_float only works for EmbeddingBag"):
            nnqat.EmbeddingBag.from_float(embed)
        embed_bag = nn.EmbeddingBag(10, 5)
        with self.assertRaisesRegex(AssertionError,
                                    "Input float module must have qconfig defined"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = None
        with self.assertRaisesRegex(AssertionError,
                                    "Input float module must have a valid qconfig"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = default_qat_qconfig
        with self.assertRaisesRegex(AssertionError,
                                    "Embedding Bag weights requires a qscheme of " +
                                    "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag.from_float(embed_bag)

    def test_embedding_qat_qconfig_equal(self):
        # Embedding QAT uses a NoopObserver class for activation and a
        # FakeQuant for weight; make sure that qconfig comparison works
        # properly for a mix of partial functions and classes in the qconfig.
        model = ManualEmbeddingBagLinear().train()
        model = prepare_qat(model)

        self.assertTrue(qconfig_equals(model.emb.qconfig,
                                       default_embedding_qat_qconfig))

class TestQuantizeEagerQATNumerics(QuantizationTestCase):
    def _test_activation_convert_numerics_impl(self, Act, data):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.act = Act()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.act(x)
                x = self.dequant(x)
                return x

        m = M().train()
        m.qconfig = default_qat_qconfig
        m = prepare_qat(m)
        before_convert = m(data)
        m = convert(m)
        after_convert = m(data)
        self.assertEqual(before_convert, after_convert)

    def test_fixed_qparam_ops(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.tanh = torch.nn.Tanh()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.sigmoid(x)
                x = self.hardsigmoid(x)
                x = self.tanh(x)
                x = self.dequant(x)
                return x

        m = M().train()
        m.qconfig = default_qat_qconfig
        m = prepare_qat(m)
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
        data = torch.randn(1, 3, 2, 4)
        before_convert = m(data)
        m = convert(m)
        after_convert = m(data)
        self.assertEqual(before_convert, after_convert)
        # make sure activation post process is removed
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            # verify fake quant module is removed
            self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))
            # verify that hooks are removed
            self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

        # make sure no fake quantize module is inserted for eval mode

        def checkNoFQModule(m):
            for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
                self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
                self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

        m = M().eval()
        m.qconfig = default_qconfig
        m = prepare(m)
        checkNoFQModule(m)
        m = convert(m)
        checkNoFQModule(m)
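
    # The ops in the test above have fixed output ranges (sigmoid and
    # hardsigmoid map into [0, 1], tanh into [-1, 1]), so QAT attaches
    # FixedQParamsFakeQuantize with constant scale/zero_point instead of an
    # observer-driven fake quant.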

    def test_leaky_relu(self):
        data = torch.randn(1, 3, 2, 4)
        self._test_activation_convert_numerics_impl(nn.LeakyReLU, data)

    def test_relu(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(x)
                return x

        m = M().train()
        m.qconfig = default_qconfig
        m = prepare_qat(m)
        # make sure no activation_post_process is inserted for relu
        self.assertFalse(hasattr(m, "activation_post_process"))
        m = convert(m)
        # make sure the ReLU module is not changed
        self.assertEqual(type(m.relu), nn.ReLU)

    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans(),
           zero_gamma=st.booleans(),
           has_bias=st.booleans(),
           use_slow_fusion=st.booleans())
    def test_conv_bn_relu(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        padding_mode,
        use_relu,
        eps,
        momentum,
        freeze_bn,
        zero_gamma,
        has_bias,
        use_slow_fusion,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        dilation_h = dilation_w = dilation

        conv_op = Conv2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            has_bias,
            padding_mode
        ).to(dtype=torch.double)
        bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
        relu_op = ReLU()

        cls = ConvBnReLU2d if use_relu else ConvBn2d
        qat_op = cls(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            has_bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn=True,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)
        qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion

        # the approximate fusion will not work if bn.weight contains zeros
        if zero_gamma and use_slow_fusion:
            torch.nn.init.zeros_(qat_op.bn.weight)

        qat_op.apply(torch.ao.quantization.disable_fake_quant)
        if freeze_bn:
            qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
        else:
            qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)

        # align inputs and internal parameters
        input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
        conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
        if has_bias:
            conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
        bn_op.running_mean = qat_op.bn.running_mean.clone()
        bn_op.running_var = qat_op.bn.running_var.clone()
        bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach())
        bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach())

        def compose(functions):
            # functions are reversed for natural reading order
            return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

        if not use_relu:
            def relu_op(x):  # noqa: F811
                return x

        if freeze_bn:
            def ref_op(x):
                x = conv_op(x)
                x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                    (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
                    .reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
                x = relu_op(x)
                return x
        else:
            ref_op = compose([conv_op, bn_op, relu_op])

        input_clone = input.clone().detach().requires_grad_()
        for i in range(2):
            result_ref = ref_op(input)
            result_actual = qat_op(input_clone)
            self.assertEqual(result_ref, result_actual)

            # backward
            dout = torch.randn(result_ref.size(), dtype=torch.double)
            loss = (result_ref - dout).sum()
            loss.backward()
            input_grad_ref = input.grad.cpu()
            weight_grad_ref = conv_op.weight.grad.cpu()
            gamma_grad_ref = bn_op.weight.grad.cpu()
            beta_grad_ref = bn_op.bias.grad.cpu()
            running_mean_ref = bn_op.running_mean
            running_var_ref = bn_op.running_var
            num_batches_tracked_ref = bn_op.num_batches_tracked
            loss = (result_actual - dout).sum()
            loss.backward()
            input_grad_actual = input_clone.grad.cpu()
            weight_grad_actual = qat_op.weight.grad.cpu()
            gamma_grad_actual = qat_op.bn.weight.grad.cpu()
            beta_grad_actual = qat_op.bn.bias.grad.cpu()
            running_mean_actual = qat_op.bn.running_mean
            running_var_actual = qat_op.bn.running_var
            num_batches_tracked_actual = qat_op.bn.num_batches_tracked
            precision = 1e-10
            self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
            self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
            self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
            self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
            self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
            self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
            self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)
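
    # The next test compares the production ConvBn2d against the explicitly
    # folded _ReferenceConvBn2d defined above: outputs, gradients, and running
    # statistics should agree within 1e-5 across several optimizer steps.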

    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans(),
           bias=st.booleans())
    def test_conv_bn_folded_vs_unfolded(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        padding_mode,
        eps,
        momentum,
        freeze_bn,
        bias,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        dilation_h = dilation_w = dilation

        qat_op = ConvBn2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            bias,  # bias
            padding_mode,
            eps,
            momentum,
            freeze_bn=freeze_bn,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)

        qat_ref_op = _ReferenceConvBn2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            bias,  # bias
            padding_mode,
            eps,
            momentum,
            freeze_bn=freeze_bn,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)

        qat_op.apply(torch.ao.quantization.disable_fake_quant)
        qat_ref_op.apply(torch.ao.quantization.disable_fake_quant)

        # align inputs and internal parameters
        qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone())
        qat_ref_op.running_mean = qat_op.bn.running_mean.clone()
        qat_ref_op.running_var = qat_op.bn.running_var.clone()
        qat_ref_op.gamma = torch.nn.Parameter(qat_op.bn.weight.detach().clone())
        qat_ref_op.beta = torch.nn.Parameter(qat_op.bn.bias.detach().clone())
        if qat_op.bias is not None:
            qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone())

        lr = 0.01
        qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=lr)
        qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)

        for i in range(5):

            # make sure that calling model.train() does not override the
            # bn freeze setting
            qat_op.train()
            qat_ref_op.train()

            qat_op_optim.zero_grad()
            qat_ref_op_optim.zero_grad()

            input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
            input_clone = input.clone().detach().requires_grad_()

            if i > 2:
                qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                qat_ref_op.freeze_bn_stats()

            if i > 3:
                qat_op.apply(torch.ao.quantization.disable_observer)
                qat_ref_op.apply(torch.ao.quantization.disable_observer)

            result_ref = qat_ref_op(input)
            result_actual = qat_op(input_clone)
            self.assertEqual(result_ref, result_actual)

            # backward
            dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0

            loss = (result_ref - dout).sum()
            loss.backward()
            input_grad_ref = input.grad.cpu()
            weight_grad_ref = qat_ref_op.weight.grad.cpu()
            gamma_grad_ref = qat_ref_op.gamma.grad.cpu()
            beta_grad_ref = qat_ref_op.beta.grad.cpu()
            running_mean_ref = qat_ref_op.running_mean
            running_var_ref = qat_ref_op.running_var
            num_batches_tracked_ref = qat_ref_op.num_batches_tracked

            loss = (result_actual - dout).sum()
            loss.backward()
            input_grad_actual = input_clone.grad.cpu()
            weight_grad_actual = qat_op.weight.grad.cpu()
            gamma_grad_actual = qat_op.bn.weight.grad.cpu()
            beta_grad_actual = qat_op.bn.bias.grad.cpu()
            running_mean_actual = qat_op.bn.running_mean
            running_var_actual = qat_op.bn.running_var
            num_batches_tracked_actual = qat_op.bn.num_batches_tracked

            precision = 1e-5
            self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
            self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
            self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
            self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
            self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
            self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
            self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)

            qat_op_optim.step()
            qat_ref_op_optim.step()

    @override_qengines
    def test_linear_bn_numerics(self):
        qengine = torch.backends.quantized.engine
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, the fused QAT module should match the fp32 module
        m.apply(torch.ao.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))

    @skipIfNoXNNPACK
    @override_qengines
    def test_linear_bn_symm_numerics(self):
        qengine = torch.backends.quantized.engine
        if qengine != "qnnpack":
            return  # Only qnnpack supports symmetric quantization
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = default_symmetric_qnnpack_qat_qconfig
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, the fused QAT module should match the fp32 module
        m.apply(torch.ao.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))

    @override_qengines
    def test_linear_bn_workflow(self):
        qengine = torch.backends.quantized.engine
        m = nn.Sequential(
            QuantStub(),
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        data = torch.randn(4, 4)
        m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
        mp = prepare_qat(m)
        mp(data)
        mq = convert(mp)
        self.assertTrue(type(mq[1]) == nnq.Linear)
        self.assertTrue(type(mq[2]) == nn.Identity)

    @skipIfNoXNNPACK
    @override_qengines
    def test_linear_precomputed_fake_quant(self):
        qengine = torch.backends.quantized.engine
        if qengine != "qnnpack":
            return  # Only qnnpack supports symmetric quantization
        m_ref = nn.Linear(4, 4)

        m_ref_copy = copy.deepcopy(m_ref)
        qconfig = default_qconfig
        m_ref_copy.qconfig = qconfig
        weight_post_process = copy.deepcopy(qconfig.weight())
        activation = copy.deepcopy(qconfig.activation())
        activation(torch.randn(4, 4))
        m_ref_copy.activation_post_process = activation
        m_ref_copy = nnq.Linear.from_float(m_ref_copy)
        weight_post_process = qconfig.weight()
        weight_post_process.min_val = torch.tensor(-1)
        weight_post_process.max_val = torch.tensor(1)
        m_ref.weight_post_process = weight_post_process
        m_ref.activation_post_process = activation
        m_ref.qconfig = qconfig
        m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
        # compare the scales themselves, not the bound methods
        self.assertTrue(m_ref._weight_bias()[0].q_scale() != m_ref_copy._weight_bias()[0].q_scale())


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")