# Owner(s): ["oncall: quantization"]

import unittest
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    default_qconfig,
    prepare,
    quantize,
    quantize_dynamic,
)
from torch.ao.ns._numeric_suite import (
    OutputLogger,
    Shadow,
    ShadowLogger,
    compare_model_outputs,
    compare_model_stub,
    compare_weights,
    prepare_model_outputs,
    get_matching_activations,
)
from torch.testing._internal.common_quantization import (
    AnnotatedConvBnReLUModel,
    AnnotatedConvModel,
    AnnotatedConvTransposeModel,
    AnnotatedSingleLayerLinearModel,
    LSTMwithHiddenDynamicModel,
    AnnotatedTwoLayerLinearModel,
    QuantizationTestCase,
    SingleLayerLinearDynamicModel,
    test_only_eval_fn,
    skip_if_no_torchvision,
)
from torch.testing._internal.common_quantized import override_qengines
from torch.testing._internal.common_utils import IS_ARM64
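
# The tests below exercise the eager-mode Numeric Suite APIs imported from
# torch.ao.ns._numeric_suite. In short: compare_weights matches weights between
# a float and a quantized state_dict, compare_model_stub shadows selected
# quantized modules with their float counterparts and logs both outputs, and
# compare_model_outputs logs and matches activations at corresponding points in
# the float and quantized models.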

class SubModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.qconfig = default_qconfig
        self.mod1 = torch.nn.Conv2d(3, 3, 3, bias=False).to(dtype=torch.float)
        self.mod2 = nn.ReLU()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.mod1(x)
        x = self.mod2(x)
        x = self.dequant(x)
        return x


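# Fixture for the submodule test: only SubModule carries a qconfig, so after
# quantization mod1 is quantized while the outer conv stays in float.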
class ModelWithSubModules(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mod1 = SubModule()
        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)

    def forward(self, x):
        x = self.mod1(x)
        x = self.conv(x)
        return x


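# Fixture for the functional tests: each FloatFunctional instance wraps a single
# quantizable op (cat/add/mul/add_relu/add_scalar/mul_scalar) so that it can be
# swapped for its quantized counterpart during convert.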
class ModelWithFunctionals(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mycat = nnq.FloatFunctional()
        self.myadd = nnq.FloatFunctional()
        self.mymul = nnq.FloatFunctional()
        self.myadd_relu = nnq.FloatFunctional()
        self.my_scalar_add = nnq.FloatFunctional()
        self.my_scalar_mul = nnq.FloatFunctional()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.mycat.cat([x, x, x])
        x = self.myadd.add(x, x)
        x = self.mymul.mul(x, x)
        x = self.myadd_relu.add_relu(x, x)
        w = self.my_scalar_add.add_scalar(x, -0.5)
        w = self.my_scalar_mul.mul_scalar(w, 0.5)

        w = self.dequant(w)
        return w


class TestNumericSuiteEager(QuantizationTestCase):
    @override_qengines
    def test_compare_weights_conv_static(self):
        r"""Compare the weights of float and static quantized conv layer"""

        qengine = torch.backends.quantized.engine

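        # compare_weights matches entries of the two state_dicts by name and
        # returns {name: {"float": float_tensor, "quantized": quantized_tensor}};
        # here we only check that each matched pair has the same shape.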
        def compare_and_validate_results(float_model, q_model):
            weight_dict = compare_weights(
                float_model.state_dict(), q_model.state_dict()
            )
            self.assertEqual(len(weight_dict), 1)
            for v in weight_dict.values():
                self.assertTrue(v["float"].shape == v["quantized"].shape)

        model_list = [AnnotatedConvModel(qengine), AnnotatedConvBnReLUModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
            compare_and_validate_results(model, q_model)

    @override_qengines
    def test_compare_weights_linear_static(self):
        r"""Compare the weights of float and static quantized linear layer"""

        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model):
            weight_dict = compare_weights(
                float_model.state_dict(), q_model.state_dict()
            )
            self.assertEqual(len(weight_dict), 1)
            for v in weight_dict.values():
                self.assertTrue(v["float"].shape == v["quantized"].shape)

        model_list = [AnnotatedSingleLayerLinearModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            compare_and_validate_results(model, q_model)

    @override_qengines
    def test_compare_weights_linear_dynamic(self):
        r"""Compare the weights of float and dynamic quantized linear layer"""

        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model):
            weight_dict = compare_weights(
                float_model.state_dict(), q_model.state_dict()
            )
            self.assertEqual(len(weight_dict), 1)
            for v in weight_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        model_list = [SingleLayerLinearDynamicModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize_dynamic(model)
            compare_and_validate_results(model, q_model)

    @override_qengines
    def test_compare_weights_lstm_dynamic(self):
        r"""Compare the weights of float and dynamic quantized LSTM layer"""

        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model):
            weight_dict = compare_weights(
                float_model.state_dict(), q_model.state_dict()
            )
            self.assertEqual(len(weight_dict), 1)
            for v in weight_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        model_list = [LSTMwithHiddenDynamicModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize_dynamic(model)
            compare_and_validate_results(model, q_model)

    @override_qengines
    def test_compare_model_stub_conv_static(self):
        r"""Compare the output of static quantized conv layer and its float shadow module"""

        qengine = torch.backends.quantized.engine

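        # compare_model_stub wraps each quantized module whose float counterpart's
        # type appears in module_swap_list in a Shadow module that also runs the
        # float version on the same input, and returns the logged "float" and
        # "quantized" outputs per shadowed module.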
        def compare_and_validate_results(float_model, q_model, module_swap_list, data):
            ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data)
            self.assertEqual(len(ob_dict), 1)
            for v in ob_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        model_list = [AnnotatedConvModel(qengine),
                      AnnotatedConvTransposeModel("qnnpack"),  # ConvT cannot use per channel weights
                      AnnotatedConvBnReLUModel(qengine)]
        module_swap_list = [nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d, nn.ConvTranspose2d]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
            compare_and_validate_results(
                model, q_model, module_swap_list, self.img_data_2d[0][0]
            )

    @override_qengines
    def test_compare_model_stub_linear_static(self):
        r"""Compare the output of static quantized linear layer and its float shadow module"""

        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model, module_swap_list, data):
            ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data)
            self.assertEqual(len(ob_dict), 1)
            for v in ob_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        linear_data = self.calib_data[0][0]
        module_swap_list = [nn.Linear]
        model_list = [AnnotatedSingleLayerLinearModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            compare_and_validate_results(model, q_model, module_swap_list, linear_data)

    @override_qengines
    def test_compare_model_stub_partial(self):
        r"""Compare the output of static quantized linear layer and its float shadow module"""

        qengine = torch.backends.quantized.engine
        # TODO: Rebase on top of PR to remove compare and validate results here

        def compare_and_validate_results(float_model, q_model, module_swap_list, data):
            ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data)
            self.assertEqual(len(ob_dict), 1)
            for v in ob_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        linear_data = self.calib_data[0][0]
        module_swap_list = [nn.Linear]
        model_list = [AnnotatedTwoLayerLinearModel()]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            compare_and_validate_results(model, q_model, module_swap_list, linear_data)

    @override_qengines
    def test_compare_model_stub_submodule_static(self):
        r"""Compare the output of static quantized submodule and its float shadow module"""

        qengine = torch.backends.quantized.engine

        model = ModelWithSubModules().eval()
        q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
        module_swap_list = [SubModule, nn.Conv2d]
        ob_dict = compare_model_stub(
            model, q_model, module_swap_list, self.img_data_2d[0][0]
        )
        # Since the outer conv has no qconfig it is not quantized, so no shadow
        # module is inserted for it; mod1 contains a quantized conv, so it is
        # wrapped in a Shadow module.
        self.assertTrue(isinstance(q_model.mod1, Shadow))
        self.assertFalse(isinstance(q_model.conv, Shadow))


    @override_qengines
    def test_compare_model_stub_functional_static(self):
        r"""Compare the output of static quantized functional layer and its float shadow module"""

        qengine = torch.backends.quantized.engine

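        # After prepare/convert each FloatFunctional is replaced by its quantized
        # counterpart; listing nnq.FloatFunctional in module_swap_list lets all six
        # functional wrappers below be shadowed by their float originals.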
        model = ModelWithFunctionals().eval()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        q_model = prepare(model, inplace=False)
        q_model(self.img_data_2d[0][0])
        q_model = convert(q_model)
        module_swap_list = [nnq.FloatFunctional]
        ob_dict = compare_model_stub(
            model, q_model, module_swap_list, self.img_data_2d[0][0]
        )
        self.assertEqual(len(ob_dict), 6)
        self.assertTrue(isinstance(q_model.mycat, Shadow))
        self.assertTrue(isinstance(q_model.myadd, Shadow))
        self.assertTrue(isinstance(q_model.mymul, Shadow))
        self.assertTrue(isinstance(q_model.myadd_relu, Shadow))
        self.assertTrue(isinstance(q_model.my_scalar_add, Shadow))
        self.assertTrue(isinstance(q_model.my_scalar_mul, Shadow))
        for v in ob_dict.values():
            self.assertTrue(len(v["float"]) == len(v["quantized"]))
            for i, val in enumerate(v["quantized"]):
                self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

    @override_qengines
    def test_compare_model_stub_linear_dynamic(self):
        r"""Compare the output of dynamic quantized linear layer and its float shadow module"""

        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model, module_swap_list, data):
            ob_dict = compare_model_stub(float_model, q_model, module_swap_list, data)
            self.assertEqual(len(ob_dict), 1)
            for v in ob_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        linear_data = self.calib_data[0][0]

        model_list = [SingleLayerLinearDynamicModel(qengine)]
        module_swap_list = [nn.Linear, nn.LSTM]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize_dynamic(model)
            compare_and_validate_results(model, q_model, module_swap_list, linear_data)

    @override_qengines
    def test_compare_model_stub_lstm_dynamic(self):
        r"""Compare the output of dynamic quantized LSTM layer and its float shadow module"""

        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(
            float_model, q_model, module_swap_list, input, hidden
        ):
            ob_dict = compare_model_stub(
                float_model, q_model, module_swap_list, input, hidden
            )
            self.assertEqual(len(ob_dict), 1)
            for v in ob_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        lstm_input = torch.rand((1, 1, 2))
        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))

        model_list = [LSTMwithHiddenDynamicModel(qengine)]
        module_swap_list = [nn.Linear, nn.LSTM]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize_dynamic(model)
            compare_and_validate_results(
                model, q_model, module_swap_list, lstm_input, lstm_hidden
            )

    @override_qengines
    def test_compare_model_outputs_conv_static(self):
        r"""Compare the output of conv layer in static quantized model and corresponding
        output of conv layer in float model
        """
        qengine = torch.backends.quantized.engine

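        # compare_model_outputs attaches OutputLogger to matching modules in both
        # models, runs the given input through each, and returns a dict keyed by
        # "<module>.stats" with the recorded "float" and "quantized" activations.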
        def compare_and_validate_results(float_model, q_model, data):
            act_compare_dict = compare_model_outputs(float_model, q_model, data)
            expected_act_compare_dict_keys = {"conv.stats", "quant.stats"}

            self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
            for v in act_compare_dict.values():
                self.assertTrue(v["float"][0].shape == v["quantized"][0].shape)

        model_list = [AnnotatedConvModel(qengine), AnnotatedConvBnReLUModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize(model, test_only_eval_fn, [self.img_data_2d])
            compare_and_validate_results(model, q_model, self.img_data_2d[0][0])

    @override_qengines
    def test_compare_model_outputs_linear_static(self):
        r"""Compare the output of linear layer in static quantized model and corresponding
        output of linear layer in float model
        """
        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model, data):
            act_compare_dict = compare_model_outputs(float_model, q_model, data)
            expected_act_compare_dict_keys = {"fc1.quant.stats", "fc1.module.stats"}

            self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
            for v in act_compare_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        linear_data = self.calib_data[0][0]
        model_list = [AnnotatedSingleLayerLinearModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize(model, test_only_eval_fn, [self.calib_data])
            compare_and_validate_results(model, q_model, linear_data)

    @override_qengines
    def test_compare_model_outputs_functional_static(self):
        r"""Compare the output of functional layer in static quantized model and corresponding
        output of functional layer in float model
        """
        qengine = torch.backends.quantized.engine

        model = ModelWithFunctionals().eval()
        model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
        q_model = prepare(model, inplace=False)
        q_model(self.img_data_2d[0][0])
        q_model = convert(q_model)
        act_compare_dict = compare_model_outputs(model, q_model, self.img_data_2d[0][0])
        self.assertEqual(len(act_compare_dict), 5)
        expected_act_compare_dict_keys = {
            "mycat.stats",
            "myadd.stats",
            "mymul.stats",
            "myadd_relu.stats",
            "quant.stats",
        }
        self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
        for v in act_compare_dict.values():
            self.assertTrue(len(v["float"]) == len(v["quantized"]))
            for i, val in enumerate(v["quantized"]):
                self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

    @override_qengines
    def test_compare_model_outputs_linear_dynamic(self):
        r"""Compare the output of linear layer in dynamic quantized model and corresponding
        output of linear layer in float model
        """
        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model, data):
            act_compare_dict = compare_model_outputs(float_model, q_model, data)
            expected_act_compare_dict_keys = {"fc1.stats"}

            self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
            for v in act_compare_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

        linear_data = self.calib_data[0][0]

        model_list = [SingleLayerLinearDynamicModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize_dynamic(model)
            compare_and_validate_results(model, q_model, linear_data)

    @override_qengines
    def test_compare_model_outputs_lstm_dynamic(self):
        r"""Compare the output of LSTM layer in dynamic quantized model and corresponding
        output of LSTM layer in float model
        """
        qengine = torch.backends.quantized.engine

        def compare_and_validate_results(float_model, q_model, input, hidden):
            act_compare_dict = compare_model_outputs(
                float_model, q_model, input, hidden
            )
            expected_act_compare_dict_keys = {"lstm.stats"}

            self.assertTrue(act_compare_dict.keys() == expected_act_compare_dict_keys)
            for v in act_compare_dict.values():
                self.assertTrue(len(v["float"]) == len(v["quantized"]))
                for i, val in enumerate(v["quantized"]):
                    self.assertTrue(len(v["float"][i]) == len(v["quantized"][i]))
                    if i == 0:
                        self.assertTrue(v["float"][i][0].shape == v["quantized"][i][0].shape)
                    else:
                        self.assertTrue(
                            v["float"][i][0].shape == v["quantized"][i][0].shape
                        )
                        self.assertTrue(
                            v["float"][i][1].shape == v["quantized"][i][1].shape
                        )

        lstm_input = torch.rand((1, 1, 2))
        lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2))

        model_list = [LSTMwithHiddenDynamicModel(qengine)]
        for model in model_list:
            model.eval()
            if hasattr(model, "fuse_model"):
                model.fuse_model()
            q_model = quantize_dynamic(model)
            compare_and_validate_results(model, q_model, lstm_input, lstm_hidden)

    @override_qengines
    def test_output_logger(self):
        r"""Compare output from OutputLogger with the expected results"""
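        # OutputLogger.forward appends every tensor it is fed to
        # self.stats["tensor_val"], so the logged list should equal [x, y].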
        x = torch.rand(2, 2)
        y = torch.rand(2, 1)

        expected = []
        expected.append(x)
        expected.append(y)

        logger = OutputLogger()
        logger.forward(x)
        logger.forward(y)

        self.assertEqual(expected, logger.stats["tensor_val"])

    @override_qengines
    def test_shadow_logger(self):
        r"""Compare output from ShadowLogger with the expected results"""
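        # ShadowLogger.forward records one pair of outputs per call, one into
        # stats["float"] and one into stats["quantized"]; here we only check that
        # both lists grow by one entry per call.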
        a_float = torch.rand(2, 2)
        a_quantized = torch.rand(2, 2)

        b_float = torch.rand(3, 2, 2)
        b_quantized = torch.rand(3, 2, 2)

        logger = ShadowLogger()
        logger.forward(a_float, a_quantized)
        logger.forward(b_float, b_quantized)

        self.assertEqual(len(logger.stats["float"]), 2)
        self.assertEqual(len(logger.stats["quantized"]), 2)

    @skip_if_no_torchvision
    def _test_vision_model(self, float_model):
        float_model.to('cpu')
        float_model.eval()
        float_model.fuse_model()
        float_model.qconfig = torch.ao.quantization.default_qconfig
        img_data = [(torch.rand(2, 3, 224, 224, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)]
        qmodel = quantize(float_model, torch.ao.quantization.default_eval_fn, [img_data], inplace=False)

        wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())

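        # compute_error returns the signal-to-quantized-noise ratio (SQNR) in dB
        # between a float tensor x and its dequantized quantized counterpart y;
        # higher values mean the two are closer.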
        def compute_error(x, y):
            Ps = torch.norm(x)
            Pn = torch.norm(x - y)
            return 20 * torch.log10(Ps / Pn)

        data = img_data[0][0]
        # Takes in the floating point and quantized models as well as input data, and returns a dict with keys
        # corresponding to the quantized module names and each entry being a dictionary with two keys 'float' and
        # 'quantized', containing the activations of the floating point and quantized models at matching locations.
        act_compare_dict = compare_model_outputs(float_model, qmodel, data)


        for key in act_compare_dict:
            compute_error(act_compare_dict[key]['float'][0], act_compare_dict[key]['quantized'][0].dequantize())

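        # prepare_model_outputs attaches OutputLogger instances to matching modules
        # in both models so the forward passes below record activations, which
        # get_matching_activations then pairs up by module name.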
        prepare_model_outputs(float_model, qmodel)

        for data in img_data:
            float_model(data[0])
            qmodel(data[0])

        # Finds the matching activations between the floating point and quantized modules, and returns a dict with
        # keys corresponding to the quantized module names and each entry being a dictionary with two keys 'float'
        # and 'quantized', containing the matching floating point and quantized activations logged by the logger
        act_compare_dict = get_matching_activations(float_model, qmodel)

    @skip_if_no_torchvision
    @unittest.skipIf(IS_ARM64, "Not working on arm right now")
    def test_mobilenet_v2(self):
        from torchvision.models.quantization import mobilenet_v2
        self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))

    @skip_if_no_torchvision
    @unittest.skipIf(IS_ARM64, "Not working on arm right now")
    def test_mobilenet_v3(self):
        from torchvision.models.quantization import mobilenet_v3_large
        self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))