# Owner(s): ["module: unknown"]


import logging

import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao import pruning
from torch.ao.pruning import fqn_to_module
from torch.ao.quantization.quantize_fx import (
    convert_fx,
    convert_to_reference_fx,
    prepare_fx,
    prepare_qat_fx,
)
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

sparse_defaults = {
    "sparsity_level": 0.8,
    "sparse_block_shape": (1, 4),
    "zeros_per_block": 4,
}


def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
    model = nn.Sequential(
        nn.Linear(4, 4),  # 0
        nn.ReLU(),
        nn.Linear(4, 4),  # 2
        nn.ReLU(),
        tq.QuantStub(),
        nn.Linear(4, 4),  # 5
        nn.ReLU(),
        tq.DeQuantStub(),
    )
    if qconfig:
        model[4].qconfig = qconfig
        model[5].qconfig = qconfig

    sparsifier = pruning.WeightNormSparsifier(**sparse_defaults)

    sparse_config = [
        {
            "tensor_fqn": "5.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.weight"},
    ]
    return model, sparsifier, sparse_config


def _squash_mask_calibrate_and_convert(model, sparsifier, input):
    sparsifier.step()
    sparsifier.squash_mask()
    model(input)
    tq.convert(model, inplace=True)


def _calculate_sparsity(tensor):
    return ((tensor == 0).sum() / tensor.numel()).item()


# This series of tests checks the composability goals for sparsity and quantization,
# namely that performing quantization and sparsity model manipulations in various
# orderings does not cause problems.
class TestComposability(TestCase):
    # This test checks whether performing quantization prepare before sparse prepare
    # causes any issues and verifies that the correct observers are inserted and that
    # the quantized model works as expected
    def test_q_prep_before_s_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        tq.prepare(mod, inplace=True)
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))
        # check that correct observers were inserted
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This test checks whether performing sparsity prepare before quantization prepare
    # causes any issues. In particular, the previous quantization flow was unable to match
    # the post sparse prepare module names (adding parametrizations changes the module class
    # names), which would result in those parametrized modules not being quantized. This test
    # verifies that the fix for this was successful.
    def test_s_prep_before_q_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # If the sparsified modules have not undergone the final squash mask operation, it's possible
    # that the problem outlined in test_s_prep_before_q_prep would occur. This test verifies
    # both that the fix to the convert flow avoids this issue and that the resulting quantized
    # module uses the sparse version of the weight value.
    def test_convert_without_squash_mask(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before fusion causes any issues. The
    # worry was that the link created between the sparsifier and the modules that need to
    # be sparsified would be broken.
    def test_s_prep_before_fusion(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare or fusion
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This tests whether performing fusion before sparse prepare causes any issues. The
    # main worry was that the links to the modules in the sparse config would be broken by fusion.
    def test_fusion_before_s_prep(self):
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)

        # the default sparse_config is broken by fusion, but sparsification
        # still works if the correct post-fusion fqn is given
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]

        sparsifier.prepare(mod, config=sparse_config)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5][0].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before qat prepare causes issues.
    # The primary worries were that qat_prep wouldn't recognize the parametrized
    # modules and that the convert step for qat would remove the parametrizations
    # from the modules.
    def test_s_prep_before_qat_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare_qat(mod, inplace=True)
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing qat prepare before sparse prepare causes issues.
    def test_qat_prep_before_s_prep(self):
        mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        tq.prepare_qat(mod, inplace=True)

        # need to set up sparse_config on the new modules created by qat prepare
        sparse_config = [
            {
                "tensor_fqn": "5.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during qat prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])


def _module_has_activation_post_process(model, fqn_of_module):
    for node in model.graph.nodes:
        # look for an observer whose arg is the target module
        if "activation_post_process" in node.name:
            if node.args[0].target == fqn_of_module:
                return True
    return False


class TestFxComposability(TestCase):
    r"""This series of tests checks that various steps of the quantization and sparsity flow
    compose cleanly despite variation in sequencing.
    """

    def test_q_prep_fx_before_s_prep(self):
        r"""
        This test checks that the ordering of prepare_fx -> sparse prepare -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask between sparse prepare and convert_fx.
        This also tests the automatic fusion that occurs during prepare_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the default sparse_config is broken by the automatic fusion in fx,
        # but sparsification still works if the correct post-fusion fqn is given
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_q_prep_fx_s_prep_ref_conv(self):
        r"""
        This checks that the ordering prepare_fx -> sparse prepare -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # the default sparse_config is broken by the automatic fusion in fx,
        # but sparsification still works if the correct post-fusion fqn is given
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
            )
        )

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_q_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_qat_prep_fx(self):
        r"""
        This test checks that the ordering of sparse prepare -> prepare_qat_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qat_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU)
        )

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_q_prep_fx_ref(self):
        r"""
        This checks that the ordering sparse prepare -> prepare_fx -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
            )
        )

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
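

# A minimal entry-point sketch for running this file directly, assuming the standard
# PyTorch test harness; run_tests comes from the same common_utils module that already
# provides TestCase above. If this file is only ever collected through a parent test
# runner, this guard is simply unused.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()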