# Owner(s): ["oncall: quantization"]

import copy
import math

import torch
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.backends.mkldnn
import torch.nn as nn
import torch.testing._internal.hypothesis_utils as hu

from hypothesis import given, strategies as st
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.ao.quantization import (
    convert,
    default_embedding_qat_qconfig,
    default_qat_qconfig,
    default_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    DeQuantStub,
    FixedQParamsFakeQuantize,
    FusedMovingAvgObsFakeQuantize,
    get_default_qat_qconfig,
    get_embedding_qat_module_mappings,
    get_embedding_static_quant_module_mappings,
    NoopObserver,
    prepare,
    prepare_qat,
    quantize_qat,
    QuantStub,
)
from torch.ao.quantization.qconfig import qconfig_equals
from torch.nn import BatchNorm2d, Conv2d, init, ReLU
from torch.nn.modules.utils import _pair
from torch.testing._internal.common_quantization import (
    DeFusedEmbeddingBagLinear,
    ManualConvLinearQATModel,
    ManualConvLinearSymmQATModel,
    ManualDropoutQATModel,
    ManualEmbeddingBagLinear,
    ManualLinearDynamicQATModel,
    ManualLinearQATModel,
    QuantizationTestCase,
    QuantStubModel,
    test_only_eval_fn,
    test_only_train_fn,
    TwoLayerLinearModel,
)

from torch.testing._internal.common_quantized import (
    override_qengines,
    override_quantized_engine,
    supported_qengines,
)

from torch.testing._internal.common_utils import skipIfNoXNNPACK

hu.assert_deadline_disabled()
from functools import reduce

class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
    """
    Conv-BN fusion implemented with explicit folding. Useful
    to verify numerical equivalence with the non-folded version.
    """
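    # _forward() below folds BN into the convolution by scaling the conv weight
    # with gamma / sqrt(running_var + eps); during training (with BN stats
    # unfrozen) the output is then corrected using the current batch statistics.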
    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups,
                 bias,
                 padding_mode,
                 # BatchNormNd args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.eps = eps
        self.momentum = momentum
        self.freeze_bn = freeze_bn if self.training else True
        self.num_features = out_channels
        self.gamma = nn.Parameter(torch.empty(out_channels))
        self.beta = nn.Parameter(torch.empty(out_channels))
        self.affine = True
        self.track_running_stats = True
        self.running_mean = nn.Buffer(torch.zeros(out_channels))
        self.running_var = nn.Buffer(torch.ones(out_channels))
        self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long))
        self.activation_post_process = self.qconfig.activation()
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def reset_bn_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.gamma)
        init.zeros_(self.beta)
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def reset_parameters(self):
        super().reset_parameters()
        # A hack to avoid resetting on undefined parameters
        if hasattr(self, 'gamma'):
            self.reset_bn_parameters()

    def update_bn_stats(self):
        self.freeze_bn = False
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        return self

    def _forward(self, input):
        # exponential_average_factor is set to self.momentum (when it is
        # available) only so that it gets updated in the ONNX graph when this
        # node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and not self.freeze_bn and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # we use running statistics from the previous batch, so this is an
        # approximation of the approach mentioned in the whitepaper, but we only
        # need to do one convolution in this case instead of two
        running_std = torch.sqrt(self.running_var + self.eps)
        scale_factor = self.gamma / running_std
        scaled_weight = self.weight * scale_factor.reshape([-1, 1, 1, 1])
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
        else:
            zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype)
        conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias)

        if self.training and not self.freeze_bn:
            # recovering original conv to get original batch_mean and batch_var
            if self.bias is not None:
                conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])
            else:
                conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
            batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
            batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
            n = float(conv_orig.numel() / conv_orig.size()[1])
            unbiased_batch_var = batch_var * (n / (n - 1))
            batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps)

            conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \
                (self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1])
            self.running_mean = exponential_average_factor * batch_mean.detach() + \
                (1 - exponential_average_factor) * self.running_mean
            self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \
                (1 - exponential_average_factor) * self.running_var
        else:
            if self.bias is None:
                conv = conv + (self.beta - self.gamma * self.running_mean /
                               running_std).reshape([1, -1, 1, 1])
            else:
                conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1])
        return conv

    def extra_repr(self):
        # TODO(jerryzh): extend
        return super().extra_repr()

    def forward(self, input):
        return self.activation_post_process(self._forward(input))

    @classmethod
    def from_float(cls, mod, qconfig=None):
        r"""Create a QAT module from a float module or qparams_dict.
            Args: `mod` a float module, either produced by torch.ao.quantization utilities
            or directly from the user
        """
        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        if not qconfig:
            assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
            assert mod.qconfig, 'Input float module must have a valid qconfig'
            qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]
        qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                         conv.stride, conv.padding, conv.dilation,
                         conv.groups, conv.bias is not None,
                         conv.padding_mode,
                         bn.eps, bn.momentum,
                         False,
                         qconfig)
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.gamma = bn.weight
        qat_convbn.beta = bn.bias
        qat_convbn.running_mean = bn.running_mean
        qat_convbn.running_var = bn.running_var
        qat_convbn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn

class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d

    def __init__(self,
                 # ConvNd args
                 in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=None,
                 padding_mode='zeros',
                 # BatchNorm2d args
                 # num_features: out_channels
                 eps=1e-05, momentum=0.1,
                 # affine: True
                 # track_running_stats: True
                 # Args for this module
                 freeze_bn=False,
                 qconfig=None):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        _ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride,
                                    padding, dilation, False, _pair(0), groups, bias, padding_mode,
                                    eps, momentum, freeze_bn, qconfig)

class TestQuantizeEagerQAT(QuantizationTestCase):
    def setUp(self):
        super().setUp()

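        # Synthetic data: batches of random embedding indices in [0, 10) paired
        # with float targets for training, plus a small batch of indices for eval.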
        self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long),
                                         torch.randn((12, 1), dtype=torch.float)]
                                        for _ in range(2)]
        self.embed_data = [[torch.randint(0, 10, (12, 1))]]


    def test_manual(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn,
                                     [self.train_data])
                checkQuantized(model)

    def test_dropout(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualDropoutQATModel(qengine)
                model = prepare_qat(model)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.dropout), nnq.Dropout)
                    test_only_eval_fn(model, self.calib_data)
                    self.checkScriptable(model, self.calib_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn,
                                     [self.train_data])
                checkQuantized(model)

    def test_eval_only_fake_quant(self):
        r"""Using FakeQuant in evaluation-only mode.
        This is useful for estimating the accuracy loss when we quantize the
        network.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualLinearQATModel(qengine)

                model = prepare_qat(model)
                self.checkObservers(model)

                model.eval()
                test_only_eval_fn(model, self.calib_data)

    def test_conv_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualConvLinearQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearQATModel()
                model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
                checkQuantized(model)

    @skipIfNoXNNPACK
    def test_conv_linear_symm(self):
        r"""Same as test_conv_linear but with symmetric quantization.
        Supported only with qengine=qnnpack, which uses symmetric
        kernels from the xnnpack library."""
        for qengine in supported_qengines:
            if qengine != 'qnnpack':
                continue
            with override_quantized_engine(qengine):
                model = ManualConvLinearSymmQATModel()

                model = prepare_qat(model)
                self.checkObservers(model)

                test_only_train_fn(model, self.img_data_2d_train)
                model = convert(model)

                def checkQuantized(model):
                    self.assertEqual(type(model.conv), nnq.Conv2d)
                    self.assertEqual(type(model.fc1), nnq.Linear)
                    self.assertEqual(type(model.fc2), nnq.Linear)
                    test_only_eval_fn(model, self.img_data_2d)
                    self.checkScriptable(model, self.img_data_2d)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualConvLinearSymmQATModel()
                model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train])
                checkQuantized(model)

    def test_dynamic_qat_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # Dynamic QAT without memoryless observers should fail
                with self.assertRaisesRegex(ValueError,
                                            "Dynamic QAT requires a memoryless observer." +
                                            "This means a MovingAverage observer with averaging constant equal to 1"
                                            ):
                    model = ManualLinearDynamicQATModel(default_qat_qconfig)
                    model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

                model = ManualLinearDynamicQATModel()
                model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
                self.assertEqual(type(model.fc1), nnqatd.Linear)
                self.assertEqual(type(model.fc2), nnqatd.Linear)
                self.checkObservers(model)
                test_only_train_fn(model, self.train_data)
                model = convert(model, mapping={nnqatd.Linear: nnqd.Linear})
                self.assertEqual(type(model.fc1), nnqd.Linear)
                self.assertEqual(type(model.fc2), nnqd.Linear)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

    def test_defused_embedding_bag_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = DeFusedEmbeddingBagLinear().train()
                model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure activation_post_process is inserted after Linear.
                self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize)
                # make sure that Embedding has a noop for activation.
                self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
                # make sure that FakeQuant zero_points are the correct dtype
                self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
                self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)

                model = convert(model, mapping=get_embedding_static_quant_module_mappings())

                def checkQuantized(model):
                    # make sure Embedding is now a QuantizedEmbedding
                    self.assertEqual(type(model.emb), nn.quantized.Embedding)
                    # make sure Linear is now a QuantizedLinear
                    self.assertEqual(type(model.linear), nn.quantized.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)


    def test_embedding_bag_linear(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ManualEmbeddingBagLinear().train()
                model = prepare_qat(model, mapping=get_embedding_qat_module_mappings())
                self.checkObservers(model)

                test_only_train_fn(model, self.embed_linear_data_train)
                # make sure no activation_post_process is inserted for EmbeddingBag
                self.assertFalse(hasattr(model, "activation_post_process"))
                # make sure that FakeQuant zero_points are the correct dtype
                self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
                self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32)
                model = convert(model, mapping=get_embedding_static_quant_module_mappings())

                def checkQuantized(model):
                    # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                    self.assertEqual(type(model.emb), nn.quantized.EmbeddingBag)
                    # Also test that Linear has been quantized.
                    self.assertEqual(type(model.linear), nnq.Linear)

                    test_only_eval_fn(model, self.embed_data)
                    self.checkScriptable(model, self.embed_data)
                    self.checkNoQconfig(model)

                checkQuantized(model)

                model = ManualEmbeddingBagLinear()

    def test_train_save_load_eval(self):
        r"""Test the QAT flow of creating a model, doing QAT, and saving the quantized state_dict.
        During eval, we first call prepare_qat and convert on the model, then load the state_dict
        and compare results against the original model.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                model = prepare_qat(model)

                fq_state_dict = model.state_dict()

                test_only_train_fn(model, self.train_data)
                model = convert(model)

                quant_state_dict = model.state_dict()

                x = torch.rand(2, 5, dtype=torch.float)
                ref = model(x)

                # Create model again for eval. Check result using quantized state_dict
                model = TwoLayerLinearModel()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                torch.ao.quantization.prepare_qat(model, inplace=True)
                new_state_dict = model.state_dict()

                # Check to make sure the model after prepare_qat has the same state_dict as original.
                self.assertEqual(set(fq_state_dict.keys()), set(new_state_dict.keys()))

                torch.ao.quantization.convert(model, inplace=True)
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

                # Check model created using prepare has same state dict as quantized state_dict
                model = TwoLayerLinearModel()
                model.eval()
                model = torch.ao.quantization.QuantWrapper(model)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                torch.ao.quantization.prepare(model, inplace=True)
                torch.ao.quantization.convert(model, inplace=True)
                self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys()))
                model.eval()
                model.load_state_dict(quant_state_dict)
                out = model(x)
                self.assertEqual(ref, out)

    @override_qengines
    def test_forward_hooks_preserved(self):
        r"""Test that QAT preserves the pre-forward and post-forward hooks of the original model.
        """
        qengine = torch.backends.quantized.engine
        model = QuantStubModel()
        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }

        def fw_pre_hook(h_module, input):
            counter['pre_forwards'] += 1

        def fw_hook(h_module, input, output):
            counter['forwards'] += 1

        model.fc.register_forward_pre_hook(fw_pre_hook)
        model.fc.register_forward_hook(fw_hook)

        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        model = prepare_qat(model)

        def checkHooksIsPresent(model, before_convert=True):
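            # Before convert, model.fc carries both the user hook and the
            # observer's forward hook (2 total); after convert only the user
            # hook should remain.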
            forward_hooks = 1
            if before_convert:
                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
                                 "Quantization observer hook has disappeared")
                forward_hooks = 2
            self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
            self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
            self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
                             "Extra pre forward hooks have appeared on a layer")
            self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
                             "Extra post forward hooks have appeared on a layer")

        checkHooksIsPresent(model, True)
        x = torch.rand(2, 5, dtype=torch.float)
        model(x)
        torch.ao.quantization.convert(model, inplace=True)
        checkHooksIsPresent(model, False)

    def test_add_scalar_uses_input_qparams(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.ff = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.quant(x)
                x = self.ff.add_scalar(x, 1.0)
                return x

        m = M()
        m.qconfig = torch.ao.quantization.default_qconfig
        mp = torch.ao.quantization.prepare_qat(m)
        mp(torch.randn(4, 4))
        mq = torch.ao.quantization.convert(mp)
        res = mq(torch.randn(4, 4))
        eps = 1e-5
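        # add_scalar should reuse the input's quantization parameters, so the
        # output scale must match the QuantStub's observed scale.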
        self.assertTrue(torch.abs(mq.quant.scale - res.q_scale()) < eps)

    def test_mul_scalar_uses_input_qparams(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.ff = torch.ao.nn.quantized.FloatFunctional()

            def forward(self, x):
                x = self.quant(x)
                x = self.ff.mul_scalar(x, 2.0)
                return x

        m = M()
        m.qconfig = torch.ao.quantization.default_qconfig
        mp = torch.ao.quantization.prepare_qat(m)
        mp(torch.randn(4, 4))
        mq = torch.ao.quantization.convert(mp)
        res = mq(torch.randn(4, 4))
        eps = 1e-5
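        # mul_scalar rescales the input qparams by the scalar, so the output
        # scale should be the input scale multiplied by 2.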
        self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps)

    @override_qengines
    def test_qat_embedding_bag_errors(self):
        default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)

        # Test constructor parameter checks here.
        with self.assertRaisesRegex(AssertionError,
                                    "qconfig must be provided for QAT module"):
            nnqat.EmbeddingBag(10, 5, qconfig=None)

        with self.assertRaisesRegex(AssertionError,
                                    "Embedding Bag weights requires a qscheme of " +
                                    "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

        # Test from_float checks here.
        embed = nn.Embedding(10, 5)
        with self.assertRaisesRegex(AssertionError,
                                    "qat.EmbeddingBag.from_float only works for EmbeddingBag"):
            nnqat.EmbeddingBag.from_float(embed)
        embed_bag = nn.EmbeddingBag(10, 5)
        with self.assertRaisesRegex(AssertionError,
                                    "Input float module must have qconfig defined"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = None
        with self.assertRaisesRegex(AssertionError,
                                    "Input float module must have a valid qconfig"):
            nnqat.EmbeddingBag.from_float(embed_bag)
        embed_bag.qconfig = default_qat_qconfig
        with self.assertRaisesRegex(AssertionError,
                                    "Embedding Bag weights requires a qscheme of " +
                                    "torch.per_channel_affine_float_qparams"):
            nnqat.EmbeddingBag.from_float(embed_bag)

    def test_embedding_qat_qconfig_equal(self):
        # Embedding QAT uses a NoopObserver class for activation and a
        # FakeQuant for weight; make sure that qconfig comparison works
        # properly for a mix of partial functions and classes in the qconfig.
        model = ManualEmbeddingBagLinear().train()
        model = prepare_qat(model)

        self.assertTrue(qconfig_equals(model.emb.qconfig,
                                       default_embedding_qat_qconfig))

class TestQuantizeEagerQATNumerics(QuantizationTestCase):
    def _test_activation_convert_numerics_impl(self, Act, data):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.act = Act()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.act(x)
                x = self.dequant(x)
                return x

        m = M().train()
        m.qconfig = default_qat_qconfig
        m = prepare_qat(m)
        before_convert = m(data)
        m = convert(m)
        after_convert = m(data)
        self.assertEqual(before_convert, after_convert)

    def test_fixed_qparam_ops(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.tanh = torch.nn.Tanh()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.sigmoid(x)
                x = self.hardsigmoid(x)
                x = self.tanh(x)
                x = self.dequant(x)
                return x

        m = M().train()
        m.qconfig = default_qat_qconfig
        m = prepare_qat(m)
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize)
        data = torch.randn(1, 3, 2, 4)
        before_convert = m(data)
        m = convert(m)
        after_convert = m(data)
        self.assertEqual(before_convert, after_convert)
        # make sure activation post process is removed
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            # verify fake quant module is removed
            self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process'))
            # verify that hooks are removed
            self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

        # make sure no fake quantize module is inserted for eval mode

        def checkNoFQModule(m):
            for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
                self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
                self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

        m = M().eval()
        m.qconfig = default_qconfig
        m = prepare(m)
        checkNoFQModule(m)
        m = convert(m)
        checkNoFQModule(m)

    def test_leaky_relu(self):
        data = torch.randn(1, 3, 2, 4)
        self._test_activation_convert_numerics_impl(nn.LeakyReLU, data)

    def test_relu(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(x)
                return x

        m = M().train()
        m.qconfig = default_qconfig
        m = prepare_qat(m)
        # make sure no activation_post_process is inserted for relu
        self.assertFalse(hasattr(m, "activation_post_process"))
        m = convert(m)
        # make sure the ReLU module is not changed
        self.assertEqual(type(m.relu), nn.ReLU)

    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans(),
           zero_gamma=st.booleans(),
           has_bias=st.booleans(),
           use_slow_fusion=st.booleans())
    def test_conv_bn_relu(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            use_relu,
            eps,
            momentum,
            freeze_bn,
            zero_gamma,
            has_bias,
            use_slow_fusion,
    ):
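        # Compare the fused QAT ConvBn(ReLU)2d module against an unfused
        # Conv2d + BatchNorm2d (+ ReLU) reference, checking forward results,
        # gradients, and BN running stats in double precision.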
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        dilation_h = dilation_w = dilation

        conv_op = Conv2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            has_bias,
            padding_mode
        ).to(dtype=torch.double)
        bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
        relu_op = ReLU()

        cls = ConvBnReLU2d if use_relu else ConvBn2d
        qat_op = cls(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            has_bias,
            padding_mode,
            eps,
            momentum,
            freeze_bn=True,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)
        qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion

        # the approximate fusion will not work if bn.weight has 0
        if zero_gamma and use_slow_fusion:
            torch.nn.init.zeros_(qat_op.bn.weight)

        qat_op.apply(torch.ao.quantization.disable_fake_quant)
        if freeze_bn:
            qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
        else:
            qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)

        # align inputs and internal parameters
        input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
        conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
        if has_bias:
            conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
        bn_op.running_mean = qat_op.bn.running_mean.clone()
        bn_op.running_var = qat_op.bn.running_var.clone()
        bn_op.weight = torch.nn.Parameter(qat_op.bn.weight.detach())
        bn_op.bias = torch.nn.Parameter(qat_op.bn.bias.detach())

        def compose(functions):
            # functions are reversed for natural reading order
            return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

        if not use_relu:
            def relu_op(x):  # noqa: F811
                return x

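        # With frozen BN, the reference applies batch norm using the stored
        # running statistics instead of per-batch statistics.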
        if freeze_bn:
            def ref_op(x):
                x = conv_op(x)
                x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                    (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \
                    .reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
                x = relu_op(x)
                return x
        else:
            ref_op = compose([conv_op, bn_op, relu_op])

        input_clone = input.clone().detach().requires_grad_()
        for i in range(2):
            result_ref = ref_op(input)
            result_actual = qat_op(input_clone)
            self.assertEqual(result_ref, result_actual)

            # backward
            dout = torch.randn(result_ref.size(), dtype=torch.double)
            loss = (result_ref - dout).sum()
            loss.backward()
            input_grad_ref = input.grad.cpu()
            weight_grad_ref = conv_op.weight.grad.cpu()
            gamma_grad_ref = bn_op.weight.grad.cpu()
            beta_grad_ref = bn_op.bias.grad.cpu()
            running_mean_ref = bn_op.running_mean
            running_var_ref = bn_op.running_var
            num_batches_tracked_ref = bn_op.num_batches_tracked
            loss = (result_actual - dout).sum()
            loss.backward()
            input_grad_actual = input_clone.grad.cpu()
            weight_grad_actual = qat_op.weight.grad.cpu()
            gamma_grad_actual = qat_op.bn.weight.grad.cpu()
            beta_grad_actual = qat_op.bn.bias.grad.cpu()
            running_mean_actual = qat_op.bn.running_mean
            running_var_actual = qat_op.bn.running_var
            num_batches_tracked_actual = qat_op.bn.num_batches_tracked
            precision = 1e-10
            self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
            self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
            self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
            self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
            self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
            self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
            self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)

    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans(),
           bias=st.booleans())
    def test_conv_bn_folded_vs_unfolded(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            eps,
            momentum,
            freeze_bn,
            bias,
    ):
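        # Train the fused ConvBn2d QAT module and the _ReferenceConvBn2d
        # (explicit folding) implementation side by side for a few SGD steps
        # and check that outputs, gradients, and BN buffers stay in agreement.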
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        dilation_h = dilation_w = dilation

        qat_op = ConvBn2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            bias,  # bias
            padding_mode,
            eps,
            momentum,
            freeze_bn=freeze_bn,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)

        qat_ref_op = _ReferenceConvBn2d(
            input_channels,
            output_channels,
            (kernel_h, kernel_w),
            (stride_h, stride_w),
            (pad_h, pad_w),
            (dilation_h, dilation_w),
            groups,
            bias,  # bias
            padding_mode,
            eps,
            momentum,
            freeze_bn=freeze_bn,
            qconfig=default_qat_qconfig
        ).to(dtype=torch.double)

        qat_op.apply(torch.ao.quantization.disable_fake_quant)
        qat_ref_op.apply(torch.ao.quantization.disable_fake_quant)

        # align inputs and internal parameters
        qat_ref_op.weight = torch.nn.Parameter(qat_op.weight.detach().clone())
        qat_ref_op.running_mean = qat_op.bn.running_mean.clone()
        qat_ref_op.running_var = qat_op.bn.running_var.clone()
        qat_ref_op.gamma = torch.nn.Parameter(qat_op.bn.weight.detach().clone())
        qat_ref_op.beta = torch.nn.Parameter(qat_op.bn.bias.detach().clone())
        if qat_op.bias is not None:
            qat_ref_op.bias = torch.nn.Parameter(qat_op.bias.detach().clone())

        lr = 0.01
        qat_op_optim = torch.optim.SGD(qat_op.parameters(), lr=lr)
        qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)

        for i in range(5):

            # make sure that calling model.train() does not override the
            # bn freeze setting
            qat_op.train()
            qat_ref_op.train()

            qat_op_optim.zero_grad()
            qat_ref_op_optim.zero_grad()

            input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True)
            input_clone = input.clone().detach().requires_grad_()

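            # Partway through training, freeze the BN stats (and later disable
            # the observers) to check that both modules stay in sync across
            # those mode changes.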
            if i > 2:
                qat_op.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                qat_ref_op.freeze_bn_stats()

            if i > 3:
                qat_op.apply(torch.ao.quantization.disable_observer)
                qat_ref_op.apply(torch.ao.quantization.disable_observer)

            result_ref = qat_ref_op(input)
            result_actual = qat_op(input_clone)
            self.assertEqual(result_ref, result_actual)

            # backward
            dout = torch.randn(result_ref.size(), dtype=torch.double) + 10.0

            loss = (result_ref - dout).sum()
            loss.backward()
            input_grad_ref = input.grad.cpu()
            weight_grad_ref = qat_ref_op.weight.grad.cpu()
            gamma_grad_ref = qat_ref_op.gamma.grad.cpu()
            beta_grad_ref = qat_ref_op.beta.grad.cpu()
            running_mean_ref = qat_ref_op.running_mean
            running_var_ref = qat_ref_op.running_var
            num_batches_tracked_ref = qat_ref_op.num_batches_tracked

            loss = (result_actual - dout).sum()
            loss.backward()
            input_grad_actual = input_clone.grad.cpu()
            weight_grad_actual = qat_op.weight.grad.cpu()
            gamma_grad_actual = qat_op.bn.weight.grad.cpu()
            beta_grad_actual = qat_op.bn.bias.grad.cpu()
            running_mean_actual = qat_op.bn.running_mean
            running_var_actual = qat_op.bn.running_var
            num_batches_tracked_actual = qat_op.bn.num_batches_tracked

            precision = 1e-5
            self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
            self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0)
            self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
            self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
            self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0)
            self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0)
            self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0)

            qat_op_optim.step()
            qat_ref_op_optim.step()

    @override_qengines
    def test_linear_bn_numerics(self):
        qengine = torch.backends.quantized.engine
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, fused QAT module should match fp32 module
        m.apply(torch.ao.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))

    @skipIfNoXNNPACK
    @override_qengines
    def test_linear_bn_symm_numerics(self):
        qengine = torch.backends.quantized.engine
        if qengine != "qnnpack":
            return  # Only qnnpack supports symmetric quantization
        m_ref = nn.Sequential(
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        m_ref_copy = copy.deepcopy(m_ref)
        m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']])
        qconfig = default_symmetric_qnnpack_qat_qconfig
        m_ref_copy[0].qconfig = qconfig
        m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

        # without fake_quants, fused QAT module should match fp32 module
        m.apply(torch.ao.quantization.disable_fake_quant)
        data = torch.randn(4, 4)
        r1 = m_ref(data)
        r2 = m(data)
        self.assertTrue(torch.allclose(r1, r2))

    @override_qengines
    def test_linear_bn_workflow(self):
        qengine = torch.backends.quantized.engine
        m = nn.Sequential(
            QuantStub(),
            nn.Linear(4, 4),
            nn.BatchNorm1d(4),
        )
        data = torch.randn(4, 4)
        m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
        m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']])
        mp = prepare_qat(m)
        mp(data)
        mq = convert(mp)
        self.assertTrue(type(mq[1]) == nnq.Linear)
        self.assertTrue(type(mq[2]) == nn.Identity)


    @skipIfNoXNNPACK
    @override_qengines
    def test_linear_precomputed_fake_quant(self):
        qengine = torch.backends.quantized.engine
        if qengine != "qnnpack":
            return  # Only qnnpack supports symmetric quantization
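        # Quantize one copy of the Linear by observing its actual weights and
        # another with a precomputed weight observer range of [-1, 1]; the
        # resulting weight scales should differ.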
        m_ref = nn.Linear(4, 4)

        m_ref_copy = copy.deepcopy(m_ref)
        qconfig = default_qconfig
        m_ref_copy.qconfig = qconfig
        weight_post_process = copy.deepcopy(qconfig.weight())
        activation = copy.deepcopy(qconfig.activation())
        activation(torch.randn(4, 4))
        m_ref_copy.activation_post_process = activation
        m_ref_copy = nnq.Linear.from_float(m_ref_copy)
        weight_post_process = qconfig.weight()
        weight_post_process.min_val = torch.tensor(-1)
        weight_post_process.max_val = torch.tensor(1)
        m_ref.weight_post_process = weight_post_process
        m_ref.activation_post_process = activation
        m_ref.qconfig = qconfig
        m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
        self.assertTrue(m_ref._weight_bias()[0].q_scale() != m_ref_copy._weight_bias()[0].q_scale())


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")