# Owner(s): ["oncall: quantization"]

import unittest
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    default_qconfig,
    prepare,
    quantize,
    quantize_dynamic,
)
from torch.ao.ns._numeric_suite import (
    OutputLogger,
    Shadow,
    ShadowLogger,
    compare_model_outputs,
    compare_model_stub,
    compare_weights,
    prepare_model_outputs,
    get_matching_activations,
)
from torch.testing._internal.common_quantization import (
    AnnotatedConvBnReLUModel,
    AnnotatedConvModel,
    AnnotatedConvTransposeModel,
    AnnotatedSingleLayerLinearModel,
    LSTMwithHiddenDynamicModel,
    AnnotatedTwoLayerLinearModel,
    QuantizationTestCase,
    SingleLayerLinearDynamicModel,
    test_only_eval_fn,
    skip_if_no_torchvision,
)
from torch.testing._internal.common_quantized import override_qengines
from torch.testing._internal.common_utils import IS_ARM64


class SubModule(torch.nn.Module):
    r"""Conv2d followed by ReLU, wrapped in quant/dequant stubs and tagged
    with ``default_qconfig``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.qconfig = default_qconfig
        self.mod1 = torch.nn.Conv2d(3, 3, 3, bias=False).to(dtype=torch.float)
        self.mod2 = nn.ReLU()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.mod1(x)
        x = self.mod2(x)
        x = self.dequant(x)
        return x


class ModelWithSubModules(torch.nn.Module):
    r"""A ``SubModule`` (which carries a qconfig) followed by a plain Conv2d
    that has no qconfig of its own.
    """

    def __init__(self) -> None:
        super().__init__()
        self.mod1 = SubModule()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.mod1(x)
        x = self.conv(x)
        return x


class ModelWithFunctionals(torch.nn.Module):
    r"""Chains six ``FloatFunctional`` ops (cat, add, mul, add_relu,
    add_scalar, mul_scalar) between a quant stub and a dequant stub.
    """

    def __init__(self) -> None:
        super().__init__()
        self.mycat = nnq.FloatFunctional()
        self.myadd = nnq.FloatFunctional()
        self.mymul = nnq.FloatFunctional()
        self.myadd_relu = nnq.FloatFunctional()
        self.my_scalar_add = nnq.FloatFunctional()
        self.my_scalar_mul = nnq.FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.mycat.cat([x, x, x])
        x = self.myadd.add(x, x)
        x = self.mymul.mul(x, x)
        x = self.myadd_relu.add_relu(x, x)
        w = self.my_scalar_add.add_scalar(x, -0.5)
        w = self.my_scalar_mul.mul_scalar(w, 0.5)

        w = self.dequant(w)
        return w


class TestNumericSuiteEager(QuantizationTestCase):
    r"""Tests for the eager-mode Numeric Suite helpers (``compare_weights``,
    ``compare_model_stub``, ``compare_model_outputs`` and the loggers).

    ``self.img_data_2d`` and ``self.calib_data`` are not defined in this file;
    presumably they are fixtures set up by ``QuantizationTestCase``.
    """

    # ------------------------------------------------------------------
    # Shared helpers (names deliberately do not start with ``test``)
    # ------------------------------------------------------------------

    @staticmethod
    def _eval_and_fuse(model):
        r"""Put ``model`` in eval mode and fuse it when it offers ``fuse_model``."""
        model.eval()
        if hasattr(model, "fuse_model"):
            model.fuse_model()
        return model

    def _assert_matching_shapes(self, compare_dict):
        r"""For every entry, assert that the "float" and "quantized" values have
        the same length and that their elements have pairwise equal shapes.
        """
        for entry in compare_dict.values():
            self.assertEqual(len(entry["float"]), len(entry["quantized"]))
            for float_val, quantized_val in zip(entry["float"], entry["quantized"]):
                self.assertEqual(float_val.shape, quantized_val.shape)

    def _compare_single_weight(self, float_model, q_model):
        r"""Run ``compare_weights`` on both state dicts, assert that exactly one
        weight was matched, and return the resulting dict.
        """
        weight_dict = compare_weights(float_model.state_dict(), q_model.state_dict())
        self.assertEqual(len(weight_dict), 1)
        return weight_dict

    def _check_model_stub(self, float_model, q_model, module_swap_list, *data):
        r"""Run ``compare_model_stub`` on ``data`` and assert that exactly one
        module was shadowed and that its logged values match in shape.
        """
        ob_dict = compare_model_stub(float_model, q_model, module_swap_list, *data)
        self.assertEqual(len(ob_dict), 1)
        self._assert_matching_shapes(ob_dict)

    def _compare_outputs(self, float_model, q_model, expected_keys, *data):
        r"""Run ``compare_model_outputs`` on ``data``, assert that the resulting
        keys equal ``expected_keys``, and return the resulting dict.
        """
        act_compare_dict = compare_model_outputs(float_model, q_model, *data)
        self.assertTrue(act_compare_dict.keys() == expected_keys)
        return act_compare_dict

    # ------------------------------------------------------------------
    # compare_weights
    # ------------------------------------------------------------------

    @override_qengines
    def test_compare_weights_conv_static(self):
        r"""Compare the weights of float and static quantized conv layer"""
        qengine = torch.backends.quantized.engine

        model_list = [AnnotatedConvModel(qengine), AnnotatedConvBnReLUModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
            weight_dict = self._compare_single_weight(model, q_model)
            for v in weight_dict.values():
                self.assertEqual(v["float"].shape, v["quantized"].shape)

    @override_qengines
    def test_compare_weights_linear_static(self):
        r"""Compare the weights of float and static quantized linear layer"""
        qengine = torch.backends.quantized.engine

        model_list = [AnnotatedSingleLayerLinearModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            weight_dict = self._compare_single_weight(model, q_model)
            for v in weight_dict.values():
                self.assertEqual(v["float"].shape, v["quantized"].shape)

    @override_qengines
    def test_compare_weights_linear_dynamic(self):
        r"""Compare the weights of float and dynamic quantized linear layer"""
        qengine = torch.backends.quantized.engine

        model_list = [SingleLayerLinearDynamicModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize_dynamic(model)
            self._assert_matching_shapes(self._compare_single_weight(model, q_model))

    @override_qengines
    def test_compare_weights_lstm_dynamic(self):
        r"""Compare the weights of float and dynamic quantized LSTM layer"""
        qengine = torch.backends.quantized.engine

        model_list = [LSTMwithHiddenDynamicModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize_dynamic(model)
            self._assert_matching_shapes(self._compare_single_weight(model, q_model))

    # ------------------------------------------------------------------
    # compare_model_stub
    # ------------------------------------------------------------------

    @override_qengines
    def test_compare_model_stub_conv_static(self):
        r"""Compare the output of static quantized conv layer and its float shadow module"""
        qengine = torch.backends.quantized.engine

        model_list = [
            AnnotatedConvModel(qengine),
            AnnotatedConvTransposeModel("qnnpack"),  # ConvT cannot use per channel weights
            AnnotatedConvBnReLUModel(qengine),
        ]
        module_swap_list = [
            nn.Conv2d,
            nn.intrinsic.modules.fused.ConvReLU2d,
            nn.ConvTranspose2d,
        ]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
            self._check_model_stub(
                model, q_model, module_swap_list, self.img_data_2d[0][0]
            )

    @override_qengines
    def test_compare_model_stub_linear_static(self):
        r"""Compare the output of static quantized linear layer and its float shadow module"""
        qengine = torch.backends.quantized.engine

        linear_data = self.calib_data[0][0]
        module_swap_list = [nn.Linear]
        model_list = [AnnotatedSingleLayerLinearModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            self._check_model_stub(model, q_model, module_swap_list, linear_data)

    @override_qengines
    def test_compare_model_stub_partial(self):
        r"""Compare the output of a static quantized linear layer and its float
        shadow module in a two-layer linear model; exactly one shadowed module
        is expected.
        """
        # TODO: Rebase on top of PR to remove compare and validate results here
        linear_data = self.calib_data[0][0]
        module_swap_list = [nn.Linear]
        model_list = [AnnotatedTwoLayerLinearModel()]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            self._check_model_stub(model, q_model, module_swap_list, linear_data)

    @override_qengines
    def test_compare_model_stub_submodule_static(self):
        r"""Compare the output of static quantized submodule and its float shadow module"""
        model = ModelWithSubModules().eval()
        q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
        module_swap_list = [SubModule, nn.Conv2d]
        # Called for its side effect: it swaps Shadow modules into q_model.
        compare_model_stub(model, q_model, module_swap_list, self.img_data_2d[0][0])
        # Since the top-level conv is not quantized, we do not insert a shadow module.
        # mod1 contains a conv that is quantized, so we insert a shadow module.
        self.assertIsInstance(q_model.mod1, Shadow)
        self.assertNotIsInstance(q_model.conv, Shadow)

    @override_qengines
    def test_compare_model_stub_functional_static(self):
        r"""Compare the output of static quantized functional layer and its float shadow module"""
        model = ModelWithFunctionals().eval()
        # NOTE(review): the qconfig is pinned to "fbgemm" although this test is
        # run under override_qengines -- confirm this is intended.
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        q_model = prepare(model, inplace=False)
        q_model(self.img_data_2d[0][0])
        q_model = convert(q_model)
        module_swap_list = [nnq.FloatFunctional]
        ob_dict = compare_model_stub(
            model, q_model, module_swap_list, self.img_data_2d[0][0]
        )
        self.assertEqual(len(ob_dict), 6)
        for shadowed in (
            q_model.mycat,
            q_model.myadd,
            q_model.mymul,
            q_model.myadd_relu,
            q_model.my_scalar_add,
            q_model.my_scalar_mul,
        ):
            self.assertIsInstance(shadowed, Shadow)
        self._assert_matching_shapes(ob_dict)

    @override_qengines
    def test_compare_model_stub_linear_dynamic(self):
        r"""Compare the output of dynamic quantized linear layer and its float shadow module"""
        qengine = torch.backends.quantized.engine

        linear_data = self.calib_data[0][0]

        model_list = [SingleLayerLinearDynamicModel(qengine)]
        module_swap_list = [nn.Linear, nn.LSTM]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize_dynamic(model)
            self._check_model_stub(model, q_model, module_swap_list, linear_data)

    @override_qengines
    def test_compare_model_stub_lstm_dynamic(self):
        r"""Compare the output of dynamic quantized LSTM layer and its float shadow module"""
        qengine = torch.backends.quantized.engine

        lstm_input = torch.rand((1, 1, 2))
        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))

        model_list = [LSTMwithHiddenDynamicModel(qengine)]
        module_swap_list = [nn.Linear, nn.LSTM]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize_dynamic(model)
            self._check_model_stub(
                model, q_model, module_swap_list, lstm_input, lstm_hidden
            )

    # ------------------------------------------------------------------
    # compare_model_outputs
    # ------------------------------------------------------------------

    @override_qengines
    def test_compare_model_outputs_conv_static(self):
        r"""Compare the output of conv layer in static quantized model and corresponding
        output of conv layer in float model
        """
        qengine = torch.backends.quantized.engine
        expected_act_compare_dict_keys = {"conv.stats", "quant.stats"}

        model_list = [AnnotatedConvModel(qengine), AnnotatedConvBnReLUModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
            act_compare_dict = self._compare_outputs(
                model, q_model, expected_act_compare_dict_keys, self.img_data_2d[0][0]
            )
            # Only the first logged activation of each entry is compared here.
            for v in act_compare_dict.values():
                self.assertEqual(v["float"][0].shape, v["quantized"][0].shape)

    @override_qengines
    def test_compare_model_outputs_linear_static(self):
        r"""Compare the output of linear layer in static quantized model and corresponding
        output of linear layer in float model
        """
        qengine = torch.backends.quantized.engine
        expected_act_compare_dict_keys = {"fc1.quant.stats", "fc1.module.stats"}

        linear_data = self.calib_data[0][0]
        model_list = [AnnotatedSingleLayerLinearModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            act_compare_dict = self._compare_outputs(
                model, q_model, expected_act_compare_dict_keys, linear_data
            )
            self._assert_matching_shapes(act_compare_dict)

    @override_qengines
    def test_compare_model_outputs_functional_static(self):
        r"""Compare the output of functional layer in static quantized model and corresponding
        output of functional layer in float model
        """
        model = ModelWithFunctionals().eval()
        # NOTE(review): the qconfig is pinned to "fbgemm" although this test is
        # run under override_qengines -- confirm this is intended.
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        q_model = prepare(model, inplace=False)
        q_model(self.img_data_2d[0][0])
        q_model = convert(q_model)
        act_compare_dict = compare_model_outputs(model, q_model, self.img_data_2d[0][0])
        self.assertEqual(len(act_compare_dict), 5)
        expected_act_compare_dict_keys = {
            "mycat.stats",
            "myadd.stats",
            "mymul.stats",
            "myadd_relu.stats",
            "quant.stats",
        }
        self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
        self._assert_matching_shapes(act_compare_dict)

    @override_qengines
    def test_compare_model_outputs_linear_dynamic(self):
        r"""Compare the output of linear layer in dynamic quantized model and corresponding
        output of linear layer in float model
        """
        qengine = torch.backends.quantized.engine
        expected_act_compare_dict_keys = {"fc1.stats"}

        linear_data = self.calib_data[0][0]

        model_list = [SingleLayerLinearDynamicModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize_dynamic(model)
            act_compare_dict = self._compare_outputs(
                model, q_model, expected_act_compare_dict_keys, linear_data
            )
            self._assert_matching_shapes(act_compare_dict)

    @override_qengines
    def test_compare_model_outputs_lstm_dynamic(self):
        r"""Compare the output of LSTM layer in dynamic quantized model and corresponding
        output of LSTM layer in float model
        """
        qengine = torch.backends.quantized.engine
        expected_act_compare_dict_keys = {"lstm.stats"}

        lstm_input = torch.rand((1, 1, 2))
        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))

        model_list = [LSTMwithHiddenDynamicModel(qengine)]
        for model in model_list:
            self._eval_and_fuse(model)
            q_model = quantize_dynamic(model)
            act_compare_dict = self._compare_outputs(
                model, q_model, expected_act_compare_dict_keys, lstm_input, lstm_hidden
            )
            for v in act_compare_dict.values():
                self.assertEqual(len(v["float"]), len(v["quantized"]))
                for i, (float_out, quantized_out) in enumerate(
                    zip(v["float"], v["quantized"])
                ):
                    self.assertEqual(len(float_out), len(quantized_out))
                    # Element 0 is compared for every logged value; element 1
                    # is compared only for values after the first one.
                    self.assertEqual(float_out[0].shape, quantized_out[0].shape)
                    if i != 0:
                        self.assertEqual(float_out[1].shape, quantized_out[1].shape)

    # ------------------------------------------------------------------
    # Loggers
    # ------------------------------------------------------------------

    @override_qengines
    def test_output_logger(self):
        r"""Compare output from OutputLogger with the expected results"""
        x = torch.rand(2, 2)
        y = torch.rand(2, 1)
        expected = [x, y]

        logger = OutputLogger()
        logger.forward(x)
        logger.forward(y)

        self.assertEqual(expected, logger.stats["tensor_val"])

    @override_qengines
    def test_shadow_logger(self):
        r"""Compare output from ShadowLogger with the expected results"""
        a_float = torch.rand(2, 2)
        a_quantized = torch.rand(2, 2)

        b_float = torch.rand(3, 2, 2)
        b_quantized = torch.rand(3, 2, 2)

        logger = ShadowLogger()
        logger.forward(a_float, a_quantized)
        logger.forward(b_float, b_quantized)

        self.assertEqual(len(logger.stats["float"]), 2)
        self.assertEqual(len(logger.stats["quantized"]), 2)

    # ------------------------------------------------------------------
    # torchvision models
    # ------------------------------------------------------------------

    @skip_if_no_torchvision
    def _test_vision_model(self, float_model):
        r"""Smoke-test the Numeric Suite APIs on a fusable vision model.

        The model is quantized on random image data, then ``compare_weights``,
        ``compare_model_outputs``, ``prepare_model_outputs`` and
        ``get_matching_activations`` are called. Nothing is asserted on the
        results; the test only requires that these calls run without raising.
        """
        float_model.to('cpu')
        float_model.eval()
        float_model.fuse_model()
        float_model.qconfig = torch.ao.quantization.default_qconfig
        img_data = [
            (
                torch.rand(2, 3, 224, 224, dtype=torch.float),
                torch.randint(0, 1, (2,), dtype=torch.long),
            )
            for _ in range(2)
        ]
        qmodel = quantize(
            float_model,
            torch.ao.quantization.default_eval_fn,
            [img_data],
            inplace=False,
        )

        compare_weights(float_model.state_dict(), qmodel.state_dict())

        def compute_error(x, y):
            # 20 * log10(||x|| / ||x - y||)
            Ps = torch.norm(x)
            Pn = torch.norm(x - y)
            return 20 * torch.log10(Ps / Pn)

        data = img_data[0][0]
        # Take in floating point and quantized model as well as input data, and returns a dict, with keys
        # corresponding to the quantized module names and each entry being a dictionary with two keys 'float' and
        # 'quantized', containing the activations of floating point and quantized model at matching locations.
        act_compare_dict = compare_model_outputs(float_model, qmodel, data)

        for key in act_compare_dict:
            # The value is discarded; this only checks that the error can be computed.
            compute_error(
                act_compare_dict[key]['float'][0],
                act_compare_dict[key]['quantized'][0].dequantize(),
            )

        prepare_model_outputs(float_model, qmodel)

        for batch in img_data:
            float_model(batch[0])
            qmodel(batch[0])

        # Find the matching activation between floating point and quantized modules, and return a dict with key
        # corresponding to quantized module names and each entry being a dictionary with two keys 'float'
        # and 'quantized', containing the matching floating point and quantized activations logged by the logger
        get_matching_activations(float_model, qmodel)

    @skip_if_no_torchvision
    @unittest.skipIf(IS_ARM64, "Not working on arm right now")
    def test_mobilenet_v2(self):
        r"""Run the vision-model smoke test on torchvision's quantizable MobileNetV2"""
        from torchvision.models.quantization import mobilenet_v2
        self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))

    @skip_if_no_torchvision
    @unittest.skipIf(IS_ARM64, "Not working on arm right now")
    def test_mobilenet_v3(self):
        r"""Run the vision-model smoke test on torchvision's quantizable MobileNetV3-Large"""
        from torchvision.models.quantization import mobilenet_v3_large
        self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))