# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized.reference as nnqr
import torch.ao.quantization
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd

from torch.ao.quantization import (
    get_default_static_quant_module_mappings,
    default_float_qparams_observer,
    PerChannelMinMaxObserver,
)
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    prepare_dynamic,
    _make_conv_test_input,
    skipIfNoFBGEMM,
    lengths_to_offsets,
    skipIfNoONEDNN,
    _make_conv_add_extra_input_tensor,
)
from torch.testing._internal.common_quantized import (
    _calculate_dynamic_qparams,
    override_quantized_engine,
    override_qengines,
    qengine_is_qnnpack,
    qengine_is_onednn,
)
import torch.fx
from hypothesis import assume, given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

import copy
import io
import numpy as np
import itertools

"""
Note that the tests in this file are just API tests, to make sure we wrapped the
quantized operator implementations correctly in the user-facing APIs; these are
not correctness tests for the underlying quantized operators. For correctness
tests please see `test/quantization/test_quantized_op.py`.
"""

class TestStaticQuantizedModule(QuantizationTestCase):
    def test_relu(self):
        relu_module = nn.ReLU()
        relu6_module = nnq.ReLU6()

        x = torch.arange(-10, 10, dtype=torch.float)
        y_ref = torch.relu(x)
        y6_ref = torch.nn.modules.ReLU6()(x)

        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.qint32)
        qy = relu_module(qx)
        qy6 = relu6_module(qx)

        self.assertEqual(y_ref, qy.dequantize(),
                         msg="ReLU module API failed")
        self.assertEqual(y6_ref, qy6.dequantize(),
                         msg="ReLU6 module API failed")

    @override_qengines
    def test_linear(self):
        """test API functionality for nn.quantized.linear"""
        options = itertools.product(
            [1, 5],
            [16, 32],
            [4, 8],
            [True, False],
            [True, False])
        for (batch_size, in_features, out_features, use_bias, per_channel) in options:
            self._test_linear_api_impl(
                nnq.Linear, 'QuantizedLinear', torch.ops.quantized.linear, batch_size,
                in_features, out_features, use_bias, per_channel)

    @override_qengines
    def test_linear_relu(self):
        """test API functionality for nn.intrinsic.quantized.linear_relu"""
        options = itertools.product(
            [1, 5],
            [16, 32],
            [4, 8],
            [True, False],
            [True, False])
        for (batch_size, in_features, out_features, use_bias, per_channel) in options:
            self._test_linear_api_impl(
                nniq.LinearReLU, 'QuantizedLinearReLU', torch.ops.quantized.linear_relu,
                batch_size, in_features, out_features, use_bias, per_channel)

    def _test_linear_api_impl(self, qlinear_module, module_name, qlinear_op,
                              batch_size, in_features, out_features, use_bias,
                              per_channel, **post_ops_kwargs):
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False

        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            # ONEDNN only supports symmetric quantization of weight
            W_zp = 0 if qengine_is_onednn() else 4
            W_q = torch.quantize_per_tensor(W, 0.1, W_zp, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
        qlinear = qlinear_module(in_features, out_features, **post_ops_kwargs)

        qlinear_copy = copy.deepcopy(qlinear)
        # set random quantized weight and bias before test torch scriptable
        qlinear_copy.set_weight_bias(W_q, B)
        self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)

        # testing packed param implementation
        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)

        # Check if the module implementation matches calling the
        # ops directly
        W_pack = qlinear._packed_params._packed_params
        Z_ref = qlinear_op(X_q, W_pack, scale, zero_point, **post_ops_kwargs)

        self.assertEqual(Z_ref, Z_q)
        self.assertTrue(module_name in str(qlinear))

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        for weights_only in [True, False]:
            b.seek(0)
            loaded_dict = torch.load(b, weights_only=weights_only)
            for key in model_dict:
                if isinstance(model_dict[key], torch._C.ScriptObject):
                    assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                    w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                    w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                    self.assertEqual(w_model, w_loaded)
                    self.assertEqual(b_model, b_loaded)
                else:
                    self.assertEqual(model_dict[key], loaded_dict[key])

            loaded_qlinear = qlinear_module(
                in_features, out_features, **post_ops_kwargs)
            loaded_qlinear.load_state_dict(loaded_dict)
            linear_unpack = torch.ops.quantized.linear_unpack
            self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                             linear_unpack(loaded_qlinear._packed_params._packed_params))
            self.assertEqual(qlinear.scale, loaded_qlinear.scale)
            self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
            # scripting will add __overloads__ to __dict__, which is why we script a copy
            # to be able to do the check in the next line
            self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True)
            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
            self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
            self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
            Z_q2 = loaded_qlinear(X_q)
            self.assertEqual(Z_q, Z_q2)

        # Test serialization
        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded = torch.load(b, weights_only=False)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test torch.package
        buffer = io.BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_pickle("module", "qlinear.pkl", qlinear)
        buffer.seek(0)

        importer = PackageImporter(buffer)
        loaded_from_package = importer.load_pickle("module", "qlinear.pkl")
        self.assertEqual(qlinear.weight(), loaded_from_package.weight())
        self.assertEqual(qlinear.scale, loaded_from_package.scale)
        self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point)

        for name, module in loaded_from_package.named_modules():
            # noop, just make sure attribute "_modules" is restored correctly during torch.package import
            assert(name is not None)  # noqa: E275

        # Test copy and deepcopy
        copied_linear = copy.copy(qlinear)
        self.assertEqual(copied_linear.bias(), qlinear.bias())
        self.assertEqual(copied_linear.scale, qlinear.scale)
        self.assertEqual(copied_linear.zero_point,
                         qlinear.zero_point)
        Y_copied = copied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_linear = copy.deepcopy(qlinear)
        self.assertEqual(deepcopied_linear.bias(), qlinear.bias())
        self.assertEqual(deepcopied_linear.scale, qlinear.scale)
        self.assertEqual(deepcopied_linear.zero_point,
                         qlinear.zero_point)
        Y_deepcopied = deepcopied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Make sure `from_float` works for all linear variants
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]

        for mut in modules_under_test:
            # Test from_float.
            float_linear = mut(in_features, out_features).float()
            float_linear.qconfig = torch.ao.quantization.default_qconfig
            torch.ao.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.ao.quantization.convert(quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))

    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
        # testing Quantize API
        qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)

    def _test_conv_api_impl(
            self, module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size, out_channels_per_group,
            groups, kernel_size, stride, padding, padding_mode, dilation,
            X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
            use_bias, post_op, use_channelwise, X2_scale=1.0, X2_zero_point=0):
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
            W_scale, W_zero_point, use_bias, use_channelwise)
        example_input = [X, ]
        example_input_q = [X_q, ]

        if post_op in ["add", "add_relu"]:
            X2, X2_q = _make_conv_add_extra_input_tensor(X2_scale, X2_zero_point, conv_module[0](X).size())
            example_input = [X, X2]
            example_input_q = [X_q, X2_q]

        # Make sure the weight shape is correct
        self.assertTrue(qconv_module.weight().shape == W_q.shape)

        qconv_module.set_weight_bias(W_q, b)
        qconv_module.scale = Y_scale
        qconv_module.zero_point = Y_zero_point

        raw_conv_module = conv_module[0] if post_op in ["relu", "add", "add_relu"] else conv_module
        raw_conv_module.weight.data = W
        if use_bias:
            raw_conv_module.bias.data = b

        # Test members
        self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name())
        self.assertTrue(hasattr(qconv_module, '_packed_params'))
        self.assertTrue(hasattr(qconv_module, 'scale'))
        self.assertTrue(hasattr(qconv_module, 'zero_point'))

        # Test properties
        self.assertEqual(W_q, qconv_module.weight())
        if use_bias:
            self.assertEqual(b, qconv_module.bias())
        self.assertEqual(Y_scale, qconv_module.scale)
        self.assertEqual(Y_zero_point, qconv_module.zero_point)

        # Test forward
        Y_exp = conv_module(*example_input)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
        Y_act = qconv_module(*example_input_q)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired-actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
        # skip numerics checking for reference module
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

        # Test serialization of quantized Conv Module using state_dict
        model_dict = qconv_module.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        for weights_only in [True, False]:
            bytes_io.seek(0)
            loaded_dict = torch.load(bytes_io, weights_only=weights_only)
            for key in loaded_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            loaded_qconv_module = type(qconv_module)(
                in_channels, out_channels, kernel_size, stride, padding, dilation,
                groups, use_bias, padding_mode=padding_mode)
            loaded_qconv_module.load_state_dict(loaded_dict)

        self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
        self.assertTrue(module_name == loaded_qconv_module._get_name())
        self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
        self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

        self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
        if use_bias:
            self.assertEqual(qconv_module.bias(), loaded_qconv_module.bias())
        self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
        self.assertEqual(qconv_module.zero_point,
                         loaded_qconv_module.zero_point)
        Y_loaded = loaded_qconv_module(*example_input_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)

        # Test serialization
        b = io.BytesIO()
        torch.save(qconv_module, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded_conv = torch.load(b, weights_only=False)

        self.assertEqual(loaded_conv.bias(), qconv_module.bias())
        self.assertEqual(loaded_conv.scale, qconv_module.scale)
        self.assertEqual(loaded_conv.zero_point,
                         qconv_module.zero_point)

        # Test copy and deepcopy
        copied_conv = copy.copy(qconv_module)
        self.assertEqual(copied_conv.bias(), qconv_module.bias())
        self.assertEqual(copied_conv.scale, qconv_module.scale)
        self.assertEqual(copied_conv.zero_point,
                         qconv_module.zero_point)
        Y_copied = copied_conv(*example_input_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_conv = copy.deepcopy(qconv_module)
        self.assertEqual(deepcopied_conv.bias(), qconv_module.bias())
        self.assertEqual(deepcopied_conv.scale, qconv_module.scale)
        self.assertEqual(deepcopied_conv.zero_point,
                         qconv_module.zero_point)
        Y_deepcopied = deepcopied_conv(*example_input_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # JIT testing
        self.checkScriptable(
            qconv_module, [example_input_q],
            check_save_load=True)

        class _FusedModule_two_input_args(torch.ao.nn.intrinsic._FusedModule):
            # Helper module for ConvAdd2d since torch.ao.nn.intrinsic._FusedModule
            # only supports one input argument
            def forward(self, x1, x2):
                input = self[0](x1, x2)
                return input

        # Test from_float
        fused_conv_module = _FusedModule_two_input_args(conv_module) \
            if post_op in ["add", "add_relu"] else torch.ao.nn.intrinsic._FusedModule(conv_module)

        fused_conv_module.qconfig = torch.ao.quantization.default_qconfig
        torch.ao.quantization.prepare(fused_conv_module, inplace=True)
        example_input[0] = example_input[0].float()
        fused_conv_module(*example_input)
        converted_qconv_module = fused_conv_module
        reference_mapping = get_default_static_quant_module_mappings()
        reference_mapping[type(conv_module)] = type(qconv_module)
        torch.ao.quantization.convert(converted_qconv_module, mapping=reference_mapping, inplace=True)

        # Smoke test to make sure the module actually runs
        if use_bias:
            self.assertEqual(conv_module[0].bias if (post_op in ["relu", "add", "add_relu"]) else conv_module.bias,
                             converted_qconv_module[0].bias())
        # Smoke test extra_repr
        self.assertTrue(module_name == converted_qconv_module[0]._get_name())

    @override_qengines
    def test_conv1d_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            batch_size = 2
            in_channels_per_group = 2
            length = 8
            out_channels_per_group = 2
            groups = 3
            kernel = 3
            stride = 2
            pad = 1
            dilation = 1
            # Tests the correctness of the conv1d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (length,)
            kernel_size = (kernel, )
            stride = (stride, )
            pad = (pad, )
            dilation = (dilation, )
            X_scale = 1.3
            X_zero_point = 2
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            if torch.backends.quantized.engine == 'qnnpack':
                use_channelwise = False
            qconv_cls = nnq.Conv1d
            module_name = "QuantizedConv1d"
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv1d(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, "none", use_channelwise)

    @override_qengines
    def test_conv1d_relu_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        length = 8
        out_channels_per_group = 2
        groups = 3
        kernel = 3
        stride = 2
        pad = 1
        dilation = 1
        # Tests the correctness of the conv1d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (length,)
        kernel_size = (kernel, )
        stride = (stride, )
        pad = (pad, )
        dilation = (dilation, )
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU1d
        module_name = "QuantizedConvReLU1d"
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == 'qnnpack':
                use_channelwise = False
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv1d(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode)
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU1d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, "relu", use_channelwise)

    @override_qengines
    def test_conv2d_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nnq.Conv2d
            module_name = "QuantizedConv2d"
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, "none", use_channelwise)

    @override_qengines
    def test_conv2d_relu_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        stride_h = 2
        stride_w = 2
        pad_h = 1
        pad_w = 1
        dilation = 1
        # Tests the correctness of the conv2d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU2d
        module_name = "QuantizedConvReLU2d"
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU2d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, "relu", use_channelwise)

    @skipIfNoFBGEMM
    def test_conv3d_api(self):
        options = itertools.product(
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        D = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        kernel_d = 3
        stride_h = 2
        stride_w = 2
        stride_d = 2
        pad_mode = "zeros"  # 3d doesn't support reflect padding
        pad_h = 1
        pad_w = 1
        pad_d = 1
        dilation = 1
        # Tests the correctness of the conv3d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nnq.Conv3d
        module_name = "QuantizedConv3d"
        for use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            with override_quantized_engine('fbgemm'):
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale,
                    W_zero_point, Y_scale, Y_zero_point, use_bias, "none",
                    use_channelwise)

    @skipIfNoFBGEMM
    def test_conv3d_relu_api(self):
        options = itertools.product(
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        D = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        kernel_d = 3
        stride_h = 2
        stride_w = 2
        stride_d = 2
        pad_mode = "zeros"  # 3d doesn't support reflect padding
        pad_h = 1
        pad_w = 1
        pad_d = 1
        dilation = 1
        # Tests the correctness of the conv3d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU3d
        module_name = "QuantizedConvReLU3d"
        for use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            with override_quantized_engine('fbgemm'):
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU3d(conv_module, relu_module)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale,
                    W_zero_point, Y_scale, Y_zero_point, use_bias, "relu",
                    use_channelwise)

    @skipIfNoONEDNN
    def test_conv2d_add(self):
        """test API functionality for nn.intrinsic.quantized.ConvAdd2d"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                ["zeros", "reflect"],  # pad_mode
                [True, False],  # use_bias
                [True, False],  # use_channelwise
            )
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            X2_scale = 1.2
            X2_zero_point = 1
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nniq.ConvAdd2d
            module_name = "QuantizedConvAdd2d"
            for pad_mode, use_bias, use_channelwise in options:
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = torch.ao.nn.intrinsic.ConvAdd2d(conv_module, torch.add)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                    Y_scale, Y_zero_point, use_bias, "add", use_channelwise, X2_scale, X2_zero_point)

    @skipIfNoONEDNN
    def test_conv2d_add_relu(self):
        """test API functionality for nn.intrinsic.quantized.ConvAddReLU2d"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                ["zeros", "reflect"],  # pad_mode
                [True, False],  # use_bias
                [True, False],  # use_channelwise
            )
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            X2_scale = 1.2
            X2_zero_point = 1
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nniq.ConvAddReLU2d
            module_name = "QuantizedConvAddReLU2d"
            for pad_mode, use_bias, use_channelwise in options:
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = torch.ao.nn.intrinsic.ConvAddReLU2d(conv_module, torch.add, nn.ReLU())
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                    Y_scale, Y_zero_point, use_bias, "add_relu", use_channelwise, X2_scale, X2_zero_point)

    def test_pool_api(self):
        """Tests the correctness of the pool module.
        The correctness is defined against the functional implementation.
898 """ 899 N, C, H, W = 10, 10, 10, 3 900 kwargs = { 901 'kernel_size': 2, 902 'stride': None, 903 'padding': 0, 904 'dilation': 1 905 } 906 907 scale, zero_point = 1.0 / 255, 128 908 909 X = torch.randn(N, C, H, W, dtype=torch.float32) 910 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 911 dtype=torch.quint8) 912 qX_expect = torch.nn.functional.max_pool2d(qX, **kwargs) 913 914 pool_under_test = torch.ao.nn.quantized.MaxPool2d(**kwargs) 915 qX_hat = pool_under_test(qX) 916 self.assertEqual(qX_expect, qX_hat) 917 918 # JIT Testing 919 self.checkScriptable(pool_under_test, [[X]]) 920 921 def test_dropout(self): 922 """Tests the correctness of the dropout module. 923 The correctness is defined against the functional implementation. 924 """ 925 x = torch.randn((2, 4, 6, 8), dtype=torch.float) 926 float_mod = torch.nn.Dropout(p=0.5) 927 float_mod.training = False 928 929 y_ref = float_mod(x) 930 quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8) 931 932 quant_mod = nnq.Dropout(p=0.5) 933 qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8) 934 qy = quant_mod(qx) 935 936 self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(), 937 msg="Dropout module API failed") 938 939 def _test_dropout_serialization(self, get_model, data1, data2): 940 m1 = get_model() 941 m1.qconfig = torch.ao.quantization.default_qconfig 942 mp1 = torch.ao.quantization.prepare(m1) 943 mp1(data1) 944 mq1 = torch.ao.quantization.convert(mp1) 945 ref1 = mq1(data2) 946 947 m2 = get_model() 948 m2.qconfig = torch.ao.quantization.default_qconfig 949 mp2 = torch.ao.quantization.prepare(m2) 950 mq2 = torch.ao.quantization.convert(mp2) 951 952 mq2.load_state_dict(mq1.state_dict()) 953 ref2 = mq2(data2) 954 955 self.assertTrue(torch.allclose(ref1, ref2)) 956 957 def test_dropout_serialization(self): 958 data1 = torch.randn(2, 4, 6, 8) 959 data2 = torch.randn(2, 4, 6, 8) 960 961 def _get_model(): 962 return nn.Sequential( 963 torch.ao.quantization.QuantStub(), 964 nn.Dropout(p=0.5), 965 torch.ao.quantization.DeQuantStub() 966 ).eval() 967 968 self._test_dropout_serialization(_get_model, data1, data2) 969 970 971 972 def test_batch_norm2d(self): 973 """Tests the correctness of the batchnorm2d module. 974 The correctness is defined against the functional implementation. 975 """ 976 x = torch.randn((2, 4, 6, 8), dtype=torch.float) 977 float_mod = torch.nn.BatchNorm2d(4) 978 float_mod.training = False 979 980 y_ref = float_mod(x) 981 quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8) 982 983 quant_mod = nnq.BatchNorm2d(4) 984 qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8) 985 qy = quant_mod(qx) 986 987 self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(), 988 msg="BatchNorm2d module API failed") 989 990 def test_batch_norm3d(self): 991 """Tests the correctness of the batchnorm3d module. 992 The correctness is defined against the functional implementation. 
993 """ 994 x = torch.randn((2, 4, 6, 8, 10), dtype=torch.float) 995 float_mod = torch.nn.BatchNorm3d(4) 996 float_mod.training = False 997 998 y_ref = float_mod(x) 999 quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8) 1000 1001 quant_mod = nnq.BatchNorm3d(4) 1002 qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8) 1003 qy = quant_mod(qx) 1004 1005 self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(), 1006 msg="BatchNorm3d module API failed") 1007 1008 def _test_batch_norm_serialization(self, get_model, data1, data2): 1009 m1 = get_model() 1010 m1.qconfig = torch.ao.quantization.default_qconfig 1011 mp1 = torch.ao.quantization.prepare(m1) 1012 mp1(data1) 1013 mq1 = torch.ao.quantization.convert(mp1) 1014 ref1 = mq1(data2) 1015 1016 m2 = get_model() 1017 m2.qconfig = torch.ao.quantization.default_qconfig 1018 mp2 = torch.ao.quantization.prepare(m2) 1019 mq2 = torch.ao.quantization.convert(mp2) 1020 1021 mq2.load_state_dict(mq1.state_dict()) 1022 ref2 = mq2(data2) 1023 1024 self.assertTrue(torch.allclose(ref1, ref2)) 1025 1026 def test_batch_norm2d_serialization(self): 1027 data1 = torch.randn(2, 4, 6, 8) 1028 data2 = torch.randn(2, 4, 6, 8) 1029 1030 def _get_model(): 1031 return nn.Sequential( 1032 torch.ao.quantization.QuantStub(), 1033 nn.BatchNorm2d(4), 1034 torch.ao.quantization.DeQuantStub() 1035 ).eval() 1036 1037 self._test_batch_norm_serialization(_get_model, data1, data2) 1038 1039 def test_batch_norm3d_serialization(self): 1040 data1 = torch.randn(2, 4, 6, 8, 1) 1041 data2 = torch.randn(2, 4, 6, 8, 1) 1042 1043 def _get_model(): 1044 return nn.Sequential( 1045 torch.ao.quantization.QuantStub(), 1046 nn.BatchNorm3d(4), 1047 torch.ao.quantization.DeQuantStub() 1048 ).eval() 1049 1050 self._test_batch_norm_serialization(_get_model, data1, data2) 1051 1052 def test_layer_norm(self): 1053 """Tests the correctness of the layernorm module. 1054 The correctness is defined against the functional implementation. 1055 """ 1056 x_scale = 10.0 / 256 1057 x_zero_point = 0 1058 y_scale = 5.0 / 256 1059 y_zero_point = 127 1060 1061 dims = (1, 4, 8) 1062 1063 X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10 1064 qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8) 1065 dqX = qX.dequantize() 1066 1067 float_mod = torch.nn.LayerNorm(dqX.size()[1:]).float() 1068 float_mod.weight = torch.nn.Parameter(torch.rand(*dims[1:])) 1069 float_mod.bias = torch.nn.Parameter(torch.rand(*dims[1:])) 1070 1071 dqY_ref = float_mod(dqX) 1072 qY_ref = torch.quantize_per_tensor( 1073 dqY_ref, y_scale, y_zero_point, dtype=torch.quint8) 1074 1075 quant_mod = nnq.LayerNorm( 1076 qX.size()[1:], float_mod.weight, float_mod.bias, y_scale, y_zero_point) 1077 qY = quant_mod(qX) 1078 1079 self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(), 1080 msg=f"LayerNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}") 1081 1082 def test_group_norm(self): 1083 """Tests the correctness of the groupnorm module. 1084 The correctness is defined against the functional implementation. 
1085 """ 1086 x_scale = 10.0 / 256 1087 x_zero_point = 0 1088 y_scale = 5.0 / 256 1089 y_zero_point = 127 1090 1091 dims = (1, 4, 8) 1092 1093 X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10 1094 qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8) 1095 dqX = qX.dequantize() 1096 1097 float_mod = torch.nn.GroupNorm(2, 4).float() 1098 float_mod.weight = torch.nn.Parameter(torch.rand(dims[1])) 1099 float_mod.bias = torch.nn.Parameter(torch.rand(dims[1])) 1100 1101 dqY_ref = float_mod(dqX) 1102 qY_ref = torch.quantize_per_tensor( 1103 dqY_ref, y_scale, y_zero_point, dtype=torch.quint8) 1104 1105 quant_mod = nnq.GroupNorm( 1106 2, 2, float_mod.weight, float_mod.bias, y_scale, y_zero_point) 1107 qY = quant_mod(qX) 1108 1109 self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(), 1110 msg=f"GroupNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}") 1111 1112 def test_instance_norm(self): 1113 """Tests the correctness of the instancenorm{n}d modules. 1114 The correctness is defined against the functional implementation. 1115 """ 1116 x_scale = 10.0 / 256 1117 x_zero_point = 0 1118 y_scale = 5.0 / 256 1119 y_zero_point = 127 1120 1121 dims_to_modules = [ 1122 ((1, 4, 8), torch.nn.InstanceNorm1d, nnq.InstanceNorm1d), 1123 ((1, 4, 8, 1), torch.nn.InstanceNorm2d, nnq.InstanceNorm2d), 1124 ((1, 4, 8, 1, 1), torch.nn.InstanceNorm3d, nnq.InstanceNorm3d), 1125 ] 1126 1127 for dim_to_modules in dims_to_modules: 1128 dims, float_cls, q_cls = dim_to_modules 1129 1130 X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10 1131 qX = torch.quantize_per_tensor( 1132 X, x_scale, x_zero_point, dtype=torch.quint8) 1133 dqX = qX.dequantize() 1134 1135 float_mod = float_cls(dims[1]).float() 1136 float_mod.weight = torch.nn.Parameter(torch.rand(dims[1])) 1137 float_mod.bias = torch.nn.Parameter(torch.rand(dims[1])) 1138 1139 dqY_ref = float_mod(dqX) 1140 qY_ref = torch.quantize_per_tensor( 1141 dqY_ref, y_scale, y_zero_point, dtype=torch.quint8) 1142 1143 quant_mod = q_cls( 1144 dims[1], float_mod.weight, float_mod.bias, y_scale, 1145 y_zero_point) 1146 qY = quant_mod(qX) 1147 1148 self.assertEqual( 1149 qY_ref.int_repr().numpy(), qY.int_repr().numpy(), 1150 msg=f"InstanceNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}") 1151 1152 def _test_activation_module_impl(self, name, float_module_class, quantized_module_class, extra_kwargs): 1153 """Tests the correctness of the ELU module. 1154 The correctness is defined against the functional implementation. 
1155 """ 1156 x_scale = 10.0 / 256 1157 x_zero_point = 0 1158 y_scale = 5.0 / 256 1159 y_zero_point = 127 1160 alpha = 1.5 1161 1162 dims = (1, 4, 8) 1163 1164 X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10 1165 qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8) 1166 dqX = qX.dequantize() 1167 1168 float_mod = float_module_class(**extra_kwargs).float() 1169 1170 dqY_ref = float_mod(dqX) 1171 qY_ref = torch.quantize_per_tensor( 1172 dqY_ref, y_scale, y_zero_point, dtype=torch.quint8) 1173 1174 quant_mod = quantized_module_class(y_scale, y_zero_point, **extra_kwargs) 1175 qY = quant_mod(qX) 1176 self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(), 1177 msg=f"{name} module API failed, qY_ref\n{qY_ref} vs qY\n{qY}") 1178 1179 def _test_leaky_relu_serialization(self): 1180 scale_original = 10.0 / 256 1181 zero_point_original = 1.0 1182 1183 quant_mod_original = nnq.LeakyReLU(scale_original, zero_point_original) 1184 state_dict = quant_mod_original.state_dict() 1185 1186 scale_new = 5.0 / 256 1187 zero_point_new = 2.0 1188 quant_mod_new = nnq.LeakyReLU(scale_new, zero_point_new) 1189 quant_mod_new.load_state_dict(state_dict) 1190 1191 self.assertEqual(quant_mod_original.scale, quant_mod_new.scale) 1192 self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point) 1193 1194 def test_elu(self): 1195 """Tests the correctness of the ELU module. 1196 The correctness is defined against the functional implementation. 1197 """ 1198 self._test_activation_module_impl("ELU", nn.ELU, nnq.ELU, {"alpha": 1.5}) 1199 1200 def test_leaky_relu(self): 1201 self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2}) 1202 self._test_leaky_relu_serialization() 1203 1204 def test_sigmoid(self): 1205 self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {}) 1206 1207 def _test_hard_swish_serialization(self): 1208 scale_original = 10.0 / 256 1209 zero_point_original = 1.0 1210 1211 quant_mod_original = nnq.Hardswish(scale_original, zero_point_original) 1212 state_dict = quant_mod_original.state_dict() 1213 1214 scale_new = 5.0 / 256 1215 zero_point_new = 2.0 1216 quant_mod_new = nnq.Hardswish(scale_new, zero_point_new) 1217 quant_mod_new.load_state_dict(state_dict) 1218 1219 self.assertEqual(quant_mod_original.scale, quant_mod_new.scale) 1220 self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point) 1221 1222 def test_hard_swish(self): 1223 self._test_activation_module_impl("Hardswish", nn.Hardswish, nnq.Hardswish, {}) 1224 self._test_hard_swish_serialization() 1225 1226 @given( 1227 num_embeddings=st.integers(10, 50), 1228 embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), 1229 set_qconfig=st.booleans(), 1230 ) 1231 @skipIfNoFBGEMM 1232 def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig): 1233 num_lengths = np.random.randint(1, 6) 1234 lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32) 1235 num_indices = np.sum(lengths) 1236 indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64)) 1237 weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)) 1238 1239 obs = default_float_qparams_observer() 1240 obs(weights) 1241 qparams = obs.calculate_qparams() 1242 1243 dtypes = [torch.quint4x2, torch.quint8] 1244 embedding_funcs = [torch.ops.quantized.embedding_4bit, torch.ops.quantized.embedding_byte] 1245 1246 for dtype, embedding_func in 
        for dtype, embedding_func in zip(dtypes, embedding_funcs):
            # Quantize the weights
            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype)
            qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
            qemb.set_weight(qweight)
            qemb(indices)

            # Ensure the module has the correct weights
            self.assertEqual(qweight, qemb.weight())
            w_packed = qemb._packed_params._packed_weight
            module_out = qemb(indices)

            # Call the bit qembedding operator directly
            ref = embedding_func(w_packed, indices, pruned_weights=False)
            self.assertEqual(module_out, ref)
            self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False,
                                             is_emb_bag=False, dtype=dtype)

    @given(
        num_embeddings=st.integers(10, 50),
        embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
        num_offsets=st.integers(1, 20),
        set_qconfig=st.booleans(),
    )
    @skipIfNoFBGEMM
    def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
        r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
        """

        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))

        offsets = lengths_to_offsets(lengths)
        # include the last offset
        offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        for qdtype in [torch.quint8, torch.quint4x2]:
            obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
            obs(weights)
            # Get the scale and zero point for the weight tensor
            qparams = obs.calculate_qparams()
            # Quantize the weights
            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
            qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
                                    include_last_offset=True, mode='sum', _weight=qweight, dtype=qdtype)
            qemb(indices, offsets)

            # Ensure the module has the correct weights
            self.assertEqual(qweight, qemb.weight())

            w_packed = qemb._packed_params._packed_weight
            module_out = qemb(indices, offsets)

            # Call the qembedding_bag operator directly
            if qdtype == torch.quint8:
                ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
                                                             per_sample_weights=None,
                                                             include_last_offset=True)
            else:
                ref = torch.ops.quantized.embedding_bag_4bit(w_packed, indices, offsets, mode=0,
                                                             per_sample_weights=None,
                                                             include_last_offset=True)

            self.assertEqual(module_out, ref)
            self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
                                             offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)

    def test_prelu(self):
        for num_parameters in range(1, 10):
            x = torch.randn(4, num_parameters, 4)
            qx = torch.quantize_per_tensor_dynamic(x, dtype=torch.quint8, reduce_range=False)

            f_prelu = torch.nn.PReLU(num_parameters=num_parameters)
            f_prelu.weight = torch.nn.Parameter(torch.randn(num_parameters).abs())
            f_prelu.qconfig = torch.ao.quantization.QConfig(
                activation=torch.ao.quantization.default_observer,
                weight=torch.ao.quantization.default_observer,)
            f_prelu.activation_post_process = f_prelu.qconfig.activation()
            f_prelu.activation_post_process(f_prelu(x))
            q_prelu = nnq.PReLU.from_float(f_prelu)
            w_obs = f_prelu.qconfig.weight()
            w_obs(f_prelu.weight)
            w_scale, w_zp = w_obs.calculate_qparams()
            q_prelu_weight = torch.quantize_per_tensor(
                f_prelu.weight,
                dtype=torch.quint8,
                scale=w_scale,
                zero_point=w_zp
            ).dequantize()

            # check that the weight makes sense
            self.assertEqual(q_prelu.weight.dequantize(), q_prelu_weight)
            f_prelu.weight = torch.nn.Parameter(q_prelu.weight.dequantize())
            qy = q_prelu(qx)
            qy_ref = torch.quantize_per_tensor(
                f_prelu(qx.dequantize()), q_prelu.scale, q_prelu.zero_point, dtype=torch.quint8
            )
            # check that the output makes sense
            self.assertEqual(qy, qy_ref, atol=.1, rtol=.1)

    def test_channel_shuffle(self):
        """Tests the correctness of the ChannelShuffle module.
        """
        x_scale = 10.0 / 256
        x_zero_point = 1
        y_scale = x_scale
        y_zero_point = x_zero_point

        dims = (1, 4, 4, 8)
        groups = 2

        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = torch.nn.ChannelShuffle(groups).float()
        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = torch.nn.ChannelShuffle(groups)
        qY = quant_mod(qX)

        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         msg=f"ChannelShuffle module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    @skipIfNoONEDNN
    def test_linear_leaky_relu(self):
        """test API functionality for nn.intrinsic.quantized.linear_leaky_relu"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                [1, 5],  # batch size
                [16, 32],  # in_features
                [4, 8],  # out_features
                [True, False],  # use_bias
                [True, False],  # per_channel
                [0.01, 0.05])  # negative slope
            for (batch_size, in_features, out_features, use_bias,
                 per_channel, neg_slope) in options:
                self._test_linear_api_impl(
                    nniq.LinearLeakyReLU, 'QuantizedLinearLeakyReLU',
                    torch.ops.quantized.linear_leaky_relu,
                    batch_size, in_features, out_features, use_bias,
                    per_channel, negative_slope=neg_slope)

    @skipIfNoONEDNN
    def test_linear_tanh(self):
        """test API functionality for nn.intrinsic.quantized.linear_tanh"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                [1, 5],  # batch size
                [16, 32],  # in_features
                [4, 8],  # out_features
                [True, False],  # use_bias
                [True, False])  # per_channel
            for (batch_size, in_features, out_features, use_bias,
                 per_channel) in options:
                self._test_linear_api_impl(
                    nniq.LinearTanh, 'QuantizedLinearTanh',
                    torch.ops.quantized.linear_tanh,
                    batch_size, in_features, out_features, use_bias,
                    per_channel)

class TestDynamicQuantizedModule(QuantizationTestCase):
    def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias):
        in_channels = 3
        out_channels = 10
        kernel_size = 2
        stride = 1
        padding = 0
        dilation = 1
        groups = 1
        padding_mode = 'zeros'

        if qengine_is_qnnpack():
            reduce_range = False
        else:
            reduce_range = True

        X_fp32 = torch.randn(*([in_channels] * dim))
        s, z = _calculate_dynamic_qparams(X_fp32, dtype, reduce_range)
        X_q = torch.quantize_per_tensor(X_fp32, s, z, dtype)
        X_dq = torch.dequantize(X_q)

        quantized_module = q_mod(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                 dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
        dynamic_module = dq_mod(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)

        quantized_module.scale, quantized_module.zero_point = s, z
        dynamic_module.set_weight_bias(*quantized_module._weight_bias())

        Y_q_ref = quantized_module(X_q)
        Y_ref = torch.dequantize(Y_q_ref)

        Y = dynamic_module(X_dq, reduce_range)

        self.assertEqual(Y, Y_ref)

        # Test serialization of quantized Conv Module using state_dict
        W_q, b = dynamic_module._weight_bias()
        model_dict = dynamic_module.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        self.assertEqual(model_dict['bias'], b)
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        for weights_only in [True, False]:
            bytes_io.seek(0)
            loaded_dict = torch.load(bytes_io, weights_only=weights_only)
            for key in loaded_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            loaded_qconv_module = type(dynamic_module)(
                in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
            loaded_qconv_module.load_state_dict(loaded_dict)

        self.assertTrue(dir(loaded_qconv_module) == dir(dynamic_module))
        self.assertTrue(dynamic_module._get_name() == loaded_qconv_module._get_name())
        self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
        self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

        self.assertEqual(dynamic_module.weight(), loaded_qconv_module.weight())
        if bias:
            self.assertEqual(dynamic_module.bias(), loaded_qconv_module.bias())
        self.assertEqual(dynamic_module.scale, loaded_qconv_module.scale)
        self.assertEqual(dynamic_module.zero_point,
                         loaded_qconv_module.zero_point)
        Y_loaded = loaded_qconv_module(X_fp32, reduce_range)
        np.testing.assert_array_almost_equal(
            Y.numpy(), Y_loaded.numpy(), decimal=0)

        # Test serialization
        b = io.BytesIO()
        torch.save(dynamic_module, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded_conv = torch.load(b, weights_only=False)

        self.assertEqual(loaded_conv.bias(), dynamic_module.bias())
        self.assertEqual(loaded_conv.scale, dynamic_module.scale)
        self.assertEqual(loaded_conv.zero_point,
                         dynamic_module.zero_point)

        # Test copy and deepcopy
        copied_conv = copy.copy(dynamic_module)
        self.assertEqual(copied_conv.bias(), dynamic_module.bias())
        self.assertEqual(copied_conv.scale, dynamic_module.scale)
        self.assertEqual(copied_conv.zero_point,
                         dynamic_module.zero_point)
        Y_copied = copied_conv(X_fp32, reduce_range)
        np.testing.assert_array_almost_equal(
            Y.numpy(), Y_copied.numpy(), decimal=0)

        deepcopied_conv = copy.deepcopy(dynamic_module)
        self.assertEqual(deepcopied_conv.bias(), dynamic_module.bias())
        self.assertEqual(deepcopied_conv.scale, dynamic_module.scale)
        self.assertEqual(deepcopied_conv.zero_point,
                         dynamic_module.zero_point)
        Y_deepcopied = deepcopied_conv(X_fp32, reduce_range)
        np.testing.assert_array_almost_equal(
            Y.numpy(), Y_deepcopied.numpy(), decimal=0)

        # need to fix this
        # JIT testing
        self.checkScriptable(
            dynamic_module, [[X_dq]],
            check_save_load=True)

        # Test from_float
        conv_module = dynamic_module._FLOAT_MODULE(in_channels, out_channels, kernel_size)
        conv_module.qconfig = torch.ao.quantization.default_dynamic_qconfig  # type: ignore[assignment]
        prepare_dynamic(conv_module)
        conv_module(X_dq)
        quantized_conv_module = dq_mod.from_float(conv_module)

        # Smoke test to make sure the module actually runs
        quantized_conv_module(X_dq)

        # Smoke test extra_repr
        self.assertEqual(dynamic_module._get_name(), quantized_conv_module._get_name())

    @override_qengines
    def test_dynamic_conv1d(self):
        q_mod = torch.ao.nn.quantized.Conv1d
        dq_mod = torch.ao.nn.quantized.dynamic.Conv1d
        dim = 3
        dtype = torch.quint8

        for bias in [True, False]:
            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)

    @override_qengines
    def test_dynamic_conv2d(self):
        q_mod = torch.ao.nn.quantized.Conv2d
        dq_mod = torch.ao.nn.quantized.dynamic.Conv2d
        dim = 4
        dtype = torch.quint8

        for bias in [True, False]:
            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)

    @override_qengines
    def test_dynamic_conv3d(self):
        q_mod = torch.ao.nn.quantized.Conv3d
        dq_mod = torch.ao.nn.quantized.dynamic.Conv3d
        dim = 5
        dtype = torch.quint8

        if qengine_is_qnnpack():
            return  # qnnpack doesn't support unpacking conv3d
        for bias in [True, False]:
            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)

    @override_qengines
    def test_dynamic_convtranspose1d(self):
        q_mod = torch.ao.nn.quantized.ConvTranspose1d
        dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose1d
        dim = 3
        dtype = torch.quint8

        for bias in [True, False]:
            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)

    @override_qengines
    def test_dynamic_convtranspose2d(self):
        q_mod = torch.ao.nn.quantized.ConvTranspose2d
        dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose2d
        dim = 4
        dtype = torch.quint8

        for bias in [True, False]:
            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)

    @override_qengines
    def test_dynamic_convtranspose3d(self):
        q_mod = torch.ao.nn.quantized.ConvTranspose3d
        dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose3d
        dim = 5
        dtype = torch.quint8

        if qengine_is_qnnpack():
            return  # qnnpack doesn't support unpacking conv3d
        for bias in [True, False]:
            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)

    @given(
        batch_size=st.integers(1, 5),
        in_features=st.integers(16, 32),
        out_features=st.integers(4, 8),
        use_bias=st.booleans(),
        use_default_observer=st.booleans(),
    )
    @override_qengines
    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
        """test API functionality for nn.quantized.dynamic.Linear"""
        W = torch.rand(out_features, in_features).float()
        qscheme = torch.per_tensor_symmetric if qengine_is_onednn() else torch.per_tensor_affine
        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8, qscheme=qscheme)
        W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
        X = torch.rand(batch_size, in_features).float()
        B = torch.rand(out_features).float() if use_bias else None
        qlinear = nnqd.Linear(in_features, out_features)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear.set_weight_bias(W_q, B)
        qlinear(X)

        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q)
        W_pack = qlinear._packed_params._packed_params
        Z_dq = qlinear(X)

        # Check if the module implementation matches calling the
        # ops directly
        Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack, reduce_range=True)
        self.assertEqual(Z_ref, Z_dq)

        # Test serialization of dynamic quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        for weights_only in [True, False]:
            b.seek(0)
            loaded_dict = torch.load(b, weights_only=weights_only)
            for key in model_dict:
                if isinstance(model_dict[key], torch._C.ScriptObject):
                    assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                    w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                    w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                    self.assertEqual(w_model, w_loaded)
                    self.assertEqual(b_model, b_loaded)
                else:
                    self.assertEqual(model_dict[key], loaded_dict[key])
        loaded_qlinear = nnqd.Linear(in_features, out_features)
        loaded_qlinear.load_state_dict(loaded_dict)

        linear_unpack = torch.ops.quantized.linear_unpack
        self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                         linear_unpack(loaded_qlinear._packed_params._packed_params))
        if use_bias:
            self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
        self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
        self.assertTrue(hasattr(qlinear, '_packed_params'))
        self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
        self.assertTrue(hasattr(qlinear, '_weight_bias'))
        self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))

        self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
        self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
        Z_dq2 = qlinear(X)
        self.assertEqual(Z_dq, Z_dq2)

        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded = torch.load(b, weights_only=False)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test JIT
        self.checkScriptable(qlinear, [[X]], check_save_load=True)

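        # from_float is exercised both for plain nn.Linear and for the
        # NonDynamicallyQuantizableLinear subclass (used e.g. as the out_proj of
        # nn.MultiheadAttention).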
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        for mut in modules_under_test:
            # Test from_float
            float_linear = mut(in_features, out_features).float()
            if use_default_observer:
                float_linear.qconfig = torch.ao.quantization.default_dynamic_qconfig
            prepare_dynamic(float_linear)
            float_linear(X.float())
            quantized_float_linear = nnqd.Linear.from_float(float_linear)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))

    @given(
        dtype=st.sampled_from([torch.qint8, torch.float16]),
        bidirectional=st.booleans(),
    )
    @override_qengines
    def test_lstm_api(self, dtype, bidirectional):
        r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
        """
        # Check that module matches the numerics of the op and ensure that module can be
        # instantiated for all engines and dtypes
        seq_len = 4
        batch = 2
        input_size = 3
        hidden_size = 7
        num_layers = 2
        bias = True
        weight_keys = []
        bias_keys = []
        num_directions = 2 if bidirectional else 1
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                key_name1 = f'weight_ih_l{layer}{suffix}'
                key_name2 = f'weight_hh_l{layer}{suffix}'
                weight_keys.append(key_name1)
                weight_keys.append(key_name2)
                key_name1 = f'bias_ih_l{layer}{suffix}'
                key_name2 = f'bias_hh_l{layer}{suffix}'
                bias_keys.append(key_name1)
                bias_keys.append(key_name2)

        if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
            # fp16 dynamic quant is not supported for qnnpack or onednn
            x = torch.randn(seq_len, batch, input_size)
            h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
            c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
            cell_dq = torch.ao.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                         hidden_size=hidden_size,
                                                         num_layers=num_layers,
                                                         bias=bias,
                                                         batch_first=False,
                                                         dropout=0.0,
                                                         bidirectional=bidirectional,
                                                         dtype=dtype)
            ref_dq = torch.ao.nn.quantized.dynamic.LSTM(input_size=input_size,
                                                        hidden_size=hidden_size,
                                                        num_layers=num_layers,
                                                        bias=bias,
                                                        batch_first=False,
                                                        dropout=0.0,
                                                        bidirectional=bidirectional,
                                                        dtype=dtype)

            _all_params = ([m.param for m in cell_dq._all_weight_values])
            result = torch.quantized_lstm(x, (h, c),
                                          _all_params,
                                          cell_dq.bias,
                                          cell_dq.num_layers,
                                          float(cell_dq.dropout),
                                          False,
                                          bidirectional,
                                          False,
                                          dtype=dtype,
                                          use_dynamic=True)

            y, (h, c) = cell_dq(x, (h, c))
            self.assertEqual(result[0], y)
            self.assertEqual(result[1], h)
            self.assertEqual(result[2], c)
            x = torch.randn(10, 20, 3)
            self.check_eager_serialization(cell_dq, ref_dq, [x])
            self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)

    @override_qengines
    def test_gru_api(self):
        r"""Test execution and serialization for dynamic quantized GRU modules on int8 and fp16
        """
        # Check that module matches the numerics of the op and ensure that module can be
        # instantiated for all engines and dtypes

        for dtype in [torch.qint8, torch.float16]:
            if dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn"):
                # fp16 dynamic quant is not supported for qnnpack or onednn
                continue
            # Test default instantiation
            seq_len = 4
            batch = 2
            input_size = 3
            hidden_size = 7
            num_layers = 2
            bias = True
            bidirectional = False

            x = torch.rand(seq_len, batch, input_size)
            h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)

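            # Construct the dynamically quantized GRU module and compare its output
            # with the torch.quantized_gru op called on the same packed parameters.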
            cell_dq = torch.ao.nn.quantized.dynamic.GRU(input_size=input_size,
                                                        hidden_size=hidden_size,
                                                        num_layers=num_layers,
                                                        bias=bias,
                                                        batch_first=False,
                                                        dropout=0.0,
                                                        bidirectional=bidirectional,
                                                        dtype=dtype)

            _all_params = ([m.param for m in cell_dq._all_weight_values])
            result = torch.quantized_gru(x,
                                         h,
                                         _all_params,
                                         cell_dq.bias,
                                         cell_dq.num_layers,
                                         float(cell_dq.dropout),
                                         False,
                                         bidirectional,
                                         False)

            y, h = cell_dq(x, h)
            self.assertEqual(result[0], y, msg="GRU module API failed")
            self.assertEqual(result[1], h, msg="GRU module API failed")

    @given(
        dtype=st.sampled_from([torch.qint8, torch.float16]),
    )
    @override_qengines
    def test_cell_api(self, dtype):
        r"""Test execution and serialization for dynamic quantized rnn cell modules on int8 and fp16
        """
        # Check that module matches the numerics of the op and ensure that module can be
        # instantiated for all engines and dtypes
        batch = 7
        input_size = 3
        hidden_size = 7
        bias = True

        x = torch.rand(batch, input_size)
        h = torch.rand(batch, hidden_size)
        cell_dict = {'LSTMCell': torch.ao.nn.quantized.dynamic.LSTMCell,
                     'GRUCell': torch.ao.nn.quantized.dynamic.GRUCell,
                     'RNNTanh': torch.ao.nn.quantized.dynamic.RNNCell,
                     'RNNReLU': torch.ao.nn.quantized.dynamic.RNNCell
                     }
        state = {'LSTMCell': (h, h),
                 'GRUCell': h,
                 'RNNTanh': h,
                 'RNNReLU': h}

        qfn_dict = {'LSTMCell': torch.ops.quantized.quantized_lstm_cell_dynamic,
                    'GRUCell': torch.ops.quantized.quantized_gru_cell_dynamic,
                    'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic,
                    'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic}

        for rnn_type in cell_dict.keys():
            if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
                # fp16 dynamic quant is not supported for qnnpack or onednn
                kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype}
                if rnn_type == 'RNNReLU':
                    kwargs['nonlinearity'] = "relu"
                elif rnn_type == 'RNNTanh':
                    kwargs['nonlinearity'] = "tanh"

                cell_dq = cell_dict[rnn_type](**kwargs)
                result = qfn_dict[rnn_type](x, state[rnn_type],
                                            cell_dq._packed_weight_ih, cell_dq._packed_weight_hh,
                                            cell_dq.bias_ih, cell_dq.bias_hh)
                result_module = cell_dq(x, state[rnn_type])
                self.assertEqual(result[0], result_module[0], msg="RNNCell module API failed")
                self.assertEqual(result[1], result_module[1], msg="RNNCell module API failed")
                weight_keys = ['weight_ih', 'weight_hh']
                bias_keys = ['bias_ih', 'bias_hh']
                self.check_eager_serialization(cell_dq, cell_dict[rnn_type](**kwargs), [x])
                self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)

class TestReferenceQuantizedModule(QuantizationTestCase):
    def _quant_dequant_weight(self, weight, weight_qparams):
        qscheme = weight_qparams["qscheme"]
        scale = weight_qparams["scale"]
        zero_point = weight_qparams["zero_point"]
        dtype = weight_qparams["dtype"]
        if qscheme == torch.per_tensor_affine:
            weight = torch.quantize_per_tensor(weight, scale, zero_point, dtype)
        else:
            # per channel affine
            axis = weight_qparams["axis"]
            weight = torch.quantize_per_channel(weight, scale, zero_point, axis, dtype)
        weight = weight.dequantize()
        return weight

    # TODO: add tests for conv and linear
    def test_rnn_cell(self):
        """ Checks that the rnn cell reference quantized modules have correct numerics
        This includes LSTMCell, GRUCell, RNNCell
        """
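        # Strategy: give the reference module the same weight_qparams that are used
        # to quantize-dequantize the fp32 module's weights by hand; the two modules
        # should then produce identical outputs.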
        batch = 7
        input_size = 3
        hidden_size = 7
        bias = True

        x = torch.rand(batch, input_size)
        h = torch.rand(batch, hidden_size)
        cell_dict = {'LSTMCell': torch.nn.LSTMCell,
                     'GRUCell': torch.nn.GRUCell,
                     'RNNTanh': torch.nn.RNNCell,
                     'RNNReLU': torch.nn.RNNCell
                     }
        state = {'LSTMCell': (h, h),
                 'GRUCell': h,
                 'RNNTanh': h,
                 'RNNReLU': h}

        qfn_dict = {'LSTMCell': nnqr.LSTMCell,
                    'GRUCell': nnqr.GRUCell,
                    'RNNTanh': nnqr.RNNCell,
                    'RNNReLU': nnqr.RNNCell}

        for rnn_type in cell_dict.keys():
            kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias}
            if rnn_type == 'RNNReLU':
                kwargs['nonlinearity'] = "relu"
            elif rnn_type == 'RNNTanh':
                kwargs['nonlinearity'] = "tanh"

            fp_cell = cell_dict[rnn_type](**kwargs)
            # initialize ref rnn cell module
            weight_qparams = {
                'qscheme': torch.per_tensor_affine,
                'dtype': torch.quint8,
                'scale': 2.0,
                'zero_point': 5
            }
            weight_qparams_dict = {
                "weight_ih": weight_qparams,
                "weight_hh": weight_qparams,
                "is_decomposed": False,
            }
            ref_kwargs = kwargs.copy()
            ref_kwargs["weight_qparams_dict"] = weight_qparams_dict
            ref_cell = qfn_dict[rnn_type](**ref_kwargs)
            # reassign the weights from the fp32 rnn cell module
            ref_cell.weight_ih = fp_cell.weight_ih
            ref_cell.weight_hh = fp_cell.weight_hh
            ref_cell.bias_ih = fp_cell.bias_ih
            ref_cell.bias_hh = fp_cell.bias_hh

            ref_res = ref_cell(x, state[rnn_type])

            # before running fp_cell for comparison, quantize and dequantize
            # its weights so it sees the same numerics as the reference module
            fp_cell.weight_ih = torch.nn.Parameter(self._quant_dequant_weight(fp_cell.weight_ih, weight_qparams_dict["weight_ih"]))
            fp_cell.weight_hh = torch.nn.Parameter(self._quant_dequant_weight(fp_cell.weight_hh, weight_qparams_dict["weight_hh"]))
            fp_res = fp_cell(x, state[rnn_type])
            self.assertEqual(ref_res[0], fp_res[0], msg="RNNCell module API failed")
            self.assertEqual(ref_res[1], fp_res[1], msg="RNNCell module API failed")

    def test_rnn(self):
        """ Checks that the rnn reference quantized modules have correct numerics
        This includes LSTM
        """
        seq_len = 4
        batch = 2
        input_size = 3
        hidden_size = 7
        num_layers = 2
        bias = True
        for bidirectional in [True, False]:
            x = torch.randn(seq_len, batch, input_size)
            h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
            c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
            fp32_rnn = torch.nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                bias=bias,
                batch_first=False,
                dropout=0.0,
                bidirectional=bidirectional)
            # initialize ref rnn module
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.qint8,
                "scale": 2.0,
                "zero_point": 5
            }
            weight_qparams_dict = {key: weight_qparams for key in fp32_rnn._flat_weights_names if key.startswith("weight")}
            weight_qparams_dict["is_decomposed"] = False
            ref_rnn = nnqr.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                bias=bias,
                batch_first=False,
                dropout=0.0,
                bidirectional=bidirectional,
                weight_qparams_dict=weight_qparams_dict)
            for wn in fp32_rnn._flat_weights_names:
                setattr(ref_rnn, wn, copy.deepcopy(getattr(fp32_rnn, wn)))
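            # keep the reference module's flat weight list in sync with the
            # parameters copied from the fp32 module above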
            ref_rnn._flat_weights = copy.deepcopy(fp32_rnn._flat_weights)

            # quantize and dequantize the weights for fp32_rnn module
            flat_weights = []
            for wn in fp32_rnn._flat_weights_names:
                if wn.startswith("weight"):
                    weight = self._quant_dequant_weight(getattr(fp32_rnn, wn), weight_qparams)
                else:
                    weight = getattr(fp32_rnn, wn)
                flat_weights.append(weight)
            fp32_rnn._flat_weights = flat_weights

            fp32_res = fp32_rnn(x, (h, c))
            ref_res = ref_rnn(x, (h, c))
            self.assertEqual(fp32_res, ref_res)

    def test_sparse(self):
        """ Embedding and EmbeddingBag
        """
        num_embeddings = 10
        embedding_dim = 3
        # embedding input
        ex = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

        # embedding bag input
        ebx = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
        offsets = torch.tensor([0, 4], dtype=torch.long)

        fp_to_ref = {
            nn.Embedding: (nnqr.Embedding, (ex,)),
            nn.EmbeddingBag: (nnqr.EmbeddingBag, (ebx, offsets)),
        }

        per_tensor_weight_qparams = {
            "qscheme": torch.per_tensor_affine,
            "dtype": torch.quint8,
            "scale": 2.0,
            "zero_point": 5,
            "is_decomposed": False,
        }

        per_channel_weight_qparams = {
            "qscheme": torch.per_channel_affine,
            "dtype": torch.quint8,
            "scale": torch.randn(10),
            "zero_point": torch.randint(0, 255, (10,)),
            "axis": 0,
            "is_decomposed": False,
        }

        per_channel_weight_qparams_quint4x2 = {
            "qscheme": torch.per_channel_affine_float_qparams,
            "dtype": torch.quint4x2,
            "scale": torch.randn(10),
            "zero_point": torch.randint(0, 255, (10,)),
            "axis": 0,
            "is_decomposed": False,
        }

        weight_qparams_options = [
            per_tensor_weight_qparams,
            per_channel_weight_qparams,
            per_channel_weight_qparams_quint4x2,
        ]
        for fp_cls, weight_qparams in itertools.product([nn.Embedding, nn.EmbeddingBag], weight_qparams_options):
            # TODO: torch.quint4x2 not supported in quantize_per_channel, need to add support
            if weight_qparams["dtype"] == torch.quint4x2:
                continue
            ref_cls, args = fp_to_ref[fp_cls]

            fp32_embedding = fp_cls(num_embeddings, embedding_dim)

            ref_embedding = ref_cls(num_embeddings, embedding_dim, weight_qparams=weight_qparams)
            ref_embedding.weight = fp32_embedding.weight

            # quantize and dequantize the weight for fp32 module
            fp32_embedding.weight = torch.nn.Parameter(self._quant_dequant_weight(fp32_embedding.weight, weight_qparams))

            fp32_res = fp32_embedding(*args)
            ref_res = ref_embedding(*args)
            self.assertEqual(fp32_res, ref_res)

    def test_linear_decomposed_weight_custom_qmin_qmax(self):
        """Verify that reference Linear respects custom qmin/qmax for weight
        """
        linear_fp32 = torch.nn.Linear(2, 2)
        qconfig = torch.ao.quantization.default_symmetric_qnnpack_qconfig
        w_obs = qconfig.weight()
        self.assertTrue(w_obs.quant_min == -127)
        self.assertTrue(w_obs.quant_max == 127)
        w_obs(linear_fp32.weight)
        weight_qparams = torch.ao.quantization.utils.get_qparam_dict(w_obs)
        weight_qparams["is_decomposed"] = True
        linear_ref = nnqr.Linear.from_float(linear_fp32, weight_qparams)
        linear_ref_traced = torch.fx.symbolic_trace(linear_ref)

        # verify that the qmin/qmax arguments for weight q/dq are correctly
        # taken from the observer
        found = 0
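        # quantized_decomposed.{quantize,dequantize}_per_tensor take
        # (input, scale, zero_point, quant_min, quant_max, dtype), so qmin/qmax
        # are the 4th and 5th positional arguments of each traced call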
        for n in linear_ref_traced.graph.nodes:
            if n.op != 'call_function':
                continue
            if n.target in (
                torch.ops.quantized_decomposed.quantize_per_tensor,
                torch.ops.quantized_decomposed.dequantize_per_tensor,
            ):
                _0, _1, _2, qmin, qmax, _5 = n.args
                self.assertTrue(qmin == -127)
                self.assertTrue(qmax == 127)
                found += 1
        self.assertTrue(found == 2)