# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.nn.utils.rnn import PackedSequence
from torch.ao.quantization import (
    quantize,
    prepare,
    convert,
    prepare_qat,
    quantize_dynamic,
    QuantWrapper,
    QuantStub,
    DeQuantStub,
    default_qconfig,
    default_dynamic_qconfig,
    per_channel_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    FixedQParamsObserver,
    PerChannelMinMaxObserver,
    default_dynamic_quant_observer,
    default_weight_observer,
    QConfig,
)

from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    AnnotatedSingleLayerLinearModel,
    QuantStubModel,
    ModelWithFunctionals,
    SingleLayerLinearDynamicModel,
    TwoLayerLinearModel,
    NestedModel,
    ResNetBase,
    RNNDynamicModel,
    RNNCellDynamicModel,
    ActivationsTestModel,
    NormalizationTestModel,
    test_only_eval_fn,
    prepare_dynamic,
    convert_dynamic,
    skipIfNoFBGEMM,
    EmbeddingBagModule,
    EmbeddingModule,
    EmbeddingWithStaticLinear,
    LinearReluLinearModel,
)

# annotated models
from torch.testing._internal.common_quantization import (
    AnnotatedTwoLayerLinearModel,
    AnnotatedNestedModel,
    AnnotatedSubNestedModel,
    AnnotatedCustomConfigNestedModel,
    AnnotatedSkipQuantModel,
)

from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
    override_qengines,
)

from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

# Standard library
from typing import Tuple
import numpy as np

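# The eager-mode post-training quantization (PTQ) flow exercised throughout this file is:
# attach a qconfig to the model (or submodule), call prepare() to insert observers, run
# calibration data through the model, then call convert() to swap float modules for their
# quantized counterparts (e.g. nn.Linear -> nnq.Linear).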
class TestQuantizeEagerOps(QuantizationTestCase):
    @override_qengines
    def _test_reference_module_impl(self,
                                    float_module_class,
                                    quantized_module_class,
                                    extra_module_kwargs,
                                    input_size):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = float_module_class(**extra_module_kwargs)
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.dequant(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = float_module_class(**extra_module_kwargs)
                self.quant1 = QuantStub()
                self.dequant1 = DeQuantStub()
                self.quant2 = QuantStub()
                self.dequant2 = DeQuantStub()

            def forward(self, x):
                x = self.quant1(x)
                x = self.dequant1(x)
                x = self.conv(x)
                x = self.quant2(x)
                x = self.dequant2(x)
                return x

        qengine = torch.backends.quantized.engine
        if qengine not in supported_qengines or qengine == 'qnnpack':
            return   # qnnpack does not support nnq.ConvTranspose3d

        data = torch.randn(*input_size, dtype=torch.float)
        original_m = M()
        original_ref_m = RefM()

        original_ref_m.conv.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv.bias = torch.nn.Parameter(original_m.conv.bias.detach())

        original_m.qconfig = torch.ao.quantization.default_qconfig

        m = prepare(original_m)
        # calibration
        m(data)
        m = convert(m)
        # check if the module is properly quantized
        self.assertEqual(type(m.quant), nnq.Quantize)
        self.assertEqual(type(m.conv), quantized_module_class)
        self.assertEqual(type(m.dequant), nnq.DeQuantize)
        res = m(data)

        # quantize the reference model
        original_ref_m.eval()
        original_ref_m.qconfig = torch.ao.quantization.default_qconfig

        ref_m = prepare(original_ref_m)
        ref_m(data)
        ref_m = convert(ref_m, is_reference=True)
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)

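    # convert(..., is_reference=True) keeps the op in float but surrounds it with explicit
    # quantize/dequantize pairs, so the reference model above serves as a numerical baseline
    # for the fully quantized module produced by the regular convert() path.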
    def test_conv_1d(self):
        self._test_reference_module_impl(
            nn.Conv1d,
            nnq.Conv1d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 1)
        )

    def test_conv_2d(self):
        self._test_reference_module_impl(
            nn.Conv2d,
            nnq.Conv2d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10)
        )

    def test_conv_3d(self):
        self._test_reference_module_impl(
            nn.Conv3d,
            nnq.Conv3d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10, 10)
        )

    def test_conv_transpose_1d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose1d,
            nnq.ConvTranspose1d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 1)
        )

    def test_conv_transpose_2d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose2d,
            nnq.ConvTranspose2d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10)
        )

    def test_conv_transpose_3d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose3d,
            nnq.ConvTranspose3d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10, 10)
        )

    def test_linear(self):
        self._test_reference_module_impl(
            nn.Linear,
            nnq.Linear,
            {'in_features': 5, 'out_features': 10},
            (16, 5)
        )

    @override_qengines
    def test_int16_reference_module(self):

        class RefM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.ConvTranspose2d(1, 1, 1)
                self.quant1 = QuantStub()
                self.dequant1 = DeQuantStub()
                self.quant2 = QuantStub()
                self.dequant2 = DeQuantStub()

            def forward(self, x):
                x = self.quant1(x)
                x = self.dequant1(x)
                x = self.conv(x)
                x = self.quant2(x)
                x = self.dequant2(x)
                return x


        input_size = (16, 1, 10, 10)
        data = torch.randn(*input_size, dtype=torch.float)

        original_ref_m = RefM()
        rand_w = torch.randn_like(original_ref_m.conv.weight)
        rand_b = torch.randn_like(original_ref_m.conv.bias)
        original_ref_m.conv.weight = torch.nn.Parameter(rand_w, requires_grad=False)
        original_ref_m.conv.bias = torch.nn.Parameter(rand_b, requires_grad=False)

        qengine = torch.backends.quantized.engine
        if qengine not in supported_qengines:
            return
        from torch.ao.quantization.observer import MovingAverageMinMaxObserver

        weight_obs = MovingAverageMinMaxObserver.with_args(
            dtype=torch.qint32,
            # set qmin and qmax to represent qint16
            quant_min=-1 * (2 ** 15),
            quant_max=(2 ** 15) - 1,
            qscheme=torch.per_tensor_symmetric,
        )
        act_obs = MovingAverageMinMaxObserver.with_args(
            dtype=torch.qint32,
            quant_min=-1 * (2 ** 15),
            quant_max=(2 ** 15) - 1,
        )
        custom_qconfig = QConfig(activation=act_obs, weight=weight_obs)

        # quantize the reference model
        original_ref_m.eval()
        original_ref_m.qconfig = custom_qconfig

        ref_m = prepare(original_ref_m)
        # calibration
        ref_m(torch.randn(*input_size, dtype=torch.float))

        ref_m = convert(ref_m, is_reference=True)

        myobs = MovingAverageMinMaxObserver(averaging_constant=0.5,
                                            dtype=torch.qint32,
                                            # set qmin and qmax to represent qint16
                                            quant_min=-1 * (2 ** 15),
                                            quant_max=(2 ** 15) - 1,
                                            qscheme=torch.per_tensor_symmetric,
                                            )
        result = myobs(rand_w)
        qparams = myobs.calculate_qparams()
        self.assertEqual(ref_m.conv.weight_scale, qparams[0])


    def _test_activation_op_impl(
            self, float_module_class, quantized_module_class, extra_module_kwargs):
        """ Implementation for testing common activation ops like leaky relu
        Args:
            extra_module_kwargs: keyword args to instantiate the float module
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.activation_op = float_module_class(**extra_module_kwargs)
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.activation_op(x)
                x = self.dequant(x)
                return x

        m = M().eval()
        m.qconfig = default_qconfig
        m = prepare(m)
        self.checkObservers(m)
        m = convert(m)
        self.assertEqual(type(m.activation_op), quantized_module_class)

    def test_leaky_relu(self):
        self._test_activation_op_impl(nn.LeakyReLU, nnq.LeakyReLU, {'negative_slope': 0.1, 'inplace': False})

    def test_relu(self):
        self._test_activation_op_impl(nn.ReLU, nn.ReLU, {'inplace': False})

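    # ModelWithFunctionals routes ops such as add, cat, add_relu and matmul through
    # torch.ao.nn.quantized.FloatFunctional wrappers; convert() swaps those wrappers to
    # nnq.QFunctional, which is what the checks in the test below assert.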
    # Histogram observers are slow, so the hypothesis deadline is disabled to ensure the test doesn't time out
    @given(train_mode=st.booleans())
    def test_functional_module(self, train_mode):
        model = ModelWithFunctionals()
        x = torch.rand(10, 1, dtype=torch.float)
        xq = torch.quantize_per_tensor(x, 0.01, 30, torch.quint8)
        self.checkScriptable(model, [[x]], check_save_load=True)
        if train_mode:
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
            model = prepare_qat(model)
        else:
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            model = prepare(model)
        # Check if observers and quant/dequant nodes are inserted
        self.checkNoPrepModules(model)
        self.checkObservers(model)
        # Calibrate
        model(xq.dequantize())
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model)
            self.assertEqual(type(model.myadd), torch.ao.nn.quantized.QFunctional)
            self.assertEqual(type(model.mycat), torch.ao.nn.quantized.QFunctional)
            self.assertEqual(type(model.myadd_relu), torch.ao.nn.quantized.QFunctional)
            self.assertEqual(type(model.mymatmul), torch.ao.nn.quantized.QFunctional)
            self.checkNoQconfig(model)

        checkQuantized(model)
        self.checkScriptable(model, [[xq]], check_save_load=True)

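# Static PTQ quantizes activations with fixed scale/zero_point computed from observer
# statistics, so the models below need QuantStub/DeQuantStub (or QuantWrapper) boundaries
# and a calibration pass (test_only_eval_fn on self.calib_data) before convert().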
class TestQuantizeEagerPTQStatic(QuantizationTestCase):

    def test_single_layer(self):
339        r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
340        to nnq.Linear which is the quantized version of the module
341        """
342        for qengine in supported_qengines:
343            with override_quantized_engine(qengine):
344                qconfig = torch.ao.quantization.get_default_qconfig(qengine)
345                model = AnnotatedSingleLayerLinearModel(qengine)
346                model.qconfig = qconfig
347                model = prepare(model)
348                # Check if observers and quant/dequant nodes are inserted
349                self.checkNoPrepModules(model)
350                self.checkHasPrepModules(model.fc1)
351                self.checkObservers(model)
352
353                test_only_eval_fn(model, self.calib_data)
354                model = convert(model)
355
356                def checkQuantized(model):
357                    self.checkNoPrepModules(model)
358                    self.checkHasPrepModules(model.fc1)
359                    self.checkWrappedQuantizedLinear(model.fc1)
360                    test_only_eval_fn(model, self.calib_data)
361                    self.checkScriptable(model, self.calib_data)
362                    self.checkNoQconfig(model)
363
364                checkQuantized(model)
365
366                # test one line API - out of place version
367                base = AnnotatedSingleLayerLinearModel(qengine)
368                base.qconfig = qconfig
369                keys_before = set(base.state_dict().keys())
370                model = quantize(base, test_only_eval_fn, [self.calib_data])
371                checkQuantized(model)
372                keys_after = set(base.state_dict().keys())
373                self.assertEqual(keys_before, keys_after)  # simple check that nothing changed
374
375                # in-place version
376                model = AnnotatedSingleLayerLinearModel(qengine)
377                model.qconfig = qconfig
378                quantize(model, test_only_eval_fn, [self.calib_data], inplace=True)
379                checkQuantized(model)
380
381    @skipIfNoFBGEMM
382    def test_two_layers(self):
383        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
384        `fc2`, and `fc1`is not quantized
385        """
386        with override_quantized_engine('fbgemm'):
387            model = AnnotatedTwoLayerLinearModel()
388            model = prepare(model)
389
390            self.checkNoPrepModules(model)
391            self.checkObservers(model)
392            self.checkNoPrepModules(model.fc1)
393            self.checkHasPrepModules(model.fc2)
394
395            test_only_eval_fn(model, self.calib_data)
396            model = convert(model)
397
398            def checkQuantized(model):
399                self.checkNoPrepModules(model)
400                self.checkNoPrepModules(model.fc1)
401                self.checkHasPrepModules(model.fc2)
402                self.assertEqual(type(model.fc1), torch.nn.Linear)
403                self.checkWrappedQuantizedLinear(model.fc2)
404                test_only_eval_fn(model, self.calib_data)
405                self.checkScriptable(model, self.calib_data)
406                self.checkNoQconfig(model)
407
408            checkQuantized(model)
409
410            # test one line API
411            model = quantize(AnnotatedTwoLayerLinearModel(), test_only_eval_fn,
412                             [self.calib_data])
413            checkQuantized(model)
414
415    def test_nested1(self):
416        r"""Test quantization for nested model, top level 'fc3' and
417        'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
418        """
419        for qengine in supported_qengines:
420            with override_quantized_engine(qengine):
421                model = AnnotatedNestedModel(qengine)
422
423                def checkPrepModules(model, before_calib=False):
424                    if before_calib:
425                        self.checkObservers(model)
426                    self.checkNoPrepModules(model)
427                    self.checkNoPrepModules(model.sub1)
428                    self.checkNoPrepModules(model.sub1.fc)
429                    self.checkNoPrepModules(model.sub1.relu)
430                    self.checkNoPrepModules(model.sub2)
431                    self.checkHasPrepModules(model.sub2.fc1)
432                    self.checkNoPrepModules(model.sub2.fc2)
433                    self.checkHasPrepModules(model.fc3)
434
435                model = prepare(model)
436                checkPrepModules(model, True)
437                test_only_eval_fn(model, self.calib_data)
438                model = convert(model)
439
440                def checkQuantized(model):
441                    checkPrepModules(model)
442                    self.checkLinear(model.sub1.fc)
443                    self.checkWrappedQuantizedLinear(model.fc3)
444                    self.checkWrappedQuantizedLinear(model.sub2.fc1)
445                    self.checkLinear(model.sub2.fc2)
446                    test_only_eval_fn(model, self.calib_data)
447                    self.checkScriptable(model, self.calib_data)
448                    self.checkNoQconfig(model)
449
450                checkQuantized(model)
451
452                # test one line API
453                model = quantize(AnnotatedNestedModel(qengine), test_only_eval_fn,
454                                 [self.calib_data])
455                checkQuantized(model)
456
457
458    @skipIfNoFBGEMM
459    def test_nested2(self):
460        model = AnnotatedSubNestedModel()
461        model = prepare(model)
462
463        def checkPrepModules(model, before_calib=False):
464            if before_calib:
465                self.checkObservers(model)
466            self.checkNoPrepModules(model)
467            self.checkNoPrepModules(model.sub1)
468            self.checkNoPrepModules(model.sub1.fc)
469            self.checkNoPrepModules(model.sub1.relu)
470            self.checkHasPrepModules(model.sub2)
471            self.checkNoPrepModules(model.sub2.module.fc1)
472            self.checkNoPrepModules(model.sub2.module.fc2)
473            self.checkHasPrepModules(model.fc3)
474
475        checkPrepModules(model, True)
476
477        test_only_eval_fn(model, self.calib_data)
478        model = convert(model)
479
480        def checkQuantized(model):
481            checkPrepModules(model)
482            self.checkLinear(model.sub1.fc)
483            self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
484            self.checkQuantizedLinear(model.sub2.module.fc1)
485            self.checkQuantizedLinear(model.sub2.module.fc2)
486            self.checkWrappedQuantizedLinear(model.fc3)
487            test_only_eval_fn(model, self.calib_data)
488            self.checkScriptable(model, self.calib_data)
489            self.checkNoQconfig(model)
490
491        checkQuantized(model)
492
493        # test one line API
494        model = quantize(AnnotatedSubNestedModel(), test_only_eval_fn,
495                         [self.calib_data])
496        checkQuantized(model)
497
498    def test_nested3(self):
499        r"""More complicated nested test case with child qconfig overrides
500        parent qconfig
501        """
502        for qengine in supported_qengines:
503            with override_quantized_engine(qengine):
504                model = AnnotatedCustomConfigNestedModel()
505                model = prepare(model)
506
507                def checkPrepModules(model, before_calib=False):
508                    if before_calib:
509                        self.checkObservers(model)
510                    self.checkNoPrepModules(model)
511                    self.checkNoPrepModules(model.sub1)
512                    self.checkNoPrepModules(model.sub1.fc)
513                    self.checkNoPrepModules(model.sub1.relu)
514                    self.checkNoPrepModules(model.sub2)
515                    self.checkHasPrepModules(model.sub2.fc1)
516                    self.checkHasPrepModules(model.sub2.fc2)
517                    self.checkHasPrepModules(model.fc3)
518
519                checkPrepModules(model, True)
520
521                test_only_eval_fn(model, self.calib_data)
522                model = convert(model)
523
524                def checkQuantized(model):
525                    checkPrepModules(model)
526                    self.checkWrappedQuantizedLinear(model.sub2.fc1)
527                    self.checkWrappedQuantizedLinear(model.sub2.fc2)
528                    self.checkWrappedQuantizedLinear(model.fc3)
529                    test_only_eval_fn(model, self.calib_data)
530                    self.checkScriptable(model, self.calib_data)
531                    self.checkNoQconfig(model)
532
533                checkQuantized(model)
534
535                # test one line API
536                model = quantize(AnnotatedCustomConfigNestedModel(), test_only_eval_fn,
537                                 [self.calib_data])
538                checkQuantized(model)
539
540    def test_skip_quant(self):
541        r"""The case when we want to skip quantizing some layers
542        """
543        for qengine in supported_qengines:
544            with override_quantized_engine(qengine):
545                model = AnnotatedSkipQuantModel(qengine)
546                model = prepare(model)
547                self.checkObservers(model)
548
549                test_only_eval_fn(model, self.calib_data)
550                model = convert(model)
551
552                def checkQuantized(model):
553                    self.checkLinear(model.fc)
554                    self.checkQuantDequant(model.sub)
555                    self.checkQuantizedLinear(model.sub.module.fc1)
556                    self.checkQuantizedLinear(model.sub.module.fc2)
557                    self.assertEqual(type(model.sub.module.relu1), nn.ReLU)
558                    self.assertEqual(type(model.sub.module.relu2), nn.ReLU)
559                    self.checkScriptable(model, self.calib_data)
560                    self.checkNoQconfig(model)
561
562                checkQuantized(model)
563
564                # test one line API
565                model = quantize(AnnotatedSkipQuantModel(qengine), test_only_eval_fn, [self.calib_data])
566                checkQuantized(model)
567
568    @skipIfNoFBGEMM
569    def test_manual(self):
570        r"""User inserts QuantStub and DeQuantStub in model code
571        and call the quantization utility functions.
572        """
573        model = QuantStubModel()
574        # propagate the qconfig of parents to children, model is changed
575        # inplace
        model = prepare(model)
        self.checkObservers(model)

        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.fc), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # test one line API
        model = quantize(QuantStubModel(), test_only_eval_fn, [self.calib_data])
        checkQuantized(model)

    def test_resnet_base(self):
        r"""Test quantization for bottleneck topology used in resnet/resnext
        and add coverage for conversion of average pool and float functional
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                model = ResNetBase().float().eval()
                model.fuse_model()
                model = QuantWrapper(model)
                model.qconfig = qconfig
                model = prepare(model)
                self.checkObservers(model)
                test_only_eval_fn(model, self.img_data_2d)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d)
                    self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
                    self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
                    self.assertEqual(type(model.module.fc), nnq.Linear)

                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

    @skipIfNoFBGEMM
    def test_normalization(self):
        r"""
        Test quantization of normalization layers
        """
        model = NormalizationTestModel()
        model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        prepare(model, inplace=True)
        self.checkObservers(model)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model.layer_norm)
            self.checkNoPrepModules(model.group_norm)
            self.checkNoPrepModules(model.instance_norm1d)
            self.checkNoPrepModules(model.instance_norm2d)
            self.checkNoPrepModules(model.instance_norm3d)
            self.assertEqual(type(model.layer_norm), nnq.LayerNorm)
            self.assertEqual(type(model.group_norm), nnq.GroupNorm)
            self.assertEqual(type(model.instance_norm1d), nnq.InstanceNorm1d)
            self.assertEqual(type(model.instance_norm2d), nnq.InstanceNorm2d)
            self.assertEqual(type(model.instance_norm3d), nnq.InstanceNorm3d)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        checkQuantized(model)

        model_oneline = quantize(
            NormalizationTestModel(), test_only_eval_fn, [self.calib_data])
        checkQuantized(model_oneline)

    def test_save_load_state_dict(self):
        r"""Test PTQ flow of creating a model and quantizing it and saving the quantized state_dict
        Load the quantized state_dict for eval and compare results against original model
        """

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)

                model = prepare(model)
                # calibrate
                test_only_eval_fn(model, self.calib_data)
                model = convert(model)
                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                quant_state_dict = model.state_dict()

                # Create model again for eval
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                model = prepare(model)
                model = convert(model)
                new_state_dict = model.state_dict()

                # Check to make sure the state dict keys match original model after convert.
                self.assertEqual(set(new_state_dict.keys()), set(quant_state_dict.keys()))

                model.load_state_dict(quant_state_dict)

                out = model(x)
                self.assertEqual(ref, out)

    @skipIfNoFBGEMM
    def test_activations(self):
        r"""
        Test quantization of activations
        """
        model = ActivationsTestModel()
        model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        prepare(model, inplace=True)
        self.checkObservers(model)
        test_only_eval_fn(model, self.calib_data)
        model = convert(model)

        def checkQuantized(model):
            self.checkNoPrepModules(model.hardswish)
            self.assertEqual(type(model.hardswish), nnq.Hardswish)
            self.assertEqual(type(model.elu), nnq.ELU)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # test one line API
        model_oneline = quantize(ActivationsTestModel(), test_only_eval_fn,
                                 [self.calib_data])
        checkQuantized(model_oneline)

    @override_qengines
    def test_forward_hooks_preserved(self):
719        r"""Test post-training static quantization on preserving
720        pre forward and post forward hooks of original model
721        """
722        qengine = torch.backends.quantized.engine
723        model = QuantStubModel()
724        counter = {
725            'pre_forwards': 0,
726            'forwards': 0,
727        }
728
729        def fw_pre_hook(h_module, input):
730            counter['pre_forwards'] += 1
731
732        def fw_hook(h_module, input, output):
733            counter['forwards'] += 1
734
735        model.fc.register_forward_pre_hook(fw_pre_hook)
736        model.fc.register_forward_hook(fw_hook)
737
738        model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
739        model = prepare(model)
740
741        def checkHooksIsPresent(model, before_convert=True):
742            num_fwd_hooks = 1
743            if before_convert:
744                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
745                                 "Quantization observer hook has disappeared")
746                num_fwd_hooks = 2
747
748            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
749            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
750            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
751                             "Extra pre forward hooks have appeared on a layer")
            # During static quantization, non-stub layers also get a quantization observer hook
            self.assertEqual(len(model.fc._forward_hooks.values()), num_fwd_hooks,
                             "Extra post forward hooks have appeared on a layer")
            # Implicitly check that fw_hook goes after _observer_forward_hook
            self.assertEqual(list(model.fc._forward_hooks.values())[-1], fw_hook,
                             "_observer_forward_hook is not the first entry of the hooks list")

        checkHooksIsPresent(model, True)
        test_only_eval_fn(model, self.calib_data)
        torch.ao.quantization.convert(model, inplace=True)
        checkHooksIsPresent(model, False)

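    # Embedding weights are quantized with float qparams (a per-row scale/zero_point);
    # float_qparams_weight_only_qconfig uses 8-bit weights, the _4bit variant uses 4-bit
    # weights, and activations are left in float.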
    @skipIfNoFBGEMM
    def test_quantized_embedding(self):
        r""" Test the post-training quantization flow, serialization and scripting
        of embedding modules
        """

        for qconfig in [float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit]:
            model = EmbeddingModule().eval()
            indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
            weights = torch.randn(10, 12, dtype=torch.float32)
            model.qconfig = qconfig
            prepare(model, inplace=True)
            convert(model, inplace=True)
            self.assertTrue('QuantizedEmbedding' in str(model))
            self.assertEqual(type(model.emb), torch.ao.nn.quantized.Embedding)
            self.checkScriptable(model, [[indices]], check_save_load=True)

            idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
            offsets = torch.LongTensor([0, 4])
            x = torch.randn(2, 4)
            model = EmbeddingWithStaticLinear().eval()
            prepare(model, inplace=True)
            convert(model, inplace=True)
            self.assertTrue('QuantizedEmbedding' in str(model))
            self.assertTrue('QuantizedLinear' in str(model))
            self.checkQuantizedLinear(model.fc)
            model(idx, offsets, x)

    @skipIfNoFBGEMM
    def test_dequant_stub(self):
        m = QuantStubModel().eval()
        prepare(m, inplace=True)
        self.checkObservers(m)
        convert(m, inplace=True)
        self.assertEqual(type(m.quant), nnq.Quantize)
        self.assertEqual(type(m.fc), nnq.Linear)
        self.assertEqual(type(m.dequant), nnq.DeQuantize)

        # check DeQuantStub is not swapped when it doesn't have a qconfig
        m2 = QuantStubModel().eval()
        m2.dequant.qconfig = None
        prepare(m2, inplace=True)
        self.checkObservers(m2)
        convert(m2, inplace=True)
        self.assertEqual(type(m2.quant), nnq.Quantize)
        self.assertEqual(type(m2.fc), nnq.Linear)
        self.assertEqual(type(m2.dequant), DeQuantStub)


    def test_quantized_embedding_bag(self):
        r""" Test the post-training quantization flow, serialization and scripting
        of embedding_bag modules
        """
        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        weights = torch.randn(10, 12, dtype=torch.float32)

        for dtype in [torch.quint8, torch.quint4x2]:
            model = EmbeddingBagModule().eval()
            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
                                                                        qscheme=torch.per_channel_affine_float_qparams,
                                                                        ch_axis=0)
            float_qparams_qconfig = QConfig(activation=default_dynamic_quant_observer,
                                            weight=float_qparams_observer)
            model.qconfig = float_qparams_qconfig

            prepare(model, inplace=True)
            quantized_model = convert(model)

            per_sample_weights = torch.from_numpy(np.random.uniform(
                low=0.01, high=0.5, size=[len(indices)]).astype(np.float32))

            # Test to make sure module is quantized correctly.
            self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
            self.checkDynamicQuantizedModule(quantized_model.emb, torch.ao.nn.quantized.EmbeddingBag, torch.quint8)
            self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True)

            class EmbeddingBagWithLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                     include_last_offset=True, scale_grad_by_freq=False, mode='sum')
                    self.fc = torch.nn.Linear(5, 5)

                def forward(self, indices, offsets, per_sample_weights, linear_in):
                    return self.emb(indices, offsets, per_sample_weights), self.fc(linear_in)

            # Test quantization of embedding_bag layer only
            model2 = EmbeddingBagWithLinear().eval()
            model2.emb.qconfig = float_qparams_qconfig
            prepare(model2, inplace=True)
            quantized_model = convert(model2)

            self.assertTrue('QuantizedEmbeddingBag' in str(quantized_model))
            self.checkLinear(model2.fc)
            self.checkDynamicQuantizedModule(quantized_model.emb, torch.ao.nn.quantized.EmbeddingBag, torch.quint8)

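    # Custom module flow: prepare() swaps CustomModule for ObservedCustomModule via the
    # float_to_observed mapping (calling from_float), and convert() swaps the observed
    # module for QuantizedCustomModule via the observed_to_quantized mapping (calling
    # from_observed).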
    @skipIfNoFBGEMM
    def test_custom_module_class(self):
        class CustomModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        class ObservedCustomModule(torch.nn.Module):
            def __init__(self, conv):
                super().__init__()
                self.conv = conv

            def forward(self, x):
                return self.conv(x)

            @classmethod
            def from_float(cls, float_module):
                assert hasattr(float_module, 'qconfig')
                observed = cls(float_module.conv)
                observed.qconfig = float_module.qconfig
                return observed

        class QuantizedCustomModule(torch.nn.Module):
            def __init__(self, conv):
                super().__init__()
                self.conv = conv

            def forward(self, x):
                return self.conv(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                assert hasattr(observed_module, 'activation_post_process')
                observed_module.conv.activation_post_process = \
                    observed_module.activation_post_process
                quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
                return quantized

        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.custom = CustomModule()

            def forward(self, x):
                return self.custom(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = QuantStub()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.sub = Sub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.sub(x)
                x = self.dequant(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = QuantStub()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.dequant(x)
                return x

        data = torch.randn(1, 1, 1, 1)
        # instantiate M and RefM and align the parameters
        original_m = M()
        original_ref_m = RefM()
        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.sub.custom.conv.weight.detach())
        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.sub.custom.conv.bias.detach())

        original_m.qconfig = default_qconfig
        prepare_custom_config_dict = {
            "float_to_observed_custom_module_class": {
                CustomModule: ObservedCustomModule
            }
        }
        convert_custom_config_dict = {
            "observed_to_quantized_custom_module_class": {
                ObservedCustomModule: QuantizedCustomModule
            }
        }
        m = prepare(
            original_m,
            prepare_custom_config_dict=prepare_custom_config_dict)
        self.checkObservers(m, None, prepare_custom_config_dict)
        # calibration
        m(data)
        # all activation observers are inserted in the top level module

        # check converted/quantized model
        m = convert(
            m,
            convert_custom_config_dict=convert_custom_config_dict)
        # check if the module is properly quantized
        self.assertEqual(type(m.quant), nnq.Quantize)
        self.assertEqual(type(m.conv), nnq.Conv2d)
        self.assertEqual(type(m.sub), Sub)
        self.assertEqual(type(m.sub.custom), QuantizedCustomModule)
        self.assertEqual(type(m.sub.custom.conv), nnq.Conv2d)
        self.assertEqual(type(m.dequant), nnq.DeQuantize)
        res = m(data)

        # quantize the reference model
        original_ref_m.eval()
        original_ref_m.qconfig = default_qconfig
        ref_m = prepare(original_ref_m)
        ref_m(data)
        ref_m = convert(ref_m)
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        with self.assertRaises(AssertionError) as context:
            mp = torch.ao.quantization.prepare(m)
        self.assertTrue(
            str(context.exception) ==
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_qconfig_none(self):
        r"""
        Verifies that having qconfig==None for conv transpose does not crash
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
        m[0].qconfig = None
        mp = torch.ao.quantization.prepare(m)

    @skipIfNoFBGEMM
    def test_quantwrapper_attaches_qconfig_to_dequant(self):
        qconfig = torch.ao.quantization.default_qconfig

        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        for i in range(len(m)):
            m[i].qconfig = qconfig
            m[i] = torch.ao.quantization.QuantWrapper(m[i])

        mp = torch.ao.quantization.prepare(m)
        mq = torch.ao.quantization.convert(mp)
        self.assertTrue(isinstance(mq[0].dequant, nnq.DeQuantize))

    def test_activations_in_non_leaf_module_list(self):
        """
        Ensure activations like `nn.Sigmoid` and `nn.Tanh` are properly handled in
        `non_leaf_module_list`.
        """
        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = QuantStub()
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.softmax = torch.nn.Softmax()
                self.tanh = torch.nn.Tanh()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.sigmoid(x)
                x = self.hardsigmoid(x)
                x = self.softmax(x)
                x = self.tanh(x)
                x = self.dequant(x)
                return x

        qconfig = QConfig(
            activation=FixedQParamsObserver.with_args(scale=123.0, zero_point=0),
            weight=default_weight_observer
        )
        m = MyModel()
        m.qconfig = qconfig
        m = prepare(m, observer_non_leaf_module_list=[
            torch.nn.Sigmoid,
            torch.nn.Hardsigmoid,
            torch.nn.Softmax,
            torch.nn.Tanh,
        ])

        # Should use the observer specified in the QConfig instead of the default (FixedQParamsFakeQuantize)
        self.assertTrue(isinstance(m.sigmoid.activation_post_process, FixedQParamsObserver))
        self.assertTrue(isinstance(m.hardsigmoid.activation_post_process, FixedQParamsObserver))
        self.assertTrue(isinstance(m.softmax.activation_post_process, FixedQParamsObserver))
        self.assertTrue(isinstance(m.tanh.activation_post_process, FixedQParamsObserver))

    @skipIfNoFBGEMM
    def test_mha_batch_first_attr_is_copied_in_prepare(self):
        class TransformerDecoderLayer(nn.Module):
            def __init__(self, d_model, nhead, batch_first):
                super().__init__()
                self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.1, batch_first=batch_first)

        qengine = torch.backends.quantized.engine
        for batch_first in [True, False]:
            model = TransformerDecoderLayer(512, 8, batch_first)
            quantization_config = torch.ao.quantization.get_default_qconfig(qengine)
            model.qconfig = quantization_config
            prepared_model = torch.ao.quantization.prepare(model, inplace=False)
            self.assertTrue(prepared_model.self_attn.batch_first == model.self_attn.batch_first)

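# Dynamic PTQ quantizes weights ahead of time and quantizes activations on the fly at
# runtime, so no QuantStub/DeQuantStub boundaries or calibration pass are required;
# prepare_dynamic/convert_dynamic (or the one-line quantize_dynamic API) simply swap the
# targeted modules, e.g. nn.Linear -> torch.ao.nn.quantized.dynamic.Linear.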
@skipIfNoFBGEMM
class TestQuantizeEagerPTQDynamic(QuantizationTestCase):
    def test_single_layer(self):
1089        r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
1090        make sure it is swapped to nnqd.Linear which is the quantized version of
1091        the module
1092        """
1093        for dtype in [torch.qint8, torch.float16]:
1094            model = SingleLayerLinearDynamicModel().eval()
1095            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
1096            qconfig_dict = {
1097                'fc1': qconfig
1098            }
1099            prepare_dynamic(model, qconfig_dict)
1100            convert_dynamic(model)
1101
1102            def checkQuantized(model):
1103                self.checkDynamicQuantizedLinear(model.fc1, dtype)
1104                self.checkScriptable(model, self.calib_data, check_save_load=True)
1105                self.checkNoQconfig(model)
1106
1107            checkQuantized(model)
1108
1109            # test one line API - out of place version
1110            base = SingleLayerLinearDynamicModel()
1111            keys_before = set(base.state_dict().keys())
1112            model = quantize_dynamic(base, qconfig_dict)
1113            checkQuantized(model)
1114            keys_after = set(base.state_dict().keys())
1115            self.assertEqual(keys_before, keys_after)  # simple check that nothing changed
1116
1117            # in-place version
1118            model = SingleLayerLinearDynamicModel()
1119            quantize_dynamic(model, qconfig_dict, inplace=True)
1120            checkQuantized(model)
1121
1122            # Test set qconfig
1123            model = SingleLayerLinearDynamicModel()
1124            quantize_dynamic(model, {nn.Linear}, inplace=True, dtype=dtype)
1125            checkQuantized(model)
1126
1127    def test_two_layers(self):
1128        r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one
1129        `fc2`, and `fc1`is not quantized
1130        """
1131        for dtype in [torch.qint8, torch.float16]:
1132            model = TwoLayerLinearModel().eval()
1133            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
1134            qconfig_dict = {
1135                'fc2': qconfig
1136            }
1137            prepare_dynamic(model, qconfig_dict)
1138
1139            convert_dynamic(model)
1140
1141            def checkQuantized(model):
1142                self.assertEqual(type(model.fc1), torch.nn.Linear)
1143                self.checkDynamicQuantizedLinear(model.fc2, dtype=dtype)
1144                self.checkScriptable(model, self.calib_data, check_save_load=True)
1145                self.checkNoQconfig(model)
1146
1147            checkQuantized(model)
1148
1149            # test one line API
1150            model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict)
1151            checkQuantized(model)
1152
1153            # Test set API
1154            model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}, dtype=dtype)
1155            checkQuantized(model)
1156
1157    def test_nested1(self):
1158        r"""Test quantization for nested model, top level 'fc3' and
1159        'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized
1160        """
1161        for dtype in [torch.qint8, torch.float16]:
1162            model = NestedModel().eval()
1163            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
1164            qconfig_dict = {
1165                'fc3': qconfig,
1166                'sub2.fc1': qconfig
1167            }
1168
1169            prepare_dynamic(model, qconfig_dict)
1170            convert_dynamic(model)
1171
1172            def checkQuantized(model):
1173                self.checkLinear(model.sub1.fc)
1174                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
1175                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
1176                self.checkLinear(model.sub2.fc2)
1177                self.checkScriptable(model, self.calib_data, check_save_load=True)
1178                self.checkNoQconfig(model)
1179
1180            checkQuantized(model)
1181
1182            # test one line API
1183            model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
1184            checkQuantized(model)
1185
1186            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}, dtype=dtype)
1187            checkQuantized(model)
1188
1189    def test_nested2(self):
1190        r"""Another test case for quantized, we will quantize all submodules
1191        of submodule sub2
1192        """
1193        for dtype in [torch.qint8, torch.float16]:
1194            model = NestedModel().eval()
1195            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
1196            qconfig_dict = {
1197                'fc3': qconfig,
1198                'sub2': qconfig
1199            }
1200            prepare_dynamic(model, qconfig_dict)
1201
1202            convert_dynamic(model)
1203
1204            def checkQuantized(model):
1205                self.checkLinear(model.sub1.fc)
1206                self.assertEqual(type(model.sub1.relu), torch.nn.ReLU)
1207                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
1208                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
1209                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
1210                self.checkScriptable(model, self.calib_data, check_save_load=True)
1211                self.checkNoQconfig(model)
1212
1213            checkQuantized(model)
1214
1215            # test one line API
1216            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
1217            checkQuantized(model)
1218
1219            # Test set API
1220            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}, dtype=dtype)
1221            checkQuantized(model)
1222
1223    def test_nested3(self):
1224        r"""More complicated nested test case with child qconfig overrides
1225        parent qconfig
1226        """
1227        for dtype in [torch.qint8, torch.float16]:
1228            model = NestedModel().eval()
1229            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
1230            qconfig_dynamic_dict = {
1231                'fc3': qconfig,
1232                'sub2': qconfig,
1233                'sub2.fc1': qconfig
1234            }
1235            prepare_dynamic(model, qconfig_dynamic_dict)
1236
1237            convert_dynamic(model)
1238
1239            def checkQuantized(model):
1240                self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=dtype)
1241                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
1242                self.checkDynamicQuantizedLinear(model.fc3, dtype=dtype)
1243                self.checkScriptable(model, self.calib_data, check_save_load=True)
1244                self.checkNoQconfig(model)
1245
1246            checkQuantized(model)
1247
1248            # test one line API
1249            model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict)
1250            checkQuantized(model)
1251
1252            # Test set API
1253            model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}, dtype=dtype)
1254            checkQuantized(model)
1255
1256    def test_type_match_rule(self):
1257        r"""Test quantization for nested model, top level 'fc3' and
1258        'fc1' of submodule 'sub2', All 'torch.nn.Linear' modules are quantized
1259        """
1260        for dtype in [torch.qint8, torch.float16]:
1261            model = NestedModel().eval()
1262            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
1263            qconfig_dict = {
1264                'fc3': None,
1265                'sub2.fc1': None,
1266                torch.nn.Linear: qconfig
1267            }
1268
1269            prepare_dynamic(model, qconfig_dict)
1270            test_only_eval_fn(model, self.calib_data)
1271            convert_dynamic(model)
1272
1273            def checkQuantized(model):
1274                self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=dtype)
1275                self.checkLinear(model.fc3)
1276                self.checkLinear(model.sub2.fc1)
1277                self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=dtype)
1278                test_only_eval_fn(model, self.calib_data)
1279                self.checkScriptable(model, self.calib_data, check_save_load=True)
1280                self.checkNoQconfig(model)
1281
1282            checkQuantized(model)
1283
1284            # test one line API
1285            model = quantize_dynamic(NestedModel().eval(), qconfig_dict, dtype=dtype)
1286            checkQuantized(model)
1287
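    # per_channel_dynamic_qconfig observes each output channel of the weight separately,
    # producing one scale/zero_point pair per output row of a Linear weight instead of a
    # single per-tensor pair.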
    def test_per_channel_linear_quantize(self):
        r"""Test quantization for per_channel dynamic quantization
        """
        model = NestedModel().eval()
        qconfig_dict = {
            torch.nn.Linear: per_channel_dynamic_qconfig
        }

        prepare_dynamic(model, qconfig_dict)
        test_only_eval_fn(model, self.calib_data)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinear(model.sub1.fc, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.fc3, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc1, dtype=torch.qint8)
            self.checkDynamicQuantizedLinear(model.sub2.fc2, dtype=torch.qint8)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data, check_save_load=True)
            self.checkNoQconfig(model)

        checkQuantized(model)
        # test one line API
        model = quantize_dynamic(NestedModel().eval(), qconfig_dict)
        checkQuantized(model)

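    # fuse_modules() merges the Linear + ReLU pair into a single intrinsic LinearReLU
    # module before quantization, so dynamic conversion yields one fused dynamic quantized
    # LinearReLU for fc1 instead of separate Linear and ReLU modules.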
    def test_linear_relu_fusion(self):
        dtype = torch.qint8
        model = LinearReluLinearModel().eval()
        qconfig = default_dynamic_qconfig
        qconfig_dict = {'' : qconfig}
        torch.ao.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
        prepare_dynamic(model, qconfig_dict)
        convert_dynamic(model)

        def checkQuantized(model):
            self.checkDynamicQuantizedLinearRelu(model.fc1, dtype)
            self.checkDynamicQuantizedLinear(model.fc2, dtype)
            self.checkScriptable(model, self.calib_data, check_save_load=True)
            self.checkNoQconfig(model)

        checkQuantized(model)

    @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
           dtype=st.sampled_from([torch.qint8, torch.float16]))
    def test_quantized_rnn(self, qconfig, dtype):
        r"""Test dynamic quantization, scriptability and serialization of
        dynamically quantized LSTM and GRU modules with int8 and fp16 dtypes
        """
        niter = 10
        x = torch.tensor([[100, -155],
                          [-155, 100],
                          [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
        qconfig_dict = {
            torch.nn.LSTM : qconfig,
            torch.nn.GRU: qconfig
        }

        def checkQuantized(model, module_type):
            mod_type_map = {'LSTM': torch.ao.nn.quantized.dynamic.LSTM,
                            'GRU': torch.ao.nn.quantized.dynamic.GRU}
            mod_repr_map = {'LSTM': 'DynamicQuantizedLSTM',
                            'GRU': 'DynamicQuantizedGRU'}
            self.assertTrue(mod_repr_map[module_type] in str(model))
            self.checkDynamicQuantizedModule(model.mod, mod_type_map[module_type], dtype)

        for module_type in ['LSTM', 'GRU']:
            model = RNNDynamicModel(module_type).eval()

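            # For fp16, quantize_dynamic is called without an explicit qconfig_spec,
            # presumably because the built-in default spec for fp16 already covers
            # LSTM/GRU; for int8 the sampled qconfig is applied via qconfig_dict.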
            if dtype == torch.float16:
                model_quantized = quantize_dynamic(model=model, dtype=dtype)
            else:
                model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype)

            checkQuantized(model_quantized, module_type)
            self.checkScriptable(model_quantized, [[x]], check_save_load=True)

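            # TorchScript needs explicit type annotations to accept PackedSequence
            # inputs, so the quantized RNN is wrapped in a small module whose forward
            # signature declares the packed input and output types.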
            class ScriptWrapperPackedLSTM(torch.nn.Module):
                def __init__(self, cell):
                    super().__init__()
                    self.cell = cell

                def forward(self, x: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
                    return self.cell(x)

            class ScriptWrapperPackedGRU(torch.nn.Module):
                def __init__(self, cell):
                    super().__init__()
                    self.cell = cell

                def forward(self, x: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]:
                    return self.cell(x)

            script_wrapper_map = {'LSTM': ScriptWrapperPackedLSTM,
                                  'GRU': ScriptWrapperPackedGRU}
            packed_input = torch.nn.utils.rnn.pack_padded_sequence(x, torch.tensor([10, 5, 2]))
            model_with_packed_input = script_wrapper_map[module_type](model_quantized.mod)
            model_with_packed_input(packed_input)
            scripted = torch.jit.script(model_with_packed_input)
            scripted(packed_input)
            # We cannot trace with input dtype being a packed sequence
            self._checkScriptable(model_with_packed_input, scripted, [[packed_input]], True)


    @given(qconfig=st.sampled_from([per_channel_dynamic_qconfig, default_dynamic_qconfig]),
           dtype=st.sampled_from([torch.qint8, torch.float16]))
    def test_quantized_rnn_cell(self, qconfig, dtype):
        r"""Test dynamic quantization, scriptability and serialization of
        dynamically quantized RNN cell modules with int8 and fp16 dtypes
        """
        qconfig_dict = {
            torch.nn.LSTMCell : qconfig,
            torch.nn.GRUCell : qconfig,
            torch.nn.RNNCell : qconfig
        }

        for module_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']:
            model = RNNCellDynamicModel(module_type).eval()
            x = torch.tensor([[100, -155],
                             [-155, 100],
                             [100, -155]], dtype=torch.float)

            if torch.backends.quantized.engine == 'qnnpack' and dtype == torch.float16:
                # fp16 dynamic quant is not supported for qnnpack
                continue

            if dtype == torch.float16:
                model_quantized = quantize_dynamic(model=model, dtype=dtype)
            else:
                model_quantized = quantize_dynamic(model=model, qconfig_spec=qconfig_dict, dtype=dtype)

            def checkQuantized(model, module_type):
                mod_type_map = {'LSTMCell': torch.ao.nn.quantized.dynamic.LSTMCell,
                                'GRUCell': torch.ao.nn.quantized.dynamic.GRUCell,
                                'RNNTanh': torch.ao.nn.quantized.dynamic.RNNCell,
                                'RNNReLU': torch.ao.nn.quantized.dynamic.RNNCell}

                mod_repr_map = {'LSTMCell': 'DynamicQuantizedLSTMCell',
                                'GRUCell': 'DynamicQuantizedGRUCell',
                                'RNNTanh': 'DynamicQuantizedRNNCell',
                                'RNNReLU': 'DynamicQuantizedRNNCell'}

                self.assertTrue(mod_repr_map[module_type] in str(model))
                self.checkDynamicQuantizedModule(model.mod, mod_type_map[module_type], dtype)
                self.checkNoQconfig(model)

            # Smoke test extra reprs
            checkQuantized(model_quantized, module_type)
            self.checkScriptable(model_quantized, [[x]], check_save_load=True)


    def test_forward_hooks_preserved(self):
        r"""Test that post-training dynamic quantization preserves the
        pre-forward and post-forward hooks of the original model
        """
        for dtype in [torch.qint8, torch.float16]:
            model = SingleLayerLinearDynamicModel().eval()
            qconfig = float16_dynamic_qconfig if dtype == torch.float16 else default_dynamic_qconfig
            qconfig_dict = {
                'fc1': qconfig
            }
            convert_dynamic(model)

            counter = {
                'pre_forwards': 0,
                'forwards': 0,
            }
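            # These counters are incremented by the hooks below; the test asserts
            # only that the hooks remain registered across prepare/convert, not how
            # many times they fire.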

            def fw_pre_hook(h_module, input):
                counter['pre_forwards'] += 1

            def fw_hook(h_module, input, output):
                counter['forwards'] += 1

            model.fc1.register_forward_pre_hook(fw_pre_hook)
            model.fc1.register_forward_hook(fw_hook)
            prepare_dynamic(model, qconfig_dict)

            def checkHooksIsPresent(model):
                self.assertObjectIn(fw_pre_hook, model.fc1._forward_pre_hooks.values())
                self.assertObjectIn(fw_hook, model.fc1._forward_hooks.values())
                self.assertEqual(len(model.fc1._forward_pre_hooks.values()), 1,
                                 "Extra pre forward hooks have appeared on a layer")
                self.assertEqual(len(model.fc1._forward_hooks.values()), 1,
                                 "Extra post forward hooks have appeared on a layer")

            checkHooksIsPresent(model)
            test_only_eval_fn(model, self.calib_data)
            convert_dynamic(model)
            checkHooksIsPresent(model)

    @skipIfNoFBGEMM
    def test_embedding_bag_dynamic(self):
        class EmbeddingBagWithLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                 include_last_offset=True, scale_grad_by_freq=False, mode='sum')
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, indices, offsets, linear_in):
                return self.emb(indices, offsets), self.fc(linear_in)
        model = EmbeddingBagWithLinear().eval()

        qconfig_dict = {
            torch.nn.EmbeddingBag : float_qparams_weight_only_qconfig,
            torch.nn.Linear: default_dynamic_qconfig
        }
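        # The EmbeddingBag gets weight-only quantization (float qparams per row of
        # the embedding table), while the Linear layer uses regular dynamic
        # quantization; the repr checks below confirm both replacements happened.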
        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        q_model = quantize_dynamic(model, qconfig_dict)

        q_model(indices, offsets, torch.randn(5, 5))
        self.assertTrue('QuantizedEmbeddingBag' in str(q_model.emb))
        self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc))

    @skipIfNoFBGEMM
    def test_embedding_ops_dynamic(self):
        class EmbeddingWithLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(
                    num_embeddings=10, embedding_dim=12, scale_grad_by_freq=False)
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, indices, linear_in):
                return self.emb(indices), self.fc(linear_in)
        model = EmbeddingWithLinear().eval()
        qconfig_dict = {
            torch.nn.Embedding : float_qparams_weight_only_qconfig,
            torch.nn.Linear: default_dynamic_qconfig
        }
        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        q_model = quantize_dynamic(model, qconfig_dict)
        self.assertTrue('QuantizedEmbedding' in str(q_model.emb))
        self.assertTrue('DynamicQuantizedLinear' in str(q_model.fc))
        q_model(indices, torch.randn(5, 5))

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")