# Owner(s): ["oncall: quantization"]

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized.reference as nnqr
import torch.ao.quantization
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd

from torch.ao.quantization import (
    get_default_static_quant_module_mappings,
    default_float_qparams_observer,
    PerChannelMinMaxObserver,
)
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    prepare_dynamic,
    _make_conv_test_input,
    skipIfNoFBGEMM,
    lengths_to_offsets,
    skipIfNoONEDNN,
    _make_conv_add_extra_input_tensor,
)
from torch.testing._internal.common_quantized import (
    _calculate_dynamic_qparams,
    override_quantized_engine,
    override_qengines,
    qengine_is_qnnpack,
    qengine_is_onednn,
)
import torch.fx
from hypothesis import assume, given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
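# (Hypothesis deadlines are disabled above because quantized kernel timings
# can vary widely across engines (fbgemm/qnnpack/onednn) and test hardware.)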

import copy
import io
import numpy as np
import itertools

"""
Note that the tests in this file are just API tests, to make sure we wrapped
the quantized operator implementations correctly in the user-facing APIs;
these are not correctness tests for the underlying quantized operators. For
correctness tests please see `test/quantization/test_quantized_op.py`.
"""

class TestStaticQuantizedModule(QuantizationTestCase):
    def test_relu(self):
        relu_module = nn.ReLU()
        relu6_module = nnq.ReLU6()

        x = torch.arange(-10, 10, dtype=torch.float)
        y_ref = torch.relu(x)
        y6_ref = torch.nn.modules.ReLU6()(x)

        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.qint32)
        qy = relu_module(qx)
        qy6 = relu6_module(qx)

        self.assertEqual(y_ref, qy.dequantize(),
                         msg="ReLU module API failed")
        self.assertEqual(y6_ref, qy6.dequantize(),
                         msg="ReLU6 module API failed")

    @override_qengines
    def test_linear(self):
        """test API functionality for nn.quantized.linear"""
        options = itertools.product(
            [1, 5],
            [16, 32],
            [4, 8],
            [True, False],
            [True, False])
        for (batch_size, in_features, out_features, use_bias, per_channel) in options:
            self._test_linear_api_impl(
                nnq.Linear, 'QuantizedLinear', torch.ops.quantized.linear, batch_size,
                in_features, out_features, use_bias, per_channel)

    @override_qengines
    def test_linear_relu(self):
        """test API functionality for nn.intrinsic.quantized.linear_relu"""
        options = itertools.product(
            [1, 5],
            [16, 32],
            [4, 8],
            [True, False],
            [True, False])
        for (batch_size, in_features, out_features, use_bias, per_channel) in options:
            self._test_linear_api_impl(
                nniq.LinearReLU, 'QuantizedLinearReLU', torch.ops.quantized.linear_relu,
                batch_size, in_features, out_features, use_bias, per_channel)

    def _test_linear_api_impl(self, qlinear_module, module_name, qlinear_op,
                              batch_size, in_features, out_features, use_bias,
                              per_channel, **post_ops_kwargs):
        if torch.backends.quantized.engine == 'qnnpack':
            per_channel = False

        W = torch.rand(out_features, in_features).float()
        if per_channel:
            scale_tensor = torch.ones(out_features, dtype=torch.double)
            zero_point_tensor = torch.zeros(out_features, dtype=torch.long)
            for i in range(len(scale_tensor)):
                scale_tensor[i] = (i + 1.0) / 255.0
            W_q = torch.quantize_per_channel(W, scales=scale_tensor,
                                             zero_points=zero_point_tensor,
                                             axis=0, dtype=torch.qint8)
        else:
            # ONEDNN only supports symmetric quantization of weight
            W_zp = 0 if qengine_is_onednn() else 4
            W_q = torch.quantize_per_tensor(W, 0.1, W_zp, torch.qint8)

        X = torch.rand(batch_size, in_features).float()
        X_q = torch.quantize_per_tensor(X, 0.2, 10, torch.quint8)
        B = torch.rand(out_features).float() if use_bias else None
        scale = 0.5
        zero_point = 3
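
        # For reference: per-tensor affine quantization maps a float x to an
        # integer q as roughly q = clamp(round(x / scale) + zero_point), and
        # dequantization recovers x_hat = (q - zero_point) * scale. The scale
        # and zero_point above are assigned to the module below as its output
        # requantization parameters.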
        qlinear = qlinear_module(in_features, out_features, **post_ops_kwargs)

        qlinear_copy = copy.deepcopy(qlinear)
        # set random quantized weight and bias before testing torch scriptability
        qlinear_copy.set_weight_bias(W_q, B)
        self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
        # Run module with default-initialized parameters.
        # This tests that the constructor is correct.
        qlinear(X_q)

        qlinear.set_weight_bias(W_q, B)
        # Simple round-trip test to ensure weight()/set_weight() API
        self.assertEqual(qlinear.weight(), W_q, atol=1e-5, rtol=0)

        # testing packed param implementation
        qlinear.scale = float(scale)
        qlinear.zero_point = int(zero_point)
        Z_q = qlinear(X_q)

        # Check if the module implementation matches calling the
        # ops directly
        W_pack = qlinear._packed_params._packed_params
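        # (_packed_params holds the backend-specific prepacked weight as an
        # opaque ScriptObject; it is what the quantized linear kernel consumes.)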
        Z_ref = qlinear_op(X_q, W_pack, scale, zero_point, **post_ops_kwargs)

        self.assertEqual(Z_ref, Z_q)
        self.assertTrue(module_name in str(qlinear))

        # Test serialization of quantized Linear Module using state_dict
        model_dict = qlinear.state_dict()
        b = io.BytesIO()
        torch.save(model_dict, b)
        for weights_only in [True, False]:
            b.seek(0)
            loaded_dict = torch.load(b, weights_only=weights_only)
            for key in model_dict:
                if isinstance(model_dict[key], torch._C.ScriptObject):
                    assert isinstance(loaded_dict[key], torch._C.ScriptObject)
                    w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
                    w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
                    self.assertEqual(w_model, w_loaded)
                    self.assertEqual(b_model, b_loaded)
                else:
                    self.assertEqual(model_dict[key], loaded_dict[key])

            loaded_qlinear = qlinear_module(
                in_features, out_features, **post_ops_kwargs)
            loaded_qlinear.load_state_dict(loaded_dict)
            linear_unpack = torch.ops.quantized.linear_unpack
            self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
                             linear_unpack(loaded_qlinear._packed_params._packed_params))
            self.assertEqual(qlinear.scale, loaded_qlinear.scale)
            self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point)
            # scripting will add __overloads__ to __dict__, which is why we script a copy
            # to be able to do the check in the next line
            self.checkScriptable(copy.deepcopy(loaded_qlinear), [[X_q]], check_save_load=True)
            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
            self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
            self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
            Z_q2 = loaded_qlinear(X_q)
            self.assertEqual(Z_q, Z_q2)

        # Test serialization
        b = io.BytesIO()
        torch.save(qlinear, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded = torch.load(b, weights_only=False)
        self.assertEqual(qlinear.weight(), loaded.weight())
        self.assertEqual(qlinear.scale, loaded.scale)
        self.assertEqual(qlinear.zero_point, loaded.zero_point)

        # Test torch.package
        buffer = io.BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_pickle("module", "qlinear.pkl", qlinear)
        buffer.seek(0)

        importer = PackageImporter(buffer)
        loaded_from_package = importer.load_pickle("module", "qlinear.pkl")
        self.assertEqual(qlinear.weight(), loaded_from_package.weight())
        self.assertEqual(qlinear.scale, loaded_from_package.scale)
        self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point)

        for name, module in loaded_from_package.named_modules():
            # noop, just make sure attribute "_modules" is restored correctly during torch.package import
            assert name is not None

        # Test copy and deepcopy
        copied_linear = copy.copy(qlinear)
        self.assertEqual(copied_linear.bias(), qlinear.bias())
        self.assertEqual(copied_linear.scale, qlinear.scale)
        self.assertEqual(copied_linear.zero_point,
                         qlinear.zero_point)
        Y_copied = copied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_linear = copy.deepcopy(qlinear)
        self.assertEqual(deepcopied_linear.bias(), qlinear.bias())
        self.assertEqual(deepcopied_linear.scale, qlinear.scale)
        self.assertEqual(deepcopied_linear.zero_point,
                         qlinear.zero_point)
        Y_deepcopied = deepcopied_linear(X_q)
        np.testing.assert_array_almost_equal(
            Z_q.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # Test JIT
        self.checkScriptable(qlinear, [[X_q]], check_save_load=True)

        # Make sure `from_float` works for all linear variants
        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]

        for mut in modules_under_test:
            # Test from_float.
            float_linear = mut(in_features, out_features).float()
            float_linear.qconfig = torch.ao.quantization.default_qconfig
            torch.ao.quantization.prepare(float_linear, inplace=True)
            float_linear(X.float())
            # Sequential allows swapping using "convert".
            quantized_float_linear = torch.nn.Sequential(float_linear)
            quantized_float_linear = torch.ao.quantization.convert(quantized_float_linear, inplace=True)

            # Smoke test to make sure the module actually runs
            quantized_float_linear(X_q)

            # Smoke test extra_repr
            self.assertTrue('QuantizedLinear' in str(quantized_float_linear))

    def test_quant_dequant_api(self):
        r = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.float)
        scale, zero_point, dtype = 1.0, 2, torch.qint8
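        # Worked example: with scale=1.0 and zero_point=2, the value 1. maps to
        # int_repr round(1. / 1.0) + 2 = 3 and dequantizes to (3 - 2) * 1.0 = 1.,
        # so the round trips below are exact for these inputs.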
        # testing Quantize API
        qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
        quant_m = nnq.Quantize(scale, zero_point, dtype)
        qr2 = quant_m(r)
        self.assertEqual(qr, qr2)
        # testing Dequantize API
        rqr = qr.dequantize()
        dequant_m = nnq.DeQuantize()
        rqr2 = dequant_m(qr2)
        self.assertEqual(rqr, rqr2)

    def _test_conv_api_impl(
            self, module_name, qconv_module, conv_module, batch_size,
            in_channels_per_group, input_feature_map_size, out_channels_per_group,
            groups, kernel_size, stride, padding, padding_mode, dilation,
            X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
            use_bias, post_op, use_channelwise, X2_scale=1.0, X2_zero_point=0):
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)

        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale, X_zero_point,
            W_scale, W_zero_point, use_bias, use_channelwise)
        example_input = [X, ]
        example_input_q = [X_q, ]

        if post_op in ["add", "add_relu"]:
            X2, X2_q = _make_conv_add_extra_input_tensor(X2_scale, X2_zero_point, conv_module[0](X).size())
            example_input = [X, X2]
            example_input_q = [X_q, X2_q]

        # Make sure the weight shape is correct
        self.assertTrue(qconv_module.weight().shape == W_q.shape)

        qconv_module.set_weight_bias(W_q, b)
        qconv_module.scale = Y_scale
        qconv_module.zero_point = Y_zero_point

        raw_conv_module = conv_module[0] if post_op in ["relu", "add", "add_relu"] else conv_module
        raw_conv_module.weight.data = W
        if use_bias:
            raw_conv_module.bias.data = b

        # Test members
        self.assertTrue(module_name == qconv_module._get_name(), module_name + " " + qconv_module._get_name())
        self.assertTrue(hasattr(qconv_module, '_packed_params'))
        self.assertTrue(hasattr(qconv_module, 'scale'))
        self.assertTrue(hasattr(qconv_module, 'zero_point'))

        # Test properties
        self.assertEqual(W_q, qconv_module.weight())
        if use_bias:
            self.assertEqual(b, qconv_module.bias())
        self.assertEqual(Y_scale, qconv_module.scale)
        self.assertEqual(Y_zero_point, qconv_module.zero_point)

        # Test forward
        Y_exp = conv_module(*example_input)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
        Y_act = qconv_module(*example_input_q)

        # Make sure the results match
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired-actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between reference
        # and test. Off-by-1 differences arise due to the order of round and
        # zero_point addition operation, i.e., if addition followed by round is
        # used by reference and round followed by addition is used by test, the
        # results may differ by 1.
        # For example, the result of round(2.5) + 1 is 3 while round(2.5 + 1) is
        # 4 assuming the rounding mode is round-to-nearest, ties-to-even.
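        # (torch.round itself follows this ties-to-even convention, e.g.
        # torch.round(torch.tensor(2.5)) == 2.0 while torch.round(torch.tensor(3.5)) == 4.0.)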
        # skip numerics checking for reference module
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

        # Test serialization of quantized Conv Module using state_dict
        model_dict = qconv_module.state_dict()
        self.assertEqual(model_dict['weight'], W_q)
        if use_bias:
            self.assertEqual(model_dict['bias'], b)
        bytes_io = io.BytesIO()
        torch.save(model_dict, bytes_io)
        for weights_only in [True, False]:
            bytes_io.seek(0)
            loaded_dict = torch.load(bytes_io, weights_only=weights_only)
            for key in loaded_dict:
                self.assertEqual(model_dict[key], loaded_dict[key])
            loaded_qconv_module = type(qconv_module)(
                in_channels, out_channels, kernel_size, stride, padding, dilation,
                groups, use_bias, padding_mode=padding_mode)
            loaded_qconv_module.load_state_dict(loaded_dict)

            self.assertTrue(dir(loaded_qconv_module) == dir(qconv_module))
            self.assertTrue(module_name == loaded_qconv_module._get_name())
            self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
            self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))

            self.assertEqual(qconv_module.weight(), loaded_qconv_module.weight())
            if use_bias:
                self.assertEqual(qconv_module.bias(), loaded_qconv_module.bias())
            self.assertEqual(qconv_module.scale, loaded_qconv_module.scale)
            self.assertEqual(qconv_module.zero_point,
                             loaded_qconv_module.zero_point)
            Y_loaded = loaded_qconv_module(*example_input_q)
            np.testing.assert_array_almost_equal(
                Y_exp.int_repr().numpy(), Y_loaded.int_repr().numpy(), decimal=0)

        # Test serialization
        b = io.BytesIO()
        torch.save(qconv_module, b)
        b.seek(0)
        # weights_only=False as this is legacy code that saves the model
        loaded_conv = torch.load(b, weights_only=False)

        self.assertEqual(loaded_conv.bias(), qconv_module.bias())
        self.assertEqual(loaded_conv.scale, qconv_module.scale)
        self.assertEqual(loaded_conv.zero_point,
                         qconv_module.zero_point)

        # Test copy and deepcopy
        copied_conv = copy.copy(qconv_module)
        self.assertEqual(copied_conv.bias(), qconv_module.bias())
        self.assertEqual(copied_conv.scale, qconv_module.scale)
        self.assertEqual(copied_conv.zero_point,
                         qconv_module.zero_point)
        Y_copied = copied_conv(*example_input_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_copied.int_repr().numpy(), decimal=0)

        deepcopied_conv = copy.deepcopy(qconv_module)
        self.assertEqual(deepcopied_conv.bias(), qconv_module.bias())
        self.assertEqual(deepcopied_conv.scale, qconv_module.scale)
        self.assertEqual(deepcopied_conv.zero_point,
                         qconv_module.zero_point)
        Y_deepcopied = deepcopied_conv(*example_input_q)
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_deepcopied.int_repr().numpy(), decimal=0)

        # JIT testing
        self.checkScriptable(
            qconv_module, [example_input_q],
            check_save_load=True)

        class _FusedModule_two_input_args(torch.ao.nn.intrinsic._FusedModule):
            # Helper module for ConvAdd2d, since torch.ao.nn.intrinsic._FusedModule only supports one input arg
            def forward(self, x1, x2):
                return self[0](x1, x2)

        # Test from_float
        fused_conv_module = _FusedModule_two_input_args(conv_module) \
            if post_op in ["add", "add_relu"] else torch.ao.nn.intrinsic._FusedModule(conv_module)

        fused_conv_module.qconfig = torch.ao.quantization.default_qconfig
        torch.ao.quantization.prepare(fused_conv_module, inplace=True)
        example_input[0] = example_input[0].float()
        fused_conv_module(*example_input)
        converted_qconv_module = fused_conv_module
        reference_mapping = get_default_static_quant_module_mappings()
        reference_mapping[type(conv_module)] = type(qconv_module)
        torch.ao.quantization.convert(converted_qconv_module, mapping=reference_mapping, inplace=True)

        # Smoke test to make sure the module actually runs
        if use_bias:
            self.assertEqual(conv_module[0].bias if (post_op in ["relu", "add", "add_relu"]) else conv_module.bias,
                             converted_qconv_module[0].bias())
        # Smoke test extra_repr
        self.assertTrue(module_name == converted_qconv_module[0]._get_name())

    @override_qengines
    def test_conv1d_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            batch_size = 2
            in_channels_per_group = 2
            length = 8
            out_channels_per_group = 2
            groups = 3
            kernel = 3
            stride = 2
            pad = 1
            dilation = 1
            # Tests the correctness of the conv1d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (length,)
            kernel_size = (kernel, )
            stride = (stride, )
            pad = (pad, )
            dilation = (dilation, )
            X_scale = 1.3
            X_zero_point = 2
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nnq.Conv1d
            module_name = "QuantizedConv1d"
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv1d(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, "none", use_channelwise)

    @override_qengines
    def test_conv1d_relu_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        length = 8
        out_channels_per_group = 2
        groups = 3
        kernel = 3
        stride = 2
        pad = 1
        dilation = 1
        # Tests the correctness of the conv1d-relu module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (length,)
        kernel_size = (kernel, )
        stride = (stride, )
        pad = (pad, )
        dilation = (dilation, )
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU1d
        module_name = "QuantizedConvReLU1d"
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == 'qnnpack':
                use_channelwise = False
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv1d(
                in_channels, out_channels, kernel, stride, pad,
                dilation, groups, use_bias, padding_mode=pad_mode)
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU1d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, pad, pad_mode,
                dilation, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
                Y_zero_point, use_bias, "relu", use_channelwise)

    @override_qengines
    def test_conv2d_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nnq.Conv2d
            module_name = "QuantizedConv2d"
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, "none", use_channelwise)

    @override_qengines
    def test_conv2d_relu_api(self):
        options = itertools.product(
            ["zeros", "reflect"],  # pad_mode
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        stride_h = 2
        stride_w = 2
        pad_h = 1
        pad_w = 1
        dilation = 1
        # Tests the correctness of the conv2d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU2d
        module_name = "QuantizedConvReLU2d"
        for pad_mode, use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            qconv_module = qconv_cls(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode
            )

            conv_module = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding,
                dilation, groups, use_bias, padding_mode=pad_mode)
            relu_module = nn.ReLU()
            conv_module = nni.ConvReLU2d(conv_module, relu_module)
            conv_module = conv_module.float()

            self._test_conv_api_impl(
                module_name, qconv_module, conv_module, batch_size,
                in_channels_per_group, input_feature_map_size,
                out_channels_per_group, groups, kernel_size, stride, padding,
                pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                Y_scale, Y_zero_point, use_bias, "relu", use_channelwise)

    @skipIfNoFBGEMM
    def test_conv3d_api(self):
        options = itertools.product(
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        D = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        kernel_d = 3
        stride_h = 2
        stride_w = 2
        stride_d = 2
        pad_mode = "zeros"  # 3d doesn't support reflect padding
        pad_h = 1
        pad_w = 1
        pad_d = 1
        dilation = 1
        # Tests the correctness of the conv3d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nnq.Conv3d
        module_name = "QuantizedConv3d"
        for use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            with override_quantized_engine('fbgemm'):
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale,
                    W_zero_point, Y_scale, Y_zero_point, use_bias, "none",
                    use_channelwise)

    @skipIfNoFBGEMM
    def test_conv3d_relu_api(self):
        options = itertools.product(
            [True, False],  # use_bias
            [True, False],  # use_channelwise
        )
        batch_size = 2
        in_channels_per_group = 2
        H = 8
        W = 8
        D = 8
        out_channels_per_group = 2
        groups = 3
        kernel_h = 3
        kernel_w = 3
        kernel_d = 3
        stride_h = 2
        stride_w = 2
        stride_d = 2
        pad_mode = "zeros"  # 3d doesn't support reflect padding
        pad_h = 1
        pad_w = 1
        pad_d = 1
        dilation = 1
        # Tests the correctness of the conv3d module.
        in_channels = in_channels_per_group * groups
        out_channels = out_channels_per_group * groups
        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)
        X_scale = 1.3
        X_zero_point = 2
        W_scale = [0.5]
        W_zero_point = [0] if qengine_is_onednn() else [3]
        Y_scale = 5.0
        Y_zero_point = 4
        qconv_cls = nniq.ConvReLU3d
        module_name = "QuantizedConvReLU3d"
        for use_bias, use_channelwise in options:
            if torch.backends.quantized.engine == "qnnpack":
                use_channelwise = False
            with override_quantized_engine('fbgemm'):
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv3d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                relu_module = nn.ReLU()
                conv_module = nni.ConvReLU3d(conv_module, relu_module)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale,
                    W_zero_point, Y_scale, Y_zero_point, use_bias, "relu",
                    use_channelwise)

    @skipIfNoONEDNN
    def test_conv2d_add(self):
        """test API functionality for nn.intrinsic.quantized.ConvAdd2d"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                ["zeros", "reflect"],  # pad_mode
                [True, False],  # use_bias
                [True, False],  # use_channelwise
            )
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            X2_scale = 1.2
            X2_zero_point = 1
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nniq.ConvAdd2d
            module_name = "QuantizedConvAdd2d"
            for pad_mode, use_bias, use_channelwise in options:
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = torch.ao.nn.intrinsic.ConvAdd2d(conv_module, torch.add)
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                    Y_scale, Y_zero_point, use_bias, "add", use_channelwise, X2_scale, X2_zero_point)

    @skipIfNoONEDNN
    def test_conv2d_add_relu(self):
        """test API functionality for nn.intrinsic.quantized.ConvAddReLU2d"""
        with override_quantized_engine('onednn'):
            options = itertools.product(
                ["zeros", "reflect"],  # pad_mode
                [True, False],  # use_bias
                [True, False],  # use_channelwise
            )
            batch_size = 2
            in_channels_per_group = 2
            H = 8
            W = 8
            out_channels_per_group = 2
            groups = 3
            kernel_h = 3
            kernel_w = 3
            stride_h = 2
            stride_w = 2
            pad_h = 1
            pad_w = 1
            dilation = 1
            # Tests the correctness of the conv2d module.
            in_channels = in_channels_per_group * groups
            out_channels = out_channels_per_group * groups
            input_feature_map_size = (H, W)
            kernel_size = (kernel_h, kernel_w)
            stride = (stride_h, stride_w)
            padding = (pad_h, pad_w)
            dilation = (dilation, dilation)
            X_scale = 1.3
            X_zero_point = 2
            X2_scale = 1.2
            X2_zero_point = 1
            W_scale = [0.5]
            W_zero_point = [0] if qengine_is_onednn() else [3]
            Y_scale = 5.0
            Y_zero_point = 4
            qconv_cls = nniq.ConvAddReLU2d
            module_name = "QuantizedConvAddReLU2d"
            for pad_mode, use_bias, use_channelwise in options:
                qconv_module = qconv_cls(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode
                )

                conv_module = nn.Conv2d(
                    in_channels, out_channels, kernel_size, stride, padding,
                    dilation, groups, use_bias, padding_mode=pad_mode)
                conv_module = torch.ao.nn.intrinsic.ConvAddReLU2d(conv_module, torch.add, nn.ReLU())
                conv_module = conv_module.float()

                self._test_conv_api_impl(
                    module_name, qconv_module, conv_module, batch_size,
                    in_channels_per_group, input_feature_map_size,
                    out_channels_per_group, groups, kernel_size, stride, padding,
                    pad_mode, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
                    Y_scale, Y_zero_point, use_bias, "add_relu", use_channelwise, X2_scale, X2_zero_point)

    def test_pool_api(self):
        """Tests the correctness of the pool module.
        The correctness is defined against the functional implementation.
        """
        N, C, H, W = 10, 10, 10, 3
        kwargs = {
            'kernel_size': 2,
            'stride': None,
            'padding': 0,
            'dilation': 1
        }

        scale, zero_point = 1.0 / 255, 128
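        # Max pooling only selects among already-quantized values, so the output
        # is expected to keep the input's scale and zero_point unchanged.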

        X = torch.randn(N, C, H, W, dtype=torch.float32)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch.quint8)
        qX_expect = torch.nn.functional.max_pool2d(qX, **kwargs)

        pool_under_test = torch.ao.nn.quantized.MaxPool2d(**kwargs)
        qX_hat = pool_under_test(qX)
        self.assertEqual(qX_expect, qX_hat)

        # JIT Testing
        self.checkScriptable(pool_under_test, [[X]])

    def test_dropout(self):
        """Tests the correctness of the dropout module.
        The correctness is defined against the functional implementation.
        """
        x = torch.randn((2, 4, 6, 8), dtype=torch.float)
        float_mod = torch.nn.Dropout(p=0.5)
        float_mod.training = False

        y_ref = float_mod(x)
        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)

        quant_mod = nnq.Dropout(p=0.5)
        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
        qy = quant_mod(qx)

        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                         msg="Dropout module API failed")

    def _test_dropout_serialization(self, get_model, data1, data2):
        m1 = get_model()
        m1.qconfig = torch.ao.quantization.default_qconfig
        mp1 = torch.ao.quantization.prepare(m1)
        mp1(data1)
        mq1 = torch.ao.quantization.convert(mp1)
        ref1 = mq1(data2)

        m2 = get_model()
        m2.qconfig = torch.ao.quantization.default_qconfig
        mp2 = torch.ao.quantization.prepare(m2)
        mq2 = torch.ao.quantization.convert(mp2)

        mq2.load_state_dict(mq1.state_dict())
        ref2 = mq2(data2)

        self.assertTrue(torch.allclose(ref1, ref2))

    def test_dropout_serialization(self):
        data1 = torch.randn(2, 4, 6, 8)
        data2 = torch.randn(2, 4, 6, 8)

        def _get_model():
            return nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.Dropout(p=0.5),
                torch.ao.quantization.DeQuantStub()
            ).eval()

        self._test_dropout_serialization(_get_model, data1, data2)
    def test_batch_norm2d(self):
        """Tests the correctness of the batchnorm2d module.
        The correctness is defined against the functional implementation.
        """
        x = torch.randn((2, 4, 6, 8), dtype=torch.float)
        float_mod = torch.nn.BatchNorm2d(4)
        float_mod.training = False

        y_ref = float_mod(x)
        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)

        quant_mod = nnq.BatchNorm2d(4)
        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
        qy = quant_mod(qx)

        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                         msg="BatchNorm2d module API failed")

    def test_batch_norm3d(self):
        """Tests the correctness of the batchnorm3d module.
        The correctness is defined against the functional implementation.
        """
        x = torch.randn((2, 4, 6, 8, 10), dtype=torch.float)
        float_mod = torch.nn.BatchNorm3d(4)
        float_mod.training = False

        y_ref = float_mod(x)
        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)

        quant_mod = nnq.BatchNorm3d(4)
        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
        qy = quant_mod(qx)

        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                         msg="BatchNorm3d module API failed")

    def _test_batch_norm_serialization(self, get_model, data1, data2):
        m1 = get_model()
        m1.qconfig = torch.ao.quantization.default_qconfig
        mp1 = torch.ao.quantization.prepare(m1)
        mp1(data1)
        mq1 = torch.ao.quantization.convert(mp1)
        ref1 = mq1(data2)

        m2 = get_model()
        m2.qconfig = torch.ao.quantization.default_qconfig
        mp2 = torch.ao.quantization.prepare(m2)
        mq2 = torch.ao.quantization.convert(mp2)

        mq2.load_state_dict(mq1.state_dict())
        ref2 = mq2(data2)

        self.assertTrue(torch.allclose(ref1, ref2))

    def test_batch_norm2d_serialization(self):
        data1 = torch.randn(2, 4, 6, 8)
        data2 = torch.randn(2, 4, 6, 8)

        def _get_model():
            return nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.BatchNorm2d(4),
                torch.ao.quantization.DeQuantStub()
            ).eval()

        self._test_batch_norm_serialization(_get_model, data1, data2)

    def test_batch_norm3d_serialization(self):
        data1 = torch.randn(2, 4, 6, 8, 1)
        data2 = torch.randn(2, 4, 6, 8, 1)

        def _get_model():
            return nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.BatchNorm3d(4),
                torch.ao.quantization.DeQuantStub()
            ).eval()

        self._test_batch_norm_serialization(_get_model, data1, data2)

    def test_layer_norm(self):
        """Tests the correctness of the layernorm module.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims = (1, 4, 8)

        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = torch.nn.LayerNorm(dqX.size()[1:]).float()
        float_mod.weight = torch.nn.Parameter(torch.rand(*dims[1:]))
        float_mod.bias = torch.nn.Parameter(torch.rand(*dims[1:]))

        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = nnq.LayerNorm(
            qX.size()[1:], float_mod.weight, float_mod.bias, y_scale, y_zero_point)
        qY = quant_mod(qX)

        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         msg=f"LayerNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def test_group_norm(self):
        """Tests the correctness of the groupnorm module.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims = (1, 4, 8)

        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = torch.nn.GroupNorm(2, 4).float()
        float_mod.weight = torch.nn.Parameter(torch.rand(dims[1]))
        float_mod.bias = torch.nn.Parameter(torch.rand(dims[1]))

        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = nnq.GroupNorm(
            2, 4, float_mod.weight, float_mod.bias, y_scale, y_zero_point)
        qY = quant_mod(qX)

        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         msg=f"GroupNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def test_instance_norm(self):
        """Tests the correctness of the instancenorm{n}d modules.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims_to_modules = [
            ((1, 4, 8), torch.nn.InstanceNorm1d, nnq.InstanceNorm1d),
            ((1, 4, 8, 1), torch.nn.InstanceNorm2d, nnq.InstanceNorm2d),
            ((1, 4, 8, 1, 1), torch.nn.InstanceNorm3d, nnq.InstanceNorm3d),
        ]

        for dim_to_modules in dims_to_modules:
            dims, float_cls, q_cls = dim_to_modules

            X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
            qX = torch.quantize_per_tensor(
                X, x_scale, x_zero_point, dtype=torch.quint8)
            dqX = qX.dequantize()

            float_mod = float_cls(dims[1]).float()
            float_mod.weight = torch.nn.Parameter(torch.rand(dims[1]))
            float_mod.bias = torch.nn.Parameter(torch.rand(dims[1]))

            dqY_ref = float_mod(dqX)
            qY_ref = torch.quantize_per_tensor(
                dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

            quant_mod = q_cls(
                dims[1], float_mod.weight, float_mod.bias, y_scale,
                y_zero_point)
            qY = quant_mod(qX)

            self.assertEqual(
                qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                msg=f"InstanceNorm module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def _test_activation_module_impl(self, name, float_module_class, quantized_module_class, extra_kwargs):
        """Tests the correctness of the given activation module.
        The correctness is defined against the functional implementation.
        """
        x_scale = 10.0 / 256
        x_zero_point = 0
        y_scale = 5.0 / 256
        y_zero_point = 127

        dims = (1, 4, 8)

        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
        dqX = qX.dequantize()

        float_mod = float_module_class(**extra_kwargs).float()

        dqY_ref = float_mod(dqX)
        qY_ref = torch.quantize_per_tensor(
            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)

        quant_mod = quantized_module_class(y_scale, y_zero_point, **extra_kwargs)
        qY = quant_mod(qX)
        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
                         msg=f"{name} module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")

    def _test_leaky_relu_serialization(self):
        scale_original = 10.0 / 256
        zero_point_original = 1.0

        quant_mod_original = nnq.LeakyReLU(scale_original, zero_point_original)
        state_dict = quant_mod_original.state_dict()

        scale_new = 5.0 / 256
        zero_point_new = 2.0
        quant_mod_new = nnq.LeakyReLU(scale_new, zero_point_new)
        quant_mod_new.load_state_dict(state_dict)

        self.assertEqual(quant_mod_original.scale, quant_mod_new.scale)
        self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point)

    def test_elu(self):
        """Tests the correctness of the ELU module.
        The correctness is defined against the functional implementation.
        """
        self._test_activation_module_impl("ELU", nn.ELU, nnq.ELU, {"alpha": 1.5})

    def test_leaky_relu(self):
        self._test_activation_module_impl("LeakyReLU", nn.LeakyReLU, nnq.LeakyReLU, {"negative_slope": 0.2})
        self._test_leaky_relu_serialization()

    def test_sigmoid(self):
        self._test_activation_module_impl("Sigmoid", nn.Sigmoid, nnq.Sigmoid, {})

    def _test_hard_swish_serialization(self):
        scale_original = 10.0 / 256
        zero_point_original = 1.0

        quant_mod_original = nnq.Hardswish(scale_original, zero_point_original)
        state_dict = quant_mod_original.state_dict()

        scale_new = 5.0 / 256
        zero_point_new = 2.0
        quant_mod_new = nnq.Hardswish(scale_new, zero_point_new)
        quant_mod_new.load_state_dict(state_dict)

        self.assertEqual(quant_mod_original.scale, quant_mod_new.scale)
        self.assertEqual(quant_mod_original.zero_point, quant_mod_new.zero_point)

    def test_hard_swish(self):
        self._test_activation_module_impl("Hardswish", nn.Hardswish, nnq.Hardswish, {})
        self._test_hard_swish_serialization()

    @given(
        num_embeddings=st.integers(10, 50),
        embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
        set_qconfig=st.booleans(),
    )
    @skipIfNoFBGEMM
    def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig):
        num_lengths = np.random.randint(1, 6)
        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
        num_indices = np.sum(lengths)
        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

        obs = default_float_qparams_observer()
        obs(weights)
        qparams = obs.calculate_qparams()

        dtypes = [torch.quint4x2, torch.quint8]
        embedding_funcs = [torch.ops.quantized.embedding_4bit, torch.ops.quantized.embedding_byte]
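        # torch.quint4x2 packs two 4-bit values per byte; the multiple-of-4
        # constraint on embedding_dim above presumably keeps rows byte-aligned.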
1245
1246        for dtype, embedding_func in zip(dtypes, embedding_funcs):
1247            # Quantize the weights
1248            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype)
1249            qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
1250            qemb.set_weight(qweight)
1251            qemb(indices)
1252
1253            # Ensure the module has the correct weights
1254            self.assertEqual(qweight, qemb.weight())
1255            w_packed = qemb._packed_params._packed_weight
1256            module_out = qemb(indices)
1257
1258            # Call the bit qembedding operator directly
1259            ref = embedding_func(w_packed, indices, pruned_weights=False)
1260            self.assertEqual(module_out, ref)
1261            self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False,
1262                                             is_emb_bag=False, dtype=dtype)
1263
1264    @given(
1265        num_embeddings=st.integers(10, 50),
1266        embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
1267        num_offsets=st.integers(1, 20),
1268        set_qconfig=st.booleans(),
1269    )
1270    @skipIfNoFBGEMM
1271    def test_embedding_bag_api(self, num_embeddings, embedding_dim, num_offsets, set_qconfig):
1272        r"""Test execution and serialization for dynamic quantized embedding_bag modules on int8
1273        """
1274
1275        num_lengths = np.random.randint(1, 6)
1276        lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
1277        num_indices = np.sum(lengths)
1278        indices = torch.from_numpy(np.random.randint(low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
1279
1280        offsets = lengths_to_offsets(lengths)
1281        # include the last offset
1282        offsets = torch.cat((offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0)
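        # e.g. lengths [2, 3] -> offsets [0, 2], and appending the total index
        # count gives [0, 2, 5], the include_last_offset format used below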
1283        weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))
1284
1285        for qdtype in [torch.quint8, torch.quint4x2]:
1286            obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
1287            obs(weights)
1288            # Get the scale and zero point for the weight tensor
1289            qparams = obs.calculate_qparams()
            # Quantize the weights per channel
1291            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
1292            qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
1293                                    include_last_offset=True, mode='sum', _weight=qweight, dtype=qdtype)
1294            qemb(indices, offsets)
1295
1296            # Ensure the module has the correct weights
1297            self.assertEqual(qweight, qemb.weight())
1298
1299            w_packed = qemb._packed_params._packed_weight
1300            module_out = qemb(indices, offsets)
1301
1302            # Call the qembedding_bag operator directly
1303            if qdtype == torch.quint8:
1304                ref = torch.ops.quantized.embedding_bag_byte(w_packed, indices, offsets, mode=0,
1305                                                             per_sample_weights=None,
1306                                                             include_last_offset=True)
1307            else:
1308                ref = torch.ops.quantized.embedding_bag_4bit(w_packed, indices, offsets, mode=0,
1309                                                             per_sample_weights=None,
1310                                                             include_last_offset=True)
1311
1312            self.assertEqual(module_out, ref)
1313            self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices,
1314                                             offsets, set_qconfig, is_emb_bag=True, dtype=qdtype)
1315
1316    def test_prelu(self):
1317        for num_parameters in range(1, 10):
1318            x = torch.randn(4, num_parameters, 4)
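            # quantize_per_tensor_dynamic derives scale/zero_point from the
            # observed min/max of x at call time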
1319            qx = torch.quantize_per_tensor_dynamic(x, dtype=torch.quint8, reduce_range=False)
1322            f_prelu = torch.nn.PReLU(num_parameters=num_parameters)
1323            f_prelu.weight = torch.nn.Parameter(torch.randn(num_parameters).abs())
1324            f_prelu.qconfig = torch.ao.quantization.QConfig(
1325                activation=torch.ao.quantization.default_observer,
1326                weight=torch.ao.quantization.default_observer,)
1327            f_prelu.activation_post_process = f_prelu.qconfig.activation()
1328            f_prelu.activation_post_process(f_prelu(x))
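            # a forward pass populates the activation observer so that
            # from_float can compute the output scale/zero_point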
1329            q_prelu = nnq.PReLU.from_float(f_prelu)
1330            w_obs = f_prelu.qconfig.weight()
1331            w_obs(f_prelu.weight)
1332            w_scale, w_zp = w_obs.calculate_qparams()
1333            q_prelu_weight = torch.quantize_per_tensor(
1334                f_prelu.weight,
1335                dtype=torch.quint8,
1336                scale=w_scale,
1337                zero_point=w_zp
1338            ).dequantize()
1339
            # check that the module weight matches a manual quantize-dequantize
            # of the float weight
1341            self.assertEqual(q_prelu.weight.dequantize(), q_prelu_weight)
1342            f_prelu.weight = torch.nn.Parameter(q_prelu.weight.dequantize())
1343            qy = q_prelu(qx)
1344            qy_ref = torch.quantize_per_tensor(
1345                f_prelu(qx.dequantize()), q_prelu.scale, q_prelu.zero_point, dtype=torch.quint8
1346            )
            # check that the module output matches the reference within tolerance
1348            self.assertEqual(qy, qy_ref, atol=.1, rtol=.1)
1349
1350    def test_channel_shuffle(self):
1351        """Tests the correctness of the ChannelShuffle module.
1352        """
1353        x_scale = 10.0 / 256
1354        x_zero_point = 1
1355        y_scale = x_scale
1356        y_zero_point = x_zero_point
1357
1358        dims = (1, 4, 4, 8)
1359        groups = 2
1360
1361        X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10
1362        qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8)
1363        dqX = qX.dequantize()
1364
1365        float_mod = torch.nn.ChannelShuffle(groups).float()
1366        dqY_ref = float_mod(dqX)
1367        qY_ref = torch.quantize_per_tensor(
1368            dqY_ref, y_scale, y_zero_point, dtype=torch.quint8)
1369
1370        quant_mod = torch.nn.ChannelShuffle(groups)
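        # ChannelShuffle only permutes elements, so the float module can be
        # applied to the quantized tensor directly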
1371        qY = quant_mod(qX)
1372
1373        self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(),
1374                         msg=f"ChannelShuffle module API failed, qY_ref\n{qY_ref} vs qY\n{qY}")
1375
1376    @skipIfNoONEDNN
1377    def test_linear_leaky_relu(self):
1378        """test API functionality for nn.intrinsic.quantized.linear_leaky_relu"""
1379        with override_quantized_engine('onednn'):
1380            options = itertools.product(
1381                [1, 5],  # batch size
1382                [16, 32],  # in_features
1383                [4, 8],  # out_features
1384                [True, False],  # use_bias
1385                [True, False],  # per_channel
1386                [0.01, 0.05])  # negative slope
1387            for (batch_size, in_features, out_features, use_bias,
1388                 per_channel, neg_slope) in options:
1389                self._test_linear_api_impl(
1390                    nniq.LinearLeakyReLU, 'QuantizedLinearLeakyReLU',
1391                    torch.ops.quantized.linear_leaky_relu,
1392                    batch_size, in_features, out_features, use_bias,
1393                    per_channel, negative_slope=neg_slope)
1394
1395    @skipIfNoONEDNN
1396    def test_linear_tanh(self):
1397        """test API functionality for nn.intrinsic.quantized.linear_tanh"""
1398        with override_quantized_engine('onednn'):
1399            options = itertools.product(
1400                [1, 5],  # batch size
1401                [16, 32],  # in_features
1402                [4, 8],  # out_features
1403                [True, False],  # use_bias
                [True, False])  # per_channel
1405            for (batch_size, in_features, out_features, use_bias,
1406                 per_channel) in options:
1407                self._test_linear_api_impl(
1408                    nniq.LinearTanh, 'QuantizedLinearTanh',
1409                    torch.ops.quantized.linear_tanh,
1410                    batch_size, in_features, out_features, use_bias,
1411                    per_channel)
1412
1413class TestDynamicQuantizedModule(QuantizationTestCase):
1414    def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias):
1415        in_channels = 3
1416        out_channels = 10
1417        kernel_size = 2
1418        stride = 1
1419        padding = 0
1420        dilation = 1
1421        groups = 1
1422        padding_mode = 'zeros'
1423
1424        if qengine_is_qnnpack():
1425            reduce_range = False
1426        else:
1427            reduce_range = True
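        # fbgemm-style engines use reduce_range to trim activations to 7 bits
        # and avoid intermediate overflow; qnnpack does not need it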
1428
1429        X_fp32 = torch.randn(*([in_channels] * dim))
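        # e.g. dim=4 yields a (3, 3, 3, 3) input: batch, channel and spatial
        # sizes all equal to in_channels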
1430        s, z = _calculate_dynamic_qparams(X_fp32, dtype, reduce_range)
1431        X_q = torch.quantize_per_tensor(X_fp32, s, z, dtype)
1432        X_dq = torch.dequantize(X_q)
1433
1434        quantized_module = q_mod(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
1435                                 dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
1436        dynamic_module = dq_mod(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
1437                                dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
1438
1439        quantized_module.scale, quantized_module.zero_point = s, z
1440        dynamic_module.set_weight_bias(*quantized_module._weight_bias())
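        # both modules now share identical quantized weights, so the dynamic
        # output can be compared against the dequantized static output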
1441
1442        Y_q_ref = quantized_module(X_q)
1443        Y_ref = torch.dequantize(Y_q_ref)
1444
1445        Y = dynamic_module(X_dq, reduce_range)
1446
1447        self.assertEqual(Y, Y_ref)
1448
1449        # Test serialization of quantized Conv Module using state_dict
1450        W_q, b = dynamic_module._weight_bias()
1451        model_dict = dynamic_module.state_dict()
1452        self.assertEqual(model_dict['weight'], W_q)
1453        self.assertEqual(model_dict['bias'], b)
1454        bytes_io = io.BytesIO()
1455        torch.save(model_dict, bytes_io)
1456        for weights_only in [True, False]:
1457            bytes_io.seek(0)
1458            loaded_dict = torch.load(bytes_io, weights_only=weights_only)
1459            for key in loaded_dict:
1460                self.assertEqual(model_dict[key], loaded_dict[key])
1461            loaded_qconv_module = type(dynamic_module)(
1462                in_channels, out_channels, kernel_size, stride=stride, padding=padding,
1463                dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
1464            loaded_qconv_module.load_state_dict(loaded_dict)
1465
1466            self.assertTrue(dir(loaded_qconv_module) == dir(dynamic_module))
1467            self.assertTrue(dynamic_module._get_name() == loaded_qconv_module._get_name())
1468            self.assertTrue(hasattr(loaded_qconv_module, '_packed_params'))
1469            self.assertTrue(hasattr(loaded_qconv_module, '_weight_bias'))
1470
1471            self.assertEqual(dynamic_module.weight(), loaded_qconv_module.weight())
1472            if bias:
1473                self.assertEqual(dynamic_module.bias(), loaded_qconv_module.bias())
1474            self.assertEqual(dynamic_module.scale, loaded_qconv_module.scale)
1475            self.assertEqual(dynamic_module.zero_point,
1476                             loaded_qconv_module.zero_point)
1477            Y_loaded = loaded_qconv_module(X_fp32, reduce_range)
1478            np.testing.assert_array_almost_equal(
1479                Y.numpy(), Y_loaded.numpy(), decimal=0)
1480
1481        # Test serialization
1482        b = io.BytesIO()
1483        torch.save(dynamic_module, b)
1484        b.seek(0)
1485        # weights_only=False as this is legacy code that saves the model
1486        loaded_conv = torch.load(b, weights_only=False)
1487
1488        self.assertEqual(loaded_conv.bias(), dynamic_module.bias())
1489        self.assertEqual(loaded_conv.scale, dynamic_module.scale)
1490        self.assertEqual(loaded_conv.zero_point,
1491                         dynamic_module.zero_point)
1492
1493        # Test copy and deepcopy
1494        copied_conv = copy.copy(dynamic_module)
1495        self.assertEqual(copied_conv.bias(), dynamic_module.bias())
1496        self.assertEqual(copied_conv.scale, dynamic_module.scale)
1497        self.assertEqual(copied_conv.zero_point,
1498                         dynamic_module.zero_point)
1499        Y_copied = copied_conv(X_fp32, reduce_range)
1500        np.testing.assert_array_almost_equal(
1501            Y.numpy(), Y_copied.numpy(), decimal=0)
1502
1503        deepcopied_conv = copy.deepcopy(dynamic_module)
1504        self.assertEqual(deepcopied_conv.bias(), dynamic_module.bias())
1505        self.assertEqual(deepcopied_conv.scale, dynamic_module.scale)
1506        self.assertEqual(deepcopied_conv.zero_point,
1507                         dynamic_module.zero_point)
        Y_deepcopied = deepcopied_conv(X_fp32, reduce_range)
1509        np.testing.assert_array_almost_equal(
1510            Y.numpy(), Y_deepcopied.numpy(), decimal=0)
1511
        # TODO: fix this
1513        # JIT testing
1514        self.checkScriptable(
1515            dynamic_module, [[X_dq]],
1516            check_save_load=True)
1517
1518        # Test from_float
1519        conv_module = dynamic_module._FLOAT_MODULE(in_channels, out_channels, kernel_size)
1520        conv_module.qconfig = torch.ao.quantization.default_dynamic_qconfig  # type: ignore[assignment]
1521        prepare_dynamic(conv_module)
1522        conv_module(X_dq)
1523        quantized_conv_module = dq_mod.from_float(conv_module)
1524
1525        # Smoke test to make sure the module actually runs
1526        quantized_conv_module(X_dq)
1527
1528        # Smoke test extra_repr
1529        self.assertEqual(dynamic_module._get_name(), quantized_conv_module._get_name())
1530
1531    @override_qengines
1532    def test_dynamic_conv1d(self):
1533        q_mod = torch.ao.nn.quantized.Conv1d
1534        dq_mod = torch.ao.nn.quantized.dynamic.Conv1d
1535        dim = 3
1536        dtype = torch.quint8
1537
1538        for bias in [True, False]:
1539            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1540
1541    @override_qengines
1542    def test_dynamic_conv2d(self):
1543        q_mod = torch.ao.nn.quantized.Conv2d
1544        dq_mod = torch.ao.nn.quantized.dynamic.Conv2d
1545        dim = 4
1546        dtype = torch.quint8
1547
1548        for bias in [True, False]:
1549            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1550
1551    @override_qengines
1552    def test_dynamic_conv3d(self):
1553        q_mod = torch.ao.nn.quantized.Conv3d
1554        dq_mod = torch.ao.nn.quantized.dynamic.Conv3d
1555        dim = 5
1556        dtype = torch.quint8
1557
1558        if qengine_is_qnnpack():
1559            return  # qnnpack doesn't support unpacking conv3d
1560        for bias in [True, False]:
1561            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1562
1563    @override_qengines
1564    def test_dynamic_convtranspose1d(self):
1565        q_mod = torch.ao.nn.quantized.ConvTranspose1d
1566        dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose1d
1567        dim = 3
1568        dtype = torch.quint8
1569
1570        for bias in [True, False]:
1571            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1572
1573    @override_qengines
1574    def test_dynamic_convtranspose2d(self):
1575        q_mod = torch.ao.nn.quantized.ConvTranspose2d
1576        dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose2d
1577        dim = 4
1578        dtype = torch.quint8
1579
1580        for bias in [True, False]:
1581            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1582
1583    @override_qengines
1584    def test_dynamic_convtranspose3d(self):
1585        q_mod = torch.ao.nn.quantized.ConvTranspose3d
1586        dq_mod = torch.ao.nn.quantized.dynamic.ConvTranspose3d
1587        dim = 5
1588        dtype = torch.quint8
1589
1590        if qengine_is_qnnpack():
1591            return  # qnnpack doesn't support unpacking conv3d
1592        for bias in [True, False]:
1593            self._test_qconv_impl(q_mod, dq_mod, dim, dtype, bias)
1594
1595    @given(
1596        batch_size=st.integers(1, 5),
1597        in_features=st.integers(16, 32),
1598        out_features=st.integers(4, 8),
1599        use_bias=st.booleans(),
1600        use_default_observer=st.booleans(),
1601    )
1602    @override_qengines
1603    def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_default_observer):
1604        """test API functionality for nn.quantized.dynamic.Linear"""
1605        W = torch.rand(out_features, in_features).float()
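        # onednn requires symmetric weight quantization (zero_point == 0),
        # hence the qscheme switch below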
1606        qscheme = torch.per_tensor_symmetric if qengine_is_onednn() else torch.per_tensor_affine
1607        W_scale, W_zp = _calculate_dynamic_qparams(W, torch.qint8, qscheme=qscheme)
1608        W_q = torch.quantize_per_tensor(W, W_scale, W_zp, torch.qint8)
1609        X = torch.rand(batch_size, in_features).float()
1610        B = torch.rand(out_features).float() if use_bias else None
1611        qlinear = nnqd.Linear(in_features, out_features)
1612        # Run module with default-initialized parameters.
1613        # This tests that the constructor is correct.
1614        qlinear.set_weight_bias(W_q, B)
1615        qlinear(X)
1616
1617        # Simple round-trip test to ensure weight()/set_weight() API
1618        self.assertEqual(qlinear.weight(), W_q)
1619        W_pack = qlinear._packed_params._packed_params
1620        Z_dq = qlinear(X)
1621
1622        # Check if the module implementation matches calling the
1623        # ops directly
1624        Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack, reduce_range=True)
1625        self.assertEqual(Z_ref, Z_dq)
1626
1627        # Test serialization of dynamic quantized Linear Module using state_dict
1628        model_dict = qlinear.state_dict()
1629        b = io.BytesIO()
1630        torch.save(model_dict, b)
1631        for weights_only in [True, False]:
1632            b.seek(0)
1633            loaded_dict = torch.load(b, weights_only=weights_only)
1634            for key in model_dict:
1635                if isinstance(model_dict[key], torch._C.ScriptObject):
1636                    assert isinstance(loaded_dict[key], torch._C.ScriptObject)
1637                    w_model, b_model = torch.ops.quantized.linear_unpack(model_dict[key])
1638                    w_loaded, b_loaded = torch.ops.quantized.linear_unpack(loaded_dict[key])
1639                    self.assertEqual(w_model, w_loaded)
1640                    self.assertEqual(b_model, b_loaded)
1641                else:
1642                    self.assertEqual(model_dict[key], loaded_dict[key])
1643            loaded_qlinear = nnqd.Linear(in_features, out_features)
1644            loaded_qlinear.load_state_dict(loaded_dict)
1645
1646            linear_unpack = torch.ops.quantized.linear_unpack
1647            self.assertEqual(linear_unpack(qlinear._packed_params._packed_params),
1648                             linear_unpack(loaded_qlinear._packed_params._packed_params))
1649            if use_bias:
1650                self.assertEqual(qlinear.bias(), loaded_qlinear.bias())
1651            self.assertTrue(dir(qlinear) == dir(loaded_qlinear))
1652            self.assertTrue(hasattr(qlinear, '_packed_params'))
1653            self.assertTrue(hasattr(loaded_qlinear, '_packed_params'))
1654            self.assertTrue(hasattr(qlinear, '_weight_bias'))
1655            self.assertTrue(hasattr(loaded_qlinear, '_weight_bias'))
1656
1657            self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias())
1658            self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params._packed_params))
1659            Z_dq2 = qlinear(X)
1660            self.assertEqual(Z_dq, Z_dq2)
1661
1662        b = io.BytesIO()
1663        torch.save(qlinear, b)
1664        b.seek(0)
1665        # weights_only=False as this is legacy code that saves the model
1666        loaded = torch.load(b, weights_only=False)
1667        self.assertEqual(qlinear.weight(), loaded.weight())
1668        self.assertEqual(qlinear.zero_point, loaded.zero_point)
1669
1670        # Test JIT
1671        self.checkScriptable(qlinear, [[X]], check_save_load=True)
1672
1673        modules_under_test = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
1674        for mut in modules_under_test:
1675            # Test from_float
1676            float_linear = mut(in_features, out_features).float()
1677            if use_default_observer:
1678                float_linear.qconfig = torch.ao.quantization.default_dynamic_qconfig
1679            prepare_dynamic(float_linear)
1680            float_linear(X.float())
1681            quantized_float_linear = nnqd.Linear.from_float(float_linear)
1682
1683            # Smoke test to make sure the module actually runs
1684            quantized_float_linear(X)
1685
1686        # Smoke test extra_repr
1687        self.assertTrue('QuantizedLinear' in str(quantized_float_linear))
1688
1689    @given(
1690        dtype=st.sampled_from([torch.qint8, torch.float16]),
1691        bidirectional=st.booleans(),
1692    )
1693    @override_qengines
1694    def test_lstm_api(self, dtype, bidirectional):
1695        r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
1696        """
1697        # Check that module matches the numerics of the op and ensure that module can be
1698        # instantiated for all engines and dtypes
1699        seq_len = 4
1700        batch = 2
1701        input_size = 3
1702        hidden_size = 7
1703        num_layers = 2
1704        bias = True
1705        weight_keys = []
1706        bias_keys = []
1707        num_directions = 2 if bidirectional else 1
1708        for layer in range(num_layers):
1709            for direction in range(num_directions):
1710                suffix = '_reverse' if direction == 1 else ''
1711                key_name1 = f'weight_ih_l{layer}{suffix}'
1712                key_name2 = f'weight_hh_l{layer}{suffix}'
1713                weight_keys.append(key_name1)
1714                weight_keys.append(key_name2)
1715                key_name1 = f'bias_ih_l{layer}{suffix}'
1716                key_name2 = f'bias_hh_l{layer}{suffix}'
1717                bias_keys.append(key_name1)
1718                bias_keys.append(key_name2)
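        # e.g. num_layers=2, bidirectional=True yields keys such as
        # 'weight_ih_l1_reverse' and 'bias_hh_l0'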
1719
1720        if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
1721            # fp16 dynamic quant is not supported for qnnpack or onednn
1722            x = torch.randn(seq_len, batch, input_size)
1723            h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1724            c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1725            cell_dq = torch.ao.nn.quantized.dynamic.LSTM(input_size=input_size,
1726                                                         hidden_size=hidden_size,
1727                                                         num_layers=num_layers,
1728                                                         bias=bias,
1729                                                         batch_first=False,
1730                                                         dropout=0.0,
1731                                                         bidirectional=bidirectional,
1732                                                         dtype=dtype)
1733            ref_dq = torch.ao.nn.quantized.dynamic.LSTM(input_size=input_size,
1734                                                        hidden_size=hidden_size,
1735                                                        num_layers=num_layers,
1736                                                        bias=bias,
1737                                                        batch_first=False,
1738                                                        dropout=0.0,
1739                                                        bidirectional=bidirectional,
1740                                                        dtype=dtype)
1741
            _all_params = [m.param for m in cell_dq._all_weight_values]
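            # each element of _all_weight_values wraps the packed per-layer
            # parameters that torch.quantized_lstm consumes directly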
1743            result = torch.quantized_lstm(x, (h, c),
1744                                          _all_params,
1745                                          cell_dq.bias,
1746                                          cell_dq.num_layers,
1747                                          float(cell_dq.dropout),
1748                                          False,
1749                                          bidirectional,
1750                                          False,
1751                                          dtype=dtype,
1752                                          use_dynamic=True)
1755            y, (h, c) = cell_dq(x, (h, c))
1756            self.assertEqual(result[0], y)
1757            self.assertEqual(result[1], h)
1758            self.assertEqual(result[2], c)
1759            x = torch.randn(10, 20, 3)
1760            self.check_eager_serialization(cell_dq, ref_dq, [x])
1761            self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)
1762
1763    @override_qengines
1764    def test_gru_api(self):
1765        r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
1766        """
1767        # Check that module matches the numerics of the op and ensure that module can be
1768        # instantiated for all engines and dtypes
1769
1770        for dtype in [torch.qint8, torch.float16]:
1771            if dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn"):
1772                # fp16 dynamic quant is not supported for qnnpack or onednn
1773                continue
            # Test default instantiation
1775            seq_len = 4
1776            batch = 2
1777            input_size = 3
1778            hidden_size = 7
1779            num_layers = 2
1780            bias = True
1781            bidirectional = False
1782
1783            x = torch.rand(seq_len, batch, input_size)
1784            h = torch.rand(num_layers * (bidirectional + 1), batch, hidden_size)
1787            cell_dq = torch.ao.nn.quantized.dynamic.GRU(input_size=input_size,
1788                                                        hidden_size=hidden_size,
1789                                                        num_layers=num_layers,
1790                                                        bias=bias,
1791                                                        batch_first=False,
1792                                                        dropout=0.0,
1793                                                        bidirectional=bidirectional,
1794                                                        dtype=dtype)
1795
            _all_params = [m.param for m in cell_dq._all_weight_values]
1797            result = torch.quantized_gru(x,
1798                                         h,
1799                                         _all_params,
1800                                         cell_dq.bias,
1801                                         cell_dq.num_layers,
1802                                         float(cell_dq.dropout),
1803                                         False,
1804                                         bidirectional,
1805                                         False)
1808            y, h = cell_dq(x, h)
1809            self.assertEqual(result[0], y, msg="GRU module API failed")
1810            self.assertEqual(result[1], h, msg="GRU module API failed")
1811
1812    @given(
1813        dtype=st.sampled_from([torch.qint8, torch.float16]),
1814    )
1815    @override_qengines
1816    def test_cell_api(self, dtype):
1817        r"""Test execution and serialization for dynamic quantized lstm modules on int8 and fp16
1818        """
1819        # Check that module matches the numerics of the op and ensure that module can be
1820        # instantiated for all engines and dtypes
1821        batch = 7
1822        input_size = 3
1823        hidden_size = 7
1824        bias = True
1825
1826        x = torch.rand(batch, input_size)
1827        h = torch.rand(batch, hidden_size)
1828        cell_dict = {'LSTMCell': torch.ao.nn.quantized.dynamic.LSTMCell,
1829                     'GRUCell': torch.ao.nn.quantized.dynamic.GRUCell,
1830                     'RNNTanh': torch.ao.nn.quantized.dynamic.RNNCell,
1831                     'RNNReLU': torch.ao.nn.quantized.dynamic.RNNCell
1832                     }
1833        state = {'LSTMCell': (h, h),
1834                 'GRUCell': h,
1835                 'RNNTanh': h,
1836                 'RNNReLU': h}
1837
1838        qfn_dict = {'LSTMCell': torch.ops.quantized.quantized_lstm_cell_dynamic,
1839                    'GRUCell': torch.ops.quantized.quantized_gru_cell_dynamic,
1840                    'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic,
1841                    'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic}
1842
1843        for rnn_type in cell_dict.keys():
1844            if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")):
1845                # fp16 dynamic quant is not supported for qnnpack or onednn
1846                kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype}
1847                if rnn_type == 'RNNReLU':
1848                    kwargs['nonlinearity'] = "relu"
1849                elif rnn_type == 'RNNTanh':
1850                    kwargs['nonlinearity'] = "tanh"
1851
1852                cell_dq = cell_dict[rnn_type](**kwargs)
1853                result = qfn_dict[rnn_type](x, state[rnn_type],
1854                                            cell_dq._packed_weight_ih, cell_dq._packed_weight_hh,
1855                                            cell_dq.bias_ih, cell_dq.bias_hh)
1856                result_module = cell_dq(x, state[rnn_type])
1857                self.assertEqual(result[0], result_module[0], msg="RNNCell module API failed")
1858                self.assertEqual(result[1], result_module[1], msg="RNNCell module API failed")
1859                weight_keys = ['weight_ih', 'weight_hh']
1860                bias_keys = ['bias_ih', 'bias_hh']
1861                self.check_eager_serialization(cell_dq, cell_dict[rnn_type](**kwargs), [x])
1862                self.check_weight_bias_api(cell_dq, weight_keys, bias_keys)
1863
1864class TestReferenceQuantizedModule(QuantizationTestCase):
1865    def _quant_dequant_weight(self, weight, weight_qparams):
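        """Fake-quantize the weight: round-trip it through quantize and
        dequantize so a float module can reproduce the numerics of the
        corresponding reference quantized module."""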
1866        qscheme = weight_qparams["qscheme"]
1867        scale = weight_qparams["scale"]
1868        zero_point = weight_qparams["zero_point"]
1869        dtype = weight_qparams["dtype"]
1870        if qscheme == torch.per_tensor_affine:
1871            weight = torch.quantize_per_tensor(weight, scale, zero_point, dtype)
1872        else:
1873            # per channel affine
1874            axis = weight_qparams["axis"]
1875            weight = torch.quantize_per_channel(weight, scale, zero_point, axis, dtype)
1876        weight = weight.dequantize()
1877        return weight
1878
1879    # TODO: add tests for conv and linear
1880    def test_rnn_cell(self):
1881        """ Checks the rnn cell reference quantized modules has correct numerics
1882        This includes LSTMCell, GRUCell, RNNCell
1883        """
1884        batch = 7
1885        input_size = 3
1886        hidden_size = 7
1887        bias = True
1888
1889        x = torch.rand(batch, input_size)
1890        h = torch.rand(batch, hidden_size)
1891        cell_dict = {'LSTMCell': torch.nn.LSTMCell,
1892                     'GRUCell': torch.nn.GRUCell,
1893                     'RNNTanh': torch.nn.RNNCell,
1894                     'RNNReLU': torch.nn.RNNCell
1895                     }
1896        state = {'LSTMCell': (h, h),
1897                 'GRUCell': h,
1898                 'RNNTanh': h,
1899                 'RNNReLU': h}
1900
1901        qfn_dict = {'LSTMCell': nnqr.LSTMCell,
1902                    'GRUCell': nnqr.GRUCell,
1903                    'RNNTanh': nnqr.RNNCell,
1904                    'RNNReLU': nnqr.RNNCell}
1905
1906        for rnn_type in cell_dict.keys():
1907            kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias}
1908            if rnn_type == 'RNNReLU':
1909                kwargs['nonlinearity'] = "relu"
1910            elif rnn_type == 'RNNTanh':
1911                kwargs['nonlinearity'] = "tanh"
1912
1913            fp_cell = cell_dict[rnn_type](**kwargs)
1914            # initialize ref rnn cell module
1915            weight_qparams = {
1916                'qscheme': torch.per_tensor_affine,
1917                'dtype': torch.quint8,
1918                'scale': 2.0,
1919                'zero_point': 5
1920            }
1921            weight_qparams_dict = {
1922                "weight_ih": weight_qparams,
1923                "weight_hh": weight_qparams,
1924                "is_decomposed": False,
1925            }
1926            ref_kwargs = kwargs.copy()
1927            ref_kwargs["weight_qparams_dict"] = weight_qparams_dict
1928            ref_cell = qfn_dict[rnn_type](**ref_kwargs)
            # reassign the weights from the fp32 rnn cell module
1930            ref_cell.weight_ih = fp_cell.weight_ih
1931            ref_cell.weight_hh = fp_cell.weight_hh
1932            ref_cell.bias_ih = fp_cell.bias_ih
1933            ref_cell.bias_hh = fp_cell.bias_hh
1934
1935            ref_res = ref_cell(x, state[rnn_type])
1936
            # change the weights of fp_cell: first run a quantize and
            # dequantize round trip on each weight
1939            fp_cell.weight_ih = torch.nn.Parameter(self._quant_dequant_weight(fp_cell.weight_ih, weight_qparams_dict["weight_ih"]))
1940            fp_cell.weight_hh = torch.nn.Parameter(self._quant_dequant_weight(fp_cell.weight_hh, weight_qparams_dict["weight_hh"]))
1941            fp_res = fp_cell(x, state[rnn_type])
1942            self.assertEqual(ref_res[0], fp_res[0], msg="RNNCell module API failed")
1943            self.assertEqual(ref_res[1], fp_res[1], msg="RNNCell module API failed")
1944
1945    def test_rnn(self):
1946        """ Checks the rnn reference quantized modules has correct numerics
1947        This includes LSTM
1948        """
1949        seq_len = 4
1950        batch = 2
1951        input_size = 3
1952        hidden_size = 7
1953        num_layers = 2
1954        bias = True
1955        for bidirectional in [True, False]:
1956            x = torch.randn(seq_len, batch, input_size)
1957            h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1958            c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
1959            fp32_rnn = torch.nn.LSTM(
1960                input_size=input_size,
1961                hidden_size=hidden_size,
1962                num_layers=num_layers,
1963                bias=bias,
1964                batch_first=False,
1965                dropout=0.0,
1966                bidirectional=bidirectional)
1967            # initialize ref rnn module
1968            weight_qparams = {
1969                "qscheme": torch.per_tensor_affine,
1970                "dtype": torch.qint8,
1971                "scale": 2.0,
1972                "zero_point": 5
1973            }
1974            weight_qparams_dict = {key: weight_qparams for key in fp32_rnn._flat_weights_names if key.startswith("weight")}
1975            weight_qparams_dict["is_decomposed"] = False
1976            ref_rnn = nnqr.LSTM(
1977                input_size=input_size,
1978                hidden_size=hidden_size,
1979                num_layers=num_layers,
1980                bias=bias,
1981                batch_first=False,
1982                dropout=0.0,
1983                bidirectional=bidirectional,
1984                weight_qparams_dict=weight_qparams_dict)
1985            for wn in fp32_rnn._flat_weights_names:
1986                setattr(ref_rnn, wn, copy.deepcopy(getattr(fp32_rnn, wn)))
1987
1988            ref_rnn._flat_weights = copy.deepcopy(fp32_rnn._flat_weights)
1989
1990            # quantize and dequantize the weights for fp32_rnn module
1991            flat_weights = []
1992            for wn in fp32_rnn._flat_weights_names:
1993                if wn.startswith("weight"):
1994                    weight = self._quant_dequant_weight(getattr(fp32_rnn, wn), weight_qparams)
1995                else:
1996                    weight = getattr(fp32_rnn, wn)
1997                flat_weights.append(weight)
1998            fp32_rnn._flat_weights = flat_weights
1999
2000            fp32_res = fp32_rnn(x, (h, c))
2001            ref_res = ref_rnn(x, (h, c))
2002            self.assertEqual(fp32_res, ref_res)
2003
2004    def test_sparse(self):
2005        """ Embedding and EmbeddingBag
2006        """
2007        num_embeddings = 10
2008        embedding_dim = 3
2009        # embedding input
2010        ex = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
2011
2012        # embedding bag input
2013        ebx = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
2014        offsets = torch.tensor([0, 4], dtype=torch.long)
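        # offsets [0, 4] split the 8 indices into two bags of 4 each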
2015
2016        fp_to_ref = {
2017            nn.Embedding: (nnqr.Embedding, (ex,)),
2018            nn.EmbeddingBag: (nnqr.EmbeddingBag, (ebx, offsets)),
2019        }
2020
2021        per_tensor_weight_qparams = {
2022            "qscheme": torch.per_tensor_affine,
2023            "dtype": torch.quint8,
2024            "scale": 2.0,
2025            "zero_point": 5,
2026            "is_decomposed": False,
2027        }
2028
2029        per_channel_weight_qparams = {
2030            "qscheme": torch.per_channel_affine,
2031            "dtype": torch.quint8,
2032            "scale": torch.randn(10),
2033            "zero_point": torch.randint(0, 255, (10,)),
2034            "axis": 0,
2035            "is_decomposed": False,
2036        }
2037
2038        per_channel_weight_qparams_quint4x2 = {
2039            "qscheme": torch.per_channel_affine_float_qparams,
2040            "dtype": torch.quint4x2,
2041            "scale": torch.randn(10),
2042            "zero_point": torch.randint(0, 255, (10,)),
2043            "axis": 0,
2044            "is_decomposed": False,
2045        }
2046
2047        weight_qparams_options = [
2048            per_tensor_weight_qparams,
2049            per_channel_weight_qparams,
2050            per_channel_weight_qparams_quint4x2,
2051        ]
2052        for fp_cls, weight_qparams in itertools.product([nn.Embedding, nn.EmbeddingBag], weight_qparams_options):
2053            # TODO: torch.quint4x2 not supported in quantize_per_channel, need to add support
2054            if weight_qparams["dtype"] == torch.quint4x2:
2055                continue
2056            ref_cls, args = fp_to_ref[fp_cls]
2057
2058            fp32_embedding = fp_cls(num_embeddings, embedding_dim)
2059
2060            ref_embedding = ref_cls(num_embeddings, embedding_dim, weight_qparams=weight_qparams)
2061            ref_embedding.weight = fp32_embedding.weight
2062
2063            # quantize and dequantize the weight for fp32 module
2064            fp32_embedding.weight = torch.nn.Parameter(self._quant_dequant_weight(fp32_embedding.weight, weight_qparams))
2065
2066            fp32_res = fp32_embedding(*args)
2067            ref_res = ref_embedding(*args)
2068            self.assertEqual(fp32_res, ref_res)
2069
2070    def test_linear_decomposed_weight_custom_qmin_qmax(self):
2071        """Verify that reference Linear respects custom qmin/qmax for weight
2072        """
2073        linear_fp32 = torch.nn.Linear(2, 2)
2074        qconfig = torch.ao.quantization.default_symmetric_qnnpack_qconfig
2075        w_obs = qconfig.weight()
2076        self.assertTrue(w_obs.quant_min == -127)
2077        self.assertTrue(w_obs.quant_max == 127)
2078        w_obs(linear_fp32.weight)
2079        weight_qparams = torch.ao.quantization.utils.get_qparam_dict(w_obs)
2080        weight_qparams["is_decomposed"] = True
2081        linear_ref = nnqr.Linear.from_float(linear_fp32, weight_qparams)
2082        linear_ref_traced = torch.fx.symbolic_trace(linear_ref)
2083
2084        # verify that the qmin/qmax arguments for weight q/dq are correctly
2085        # taken from the observer
2086        found = 0
2087        for n in linear_ref_traced.graph.nodes:
2088            if n.op != 'call_function':
2089                continue
2090            if n.target in (
2091                torch.ops.quantized_decomposed.quantize_per_tensor,
2092                torch.ops.quantized_decomposed.dequantize_per_tensor,
2093            ):
2094                _0, _1, _2, qmin, qmax, _5 = n.args
2095                self.assertTrue(qmin == -127)
2096                self.assertTrue(qmax == 127)
2097                found += 1
2098        self.assertTrue(found == 2)
2099