# Owner(s): ["oncall: quantization"]
from typing import Set

import torch
import torch.nn as nn
import torch.ao.quantization.quantize_fx as quantize_fx
import torch.nn.functional as F
from torch.ao.quantization import QConfig, QConfigMapping
from torch.ao.quantization.fx._model_report.detector import (
    DynamicStaticDetector,
    InputWeightEqualizationDetector,
    PerChannelDetector,
    OutlierDetector,
)
from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
from torch.ao.quantization.fx._model_report.model_report import ModelReport
from torch.ao.quantization.observer import (
    HistogramObserver,
    default_per_channel_weight_observer,
    default_observer
)
from torch.ao.nn.intrinsic.modules.fused import ConvReLU2d, LinearReLU
from torch.testing._internal.common_quantization import (
    ConvModel,
    QuantizationTestCase,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    skipIfNoFBGEMM,
    skipIfNoQNNPACK,
    override_quantized_engine,
)


"""
Partition of input domain:

Model contains: conv or linear, both conv and linear
    Model contains: ConvTransposeNd (not supported for per_channel)

Model is: post training quantization model, quantization aware training model
Model is: composed with nn.Sequential, composed in class structure

QConfig utilizes per_channel weight observer, backend uses non per_channel weight observer
QConfig dict uses only one default qconfig, QConfig dict uses > 1 unique qconfigs

Partition on output domain:

There are possible changes / suggestions, there are no changes / suggestions
"""

# Default output string if no optimizations are possible
DEFAULT_NO_OPTIMS_ANSWER_STRING = (
    "Further Optimizations for backend {}: \nNo further per_channel optimizations possible."
)

# Example Sequential Model with multiple Conv and Linear with nesting involved
NESTED_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, 2, 1),
    torch.nn.Sequential(torch.nn.Linear(9, 27), torch.nn.ReLU()),
    torch.nn.Linear(27, 27),
    torch.nn.ReLU(),
    torch.nn.Conv2d(3, 3, 2, 1),
)

# Example Sequential Model with a Conv sub-class (LazyConv2d)
LAZY_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
    torch.nn.LazyConv2d(3, 3, 2, 1),
    torch.nn.Sequential(torch.nn.Linear(5, 27), torch.nn.ReLU()),
    torch.nn.ReLU(),
    torch.nn.Linear(27, 27),
    torch.nn.ReLU(),
    torch.nn.LazyConv2d(3, 3, 2, 1),
)

# Example Sequential Model with Fusion directly built into model
FUSION_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
    ConvReLU2d(torch.nn.Conv2d(3, 3, 2, 1), torch.nn.ReLU()),
    torch.nn.Sequential(LinearReLU(torch.nn.Linear(9, 27), torch.nn.ReLU())),
    LinearReLU(torch.nn.Linear(27, 27), torch.nn.ReLU()),
    torch.nn.Conv2d(3, 3, 2, 1),
)

# Test class
# example model to use for tests
class ThreeOps(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def get_example_inputs(self):
        return (torch.randn(1, 3, 3, 3),)

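# example model that chains two ThreeOps blocks and adds their outputs before a final ReLU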
class TwoThreeOps(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block1 = ThreeOps()
        self.block2 = ThreeOps()

    def forward(self, x):
        x = self.block1(x)
        y = self.block2(x)
        z = x + y
        z = F.relu(z)
        return z

    def get_example_inputs(self):
        return (torch.randn(1, 3, 3, 3),)

class TestFxModelReportDetector(QuantizationTestCase):

    """Prepares and calibrates the model"""

    def _prepare_model_and_run_input(self, model, q_config_mapping, input):
        model_prep = torch.ao.quantization.quantize_fx.prepare_fx(model, q_config_mapping, input)  # prep model
        model_prep(input).sum()  # calibrate the model
        return model_prep

    """Case includes:
        one conv or linear
        post training quantization
        composed as module
        qconfig uses per_channel weight observer
        Only 1 qconfig in qconfig dict
        Output has no changes / suggestions
    """

    @skipIfNoFBGEMM
    def test_simple_conv(self):

        with override_quantized_engine('fbgemm'):
            torch.backends.quantized.engine = "fbgemm"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            input = torch.randn(1, 3, 10, 10)
            prepared_model = self._prepare_model_and_run_input(ConvModel(), q_config_mapping, input)

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # no optims possible and there should be nothing in per_channel_status
            self.assertEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # there should only be one conv in this model
            self.assertEqual(per_channel_info["conv"]["backend"], torch.backends.quantized.engine)
            self.assertEqual(len(per_channel_info), 1)
            self.assertEqual(next(iter(per_channel_info)), "conv")
            self.assertEqual(
                per_channel_info["conv"]["per_channel_quantization_supported"],
                True,
            )
            self.assertEqual(per_channel_info["conv"]["per_channel_quantization_used"], True)

    """Case includes:
        Multiple conv or linear
        post training quantization
        composed as module
        qconfig doesn't use per_channel weight observer
        Only 1 qconfig in qconfig dict
        Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_multi_linear_model_without_per_channel(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                TwoLayerLinearModel(),
                q_config_mapping,
                TwoLayerLinearModel().get_example_inputs()[0],
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )
            # pick a random key to look at
            rand_key: str = next(iter(per_channel_info.keys()))
            self.assertEqual(per_channel_info[rand_key]["backend"], torch.backends.quantized.engine)
            self.assertEqual(len(per_channel_info), 2)

            # for each linear layer, should be supported but not used
            for linear_key in per_channel_info.keys():
                module_entry = per_channel_info[linear_key]

                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)

    """Case includes:
        Multiple conv or linear
        post training quantization
        composed as Module
        qconfig doesn't use per_channel weight observer
        More than 1 qconfig in qconfig dict
        Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_multiple_q_config_options(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            # qconfig with support for per_channel quantization
            per_channel_qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=True),
                weight=default_per_channel_weight_observer,
            )

            # we need to design the model
            class ConvLinearModel(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv1 = torch.nn.Conv2d(3, 3, 2, 1)
                    self.fc1 = torch.nn.Linear(9, 27)
                    self.relu = torch.nn.ReLU()
                    self.fc2 = torch.nn.Linear(27, 27)
                    self.conv2 = torch.nn.Conv2d(3, 3, 2, 1)

                def forward(self, x):
                    x = self.conv1(x)
                    x = self.fc1(x)
                    x = self.relu(x)
                    x = self.fc2(x)
                    x = self.conv2(x)
                    return x

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(
                torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine)
            ).set_object_type(torch.nn.Conv2d, per_channel_qconfig)

            prepared_model = self._prepare_model_and_run_input(
                ConvLinearModel(),
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # the only suggestions should be for the linear layers

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer
            self.assertEqual(len(per_channel_info), 4)

            # each layer should support per channel quantization
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]
                self.assertEqual(module_entry["per_channel_quantization_supported"], True)

                # linear layers should be False, conv2d layers True since they use a different qconfig
                if "fc" in key:
                    self.assertEqual(module_entry["per_channel_quantization_used"], False)
                elif "conv" in key:
                    self.assertEqual(module_entry["per_channel_quantization_used"], True)
                else:
                    raise ValueError("Should only contain conv and linear layers as key values")

    """Case includes:
        Multiple conv or linear
        post training quantization
        composed as sequential
        qconfig doesn't use per_channel weight observer
        Only 1 qconfig in qconfig dict
        Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_sequential_model_format(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                NESTED_CONV_LINEAR_EXAMPLE,
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer
            self.assertEqual(len(per_channel_info), 4)

            # for each layer, should be supported but not used
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]

                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)

    """Case includes:
        Multiple conv or linear
        post training quantization
        composed as sequential
        qconfig doesn't use per_channel weight observer
        Only 1 qconfig in qconfig dict
        Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_conv_sub_class_considered(self):

        with override_quantized_engine('qnnpack'):
            torch.backends.quantized.engine = "qnnpack"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                LAZY_CONV_LINEAR_EXAMPLE,
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer and it considered the LazyConv2d
            self.assertEqual(len(per_channel_info), 4)

            # for each layer, should be supported but not used
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]

                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)

    """Case includes:
        Multiple conv or linear
        post training quantization
        composed as sequential
        qconfig uses per_channel weight observer
        Only 1 qconfig in qconfig dict
        Output has no possible changes / suggestions
    """

    @skipIfNoFBGEMM
    def test_fusion_layer_in_sequential(self):

        with override_quantized_engine('fbgemm'):
            torch.backends.quantized.engine = "fbgemm"

            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            prepared_model = self._prepare_model_and_run_input(
                FUSION_CONV_LINEAR_EXAMPLE,
                q_config_mapping,
                torch.randn(1, 3, 10, 10),
            )

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(prepared_model)

            # no optims possible and there should be nothing in per_channel_status
            self.assertEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # to ensure it got into the nested layer and it considered all the nested fusion components
            self.assertEqual(len(per_channel_info), 4)

            # for each layer, per channel quantization should be both supported and used
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]
                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], True)

    """Case includes:
        Multiple conv or linear
        quantization aware training
        composed as model
        qconfig does not use per_channel weight observer
        Only 1 qconfig in qconfig dict
        Output has possible changes / suggestions
    """

    @skipIfNoQNNPACK
    def test_qat_aware_model_example(self):

        # first we want a QAT model
        class QATConvLinearReluModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # QuantStub converts tensors from floating point to quantized
                self.quant = torch.ao.quantization.QuantStub()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)
                self.relu = torch.nn.ReLU()
                # DeQuantStub converts tensors from quantized to floating point
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            # create a model instance
            model_fp32 = QATConvLinearReluModel()

            model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")

            # model must be in eval mode for fusion
            model_fp32.eval()
            model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [["conv", "bn", "relu"]])

            # model must be set to train mode for QAT logic to work
            model_fp32_fused.train()

            # prepare the model for QAT, different from post training quantization
            model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused)

            # run the detector
            per_channel_detector = PerChannelDetector(torch.backends.quantized.engine)
            optims_str, per_channel_info = per_channel_detector.generate_detector_report(model_fp32_prepared)

            # there should be optims possible
            self.assertNotEqual(
                optims_str,
                DEFAULT_NO_OPTIMS_ANSWER_STRING.format(torch.backends.quantized.engine),
            )

            # make sure it was able to find the single conv in the fused model
            self.assertEqual(len(per_channel_info), 1)

            # for the one conv, it should still give advice to use different qconfig
            for key in per_channel_info.keys():
                module_entry = per_channel_info[key]
                self.assertEqual(module_entry["per_channel_quantization_supported"], True)
                self.assertEqual(module_entry["per_channel_quantization_used"], False)


"""
Partition on Domain / Things to Test

- All zero tensor
- Multiple tensor dimensions
- All of the outward facing functions
- Epoch min max are correctly updating
- Batch range is correctly averaging as expected
- Reset for each epoch is correctly resetting the values

Partition on Output
- the calculation of the ratio is occurring correctly

"""


class TestFxModelReportObserver(QuantizationTestCase):
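    # wrapper model with ModelReportObservers placed before and after an inner SingleLayerLinearModel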
    class NestedModifiedSingleLayerLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.obs1 = ModelReportObserver()
            self.mod1 = SingleLayerLinearModel()
            self.obs2 = ModelReportObserver()
            self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.obs1(x)
            x = self.mod1(x)
            x = self.obs2(x)
            x = self.fc1(x)
            x = self.relu(x)
            return x

    def run_model_and_common_checks(self, model, ex_input, num_epochs, batch_size):
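        """Runs ex_input through the model in batches for num_epochs and checks that the
        ModelReportObserver statistics (epoch min/max, average batch activation range) reset and
        update as expected.
        """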
        # split up data into batches
        split_up_data = torch.split(ex_input, batch_size)
        for epoch in range(num_epochs):
            # reset all model report obs
            model.apply(
                lambda module: module.reset_batch_and_epoch_values()
                if isinstance(module, ModelReportObserver)
                else None
            )

            # quick check that a reset occurred
            self.assertEqual(
                model.obs1.average_batch_activation_range,
                torch.tensor(float(0)),
            )
            self.assertEqual(model.obs1.epoch_activation_min, torch.tensor(float("inf")))
            self.assertEqual(model.obs1.epoch_activation_max, torch.tensor(float("-inf")))

            # loop through the batches and run through
            for index, batch in enumerate(split_up_data):

                num_tracked_so_far = model.obs1.num_batches_tracked
                self.assertEqual(num_tracked_so_far, index)

                # get general info about the batch and the model to use later
                batch_min, batch_max = torch.aminmax(batch)
                current_average_range = model.obs1.average_batch_activation_range
                current_epoch_min = model.obs1.epoch_activation_min
                current_epoch_max = model.obs1.epoch_activation_max

                # run input through
                model(ex_input)

                # check that average batch activation range updated correctly
                correct_updated_value = (current_average_range * num_tracked_so_far + (batch_max - batch_min)) / (
                    num_tracked_so_far + 1
                )
                self.assertEqual(
                    model.obs1.average_batch_activation_range,
                    correct_updated_value,
                )

                if current_epoch_max - current_epoch_min > 0:
                    self.assertEqual(
                        model.obs1.get_batch_to_epoch_ratio(),
                        correct_updated_value / (current_epoch_max - current_epoch_min),
                    )

    """Case includes:
        all zero tensor
        dim size = 2
        run for 1 epoch
        run for 10 batches
        tests input data observer
    """

    def test_zero_tensor_errors(self):
        # initialize the model
        model = self.NestedModifiedSingleLayerLinear()

        # generate the desired input
        ex_input = torch.zeros((10, 1, 5))

        # run it through the model and do general tests
        self.run_model_and_common_checks(model, ex_input, 1, 1)

        # make sure final values are all 0
        self.assertEqual(model.obs1.epoch_activation_min, 0)
        self.assertEqual(model.obs1.epoch_activation_max, 0)
        self.assertEqual(model.obs1.average_batch_activation_range, 0)

        # we should get an error if we try to calculate the ratio
        with self.assertRaises(ValueError):
            ratio_val = model.obs1.get_batch_to_epoch_ratio()

600    """Case includes:
601    non-zero tensor
602    dim size = 2
603    run for 1 epoch
604    run for 1 batch
605    tests input data observer
606    """
607
608    def test_single_batch_of_ones(self):
609        # initialize the model
610        model = self.NestedModifiedSingleLayerLinear()
611
612        # generate the desired input
613        ex_input = torch.ones((1, 1, 5))
614
615        # run it through the model and do general tests
616        self.run_model_and_common_checks(model, ex_input, 1, 1)
617
618        # make sure final values are all 0 except for range
619        self.assertEqual(model.obs1.epoch_activation_min, 1)
620        self.assertEqual(model.obs1.epoch_activation_max, 1)
621        self.assertEqual(model.obs1.average_batch_activation_range, 0)
622
623        # we should get an error if we try to calculate the ratio
624        with self.assertRaises(ValueError):
625            ratio_val = model.obs1.get_batch_to_epoch_ratio()
626
627    """Case includes:
628    non-zero tensor
629    dim size = 2
630    run for 10 epoch
631    run for 15 batch
632    tests non input data observer
633    """
634
    def test_observer_after_relu(self):

        # model specific to this test
        class NestedModifiedObserverAfterRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.mod1 = SingleLayerLinearModel()
                self.obs2 = ModelReportObserver()
                self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.obs1(x)
                x = self.mod1(x)
                x = self.fc1(x)
                x = self.relu(x)
                x = self.obs2(x)
                return x

        # initialize the model
        model = NestedModifiedObserverAfterRelu()

        # generate the desired input
        ex_input = torch.randn((15, 1, 5))

        # run it through the model and do general tests
        self.run_model_and_common_checks(model, ex_input, 10, 15)

    """Case includes:
        non-zero tensor
        dim size = 2
        run for multiple epochs
        run for multiple batches
        tests input data observer
    """

    def test_random_epochs_and_batches(self):

        # set up a basic model
        class TinyNestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
                self.relu = torch.nn.ReLU()
                self.obs2 = ModelReportObserver()

            def forward(self, x):
                x = self.obs1(x)
                x = self.fc1(x)
                x = self.relu(x)
                x = self.obs2(x)
                return x

        class LargerIncludeNestModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.nested = TinyNestModule()
                self.fc1 = SingleLayerLinearModel()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.obs1(x)
                x = self.nested(x)
                x = self.fc1(x)
                x = self.relu(x)
                return x

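        # ThreeOps-style block with ModelReportObservers and a configurable BatchNorm dimension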
        class ModifiedThreeOps(torch.nn.Module):
            def __init__(self, batch_norm_dim):
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.linear = torch.nn.Linear(7, 3, 2)
                self.obs2 = ModelReportObserver()

                if batch_norm_dim == 2:
                    self.bn = torch.nn.BatchNorm2d(2)
                elif batch_norm_dim == 3:
                    self.bn = torch.nn.BatchNorm3d(4)
                else:
                    raise ValueError("Dim should only be 2 or 3")

                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.obs1(x)
                x = self.linear(x)
                x = self.obs2(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        class HighDimensionNet(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obs1 = ModelReportObserver()
                self.fc1 = torch.nn.Linear(3, 7)
                self.block1 = ModifiedThreeOps(3)
                self.fc2 = torch.nn.Linear(3, 7)
                self.block2 = ModifiedThreeOps(3)
                self.fc3 = torch.nn.Linear(3, 7)

            def forward(self, x):
                x = self.obs1(x)
                x = self.fc1(x)
                x = self.block1(x)
                x = self.fc2(x)
                y = self.block2(x)
                y = self.fc3(y)
                z = x + y
                z = F.relu(z)
                return z

        # the purpose of this test is to give the observers a variety of data examples
        # initialize the models
        models = [
            self.NestedModifiedSingleLayerLinear(),
            LargerIncludeNestModel(),
            ModifiedThreeOps(2),
            HighDimensionNet(),
        ]

        # get some number of epochs and batches
        num_epochs = 10
        num_batches = 15

        input_shapes = [(1, 5), (1, 5), (2, 3, 7), (4, 1, 8, 3)]

        # generate the desired inputs
        inputs = []
        for shape in input_shapes:
            ex_input = torch.randn((num_batches, *shape))
            inputs.append(ex_input)

        # run it through the model and do general tests
        for index, model in enumerate(models):
            self.run_model_and_common_checks(model, inputs[index], num_epochs, num_batches)


776"""
777Partition on domain / things to test
778
779There is only a single test case for now.
780
781This will be more thoroughly tested with the implementation of the full end to end tool coming soon.
782"""
783
784
785class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
786    @skipIfNoFBGEMM
787    def test_nested_detection_case(self):
788        class SingleLinear(torch.nn.Module):
789            def __init__(self) -> None:
790                super().__init__()
791                self.linear = torch.nn.Linear(3, 3)
792
793            def forward(self, x):
794                x = self.linear(x)
795                return x
796
797        class TwoBlockNet(torch.nn.Module):
798            def __init__(self) -> None:
799                super().__init__()
800                self.block1 = SingleLinear()
801                self.block2 = SingleLinear()
802
803            def forward(self, x):
804                x = self.block1(x)
805                y = self.block2(x)
806                z = x + y
807                z = F.relu(z)
808                return z
809
810
811        with override_quantized_engine('fbgemm'):
812            # create model, example input, and qconfig mapping
813            torch.backends.quantized.engine = "fbgemm"
814            model = TwoBlockNet()
815            example_input = torch.randint(-10, 0, (1, 3, 3, 3))
816            example_input = example_input.to(torch.float)
817            q_config_mapping = QConfigMapping()
818            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig("fbgemm"))
819
820            # prep model and select observer
821            model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
822            obs_ctr = ModelReportObserver
823
824            # find layer to attach to and store
825            linear_fqn = "block2.linear"  # fqn of target linear
826
827            target_linear = None
828            for node in model_prep.graph.nodes:
829                if node.target == linear_fqn:
830                    target_linear = node
831                    break
832
833            # insert into both module and graph pre and post
834
835            # set up to insert before target_linear (pre_observer)
836            with model_prep.graph.inserting_before(target_linear):
837                obs_to_insert = obs_ctr()
838                pre_obs_fqn = linear_fqn + ".model_report_pre_observer"
839                model_prep.add_submodule(pre_obs_fqn, obs_to_insert)
840                model_prep.graph.create_node(op="call_module", target=pre_obs_fqn, args=target_linear.args)
841
842            # set up and insert after the target_linear (post_observer)
843            with model_prep.graph.inserting_after(target_linear):
844                obs_to_insert = obs_ctr()
845                post_obs_fqn = linear_fqn + ".model_report_post_observer"
846                model_prep.add_submodule(post_obs_fqn, obs_to_insert)
847                model_prep.graph.create_node(op="call_module", target=post_obs_fqn, args=(target_linear,))
848
849            # need to recompile module after submodule added and pass input through
850            model_prep.recompile()
851
852            num_iterations = 10
853            for i in range(num_iterations):
854                if i % 2 == 0:
855                    example_input = torch.randint(-10, 0, (1, 3, 3, 3)).to(torch.float)
856                else:
857                    example_input = torch.randint(0, 10, (1, 3, 3, 3)).to(torch.float)
858                model_prep(example_input)
859
860            # run it through the dynamic vs static detector
861            dynamic_vs_static_detector = DynamicStaticDetector()
862            dynam_vs_stat_str, dynam_vs_stat_dict = dynamic_vs_static_detector.generate_detector_report(model_prep)
863
864            # one of the stats should be stationary, and the other non-stationary
865            # as a result, dynamic should be recommended
866            data_dist_info = [
867                dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.PRE_OBS_DATA_DIST_KEY],
868                dynam_vs_stat_dict[linear_fqn][DynamicStaticDetector.POST_OBS_DATA_DIST_KEY],
869            ]
870
871            self.assertTrue("stationary" in data_dist_info)
872            self.assertTrue("non-stationary" in data_dist_info)
873            self.assertTrue(dynam_vs_stat_dict[linear_fqn]["dynamic_recommended"])
874
class TestFxModelReportClass(QuantizationTestCase):

    @skipIfNoFBGEMM
    def test_constructor(self):
        """
        Tests the constructor of the ModelReport class.
        Specifically looks at:
        - The desired reports
        - Ensures that the observers of interest are properly initialized
        """

        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            backend = torch.backends.quantized.engine

            # create a model
            model = ThreeOps()
            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
            model_prep = quantize_fx.prepare_fx(model, q_config_mapping, model.get_example_inputs()[0])

            # make an example set of detectors
            test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
            # initialize the ModelReport with the detector set
            model_report = ModelReport(model_prep, test_detector_set)

            # make sure internal valid reports matches
            detector_name_set = {detector.get_detector_name() for detector in test_detector_set}
            self.assertEqual(model_report.get_desired_reports_names(), detector_name_set)

            # now attempt with no valid reports, should raise error
            with self.assertRaises(ValueError):
                model_report = ModelReport(model, set())

            # number of expected obs of interest entries
            num_expected_entries = len(test_detector_set)
            self.assertEqual(len(model_report.get_observers_of_interest()), num_expected_entries)

            for value in model_report.get_observers_of_interest().values():
                self.assertEqual(len(value), 0)

    @skipIfNoFBGEMM
    def test_prepare_model_callibration(self):
        """
        Tests model_report.prepare_detailed_calibration that prepares the model for calibration
        Specifically looks at:
        - Whether observers are properly inserted into regular nn.Module
        - Whether the target and the arguments of the observers are proper
        - Whether the internal representation of observers of interest is updated
        """

        with override_quantized_engine('fbgemm'):
            # create model report object

            # create model
            model = TwoThreeOps()
            # make an example set of detectors
            torch.backends.quantized.engine = "fbgemm"
            backend = torch.backends.quantized.engine
            test_detector_set = {DynamicStaticDetector(), PerChannelDetector(backend)}
            # the ModelReport will be initialized with these detectors below

            # prepare the model
            example_input = model.get_example_inputs()[0]
            current_backend = torch.backends.quantized.engine
            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)

            model_report = ModelReport(model_prep, test_detector_set)

            # prepare the model for calibration
            prepared_for_callibrate_model = model_report.prepare_detailed_calibration()

            # check whether observers were properly inserted into the regular nn.Module
            # there should be 4 observers present in this case
            modules_observer_cnt = 0
            for fqn, module in prepared_for_callibrate_model.named_modules():
                if isinstance(module, ModelReportObserver):
                    modules_observer_cnt += 1

            self.assertEqual(modules_observer_cnt, 4)

            model_report_str_check = "model_report"
            # also make sure arguments for observers in the graph are proper
            for node in prepared_for_callibrate_model.graph.nodes:
                # not all node targets are strings, so check
                if isinstance(node.target, str) and model_report_str_check in node.target:
                    # a pre-observer has the same args as the linear (the next node)
                    if "pre_observer" in node.target:
                        self.assertEqual(node.args, node.next.args)
                    # a post-observer's args are the target linear (the previous node)
                    if "post_observer" in node.target:
                        self.assertEqual(node.args, (node.prev,))

            # ensure model_report observers of interest updated
            # there should be two entries
            self.assertEqual(len(model_report.get_observers_of_interest()), 2)
            for detector in test_detector_set:
                self.assertTrue(detector.get_detector_name() in model_report.get_observers_of_interest().keys())

                # get number of entries for this detector
                detector_obs_of_interest_fqns = model_report.get_observers_of_interest()[detector.get_detector_name()]

                # assert that the per channel detector has 0 and the dynamic static has 4
                if isinstance(detector, PerChannelDetector):
                    self.assertEqual(len(detector_obs_of_interest_fqns), 0)
                elif isinstance(detector, DynamicStaticDetector):
                    self.assertEqual(len(detector_obs_of_interest_fqns), 4)

            # ensure that we can prepare for calibration only once
            with self.assertRaises(ValueError):
                prepared_for_callibrate_model = model_report.prepare_detailed_calibration()


    def get_module_and_graph_cnts(self, callibrated_fx_module):
        r"""
        Calculates the number of ModelReportObserver modules in the model as well as the
        number of model_report nodes found in the graph.
        Returns a tuple of two elements:
        int: The number of ModelReportObservers found in the model
        int: The number of model_report nodes found in the graph
        """
        # get the number of observers stored as modules
        modules_observer_cnt = 0
        for fqn, module in callibrated_fx_module.named_modules():
            if isinstance(module, ModelReportObserver):
                modules_observer_cnt += 1

        # get number of observers in the graph
        model_report_str_check = "model_report"
        graph_observer_cnt = 0
        # check each node in the graph for a model_report observer
        for node in callibrated_fx_module.graph.nodes:
            # not all node targets are strings, so check
            if isinstance(node.target, str) and model_report_str_check in node.target:
                # increment if we found a graph observer
                graph_observer_cnt += 1

        return (modules_observer_cnt, graph_observer_cnt)

    @skipIfNoFBGEMM
    def test_generate_report(self):
        """
            Tests model_report.generate_model_report to ensure report generation
            Specifically looks at:
            - Whether correct number of reports are being generated
            - Whether observers are being properly removed if specified
            - Whether generating the report a second time is correctly blocked if observers were removed
        """

        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"

            # check whether the correct number of reports are being generated
            filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)}
            single_detector_set = {DynamicStaticDetector()}

            # create our models
            model_full = TwoThreeOps()
            model_single = TwoThreeOps()

            # prepare and calibrate two different instances of the same model
            # prepare the model
            example_input = model_full.get_example_inputs()[0]
            current_backend = torch.backends.quantized.engine
            q_config_mapping = QConfigMapping()
            q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))

            model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
            model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)

            # initialize one with filled detector
            model_report_full = ModelReport(model_prep_full, filled_detector_set)
            # initialize another with a single detector set
            model_report_single = ModelReport(model_prep_single, single_detector_set)

            # prepare the models for calibration
            prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration()
            prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration()

            # now calibrate the two models
            num_iterations = 10
            for i in range(num_iterations):
                example_input = torch.tensor(torch.randint(100, (1, 3, 3, 3)), dtype=torch.float)
                prepared_for_callibrate_model_full(example_input)
                prepared_for_callibrate_model_single(example_input)

            # now generate the reports
            model_full_report = model_report_full.generate_model_report(True)
            model_single_report = model_report_single.generate_model_report(False)

            # check that sizes are appropriate
            self.assertEqual(len(model_full_report), len(filled_detector_set))
            self.assertEqual(len(model_single_report), len(single_detector_set))

            # make sure observers are being properly removed for full report since we put flag in
            modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
            self.assertEqual(modules_observer_cnt, 0)  # assert no more observer modules
            self.assertEqual(graph_observer_cnt, 0)  # assert no more observer nodes in graph

            # make sure observers aren't being removed for single report since not specified
            modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
            self.assertNotEqual(modules_observer_cnt, 0)
            self.assertNotEqual(graph_observer_cnt, 0)

            # make sure an error is raised when we try to rerun report generation for the full report but not the single report
            with self.assertRaises(Exception):
                model_full_report = model_report_full.generate_model_report(
                    prepared_for_callibrate_model_full, False
                )

            # make sure we don't run into error for single report
            model_single_report = model_report_single.generate_model_report(False)

    @skipIfNoFBGEMM
    def test_generate_visualizer(self):
        """
        Tests that the ModelReport class can properly create the ModelReportVisualizer instance
        Checks that:
            - Correct number of modules are represented
            - Modules are sorted
            - Correct number of features for each module
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            # test with multiple detectors
            detector_set = set()
            detector_set.add(OutlierDetector(reference_percentile=0.95))
            detector_set.add(InputWeightEqualizationDetector(0.5))

            model = TwoThreeOps()

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)

            # try to visualize without generating report, should throw error
            with self.assertRaises(Exception):
                mod_rep_visualization = mod_report.generate_visualizer()

            # now get the report by running it through ModelReport instance
            generated_report = mod_report.generate_model_report(remove_inserted_observers=False)

            # now getting the visualizer should not error
            mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()

            # since we tested with outlier detector, which looks at every base level module
            # should be six entries in the ordered dict
            mod_fqns_to_features = mod_rep_visualizer.generated_reports

            self.assertEqual(len(mod_fqns_to_features), 6)

            # outlier detector has 9 features per module
            # input-weight has 12 features per module
            # there is 1 common data point, so there should be 12 + 9 - 1 = 20 unique features per common module
            # all linears will be common
            for module_fqn in mod_fqns_to_features:
                if ".linear" in module_fqn:
                    linear_info = mod_fqns_to_features[module_fqn]
                    self.assertEqual(len(linear_info), 20)

    @skipIfNoFBGEMM
    def test_qconfig_mapping_generation(self):
        """
        Tests for generation of qconfigs by ModelReport API
        - Tests that a QConfigMapping is generated
        - Tests that mappings include information for relevant modules
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            # test with multiple detectors
            detector_set = set()
            detector_set.add(PerChannelDetector())
            detector_set.add(DynamicStaticDetector())

            model = TwoThreeOps()

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)


            # get the mapping without error
            qconfig_mapping = mod_report.generate_qconfig_mapping()

            # now get the report by running it through ModelReport instance
            generated_report = mod_report.generate_model_report(remove_inserted_observers=False)

            # get the visualizer so we can get access to reformatted reports by module fqn
            mod_reports_by_fqn = mod_report.generate_visualizer().generated_reports

            # compare the entries of the mapping to those of the report
            # we should have the same number of entries
            self.assertEqual(len(qconfig_mapping.module_name_qconfigs), len(mod_reports_by_fqn))

            # we should have 2 entries because only the linear layers are applicable,
            # so there should be suggestions for each named module
            self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)

            # only two linears: make sure weights use per channel min max since the backend is fbgemm,
            # and a static distribution since this was a simple single calibration
            for key in qconfig_mapping.module_name_qconfigs:
                config = qconfig_mapping.module_name_qconfigs[key]
                self.assertEqual(config.weight, default_per_channel_weight_observer)
                self.assertEqual(config.activation, default_observer)

            # make sure these can actually be used to prepare the model
            prepared = quantize_fx.prepare_fx(TwoThreeOps(), qconfig_mapping, example_input)

            # now convert the model to ensure no errors in conversion
            converted = quantize_fx.convert_fx(prepared)

    @skipIfNoFBGEMM
    def test_equalization_mapping_generation(self):
        """
        Tests for generation of qconfigs by ModelReport API
        - Tests that an equalization config is generated when the input-weight equalization detector is used
        - Tests that mappings include information for relevant modules
        """
        with override_quantized_engine('fbgemm'):
            # set the backend for this test
            torch.backends.quantized.engine = "fbgemm"
            # test with multiple detectors
            detector_set = set()
            detector_set.add(InputWeightEqualizationDetector(0.6))

            model = TwoThreeOps()

            # get test model and calibrate
            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
                model, detector_set, model.get_example_inputs()[0]
            )

            # now we actually calibrate the model
            example_input = model.get_example_inputs()[0]
            example_input = example_input.to(torch.float)

            prepared_for_callibrate_model(example_input)


            # get the mapping without error
            qconfig_mapping = mod_report.generate_qconfig_mapping()
            equalization_mapping = mod_report.generate_equalization_mapping()

            # the tests are a lot simpler for the equalization mapping

            # shouldn't have any equalization suggestions for this case
            self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)


            # make sure these can actually be used to prepare the model
            prepared = quantize_fx.prepare_fx(
                TwoThreeOps(),
                qconfig_mapping,
                example_input,
                _equalization_config=equalization_mapping
            )

            # now convert the model to ensure no errors in conversion
            converted = quantize_fx.convert_fx(prepared)

class TestFxDetectInputWeightEqualization(QuantizationTestCase):

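    # minimal Conv2d + ReLU block used to build the larger test model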
    class SimpleConv(torch.nn.Module):
        def __init__(self, con_dims):
            super().__init__()
            self.relu = torch.nn.ReLU()
            self.conv = torch.nn.Conv2d(con_dims[0], con_dims[1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        def forward(self, x):
            x = self.conv(x)
            x = self.relu(x)
            return x

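    # composite model built from SimpleConv blocks plus additional conv and linear layers;
    # used by the input-weight equalization tests in this class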
    class TwoBlockComplexNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.block1 = TestFxDetectInputWeightEqualization.SimpleConv((3, 32))
            self.block2 = TestFxDetectInputWeightEqualization.SimpleConv((3, 3))
            self.conv = torch.nn.Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
            self.linear = torch.nn.Linear(768, 10)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.block1(x)
            x = self.conv(x)
            y = self.block2(x)
            y = y.repeat(1, 1, 2, 2)
            z = x + y
            z = z.flatten(start_dim=1)
            z = self.linear(z)
            z = self.relu(z)
            return z

        def get_fusion_modules(self):
            return [['conv', 'relu']]

        def get_example_inputs(self):
            return (torch.randn((1, 3, 28, 28)),)

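    # model containing only a ReLU (no conv or linear layers)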
1294    class ReluOnly(torch.nn.Module):
1295        def __init__(self) -> None:
1296            super().__init__()
1297            self.relu = torch.nn.ReLU()
1298
1299        def forward(self, x):
1300            x = self.relu(x)
1301            return x
1302
1303        def get_example_inputs(self):
1304            return (torch.arange(27).reshape((1, 3, 3, 3)),)
1305
1306    def _get_prepped_for_calibration_model(self, model, detector_set, fused=False):
1307        r"""Returns a model that has been prepared for callibration and corresponding model_report"""
1308
1309        # pass in necessary inputs to helper
1310        example_input = model.get_example_inputs()[0]
1311        return _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused)
1312
1313    @skipIfNoFBGEMM
1314    def test_input_weight_equalization_determine_points(self):
1315        # use fbgemm and create our model instance
1316        # then create model report instance with detector
1317        with override_quantized_engine('fbgemm'):
1318
1319            detector_set = {InputWeightEqualizationDetector(0.5)}
1320
1321            # get prepared test models (non-fused and fused)
1322            non_fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set)
1323            fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set, fused=True)
1324
1325            # the reporter should still give the same counts even for the fused model
1326            for prepared_for_callibrate_model, mod_report in [non_fused, fused]:
1327
1328                # supported modules to check
1329                mods_to_check = {nn.Linear, nn.Conv2d}
1330
1331                # get the set of all node fqns in the graph
1332                node_fqns = {node.target for node in prepared_for_callibrate_model.graph.nodes}
1333
1334                # there should be 4 node fqns that have the observer inserted
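                # (one pre-observer for each supported module: block1.conv, block2.conv, conv, and linear)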
1335                correct_number_of_obs_inserted = 4
1336                number_of_obs_found = 0
1337                obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME
1338
1339                for node in prepared_for_callibrate_model.graph.nodes:
1340                    # if the obs name is inside the target, we found an observer
1341                    if obs_name_to_find in str(node.target):
1342                        number_of_obs_found += 1
1343
1344                self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)
1345
1346                # assert that each of the desired modules has the observer inserted
1347                for fqn, module in prepared_for_callibrate_model.named_modules():
1348                    # check if module is a supported module
1349                    is_in_include_list = sum(isinstance(module, x) for x in mods_to_check) > 0
1350
1351                    if is_in_include_list:
1352                        # make sure it has the observer attribute
1353                        self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
1354                    else:
1355                        # if it's not a supported type, it shouldn't have an observer attached
1356                        self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
1357
1358    @skipIfNoFBGEMM
1359    def test_input_weight_equalization_report_gen(self):
1360        # use fbgemm and create our model instance
1361        # then create model report instance with detector
1362        with override_quantized_engine('fbgemm'):
1363
1364            test_input_weight_detector = InputWeightEqualizationDetector(0.4)
1365            detector_set = {test_input_weight_detector}
1366            model = self.TwoBlockComplexNet()
1367            # prepare the model for calibration
1368            prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(
1369                model, detector_set
1370            )
1371
1372            # now we actually calibrate the model
1373            example_input = model.get_example_inputs()[0]
1374            example_input = example_input.to(torch.float)
1375
1376            prepared_for_callibrate_model(example_input)
1377
1378            # now get the report by running it through the ModelReport instance
1379            generated_report = model_report.generate_model_report(True)
1380
1381            # check that the report size is appropriate (only 1 detector)
1382            self.assertEqual(len(generated_report), 1)
1383
1384            # get the specific report for input weight equalization
1385            input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]
1386
1387            # we should have 4 layers looked at since there are 4 conv / linear layers
1388            self.assertEqual(len(input_weight_dict), 4)
1389
1390            # we can validate that the detector recorded the max and min values properly for the first layer
1391            # since only the single example input has been processed, these should match values computed from that input
1392
1393            example_input = example_input.reshape((3, 28, 28))  # reshape input
1394            for module_fqn in input_weight_dict:
1395                # look for the first linear
1396                if "block1.linear" in module_fqn:
1397                    block_1_lin_recs = input_weight_dict[module_fqn]
1398                    # get input range info and the channel axis
1399                    ch_axis = block_1_lin_recs[InputWeightEqualizationDetector.CHANNEL_KEY]
1400
1401                    # ensure that the min and max values extracted match properly
1402                    example_min, example_max = torch.aminmax(example_input, dim=ch_axis)
1403                    dimension_min = torch.amin(example_min, dim=ch_axis)
1404                    dimension_max = torch.amax(example_max, dim=ch_axis)
1405
1406                    # make sure per channel min and max are as expected
1407                    min_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
1408                    min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY
1409
1410                    max_per_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
1411                    max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY
1412
1413                    per_channel_min = block_1_lin_recs[min_per_key]
1414                    per_channel_max = block_1_lin_recs[max_per_key]
1415                    self.assertEqual(per_channel_min, dimension_min)
1416                    self.assertEqual(per_channel_max, dimension_max)
1417
1418                    # build the keys for the activation global min and max
1419                    min_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
1420                    min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY
1421
1422                    max_key = InputWeightEqualizationDetector.ACTIVATION_PREFIX
1423                    max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY
1424
1425                    # make sure the global min and max were correctly recorded and presented
1426                    global_min = block_1_lin_recs[min_key]
1427                    global_max = block_1_lin_recs[max_key]
1428                    self.assertEqual(global_min, min(dimension_min))
1429                    self.assertEqual(global_max, max(dimension_max))
1430
1431                    input_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))
1432                    # ensure the comparison stat passed back is the ratio of the sqrt range ratios
1433                    # need to get the weight ratios first
1434
1435                    # build the keys for the weight per-channel min and max
1436                    min_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
1437                    min_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MIN_KEY
1438
1439                    max_per_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
1440                    max_per_key += InputWeightEqualizationDetector.PER_CHANNEL_MAX_KEY
1441
1442                    # get weight per channel and global info
1443                    per_channel_min = block_1_lin_recs[min_per_key]
1444                    per_channel_max = block_1_lin_recs[max_per_key]
1445
1446                    # build the keys for the weight global min and max
1447                    min_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
1448                    min_key += InputWeightEqualizationDetector.GLOBAL_MIN_KEY
1449
1450                    max_key = InputWeightEqualizationDetector.WEIGHT_PREFIX
1451                    max_key += InputWeightEqualizationDetector.GLOBAL_MAX_KEY
1452
1453                    global_min = block_1_lin_recs[min_key]
1454                    global_max = block_1_lin_recs[max_key]
1455
1456                    weight_ratio = torch.sqrt((per_channel_max - per_channel_min) / (global_max - global_min))
1457
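                    # per the assertions below, the detector's comparison metric for each channel c is
                    #   comp_stat[c] = sqrt(weight_range[c] / weight_global_range) / sqrt(input_range[c] / input_global_range)
                    # i.e. weight_ratio / input_ratio computed here; values far from 1 are presumably what
                    # the detector uses to suggest input-weight equalization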
1458                    # also get comp stat for this specific layer
1459                    comp_stat = block_1_lin_recs[InputWeightEqualizationDetector.COMP_METRIC_KEY]
1460
1461                    weight_to_input_ratio = weight_ratio / input_ratio
1462
1463                    self.assertEqual(comp_stat, weight_to_input_ratio)
1464                    # only checking the first matching layer, so we can break
1465                    break
1466
1467    @skipIfNoFBGEMM
1468    def test_input_weight_equalization_report_gen_empty(self):
1469        # tests report gen on a model that doesn't have any supported layers
1470        # use fbgemm and create our model instance
1471        # then create model report instance with detector
1472        with override_quantized_engine('fbgemm'):
1473            test_input_weight_detector = InputWeightEqualizationDetector(0.4)
1474            detector_set = {test_input_weight_detector}
1475            model = self.ReluOnly()
1476            # prepare the model for calibration
1477            prepared_for_callibrate_model, model_report = self._get_prepped_for_calibration_model(model, detector_set)
1478
1479            # now we actually calibrate the model
1480            example_input = model.get_example_inputs()[0]
1481            example_input = example_input.to(torch.float)
1482
1483            prepared_for_callibrate_model(example_input)
1484
1485            # now get the report by running it through the ModelReport instance
1486            generated_report = model_report.generate_model_report(True)
1487
1488            # check that the report size is appropriate (only 1 detector)
1489            self.assertEqual(len(generated_report), 1)
1490
1491            # get the specific report for input weight equalization
1492            input_weight_str, input_weight_dict = generated_report[test_input_weight_detector.get_detector_name()]
1493
1494            # we should have 0 layers since the model only contains a ReLU, which this detector does not cover
1495            self.assertEqual(len(input_weight_dict), 0)
1496
1497            # make sure that the string only has two lines, as it should if there are no suggestions
1498            self.assertEqual(input_weight_str.count("\n"), 2)
1499
1500
1501class TestFxDetectOutliers(QuantizationTestCase):
1502
1503    class LargeBatchModel(torch.nn.Module):
1504        def __init__(self, param_size):
1505            super().__init__()
1506            self.param_size = param_size
1507            self.linear = torch.nn.Linear(param_size, param_size)
1508            self.relu_1 = torch.nn.ReLU()
1509            self.conv = torch.nn.Conv2d(param_size, param_size, 1)
1510            self.relu_2 = torch.nn.ReLU()
1511
1512        def forward(self, x):
1513            x = self.linear(x)
1514            x = self.relu_1(x)
1515            x = self.conv(x)
1516            x = self.relu_2(x)
1517            return x
1518
1519        def get_example_inputs(self):
1520            param_size = self.param_size
1521            return (torch.randn((1, param_size, param_size, param_size)),)
1522
1523        def get_outlier_inputs(self):
1524            param_size = self.param_size
1525            random_vals = torch.randn((1, param_size, param_size, param_size))
1526            # set one entry in every other channel to a massive value
1527            random_vals[:, 0:param_size:2, 0, 3] = torch.tensor([3.28e8])
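            # (half of the channels now contain a single huge value, which the outlier tests below rely on)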
1528            return (random_vals,)
1529
1530
1531    def _get_prepped_for_calibration_model(self, model, detector_set, use_outlier_data=False):
1532        r"""Returns a model that has been prepared for calibration and the corresponding model_report"""
1533        # call the general helper function to prepare the model for calibration
1534        example_input = model.get_example_inputs()[0]
1535
1536        # if we specifically want to test data with outliers, replace the input
1537        if use_outlier_data:
1538            example_input = model.get_outlier_inputs()[0]
1539
1540        return _get_prepped_for_calibration_model_helper(model, detector_set, example_input)
1541
1542    @skipIfNoFBGEMM
1543    def test_outlier_detection_determine_points(self):
1544        # use fbgemm and create our model instance
1545        # then create model report instance with detector
1546        # similar to the test for InputWeightEqualization, but with key differences that made refactoring not viable
1547        # not explicitly testing fusion because the fx workflow fuses modules automatically
1548        with override_quantized_engine('fbgemm'):
1549
1550            detector_set = {OutlierDetector(reference_percentile=0.95)}
1551
1552            # get test model and prepare it for calibration
1553            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
1554                self.LargeBatchModel(param_size=128), detector_set
1555            )
1556
1557            # supported modules to check
1558            mods_to_check = {nn.Linear, nn.Conv2d, nn.ReLU}
1559
1560            # there should be 4 node fqns that have the observer inserted
1561            correct_number_of_obs_inserted = 4
1562            number_of_obs_found = 0
1563            obs_name_to_find = InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME
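            # (the outlier detector appears to insert its pre-observer under the same attribute name,
            # so InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME works for this lookup too)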
1564
1565            number_of_obs_found = sum(
1566                1 if obs_name_to_find in str(node.target) else 0 for node in prepared_for_callibrate_model.graph.nodes
1567            )
1568            self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted)
1569
1570            # assert that each of the desired modules has the observer inserted
1571            for fqn, module in prepared_for_callibrate_model.named_modules():
1572                # check if module is a supported module
1573                is_in_include_list = isinstance(module, tuple(mods_to_check))
1574
1575                if is_in_include_list:
1576                    # make sure it has the observer attribute
1577                    self.assertTrue(hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
1578                else:
1579                    # if it's not a supported type, it shouldn't have an observer attached
1580                    self.assertTrue(not hasattr(module, InputWeightEqualizationDetector.DEFAULT_PRE_OBSERVER_NAME))
1581
1582    @skipIfNoFBGEMM
1583    def test_no_outlier_report_gen(self):
1584        # use fbgemm and create our model instance
1585        # then create model report instance with detector
1586        with override_quantized_engine('fbgemm'):
1587
1588            # test with multiple detectors
1589            outlier_detector = OutlierDetector(reference_percentile=0.95)
1590            dynamic_static_detector = DynamicStaticDetector(tolerance=0.5)
1591
1592            param_size: int = 4
1593            detector_set = {outlier_detector, dynamic_static_detector}
1594            model = self.LargeBatchModel(param_size=param_size)
1595
1596            # get test model and prepare it for calibration
1597            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
1598                model, detector_set
1599            )
1600
1601            # now we actually calibrate the model
1602            example_input = model.get_example_inputs()[0]
1603            example_input = example_input.to(torch.float)
1604
1605            prepared_for_callibrate_model(example_input)
1606
1607            # now get the report by running it through the ModelReport instance
1608            generated_report = mod_report.generate_model_report(True)
1609
1610            # check that the report size is appropriate (2 detectors)
1611            self.assertEqual(len(generated_report), 2)
1612
1613            # get the specific report for the outlier detector
1614            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]
1615
1616            # we should have 4 layers looked at: the linear, the conv, and the two relus
1617            self.assertEqual(len(outlier_dict), 4)
1618
1619            # assert the following are true for all the modules
1620            for module_fqn in outlier_dict:
1621                # get the info for the specific module
1622                module_dict = outlier_dict[module_fqn]
1623
1624                # there really should not be any outliers since the inputs were drawn from a normal distribution
1625                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
1626                self.assertEqual(sum(outlier_info), 0)
1627
1628                # ensure that the number of ratios and batches counted is the same as the number of params
1629                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
1630                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)
1631
1632
1633    @skipIfNoFBGEMM
1634    def test_all_outlier_report_gen(self):
1635        # make the percentile 0 and the ratio 1, then check that everything is flagged as an outlier
1636        # use fbgemm and create our model instance
1637        # then create model report instance with detector
1638        with override_quantized_engine('fbgemm'):
1639            # create detector of interest
1640            outlier_detector = OutlierDetector(ratio_threshold=1, reference_percentile=0)
1641
1642            param_size: int = 16
1643            detector_set = {outlier_detector}
1644            model = self.LargeBatchModel(param_size=param_size)
1645
1646            # get test model and prepare it for calibration
1647            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
1648                model, detector_set
1649            )
1650
1651            # now we actually calibrate the model
1652            example_input = model.get_example_inputs()[0]
1653            example_input = example_input.to(torch.float)
1654
1655            prepared_for_callibrate_model(example_input)
1656
1657            # now get the report by running it through the ModelReport instance
1658            generated_report = mod_report.generate_model_report(True)
1659
1660            # check that the report size is appropriate (only 1 detector)
1661            self.assertEqual(len(generated_report), 1)
1662
1663            # get the specific report for the outlier detector
1664            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]
1665
1666            # we should have 4 layers looked at: the linear, the conv, and the two relus
1667            self.assertEqual(len(outlier_dict), 4)
1668
1669            # assert the following are true for all the modules
1670            for module_fqn in outlier_dict:
1671                # get the info for the specific module
1672                module_dict = outlier_dict[module_fqn]
1673
1674                # with ratio_threshold=1 and reference_percentile=0, essentially every channel should be flagged as an outlier
1675                # however we only require that most are, in case some channels happen to be all zeros
1676                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
1677                assert sum(outlier_info) >= len(outlier_info) / 2
1678
1679                # ensure that the number of ratios and batches counted is the same as the number of params
1680                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
1681                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)
1682
1683    @skipIfNoFBGEMM
1684    def test_multiple_run_consistent_spike_outlier_report_gen(self):
1685        # consistently spike the same channels to a really high value across the batches being tested
1686        # calibrate over many runs (30) so the batch count clears the minimum threshold, then generate the report
1687        with override_quantized_engine('fbgemm'):
1688
1689            # detector of interest
1690            outlier_detector = OutlierDetector(reference_percentile=0.95)
1691
1692            param_size: int = 8
1693            detector_set = {outlier_detector}
1694            model = self.LargeBatchModel(param_size=param_size)
1695
1696            # get test model and prepare it for calibration
1697            prepared_for_callibrate_model, mod_report = self._get_prepped_for_calibration_model(
1698                model, detector_set, use_outlier_data=True
1699            )
1700
1701            # now we actually calibrate the model
1702            example_input = model.get_outlier_inputs()[0]
1703            example_input = example_input.to(torch.float)
1704
1705            # now calibrate at least 30 times so the batch count is above the minimum threshold
1706            for i in range(30):
1707                example_input = model.get_outlier_inputs()[0]
1708                example_input = example_input.to(torch.float)
1709
1710                # zero out one channel in a few of the batches (i = 0, 14, 28)
1711                if i % 14 == 0:
1712                    # make one channel constant
1713                    example_input[0][1] = torch.zeros_like(example_input[0][1])
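                    # (these constant channels are what the CONSTANT_COUNTS_KEY assertions below count)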
1714
1715                prepared_for_callibrate_model(example_input)
1716
1717            # now get the report by running it through the ModelReport instance
1718            generated_report = mod_report.generate_model_report(True)
1719
1720            # check that the report size is appropriate (only 1 detector)
1721            self.assertEqual(len(generated_report), 1)
1722
1723            # get the specific report for the outlier detector
1724            outlier_str, outlier_dict = generated_report[outlier_detector.get_detector_name()]
1725
1726            # we should have 4 layers looked at: the linear, the conv, and the two relus
1727            self.assertEqual(len(outlier_dict), 4)
1728
1729            # assert the following are true for all the modules
1730            for module_fqn in outlier_dict:
1731                # get the info for the specific module
1732                module_dict = outlier_dict[module_fqn]
1733
1734                # because we ran 30 calibration passes, most channels should have enough batches to be significant
1735                # the count could be lower because some channels could possibly be all 0
1736                sufficient_batches_info = module_dict[OutlierDetector.IS_SUFFICIENT_BATCHES_KEY]
1737                assert sum(sufficient_batches_info) >= len(sufficient_batches_info) / 2
1738
1739                # half of them should be outliers, because we set a really high value in every other channel
1740                outlier_info = module_dict[OutlierDetector.OUTLIER_KEY]
1741                self.assertEqual(sum(outlier_info), len(outlier_info) / 2)
1742
1743                # ensure that the number of ratios and batches counted is the same as the number of params
1744                self.assertEqual(len(module_dict[OutlierDetector.COMP_METRIC_KEY]), param_size)
1745                self.assertEqual(len(module_dict[OutlierDetector.NUM_BATCHES_KEY]), param_size)
1746
1747                # for the first module, ensure the per-channel max values are what we set
1748                if module_fqn == "linear.0":
1749
1750                    # check the constant-channel counts: at least 2 should be recorded
1751                    # for the first module, since we zeroed a channel in a few of the batches
1752                    counts_info = module_dict[OutlierDetector.CONSTANT_COUNTS_KEY]
1753                    assert sum(counts_info) >= 2
1754
1755                    # half of the recorded max values should be what we set
1756                    matched_max = sum(val == 3.28e8 for val in module_dict[OutlierDetector.MAX_VALS_KEY])
1757                    self.assertEqual(matched_max, param_size / 2)
1758
1759
1760class TestFxModelReportVisualizer(QuantizationTestCase):
1761
1762    def _callibrate_and_generate_visualizer(self, model, prepared_for_callibrate_model, mod_report):
1763        r"""
1764        Calibrates the passed-in model, generates the report, and returns the visualizer
1765        """
1766        # now we actually calibrate the model
1767        example_input = model.get_example_inputs()[0]
1768        example_input = example_input.to(torch.float)
1769
1770        prepared_for_callibrate_model(example_input)
1771
1772        # now get the report by running it through the ModelReport instance
1773        generated_report = mod_report.generate_model_report(remove_inserted_observers=False)
1774
1775        # now get the visualizer; this should not error
1776        mod_rep_visualizer: ModelReportVisualizer = mod_report.generate_visualizer()
1777
1778        return mod_rep_visualizer
1779
1780    @skipIfNoFBGEMM
1781    def test_get_modules_and_features(self):
1782        """
1783        Tests the get_all_unique_module_fqns and get_all_unique_feature_names methods of
1784        ModelReportVisualizer
1785
1786        Checks whether returned sets are of proper size and filtered properly
1787        """
1788        with override_quantized_engine('fbgemm'):
1789            # set the backend for this test
1790            torch.backends.quantized.engine = "fbgemm"
1791            # test with multiple detectors
1792            detector_set = set()
1793            detector_set.add(OutlierDetector(reference_percentile=0.95))
1794            detector_set.add(InputWeightEqualizationDetector(0.5))
1795
1796            model = TwoThreeOps()
1797
1798            # get test model and prepare it for calibration
1799            prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
1800                model, detector_set, model.get_example_inputs()[0]
1801            )
1802
1803            mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
1804                model, prepared_for_callibrate_model, mod_report
1805            )
1806
1807            # ensure the module fqns match the ones given by the get_all_unique_module_fqns method
1808            actual_model_fqns = set(mod_rep_visualizer.generated_reports.keys())
1809            returned_model_fqns = mod_rep_visualizer.get_all_unique_module_fqns()
1810            self.assertEqual(returned_model_fqns, actual_model_fqns)
1811
1812            # now ensure that features are all properly returned
1813            # the linear layers have all the features for both detectors,
1814            # so we can use one of them to check that the method works reliably
1815            b_1_linear_features = mod_rep_visualizer.generated_reports["block1.linear"]
1816
1817            # first test all features
1818            returned_all_feats = mod_rep_visualizer.get_all_unique_feature_names(False)
1819            self.assertEqual(returned_all_feats, set(b_1_linear_features.keys()))
1820
1821            # now test plottable features
1822            plottable_set = set()
1823
1824            for feature_name in b_1_linear_features:
1825                if type(b_1_linear_features[feature_name]) == torch.Tensor:
1826                    plottable_set.add(feature_name)
1827
1828            returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names()
1829            self.assertEqual(returned_plottable_feats, plottable_set)
1830
1831    def _prep_visualizer_helper(self):
1832        r"""
1833        Returns a mod rep visualizer that we test in various ways
1834        """
1835        # set backend for test
1836        torch.backends.quantized.engine = "fbgemm"
1837
1838        # test with multiple detectors
1839        detector_set = set()
1840        detector_set.add(OutlierDetector(reference_percentile=0.95))
1841        detector_set.add(InputWeightEqualizationDetector(0.5))
1842
1843        model = TwoThreeOps()
1844
1845        # get test model and prepare it for calibration
1846        prepared_for_callibrate_model, mod_report = _get_prepped_for_calibration_model_helper(
1847            model, detector_set, model.get_example_inputs()[0]
1848        )
1849
1850        mod_rep_visualizer: ModelReportVisualizer = self._callibrate_and_generate_visualizer(
1851            model, prepared_for_callibrate_model, mod_report
1852        )
1853
1854        return mod_rep_visualizer
1855
1856    @skipIfNoFBGEMM
1857    def test_generate_tables_match_with_report(self):
1858        """
1859        Tests the generate_filtered_tables() method of
1860        ModelReportVisualizer
1861
1862        Checks whether the generated dict has proper information
1863            A visual check that the tables look correct was performed during testing
1864        """
1865        with override_quantized_engine('fbgemm'):
1866
1867            # get the visualizer
1868            mod_rep_visualizer = self._prep_visualizer_helper()
1869
1870            table_dict = mod_rep_visualizer.generate_filtered_tables()
1871
1872            # test primarily the dict since it has the same info as the str
1873            tensor_headers, tensor_table = table_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
1874            channel_headers, channel_table = table_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
1875
1876            # these two together should be the same as the generated report info in terms of keys
1877            tensor_info_modules = {row[1] for row in tensor_table}
1878            channel_info_modules = {row[1] for row in channel_table}
1879            combined_modules: Set = tensor_info_modules.union(channel_info_modules)
1880
1881            generated_report_keys: Set = set(mod_rep_visualizer.generated_reports.keys())
1882            self.assertEqual(combined_modules, generated_report_keys)
1883
1884    @skipIfNoFBGEMM
1885    def test_generate_tables_no_match(self):
1886        """
1887        Tests the generate_filtered_tables() method of
1888        ModelReportVisualizer
1889
1890        Checks whether the generated dict has proper information
1891            A visual check that the tables look correct was performed during testing
1892        """
1893        with override_quantized_engine('fbgemm'):
1894            # get the visualizer
1895            mod_rep_visualizer = self._prep_visualizer_helper()
1896
1897            # try a random filter and make sure that there are no rows for either table
1898            empty_tables_dict = mod_rep_visualizer.generate_filtered_tables(module_fqn_filter="random not there module")
1899
1900            # test primarily the dict since it has the same info as the str
1901            tensor_headers, tensor_table = empty_tables_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
1902            channel_headers, channel_table = empty_tables_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
1903
1904            tensor_info_modules = {row[1] for row in tensor_table}
1905            channel_info_modules = {row[1] for row in channel_table}
1906            combined_modules: Set = tensor_info_modules.union(channel_info_modules)
1907            self.assertEqual(len(combined_modules), 0)  # should be no matching modules
1908
1909    @skipIfNoFBGEMM
1910    def test_generate_tables_single_feat_match(self):
1911        """
1912        Tests the generate_filtered_tables() method of
1913        ModelReportVisualizer
1914
1915        Checks whether the generated dict has proper information
1916            A visual check that the tables look correct was performed during testing
1917        """
1918        with override_quantized_engine('fbgemm'):
1919            # get the visualizer
1920            mod_rep_visualizer = self._prep_visualizer_helper()
1921
1922            # try a matching feature filter and make sure only that feature shows up
1923            # filtering to a single feature name should leave only 1 feature column in each table row
1924            single_feat_dict = mod_rep_visualizer.generate_filtered_tables(feature_filter=OutlierDetector.MAX_VALS_KEY)
1925
1926            # test primarily the dict since it has the same info as the str
1927            tensor_headers, tensor_table = single_feat_dict[ModelReportVisualizer.TABLE_TENSOR_KEY]
1928            channel_headers, channel_table = single_feat_dict[ModelReportVisualizer.TABLE_CHANNEL_KEY]
1929
1930            # get the number of features in each of these
1931            tensor_info_features = len(tensor_headers)
1932            channel_info_features = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
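            # the channel table always carries a fixed set of non-feature columns, so subtract
            # NUM_NON_FEATURE_CHANNEL_HEADERS to count only the feature columns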
1933
1934            # make sure that there are no tensor features, and that there is one channel level feature
1935            self.assertEqual(tensor_info_features, 0)
1936            self.assertEqual(channel_info_features, 1)
1937
1938def _get_prepped_for_calibration_model_helper(model, detector_set, example_input, fused: bool = False):
1939    r"""Returns a model that has been prepared for calibration and the corresponding model_report"""
1940    # set the backend for this test
1941    torch.backends.quantized.engine = "fbgemm"
1942
1943    # create model instance and prepare it
1944    example_input = example_input.to(torch.float)
1945    q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping()
1946
1947    # if the fusion parameter was passed in, make sure to test with a fused model
1948    if fused:
1949        model = torch.ao.quantization.fuse_modules(model, model.get_fusion_modules())
1950
1951    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)
1952
1953    model_report = ModelReport(model_prep, detector_set)
1954
1955    # prepare the model for calibration
1956    prepared_for_callibrate_model = model_report.prepare_detailed_calibration()
1957
1958    return (prepared_for_callibrate_model, model_report)
1959
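
# Illustrative sketch (not a test): the end-to-end ModelReport workflow that the tests above
# exercise, written out linearly. It uses only APIs already exercised in this file, assumes the
# fbgemm backend is available, and is never called by the test suite; it exists as documentation.
def _example_model_report_workflow_sketch():
    model = TwoThreeOps()
    example_input = model.get_example_inputs()[0].to(torch.float)

    # choose the detectors to run and prepare the model with FX graph mode quantization
    detector_set = {OutlierDetector(reference_percentile=0.95), InputWeightEqualizationDetector(0.5)}
    q_config_mapping = torch.ao.quantization.get_default_qconfig_mapping()
    model_prep = quantize_fx.prepare_fx(model, q_config_mapping, example_input)

    # attach the report observers, calibrate, then pull the per-detector reports
    model_report = ModelReport(model_prep, detector_set)
    prepared_for_calibration = model_report.prepare_detailed_calibration()
    prepared_for_calibration(example_input)
    generated_report = model_report.generate_model_report(remove_inserted_observers=False)

    # the visualizer exposes the same information as filtered table views
    mod_rep_visualizer = model_report.generate_visualizer()
    return generated_report, mod_rep_visualizer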