# Owner(s): ["oncall: quantization"]
from typing import Set

import torch
import torch.nn as nn
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn.functional as F
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.fx._model_report.detector import (
    DynamicStaticDetector,
    InputWeightEqualizationDetector,
    PerChannelDetector,
    OutlierDetector,
)
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
from torch.ao.quantization.fx._model_report.model_report import ModelReport
from torch.ao.quantization.observer import (
    HistogramObserver,
    default_per_channel_weight_observer,
    default_observer
)
from torch.ao.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU
from torch.testing._internal.common_quantization import (
    ConvModel,
    QuantizationTestCase,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    skipIfNoFBGEMM,
    skipIfNoQNNPACK,
    override_quantized_engine,
)


"""
Partition of input domain:

Model contains: conv or linear, both conv and linear
Model contains: ConvTransposeNd (not supported for per_channel)

Model is: post training quantization model, quantization aware training model
Model is: composed with nn.Sequential, composed in class structure

QConfig utilizes per_channel weight observer, backend uses non per_channel weight observer
QConfig_dict uses only one default qconfig, Qconfig dict uses > 1 unique qconfigs

Partition on output domain:

There are possible changes / suggestions, there are no changes / suggestions
"""
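
# For orientation, the high-level flow that the detector tests below exercise
# is roughly the following (a minimal sketch inferred from the tests
# themselves, not an additional test):
#
#   q_config_mapping = QConfigMapping().set_global(
#       torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
#   )
#   model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
#   model_prep(example_input)  # calibrate
#   detector = PerChannelDetector(torch.backends.quantized.engine)
#   report_str, per_channel_info = detector.generate_detector_report(model_prep)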

# Default output for string if no optimizations are possible
DEFAULT_NO_OPTIMS_ANSWER_STRING = (
    "Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
)

# Example Sequential Model with multiple Conv and Linear with nesting involved
NESTED_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, 2, 1),
    torch.nn.Sequential(torch.nn.Linear(9, 27), torch.nn.ReLU()),
    torch.nn.Linear(27, 27),
    torch.nn.ReLU(),
    torch.nn.Conv2d(3, 3, 2, 1),
)

# Example Sequential Model with Conv sub-class example
LAZY_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
    torch.nn.LazyConv2d(3, 3, 2, 1),
    torch.nn.Sequential(torch.nn.Linear(5, 27), torch.nn.ReLU()),
    torch.nn.ReLU(),
    torch.nn.Linear(27, 27),
    torch.nn.ReLU(),
    torch.nn.LazyConv2d(3, 3, 2, 1),
)

# Example Sequential Model with Fusion directly built into model
FUSION_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
    ConvReLU2d(torch.nn.Conv2d(3, 3, 2, 1), torch.nn.ReLU()),
    torch.nn.Sequential(LinearReLU(torch.nn.Linear(9, 27), torch.nn.ReLU())),
    LinearReLU(torch.nn.Linear(27, 27), torch.nn.ReLU()),
    torch.nn.Conv2d(3, 3, 2, 1),
)

# example model to use for tests
class ThreeOps(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def get_example_inputs(self):
        return (torch.randn(1, 3, 3, 3),)

class TwoThreeOps(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block1 = ThreeOps()
        self.block2 = ThreeOps()

    def forward(self, x):
        x = self.block1(x)
        y = self.block2(x)
        z = x + y
        z = F.relu(z)
        return z

    def get_example_inputs(self):
        return (torch.randn(1, 3, 3, 3),)

class TestFxModelReportDetector(QuantizationTestCase):

    """Prepares and calibrates the model"""

    def _prepare_model_and_run_input(self, model, q_config_mapping, input):
        model_prep = torch.ao.quantization.quantize_fx.prepare_fx(model, q_config_mapping, input)  # prep model
        model_prep(input).sum()  # calibrate the model
        return model_prep

    """Case includes:
    one conv or linear
    post training quantization
    composed as module
    qconfig uses per_channel weight observer
    Only 1 qconfig in qconfig dict
    Output has no changes / suggestions
    """

    @skipIfNoFBGEMM
    def test_simple_conv(self):

        with override_quantized_engine('fbgemm'):
            torch.backends.quantized.engine = "fbgemm"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            input = torch.randn(1, 3, 10, 10)
            prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input)

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # no optims possible and there should be nothing in per_channel_status
            self.assertEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # there should only be one conv in this model
            self.assertEqual(per_channel_info["conv"]["backend"], torch.backends.quantized.engine)
            self.assertEqual(len(per_channel_info), 1)
            self.assertEqual(next(iter(per_channel_info)), "conv")
            self.assertEqual(
                per_channel_info["conv"]["per_channel_quantization_supported"],
                True,
            )
            self.assertEqual(per_channel_info["conv"]["per_channel_quantization_used"], True)
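
    # For reference, the assertions above imply that each per_channel_info
    # entry has (at least) this shape; a sketch inferred from the checks in
    # these tests, not an exhaustive description of the detector's output:
    #
    #   {"conv": {"backend": "fbgemm",
    #             "per_channel_quantization_supported": True,
    #             "per_channel_quantization_used": True}}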

    """Case includes:
    Multiple conv or linear
    post training quantization
    composed as module
    qconfig doesn't use per_channel weight observer
    Only 1 qconfig in qconfig dict
    Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_multi_linear_model_without_per_channel(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                TwoLayerLinearModel(),
                q_config_mapping,
                TwoLayerLinearModel().get_example_inputs()[0],
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )
            # pick a random key to look at
            rand_key: str = next(iter(per_channel_info.keys()))
            self.assertEqual(per_channel_info[rand_key]["backend"], torch.backends.quantized.engine)
            self.assertEqual(len(per_channel_info), 2)

            # for each linear layer, per_channel should be supported but not used
            for linear_key in per_channel_info.keys():
                module_entry = per_channel_info[linear_key]

                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)

    """Case includes:
    Multiple conv or linear
    post training quantization
    composed as Module
    qconfig doesn't use per_channel weight observer
    More than 1 qconfig in qconfig dict
    Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_multiple_q_config_options(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            # qconfig with support for per_channel quantization
            per_channel_qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=True),
                weight=default_per_channel_weight_observer,
            )

            # design a model with both conv and linear layers
            class ConvLinearModel(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv1 = torch.nn.Conv2d(3, 3, 2, 1)
                    self.fc1 = torch.nn.Linear(9, 27)
                    self.relu = torch.nn.ReLU()
                    self.fc2 = torch.nn.Linear(27, 27)
                    self.conv2 = torch.nn.Conv2d(3, 3, 2, 1)

                def forward(self, x):
                    x = self.conv1(x)
                    x = self.fc1(x)
                    x = self.relu(x)
                    x = self.fc2(x)
                    x = self.conv2(x)
                    return x

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(
                torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
            ).set_object_type(torch.nn.Conv2d, per_channel_qconfig)

            prepared_model = self._prepare_model_and_run_input(
                ConvLinearModel(),
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # the only suggestions should be for the linear layers

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer
            self.assertEqual(len(per_channel_info), 4)

            # for each layer, per_channel quantization should be supported
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]
                self.assertEqual(module_entry["per_channel_quantization_supported"], True)

                # linear layers should be False, conv layers True because conv uses a different qconfig
                if "fc" in key:
                    self.assertEqual(module_entry["per_channel_quantization_used"], False)
                elif "conv" in key:
                    self.assertEqual(module_entry["per_channel_quantization_used"], True)
                else:
                    raise ValueError("Should only contain conv and linear layers as key values")

    """Case includes:
    Multiple conv or linear
    post training quantization
    composed as sequential
    qconfig doesn't use per_channel weight observer
    Only 1 qconfig in qconfig dict
    Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_sequential_model_format(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                NESTED_CONV_LINEAR_EXAMPLE,
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer
            self.assertEqual(len(per_channel_info), 4)

            # for each layer, should be supported but not used
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]

                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)

    """Case includes:
    Multiple conv or linear
    post training quantization
    composed as sequential
    qconfig doesn't use per_channel weight observer
    Only 1 qconfig in qconfig dict
    Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_conv_sub_class_considered(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                LAZY_CONV_LINEAR_EXAMPLE,
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer and it considered the LazyConv2d
            self.assertEqual(len(per_channel_info), 4)

            # for each layer, should be supported but not used
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]

                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)

    """Case includes:
    Multiple conv or linear
    post training quantization
    composed as sequential
    qconfig uses per_channel weight observer
    Only 1 qconfig in qconfig dict
    Output has no possible changes / suggestions
    """

    @skipIfNoFBGEMM
    def test_fusion_layer_in_sequential(self):

        with override_quantized_engine('fbgemm'):
            torch.backends.quantized.engine = "fbgemm"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                FUSION_CONV_LINEAR_EXAMPLE,
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # no optims possible and there should be nothing in per_channel_status
            self.assertEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer and it considered all the nested fusion components
            self.assertEqual(len(per_channel_info), 4)

            # for each layer, per_channel quantization should be supported and used
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]
                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], True)

    """Case includes:
    Multiple conv or linear
    quantization aware training
    composed as model
    qconfig does not use per_channel weight observer
    Only 1 qconfig in qconfig dict
    Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_qat_aware_model_example(self):

        # first we want a QAT model
        class QATConvLinearReluModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # QuantStub converts tensors from floating point to quantized
                self.quant = torch.ao.quantization.QuantStub()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)
                self.relu = torch.nn.ReLU()
                # DeQuantStub converts tensors from quantized to floating point
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            # create a model instance
            model_fp32 = QATConvLinearReluModel()

            model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")

            # model must be in eval mode for fusion
            model_fp32.eval()
            model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]])

            # model must be set to train mode for QAT logic to work
            model_fp32_fused.train()

            # prepare the model for QAT, different than for post training quantization
            model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused)

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(model_fp32_prepared)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # make sure it was able to find the single conv in the fused model
            self.assertEqual(len(per_channel_info), 1)

            # for the one conv, it should still give advice to use a different qconfig
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]
                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)


"""
Partition on Domain / Things to Test

- All zero tensor
- Multiple tensor dimensions
- All of the outward facing functions
- Epoch min max are correctly updating
- Batch range is correctly averaging as expected
- Reset for each epoch is correctly resetting the values

Partition on Output
- the calculation of the batch-to-epoch ratio occurs correctly

"""
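
# The batch-range bookkeeping checked below follows the standard incremental
# mean update. A minimal standalone sketch with hypothetical batch ranges,
# mirroring the expected-value computation in run_model_and_common_checks:
#
#   avg_range = 0.0
#   for n, (lo, hi) in enumerate([(0.0, 1.0), (0.0, 3.0)]):
#       avg_range = (avg_range * n + (hi - lo)) / (n + 1)
#   # avg_range == 2.0 after the two batches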

class TestFxModelReportObserver(QuantizationTestCase):
    class NestedModifiedSingleLayerLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.obs1 = ModelReportObserver()
            self.mod1 = SingleLayerLinearModel()
            self.obs2 = ModelReportObserver()
            self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.obs1(x)
            x = self.mod1(x)
            x = self.obs2(x)
            x = self.fc1(x)
            x = self.relu(x)
            return x

    def run_model_and_common_checks(self, model, ex_input, num_epochs, batch_size):
        # split up data into batches
        split_up_data = torch.split(ex_input, batch_size)
        for epoch in range(num_epochs):
            # reset all model report obs
            model.apply(
                lambda module: module.reset_batch_and_epoch_values()
                if isinstance(module, ModelReportObserver)
                else None
            )

            # quick check that a reset occurred
            self.assertEqual(
                model.obs1.average_batch_activation_range,
                torch.tensor(float(0)),
            )
            self.assertEqual(model.obs1.epoch_activation_min, torch.tensor(float("inf")))
            self.assertEqual(model.obs1.epoch_activation_max, torch.tensor(float("-inf")))

            # loop through the batches and run through
            for index, batch in enumerate(split_up_data):

                num_tracked_so_far = model.obs1.num_batches_tracked
                self.assertEqual(num_tracked_so_far, index)

                # get general info about the batch and the model to use later
                batch_min, batch_max = torch.aminmax(batch)
                current_average_range = model.obs1.average_batch_activation_range
                current_epoch_min = model.obs1.epoch_activation_min
                current_epoch_max = model.obs1.epoch_activation_max

                # run the batch through
                model(batch)

                # check that average batch activation range updated correctly
                correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
                    num_tracked_so_far + 1
                )
                self.assertEqual(
                    model.obs1.average_batch_activation_range,
                    correct_updated_value,
                )

                if current_epoch_max - current_epoch_min > 0:
                    self.assertEqual(
                        model.obs1.get_batch_to_epoch_ratio(),
                        correct_updated_value / (current_epoch_max - current_epoch_min),
                    )

    """Case includes:
    all zero tensor
    dim size = 2
    run for 1 epoch
    run for 10 batches
    tests input data observer
    """

    def test_zero_tensor_errors(self):
        # initialize the model
        model = self.NestedModifiedSingleLayerLinear()

        # generate the desired input
        ex_input = torch.zeros((10, 1, 5))

        # run it through the model and do general tests
        self.run_model_and_common_checks(model, ex_input, 1, 1)

        # make sure final values are all 0
        self.assertEqual(model.obs1.epoch_activation_min, 0)
        self.assertEqual(model.obs1.epoch_activation_max, 0)
        self.assertEqual(model.obs1.average_batch_activation_range, 0)

        # we should get an error if we try to calculate the ratio
        with self.assertRaises(ValueError):
            ratio_val = model.obs1.get_batch_to_epoch_ratio()

    """Case includes:
    non-zero tensor
    dim size = 2
    run for 1 epoch
    run for 1 batch
    tests input data observer
    """

    def test_single_batch_of_ones(self):
        # initialize the model
        model = self.NestedModifiedSingleLayerLinear()

        # generate the desired input
        ex_input = torch.ones((1, 1, 5))

        # run it through the model and do general tests
        self.run_model_and_common_checks(model, ex_input, 1, 1)

        # min and max should both be 1, so the range should be 0
        self.assertEqual(model.obs1.epoch_activation_min, 1)
        self.assertEqual(model.obs1.epoch_activation_max, 1)
        self.assertEqual(model.obs1.average_batch_activation_range, 0)

        # we should get an error if we try to calculate the ratio
        with self.assertRaises(ValueError):
            ratio_val = model.obs1.get_batch_to_epoch_ratio()

    """Case includes:
    non-zero tensor
    dim size = 2
    run for 10 epochs
    run for 15 batches
    tests an observer that is not on the input data
    """

    def test_observer_after_relu(self):

        # model specific to this test
        class NestedModifiedObserverAfterRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.mod1 = SingleLayerLinearModel()
                self.obs2 = ModelReportObserver()
                self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.obs1(x)
                x = self.mod1(x)
                x = self.fc1(x)
                x = self.relu(x)
                x = self.obs2(x)
                return x

        # initialize the model
        model = NestedModifiedObserverAfterRelu()

        # generate the desired input
        ex_input = torch.randn((15, 1, 5))

        # run it through the model and do general tests
        self.run_model_and_common_checks(model, ex_input, 10, 15)

    """Case includes:
    non-zero tensor
    dim size = 2
    run for multiple epochs
    run for multiple batches
    tests input data observer
    """

    def test_random_epochs_and_batches(self):

        # set up a basic model
        class TinyNestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
                self.relu = torch.nn.ReLU()
                self.obs2 = ModelReportObserver()

            def forward(self, x):
                x = self.obs1(x)
                x = self.fc1(x)
                x = self.relu(x)
                x = self.obs2(x)
                return x

        class LargerIncludeNestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.nested = TinyNestModule()
                self.fc1 = SingleLayerLinearModel()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.obs1(x)
                x = self.nested(x)
                x = self.fc1(x)
                x = self.relu(x)
                return x

        class ModifiedThreeOps(torch.nn.Module):
            def __init__(self, batch_norm_dim):
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.linear = torch.nn.Linear(7, 3, 2)
                self.obs2 = ModelReportObserver()

                if batch_norm_dim == 2:
                    self.bn = torch.nn.BatchNorm2d(2)
                elif batch_norm_dim == 3:
                    self.bn = torch.nn.BatchNorm3d(4)
                else:
                    raise ValueError("Dim should only be 2 or 3")

                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.obs1(x)
                x = self.linear(x)
                x = self.obs2(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        class HighDimensionNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.fc1 = torch.nn.Linear(3, 7)
                self.block1 = ModifiedThreeOps(3)
                self.fc2 = torch.nn.Linear(3, 7)
                self.block2 = ModifiedThreeOps(3)
                self.fc3 = torch.nn.Linear(3, 7)

            def forward(self, x):
                x = self.obs1(x)
                x = self.fc1(x)
                x = self.block1(x)
                x = self.fc2(x)
                y = self.block2(x)
                y = self.fc3(y)
                z = x + y
                z = F.relu(z)
                return z

        # the purpose of this test is to give the observers a variety of data examples
        # initialize the models
        models = [
            self.NestedModifiedSingleLayerLinear(),
            LargerIncludeNestModel(),
            ModifiedThreeOps(2),
            HighDimensionNet(),
        ]

        # get some number of epochs and batches
        num_epochs = 10
        num_batches = 15

        input_shapes = [(1, 5), (1, 5), (2, 3, 7), (4, 1, 8, 3)]

        # generate the desired inputs
        inputs = []
        for shape in input_shapes:
            ex_input = torch.randn((num_batches, *shape))
            inputs.append(ex_input)

        # run it through the model and do general tests
        for index, model in enumerate(models):
            self.run_model_and_common_checks(model, inputs[index], num_epochs, num_batches)


"""
Partition on domain / things to test

There is only a single test case for now.

This will be more thoroughly tested with the implementation of the full end to end tool coming soon.
"""

class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
    @skipIfNoFBGEMM
    def test_nested_detection_case(self):
        class SingleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                x = self.linear(x)
                return x

        class TwoBlockNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block1 = SingleLinear()
                self.block2 = SingleLinear()

            def forward(self, x):
                x = self.block1(x)
                y = self.block2(x)
                z = x + y
                z = F.relu(z)
                return z


        with override_quantized_engine('fbgemm'):
            # create model, example input, and qconfig mapping
            torch.backends.quantized.engine = "fbgemm"
            model = TwoBlockNet()
            example_input = torch.randint(-10, 0, (1, 3, 3, 3))
            example_input = example_input.to(torch.float)
            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))

            # prep model and select observer
            model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
            obs_ctr = ModelReportObserver

            # find layer to attach to and store
            linear_fqn = "block2.linear"  # fqn of target linear

            target_linear = None
            for node in model_prep.graph.nodes:
                if node.target == linear_fqn:
                    target_linear = node
                    break

            # insert into both module and graph pre and post

            # set up to insert before target_linear (pre_observer)
            with model_prep.graph.inserting_before(target_linear):
                obs_to_insert = obs_ctr()
                pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
                model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
                model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)

            # set up and insert after the target_linear (post_observer)
            with model_prep.graph.inserting_after(target_linear):
                obs_to_insert = obs_ctr()
                post_obs_fqn = linear_fqn + ".model_report_post_observer"
                model_prep.add_submodule(post_obs_fqn, obs_to_insert)
                model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))

            # need to recompile module after submodule added and pass input through
            model_prep.recompile()

            num_iterations = 10
            for i in range(num_iterations):
                if i % 2 == 0:
                    example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
                else:
                    example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
                model_prep(example_input)

            # run it through the dynamic vs static detector
            dynamic_vs_static_detector = DynamicStaticDetector()
            dynam_vs_stat_str, dynam_vs_stat_dict = dynamic_vs_static_detector.generate_detector_report(model_prep)

            # one of the stats should be stationary, and the other non-stationary
            # as a result, dynamic should be recommended
            data_dist_info = [
                dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.PRE_OBS_DATA_DIST_KEY],
                dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.POST_OBS_DATA_DIST_KEY],
            ]

            self.assertTrue("stationary" in data_dist_info)
            self.assertTrue("non-stationary" in data_dist_info)
            self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])
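
# A note on the calibration loop above: the batches alternate between the
# ranges [-10, 0) and [0, 10), so the observed distribution swings from batch
# to batch. The DynamicStaticDetector classifies a distribution as
# non-stationary when its per-batch range is large relative to its epoch range
# (cf. ModelReportObserver.get_batch_to_epoch_ratio), which is what drives the
# dynamic-quantization recommendation asserted above.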

class TestFxModelReportClass(QuantizationTestCase):

    @skipIfNoFBGEMM
    def test_constructor(self):
        """
        Tests the constructor of the ModelReport class.
        Specifically looks at:
        - The desired reports
        - Ensures that the observers of interest are properly initialized
        """

        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            backend = torch.backends.quantized.engine

            # create a model
            model = ThreeOps()
            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
            model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0])

            # make an example set of detectors
            test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
            # initialize the ModelReport with the detector set
            model_report = ModelReport(model_prep, test_detector_set)

            # make sure the internal valid reports match
            detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
            self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)

            # now attempt with no valid reports, should raise error
            with self.assertRaises(ValueError):
                model_report = ModelReport(model, set())

            # number of expected obs of interest entries
            num_expected_entries = len(test_detector_set)
            self.assertEqual(len(model_report.get_observers_of_interest()), num_expected_entries)

            for value in model_report.get_observers_of_interest().values():
                self.assertEqual(len(value), 0)

    @skipIfNoFBGEMM
    def test_prepare_model_callibration(self):
        """
        Tests model_report.prepare_detailed_calibration, which prepares the model for calibration.
        Specifically looks at:
        - Whether observers are properly inserted into regular nn.Module
        - Whether the target and the arguments of the observers are proper
        - Whether the internal representation of observers of interest is updated
        """

        with override_quantized_engine('fbgemm'):
            # create model report object

            # create model
            model = TwoThreeOps()
            # make an example set of detectors
            torch.backends.quantized.engine = "fbgemm"
            backend = torch.backends.quantized.engine
            test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}

            # prepare the model
            example_input = model.get_example_inputs()[0]
            current_backend = torch.backends.quantized.engine
            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)

            model_report = ModelReport(model_prep, test_detector_set)

            # prepare the model for calibration
            prepared_for_callibrate_model = model_report.prepare_detailed_calibration()

            # see whether observers are properly inserted into the regular nn.Module
            # there should be 4 observers present in this case
            modules_observer_cnt = 0
            for fqn, module in prepared_for_callibrate_model.named_modules():
                if isinstance(module, ModelReportObserver):
                    modules_observer_cnt += 1

            self.assertEqual(modules_observer_cnt, 4)

            model_report_str_check = "model_report"
            # also make sure arguments for observers in the graph are proper
            for node in prepared_for_callibrate_model.graph.nodes:
                # not all node targets are strings, so check
                if isinstance(node.target, str) and model_report_str_check in node.target:
                    # a pre-observer should have the same args as the linear (the next node)
                    if "pre_observer" in node.target:
                        self.assertEqual(node.args, node.next.args)
                    # a post-observer's args are the target linear (the previous node)
                    if "post_observer" in node.target:
                        self.assertEqual(node.args, (node.prev,))

            # ensure model_report observers of interest updated
            # there should be two entries
            self.assertEqual(len(model_report.get_observers_of_interest()), 2)
            for detector in test_detector_set:
                self.assertTrue(detector.get_detector_name() in model_report.get_observers_of_interest().keys())

                # get number of entries for this detector
                detector_obs_of_interest_fqns = model_report.get_observers_of_interest()[detector.get_detector_name()]

                # assert that the per channel detector has 0 and the dynamic static has 4
                if isinstance(detector, PerChannelDetector):
                    self.assertEqual(len(detector_obs_of_interest_fqns), 0)
                elif isinstance(detector, DynamicStaticDetector):
                    self.assertEqual(len(detector_obs_of_interest_fqns), 4)

            # ensure that we can prepare for calibration only once
            with self.assertRaises(ValueError):
                prepared_for_callibrate_model = model_report.prepare_detailed_calibration()


    def get_module_and_graph_cnts(self, callibrated_fx_module):
        r"""
        Counts the ModelReportObserver modules in the model as well as the observer nodes in the graph.
        Returns a tuple of two elements:
        int: The number of ModelReportObservers found in the model
        int: The number of model_report nodes found in the graph
        """
        # get the number of observers stored as modules
        modules_observer_cnt = 0
        for fqn, module in callibrated_fx_module.named_modules():
            if isinstance(module, ModelReportObserver):
                modules_observer_cnt += 1

        # get number of observers in the graph
        model_report_str_check = "model_report"
        graph_observer_cnt = 0
        for node in callibrated_fx_module.graph.nodes:
            # not all node targets are strings, so check
            if isinstance(node.target, str) and model_report_str_check in node.target:
                # increment if we found a graph observer
                graph_observer_cnt += 1

        return (modules_observer_cnt, graph_observer_cnt)

    @skipIfNoFBGEMM
    def test_generate_report(self):
        """
        Tests model_report.generate_model_report to ensure report generation.
        Specifically looks at:
        - Whether the correct number of reports is generated
        - Whether observers are properly removed if specified
        - Whether generating a report a second time is correctly blocked if observers were removed
        """

        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"

            # check whether the correct number of reports are being generated
            filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)}
            single_detector_set = {DynamicStaticDetector()}

            # create our models
            model_full = TwoThreeOps()
            model_single = TwoThreeOps()

            # prepare and calibrate two different instances of the same model
            # prepare the model
            example_input = model_full.get_example_inputs()[0]
            current_backend = torch.backends.quantized.engine
            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
            model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)

            # initialize one with the filled detector set
            model_report_full = ModelReport(model_prep_full, filled_detector_set)
            # initialize another with the single detector set
            model_report_single = ModelReport(model_prep_single, single_detector_set)

            # prepare the models for calibration
            prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration()
            prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration()

            # now calibrate the two models
            num_iterations = 10
            for i in range(num_iterations):
                example_input = torch.randint(100, (1, 3, 3, 3)).to(torch.float)
                prepared_for_callibrate_model_full(example_input)
                prepared_for_callibrate_model_single(example_input)

            # now generate the reports
            model_full_report = model_report_full.generate_model_report(True)
            model_single_report = model_report_single.generate_model_report(False)

            # check that sizes are appropriate
            self.assertEqual(len(model_full_report), len(filled_detector_set))
            self.assertEqual(len(model_single_report), len(single_detector_set))

            # make sure observers are being properly removed for full report since we put flag in
            modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
            self.assertEqual(modules_observer_cnt, 0)  # assert no more observer modules
            self.assertEqual(graph_observer_cnt, 0)  # assert no more observer nodes in graph

            # make sure observers aren't being removed for single report since not specified
            modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
            self.assertNotEqual(modules_observer_cnt, 0)
            self.assertNotEqual(graph_observer_cnt, 0)

            # make sure there is an error when we try to rerun report generation for the full
            # report (its observers were removed) but not for the single report
            with self.assertRaises(Exception):
                model_full_report = model_report_full.generate_model_report(False)

            # make sure we don't run into error for single report
            model_single_report = model_report_single.generate_model_report(False)

    @skipIfNoFBGEMM
    def test_generate_visualizer(self):
        """
        Tests that the ModelReport class can properly create the ModelReportVisualizer instance.
        Checks that:
        - Correct number of modules are represented
        - Modules are sorted
        - Correct number of features for each module
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            # test with multiple detectors
            detector_set = set()
            detector_set.add(OutlierDetector(reference_percentile=0.95))
            detector_set.add(InputWeightEqualizationDetector(0.5))

            model = TwoThreeOps()

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # trying to visualize without generating a report should throw an error
            with self.assertRaises(Exception):
                mod_rep_visualization = mod_report.generate_visualizer()

            # now get the report by running it through ModelReport instance
            generated_report = mod_report.generate_model_report(remove_inserted_observers=False)

            # now getting the visualizer should not error
            mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()

            # since we tested with the outlier detector, which looks at every base level module,
            # there should be six entries in the ordered dict
            mod_fqns_to_features = mod_rep_visualizer.generated_reports

            self.assertEqual(len(mod_fqns_to_features), 6)

            # the outlier detector has 9 features per module
            # the input-weight detector has 12 features per module
            # they share 1 feature, so there should be 12 + 9 - 1 = 20 unique features for common modules
            # all linears will be common
            for module_fqn in mod_fqns_to_features:
                if ".linear" in module_fqn:
                    linear_info = mod_fqns_to_features[module_fqn]
                    self.assertEqual(len(linear_info), 20)

    @skipIfNoFBGEMM
    def test_qconfig_mapping_generation(self):
        """
        Tests the generation of qconfigs by the ModelReport API
        - Tests that a QConfigMapping is generated
        - Tests that the mapping includes information for the relevant modules
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            # test with multiple detectors
            detector_set = set()
            detector_set.add(PerChannelDetector())
            detector_set.add(DynamicStaticDetector())

            model = TwoThreeOps()

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)


            # get the mapping without error
            qconfig_mapping = mod_report.generate_qconfig_mapping()

            # now get the report by running it through ModelReport instance
            generated_report = mod_report.generate_model_report(remove_inserted_observers=False)

            # get the visualizer so we can get access to reformatted reports by module fqn
            mod_reports_by_fqn = mod_report.generate_visualizer().generated_reports

            # compare the entries of the mapping to those of the report
            # we should have the same number of entries
            self.assertEqual(len(qconfig_mapping.module_name_qconfigs), len(mod_reports_by_fqn))

            # we should have 2 entries because only the linears are applicable,
            # so there should be a suggestion for each named module
            self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)

            # only two linears: weights should use per-channel min-max since fbgemm,
            # and activations a static distribution since this was a simple single calibration
            for key in qconfig_mapping.module_name_qconfigs:
                config = qconfig_mapping.module_name_qconfigs[key]
                self.assertEqual(config.weight, default_per_channel_weight_observer)
                self.assertEqual(config.activation, default_observer)

            # make sure these can actually be used to prepare the model
            prepared = quantize_fx.prepare_fx(TwoThreeOps(), qconfig_mapping, example_input)

            # now convert the model to ensure no errors in conversion
            converted = quantize_fx.convert_fx(prepared)
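
    # For reference, the per-module suggestion asserted above corresponds
    # roughly to this QConfig (a sketch based on the assertions, not the
    # detector's literal construction):
    #
    #   suggested = QConfig(
    #       activation=default_observer,
    #       weight=default_per_channel_weight_observer,
    #   )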

    @skipIfNoFBGEMM
    def test_equalization_mapping_generation(self):
        """
        Tests the generation of equalization configs by the ModelReport API
        - Tests that an equalization config is generated when the input-weight equalization detector is used
        - Tests that the mapping includes information for the relevant modules
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            detector_set = set()
            detector_set.add(InputWeightEqualizationDetector(0.6))

            model = TwoThreeOps()

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)


            # get the mappings without error
            qconfig_mapping = mod_report.generate_qconfig_mapping()
            equalization_mapping = mod_report.generate_equalization_mapping()

            # the tests for the equalization mapping are much simpler

            # shouldn't have any equalization suggestions for this case;
            # the qconfig mapping should still have its two suggestions
            self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)


            # make sure these can actually be used to prepare the model
            prepared = quantize_fx.prepare_fx(
                TwoThreeOps(),
                qconfig_mapping,
                example_input,
                _equalization_config=equalization_mapping
            )

            # now convert the model to ensure no errors in conversion
            converted = quantize_fx.convert_fx(prepared)

class TestFxDetectInputWeightEqualization(QuantizationTestCase):

    class SimpleConv(torch.nn.Module):
        def __init__(self, con_dims):
            super().__init__()
            self.relu = torch.nn.ReLU()
            self.conv = torch.nn.Conv2d(con_dims[0], con_dims[1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        def forward(self, x):
            x = self.conv(x)
            x = self.relu(x)
            return x

    class TwoBlockComplexNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.block1 = TestFxDetectInputWeightEqualization.SimpleConv((3, 32))
            self.block2 = TestFxDetectInputWeightEqualization.SimpleConv((3, 3))
            self.conv = torch.nn.Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
            self.linear = torch.nn.Linear(768, 10)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.block1(x)
            x = self.conv(x)
            y = self.block2(x)
            y = y.repeat(1, 1, 2, 2)
            z = x + y
            z = z.flatten(start_dim=1)
            z = self.linear(z)
            z = self.relu(z)
            return z

        def get_fusion_modules(self):
            return [['conv', 'relu']]

        def get_example_inputs(self):
            return (torch.randn((1, 3, 28, 28)),)
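
    # Shape trace for TwoBlockComplexNet above, assuming the (1, 3, 28, 28)
    # example input: block1 conv (stride 2) -> (1, 32, 14, 14); the 1x1 conv
    # with padding 1 -> (1, 3, 16, 16); block2 conv (stride 2) -> (1, 3, 8, 8);
    # repeat(1, 1, 2, 2) -> (1, 3, 16, 16), so x + y lines up; flatten ->
    # (1, 768), matching the Linear(768, 10).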

    class ReluOnly(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.relu(x)
            return x

        def get_example_inputs(self):
            return (torch.arange(27).reshape((1, 3, 3, 3)),)

    def _get_prepped_for_calibration_model(self, model, detector_set, fused=False):
        r"""Returns a model that has been prepared for calibration and the corresponding model_report"""

        # pass in the necessary inputs to the helper
        example_input = model.get_example_inputs()[0]
        return _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused)

    @skipIfNoFBGEMM
    def test_input_weight_equalization_determine_points(self):
        # use fbgemm and create our model instance
        # then create model report instance with detector
        with override_quantized_engine('fbgemm'):

            detector_set = {InputWeightEqualizationDetector(0.5)}

            # get test model and calibrate
            non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set)
            fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set, fused=True)

            # reporter should still give same counts even for fused model
            for prepared_for_callibrate_model, mod_report in [non_fused, fused]:

                # supported modules to check
                mods_to_check = {nn.Linear, nn.Conv2d}

                # get the set of all node fqns in the graph
                node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}

                # there should be 4 node fqns that have the observer inserted
                correct_number_of_obs_inserted = 4
                number_of_obs_found = 0
                obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME

                for node in prepared_for_callibrate_model.graph.nodes:
                    # if the obs name is inside the target, we found an observer
                    if obs_name_to_find in str(node.target):
                        number_of_obs_found += 1

                self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)

                # assert that each of the desired modules have the observers inserted
                for fqn, module in prepared_for_callibrate_model.named_modules():
                    # check if module is a supported module
                    is_in_include_list = any(isinstance(module, x) for x in mods_to_check)

                    if is_in_include_list:
                        # make sure it has the observer attribute
                        self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
                    else:
                        # if it's not a supported type, it shouldn't have the observer attached
                        self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))

    @skipIfNoFBGEMM
    def test_input_weight_equalization_report_gen(self):
        # use fbgemm and create our model instance
        # then create model report instance with detector
        with override_quantized_engine('fbgemm'):

            test_input_weight_detector = InputWeightEqualizationDetector(0.4)
            detector_set = {test_input_weight_detector}
            model = self.TwoBlockComplexNet()
            # prepare the model for calibration
            prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(
                model, detector_set
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # now get the report by running it through ModelReport instance
            generated_report = model_report.generate_model_report(True)

            # check that sizes are appropriate: only 1 detector
            self.assertEqual(len(generated_report), 1)

            # get the specific report for input weight equalization
            input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]

            # we should have 4 layers looked at since there are 4 conv / linear layers
            self.assertEqual(len(input_weight_dict), 4)

            # we can validate that the max and min values of the detector were recorded properly for the first conv
            # since only a single calibration batch was run, they should match the values from the original input

            example_input = example_input.reshape((3, 28, 28))  # reshape input
            for module_fqn in input_weight_dict:
                # look at the first conv layer, which sees the raw example input
                if "block1.conv" in module_fqn:
                    block_1_lin_recs = input_weight_dict[module_fqn]
                    # get input range info and the channel axis
                    ch_axis = block_1_lin_recs[InputWeightEqualizationDetector.CHANNEL_KEY]

                    # ensure that the min and max values extracted match properly
                    example_min, example_max = torch.aminmax(example_input, dim=ch_axis)
                    dimension_min = torch.amin(example_min, dim=ch_axis)
                    dimension_max = torch.amax(example_max, dim=ch_axis)

                    # make sure per channel min and max are as expected
                    min_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
                    min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY

                    max_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
                    max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY

                    per_channel_min = block_1_lin_recs[min_per_key]
                    per_channel_max = block_1_lin_recs[max_per_key]
                    self.assertEqual(per_channel_min, dimension_min)
                    self.assertEqual(per_channel_max, dimension_max)

                    # make sure the global min and max are as expected
                    min_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
                    min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY

                    max_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
                    max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY

                    # make sure the global min and max were correctly recorded and presented
                    global_min = block_1_lin_recs[min_key]
                    global_max = block_1_lin_recs[max_key]
                    self.assertEqual(global_min, min(dimension_min))
                    self.assertEqual(global_max, max(dimension_max))

                    input_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))
                    # ensure the comparison stat passed back is the ratio of the sqrt of the range ratios
                    # need to get the weight ratios first

                    # make sure per channel min and max are as expected
                    min_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
                    min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY

                    max_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
                    max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY

                    # get weight per channel and global info
                    per_channel_min = block_1_lin_recs[min_per_key]
                    per_channel_max = block_1_lin_recs[max_per_key]

                    # make sure the global min and max are as expected
                    min_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
                    min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY

                    max_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
                    max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY

                    global_min = block_1_lin_recs[min_key]
                    global_max = block_1_lin_recs[max_key]

                    weight_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))

                    # also get comp stat for this specific layer
                    comp_stat = block_1_lin_recs[InputWeightEqualizationDetector.COMP_METRIC_KEY]

                    weight_to_input_ratio = weight_ratio / input_ratio

                    self.assertEqual(comp_stat, weight_to_input_ratio)
                    # only looking at the first example, so we can break
                    break
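
    # Worked example of the comparison statistic checked above (hypothetical
    # numbers): if a channel's weight range is 4 against a global weight range
    # of 16, and its input range is 1 against a global input range of 4, then
    # weight_ratio = sqrt(4 / 16) = 0.5, input_ratio = sqrt(1 / 4) = 0.5, and
    # comp_stat = weight_ratio / input_ratio = 1.0.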

    @skipIfNoFBGEMM
    def test_input_weight_equalization_report_gen_empty(self):
        # tests report gen on a model that doesn't have any applicable layers
        # use fbgemm and create our model instance
        # then create model report instance with detector
        with override_quantized_engine('fbgemm'):
            test_input_weight_detector = InputWeightEqualizationDetector(0.4)
            detector_set = {test_input_weight_detector}
            model = self.ReluOnly()
            # prepare the model for calibration
            prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set)

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # now get the report by running it through ModelReport instance
            generated_report = model_report.generate_model_report(True)

            # check that sizes are appropriate: only 1 detector
            self.assertEqual(len(generated_report), 1)

            # get the specific report for input weight equalization
            input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]

            # we should have 0 layers since there is only a ReLU
            self.assertEqual(len(input_weight_dict), 0)

            # make sure that the string only has two lines, as it should if there are no suggestions
            self.assertEqual(input_weight_str.count("\n"), 2)


class TestFxDetectOutliers(QuantizationTestCase):

    class LargeBatchModel(torch.nn.Module):
        def __init__(self, param_size):
            super().__init__()
            self.param_size = param_size
            self.linear = torch.nn.Linear(param_size, param_size)
            self.relu_1 = torch.nn.ReLU()
            self.conv = torch.nn.Conv2d(param_size, param_size, 1)
            self.relu_2 = torch.nn.ReLU()

        def forward(self, x):
            x = self.linear(x)
            x = self.relu_1(x)
            x = self.conv(x)
            x = self.relu_2(x)
            return x

        def get_example_inputs(self):
            param_size = self.param_size
            return (torch.randn((1, param_size, param_size, param_size)),)

        def get_outlier_inputs(self):
            param_size = self.param_size
            random_vals = torch.randn((1, param_size, param_size, param_size))
            # change one value in every other channel to be a massive value
            random_vals[:, 0:param_size:2, 0, 3] = torch.tensor([3.28e8])
            return (random_vals,)


    def _get_prepped_for_calibration_model(self, model, detector_set, use_outlier_data=False):
        r"""Returns a model that has been prepared for calibration and the corresponding model_report"""
        # call the general helper function to calibrate
        example_input = model.get_example_inputs()[0]

        # if we specifically want to test data with outliers, replace the input
        if use_outlier_data:
            example_input = model.get_outlier_inputs()[0]

        return _get_prepped_for_calibration_model_helper(model, detector_set, example_input)

    @skipIfNoFBGEMM
    def test_outlier_detection_determine_points(self):
        # use fbgemm and create our model instance
        # then create model report instance with detector
        # similar to the test for InputWeightEqualization, but with key differences that made refactoring not viable
        # not explicitly testing fusion because the fx workflow handles it automatically
        with override_quantized_engine('fbgemm'):

            detector_set = {OutlierDetector(reference_percentile=0.95)}

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                self.LargeBatchModel(param_size=128), detector_set
            )

            # supported modules to check
            mods_to_check = {nn.Linear, nn.Conv2d, nn.ReLU}

            # there should be 4 node fqns that have the observer inserted
            correct_number_of_obs_inserted = 4
            obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME

            number_of_obs_found = sum(
                1 if obs_name_to_find in str(node.target) else 0 for node in prepared_for_callibrate_model.graph.nodes
            )
            self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)

            # assert that each of the desired modules have the observers inserted
            for fqn, module in prepared_for_callibrate_model.named_modules():
                # check if module is a supported module
                is_in_include_list = isinstance(module, tuple(mods_to_check))

                if is_in_include_list:
                    # make sure it has the observer attribute
                    self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
                else:
                    # if it's not a supported type, it shouldn't have an observer attached
                    self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))

    @skipIfNoFBGEMM
    def test_no_outlier_report_gen(self):
        # use fbgemm and create our model instance
        # then create model report instance with detector
        with override_quantized_engine('fbgemm'):

            # test with multiple detectors
            outlier_detector = OutlierDetector(reference_percentile=0.95)
            dynamic_static_detector = DynamicStaticDetector(tolerance=0.5)

            param_size: int = 4
            detector_set = {outlier_detector, dynamic_static_detector}
            model = self.LargeBatchModel(param_size=param_size)

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                model, detector_set
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # now get the report by running it through ModelReport instance
            generated_report = mod_report.generate_model_report(True)

            # check that sizes are appropriate: only 2 detectors
            self.assertEqual(len(generated_report), 2)

            # get the specific report for outlier detection
            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]

            # we should have 4 layers looked at: linear + relu + conv + relu
            self.assertEqual(len(outlier_dict), 4)

            # assert the following are true for all the modules
            for module_fqn in outlier_dict:
                # get the info for the specific module
                module_dict = outlier_dict[module_fqn]

                # there really shouldn't be any outliers since we drew the input data from a normal distribution
                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
                self.assertEqual(sum(outlier_info), 0)

                # ensure that the number of ratios and batches counted is the same as the number of params
                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)
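
    # NOTE: the next test inverts the setup above: with ratio_threshold=1 and
    # reference_percentile=0, each channel's max is compared against its own min,
    # so every non-constant channel should be flagged as an outlier by construction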

    @skipIfNoFBGEMM
    def test_all_outlier_report_gen(self):
        # make the reference percentile 0 and the ratio threshold 1, then check that everything is reported as an outlier
        # use fbgemm and create our model instance
        # then create model report instance with detector
        with override_quantized_engine('fbgemm'):
            # create detector of interest
            outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0)

            param_size: int = 16
            detector_set = {outlier_detector}
            model = self.LargeBatchModel(param_size=param_size)

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                model, detector_set
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # now get the report by running it through ModelReport instance
            generated_report = mod_report.generate_model_report(True)

            # check that sizes are appropriate: only 1 detector
            self.assertEqual(len(generated_report), 1)

            # get the specific report for outlier detection
            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]

            # we should have 4 layers looked at: linear + relu + conv + relu
            self.assertEqual(len(outlier_dict), 4)

            # assert the following are true for all the modules
            for module_fqn in outlier_dict:
                # get the info for the specific module
                module_dict = outlier_dict[module_fqn]

                # everything should be an outlier because each channel's max is compared against its own min
                # however we only require that most are, in case we have several all-zero channels
                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
                assert sum(outlier_info) >= len(outlier_info) / 2

                # ensure that the number of ratios and batches counted is the same as the number of params
                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)

    @skipIfNoFBGEMM
    def test_multiple_run_consistent_spike_outlier_report_gen(self):
        # specifically make one value in every other channel consistently really high across the batches being tested
        # calibrate over many runs (30) so we get above the minimum batch threshold before generating the report
        with override_quantized_engine('fbgemm'):

            # detector of interest
            outlier_detector = OutlierDetector(reference_percentile=0.95)

            param_size: int = 8
            detector_set = {outlier_detector}
            model = self.LargeBatchModel(param_size=param_size)

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
                model, detector_set, use_outlier_data=True
            )

            # now we actually calibrate the model
            example_input = model.get_outlier_inputs()[0]
            example_input = example_input.to(torch.float)

            # now calibrate at least 30 times to get above the minimum batch threshold
            for i in range(30):
                example_input = model.get_outlier_inputs()[0]
                example_input = example_input.to(torch.float)

                # zero out a channel in a few of the batches
                if i % 14 == 0:
                    # make one channel constant
                    example_input[0][1] = torch.zeros_like(example_input[0][1])

                prepared_for_callibrate_model(example_input)
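
            # with 30 calibration passes we are above the detector's minimum-batch
            # requirement, so the sufficient-batches flags checked below can be set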

            # now get the report by running it through ModelReport instance
            generated_report = mod_report.generate_model_report(True)

            # check that sizes are appropriate: only 1 detector
            self.assertEqual(len(generated_report), 1)

            # get the specific report for outlier detection
            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]

            # we should have 4 layers looked at: linear + relu + conv + relu
            self.assertEqual(len(outlier_dict), 4)

            # assert the following are true for all the modules
            for module_fqn in outlier_dict:
                # get the info for the specific module
                module_dict = outlier_dict[module_fqn]

                # because we ran 30 times, at least a couple of channels should have sufficient batches
                # could be fewer because some channels could possibly be all 0
                sufficient_batches_info = module_dict[OutlierDetector.IS_SUFFICIENT_BATCHES_KEY]
                assert sum(sufficient_batches_info) >= len(sufficient_batches_info) / 2

                # half of them should be outliers, because we set a really high value in every other channel
                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
                self.assertEqual(sum(outlier_info), len(outlier_info) / 2)

                # ensure that the number of ratios and batches counted is the same as the number of params
                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)

                # for the first module, ensure the per channel max values are what we set
                if module_fqn == "linear.0":

                    # check the constant-channel counts: at least 2 should be recorded for the first module
                    counts_info = module_dict[OutlierDetector.CONSTANT_COUNTS_KEY]
                    assert sum(counts_info) >= 2

                    # half of the recorded max values should be what we set
                    matched_max = sum(val == 3.28e8 for val in module_dict[OutlierDetector.MAX_VALS_KEY])
                    self.assertEqual(matched_max, param_size / 2)


class TestFxModelReportVisualizer(QuantizationTestCase):

    def _callibrate_and_generate_visualizer(self, model, prepared_for_callibrate_model, mod_report):
        r"""
        Calibrates the passed in model, generates the report, and returns the visualizer
        """
        # now we actually calibrate the model
        example_input = model.get_example_inputs()[0]
        example_input = example_input.to(torch.float)

        prepared_for_callibrate_model(example_input)

        # now get the report by running it through ModelReport instance
        generated_report = mod_report.generate_model_report(remove_inserted_observers=False)

        # getting the visualizer should not error
        mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()

        return mod_rep_visualizer

    @skipIfNoFBGEMM
    def test_get_modules_and_features(self):
        """
        Tests the get_all_unique_module_fqns and get_all_unique_feature_names methods of
        ModelReportVisualizer

        Checks whether returned sets are of proper size and filtered properly
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            # test with multiple detectors
            detector_set = set()
            detector_set.add(OutlierDetector(reference_percentile=0.95))
            detector_set.add(InputWeightEqualizationDetector(0.5))

            model = TwoThreeOps()
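
            # TwoThreeOps nests two ThreeOps blocks, so the report fqns include entries
            # like "block1.linear", which the feature checks below rely on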

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
                model, prepared_for_callibrate_model, mod_report
            )

            # ensure the module fqns match the ones given by the get_all_unique_module_fqns method
            actual_model_fqns = set(mod_rep_visualizer.generated_reports.keys())
            returned_model_fqns = mod_rep_visualizer.get_all_unique_module_fqns()
            self.assertEqual(returned_model_fqns, actual_model_fqns)

            # now ensure that features are all properly returned
            # all the linears have all the features for the two detectors,
            # so we can use those to check that the method is working reliably
            b_1_linear_features = mod_rep_visualizer.generated_reports["block1.linear"]

            # first test all features
            returned_all_feats = mod_rep_visualizer.get_all_unique_feature_names(False)
            self.assertEqual(returned_all_feats, set(b_1_linear_features.keys()))

            # now test plottable features
            plottable_set = set()

            for feature_name in b_1_linear_features:
                if type(b_1_linear_features[feature_name]) == torch.Tensor:
                    plottable_set.add(feature_name)

            returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names()
            self.assertEqual(returned_plottable_feats, plottable_set)

    def _prep_visualizer_helper(self):
        r"""
        Returns a mod rep visualizer that we test in various ways
        """
        # set backend for test
        torch.backends.quantized.engine = "fbgemm"

        # test with multiple detectors
        detector_set = set()
        detector_set.add(OutlierDetector(reference_percentile=0.95))
        detector_set.add(InputWeightEqualizationDetector(0.5))

        model = TwoThreeOps()

        # get test model and calibrate
        prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
            model, detector_set, model.get_example_inputs()[0]
        )

        mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
            model, prepared_for_callibrate_model, mod_report
        )

        return mod_rep_visualizer

    @skipIfNoFBGEMM
    def test_generate_tables_match_with_report(self):
        """
        Tests generate_filtered_tables() of
        ModelReportVisualizer

        Checks whether the generated dict has proper information
        Visual check that the tables look correct performed during testing
        """
        with override_quantized_engine('fbgemm'):

            # get the visualizer
            mod_rep_visualizer = self._prep_visualizer_helper()

            table_dict = mod_rep_visualizer.generate_filtered_tables()

            # test primarily the dict since it has the same info as the str
            tensor_headers, tensor_table = table_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
            channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

            # together these two should cover the same module keys as the generated report
            tensor_info_modules = {row[1] for row in tensor_table}
            channel_info_modules = {row[1] for row in channel_table}
            combined_modules: Set = tensor_info_modules.union(channel_info_modules)

            generated_report_keys: Set = set(mod_rep_visualizer.generated_reports.keys())
            self.assertEqual(combined_modules, generated_report_keys)
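
    # NOTE: in both generated tables, each row's second entry (row[1]) is the module
    # fqn, which is why these tests build their module sets from that column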

    @skipIfNoFBGEMM
    def test_generate_tables_no_match(self):
        """
        Tests generate_filtered_tables() of
        ModelReportVisualizer

        Checks whether the generated dict has proper information
        Visual check that the tables look correct performed during testing
        """
        with override_quantized_engine('fbgemm'):
            # get the visualizer
            mod_rep_visualizer = self._prep_visualizer_helper()

            # try a random filter and make sure that there are no rows for either table
            empty_tables_dict = mod_rep_visualizer.generate_filtered_tables(module_fqn_filter="random not there module")

            # test primarily the dict since it has the same info as the str
            tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
            channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

            tensor_info_modules = {row[1] for row in tensor_table}
            channel_info_modules = {row[1] for row in channel_table}
            combined_modules: Set = tensor_info_modules.union(channel_info_modules)
            self.assertEqual(len(combined_modules), 0)  # should be no matching modules

    @skipIfNoFBGEMM
    def test_generate_tables_single_feat_match(self):
        """
        Tests generate_filtered_tables() of
        ModelReportVisualizer

        Checks whether the generated dict has proper information
        Visual check that the tables look correct performed during testing
        """
        with override_quantized_engine('fbgemm'):
            # get the visualizer
            mod_rep_visualizer = self._prep_visualizer_helper()

            # try a matching filter for a feature and make sure only those features show up
            # if we filter to a very specific feature name, we should only have 1 additional column in each table row
            single_feat_dict = mod_rep_visualizer.generate_filtered_tables(feature_filter=OutlierDetector.MAX_VALS_KEY)

            # test primarily the dict since it has the same info as the str
            tensor_headers, tensor_table = single_feat_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
            channel_headers, channel_table = single_feat_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]

            # get the number of features in each of these
            tensor_info_features = len(tensor_headers)
            channel_info_features = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS

            # make sure that there are no tensor features, and that there is one channel level feature
            self.assertEqual(tensor_info_features, 0)
            self.assertEqual(channel_info_features, 1)


def _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused: bool = False):
    r"""Returns a model that has been prepared for calibration and the corresponding model_report"""
    # set the backend for this test
    torch.backends.quantized.engine = "fbgemm"

    # create model instance and prepare it
    example_input = example_input.to(torch.float)
    q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping()

    # if they passed in the fusion parameter, make sure to test that
    if fused:
        model = torch.ao.quantization.fuse_modules(model, model.get_fusion_modules())

    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)

    model_report = ModelReport(model_prep, detector_set)

    # prepare the model for calibration
    prepared_for_callibrate_model = model_report.prepare_detailed_calibration()

    return (prepared_for_callibrate_model, model_report)
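

# A minimal sketch (not a test) of the end-to-end ModelReport workflow that the
# helper above and the test cases exercise; it mirrors the calls made in
# _get_prepped_for_calibration_model_helper and _callibrate_and_generate_visualizer
# and assumes the fbgemm backend is available. The function name is illustrative.
def _example_model_report_workflow():
    torch.backends.quantized.engine = "fbgemm"

    # create a model and prepare it with the default qconfig mapping
    model = TwoThreeOps()
    example_input = model.get_example_inputs()[0].to(torch.float)
    q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping()
    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)

    # attach detectors and insert their observers into the prepared graph
    model_report = ModelReport(model_prep, {OutlierDetector(reference_percentile=0.95)})
    calibration_model = model_report.prepare_detailed_calibration()

    # run one or more calibration passes, then generate the per-detector reports
    calibration_model(example_input)
    generated_report = model_report.generate_model_report(remove_inserted_observers=False)

    # optionally visualize the collected statistics as filtered tables
    visualizer = model_report.generate_visualizer()
    table_dict = visualizer.generate_filtered_tables()
    return generated_report, table_dict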