# Owner(s): ["oncall: quantization"]


import copy
import itertools
import numpy as np
import operator
import random
import unittest
from typing import NamedTuple, List

import torch
from torch import _VF
import torch.jit
import torch.nn.functional as F
from torch.nn.modules.utils import _single, _pair

from hypothesis import settings, HealthCheck
from hypothesis import assume, given, note
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
    override_quantized_engine, supported_qengines, override_qengines, _snr
from torch.testing._internal.common_quantized import (
    qengine_is_qnnpack,
    qengine_is_onednn,
)
from torch.ao.quantization import PerChannelMinMaxObserver
from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDNN_VERSION, TEST_CUDA
from torch.testing._internal.optests import opcheck
import torch.backends.xnnpack

from torch.utils.cpp_extension import ROCM_HOME

from typing import Optional

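# Maps quantized torch dtypes to the numpy dtypes of their integer
# representations (used with the numpy reference helpers, e.g. _quantize).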
np_dtype = {
    torch.quint8 : np.uint8,
    torch.qint8 : np.int8,
    torch.qint32 : np.int32
}

TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None

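# Describes a pointwise post-op to fuse with a quantized op: an optional
# binary op, an optional unary op, and their parameters (alpha, scalars,
# algorithm).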
class PointwisePostOp(NamedTuple):
    binary_attr : str = "none"
    alpha : float = 1.0
    unary_attr : str = "none"
    scalars : List = []
    algorithm : str = ""

# Make sure we won't have overflows from the vpmaddubsw instruction used in FBGEMM.
# On current Intel x86 architectures, the vpmaddubsw instruction is used for the
# 8-bit int multiplication. It vertically multiplies each unsigned 8-bit integer
# from a with the corresponding signed 8-bit integer from b, producing
# intermediate signed 16-bit integers. This function modifies the weights to
# eliminate overflow of those signed 16-bit intermediates.
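# Worst-case example: with x0 = x1 = 255 and w0 = w1 = 127, the intermediate
# x0 * w0 + x1 * w1 = 64770 exceeds the int16 maximum of 32767, so the loop
# below would lower W[j, k + 1] to bring the sum back into range.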
def avoid_vpmaddubsw_overflow_linear(
    batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
):
    for i, j in np.ndindex((batch_size, output_channels)):
        for k in range(0, input_channels // 2 * 2, 2):
            x0 = X[i, k] - X_min
            x1 = X[i, k + 1] - X_min
            w0 = W[j, k] - 128 - W_min
            w1 = W[j, k + 1] - 128 - W_min
            if x0 * w0 + x1 * w1 < -(1 << 15):
                w1_adjusted = (-(1 << 15) - float(x0) * w0) / x1
                W[j, k + 1] = int(w1_adjusted) + 128 + W_min
            elif x0 * w0 + x1 * w1 > (1 << 15) - 1:
                w1_adjusted = ((1 << 15) - 1 - float(x0) * w0) / x1
                W[j, k + 1] = int(w1_adjusted) + 128 + W_min

    # Go through the same loop again to double check we don't have any overflow
    for i, j in np.ndindex((batch_size, output_channels)):
        for k in range(0, input_channels // 2 * 2, 2):
            x0 = X[i, k] - X_min
            x1 = X[i, k + 1] - X_min
            w0 = W[j, k] - 128 - W_min
            w1 = W[j, k + 1] - 128 - W_min
            assert -(1 << 15) <= x0 * w0 + x1 * w1 < (1 << 15)


# Reference quantized Linear operator
def qlinear_ref(X_q, X_scale, X_zp, W_q, W_scale, W_zp, b_q, Y_scale, Y_zp, dtype=np.uint8):
    X_q = np.reshape(X_q, (-1, X_q.shape[X_q.ndim - 1]))
    row_offsets_ref = X_q.sum(axis=1).astype(np.int32).reshape((-1, 1))
    col_offsets_ref = W_q.sum(axis=1).astype(np.int32).reshape((1, -1))
    assert X_q.ndim == 2
    batch_size, input_channels = X_q.shape
    Prod_XqWq_ref = (
        np.matmul(X_q.astype(np.int32), W_q.astype(np.int32).T)
        - W_zp * row_offsets_ref
        - X_zp * col_offsets_ref
        + input_channels * X_zp * W_zp
    )
    if b_q is not None:
        Prod_XqWq_ref += b_q
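    # Prod_XqWq_ref is the int32 accumulation of (X_q - X_zp) @ (W_q - W_zp).T,
    # i.e. the output expressed in units of X_scale * W_scale; requantizing it
    # with the scale Y_scale / (X_scale * W_scale) therefore yields the
    # quantized output.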
    Y_q_ref = _quantize(Prod_XqWq_ref, Y_scale / (X_scale * W_scale), Y_zp, dtype=dtype)
    return Y_q_ref

"""Computes the output shape given pooling parameters."""
def pool_output_shape(input_size, kernel_size, padding, stride,
                      dilation, ceiling_mode=False):
    if stride is None:
        stride = kernel_size
    output_size = (
        (input_size + 2 * padding - dilation * (kernel_size - 1) - 1
         + (stride - 1 if ceiling_mode else 0)) // stride + 1)
    if (ceiling_mode and
            ((output_size - 1) * stride >= input_size + padding)):
        output_size -= 1
    return output_size

"""
Util for creating a random tensor and quantization params when Hypothesis
is undesirable.
"""
def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type):
    X = (torch.rand(*shapes, dtype=torch.float) - 0.5) * rand_scale
    # Calculate reasonable quantization params
    min_val = torch.min(X)
    max_val = torch.max(X)
    if torch_type == torch.qint32:
        X_zero_point = int(torch.randint(-1 * (2 ** 31), 2 ** 31 - 1, (1,)))
        num_bins = 2 ** 32
        X_scale = float(max_val - min_val) / num_bins
    elif torch_type == torch.qint8:
        X_zero_point = int(torch.randint(-128, 127, (1,)))
        num_bins = 2 ** 8
        X_scale = float(max_val - min_val) / num_bins
    else:  # torch.quint8
        X_zero_point = 127
        num_bins = 2 ** 8
        X_scale = float(max_val - min_val) / num_bins
    if X_scale == 0:
        X_scale = 1e-10
    return X, X_scale, X_zero_point

class TestQuantizedOps(TestCase):

    """Helper function to test quantized activation functions."""
    def _test_activation_function(self, X, fn_name, test_configs):
        r"""
            When writing a unit test for an activation function, instead of
            specifying test routines that only apply to that activation
            function, use _test_activation_function, which provides general
            testing. To use the helper function, a test config must be
            provided. A test config is a list containing metadata about the
            quantized activation functions to be tested and how the tests
            should be set up; it allows simpler and more concise unit tests to
            be written by specifying the needed configurations and calling
            this helper. Inside the list, each config (a dictionary)
            represents a suite of tests that assert the correctness of the
            quantized functions. See test_qrelu, test_qrelu6, test_sigmoid,
            and test_qhardsigmoid for how their test configs are specified.
            The fields that can be included in a test config are:
            quantized_fn: a list of the quantized functions to be tested
            reference_fn: the reference function to be called on the
            dequantized X
            extra_kwargs: additional keyword arguments passed to each op
            Every test config must contain at least the quantized_fn and
            reference_fn fields.
            output_range: the output range the operator maps to. If it is
            not specified, the range is not controlled and depends on Xmin
            and Xmax.
            change_zero_point: a boolean flag indicating if the zero point
            parameter should be determined based on torch_type during
            quantization (see sigmoid/hardsigmoid for examples). If it is not
            specified, change_zero_point defaults to False and the zero point
            just takes on the default value from X.
            `output_is_observed`: if specified and True, extra
            output_scale/output_zero_point keyword arguments are appended when
            calling the quantized op.
            A minimal example config is shown in the comment below.
        """
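        # A minimal example config (mirroring test_qrelu below), for illustration:
        #   [{'quantized_fn': [torch.relu], 'reference_fn': torch.nn.functional.relu}]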
        # Retrieves the default parameters from X.
        X, (scale, zero_point, torch_type) = X
        if not isinstance(X, torch.Tensor):
            X = torch.from_numpy(X)
        if (X.device.type == 'cuda') and (torch.backends.quantized.engine == 'qnnpack'):
            return
        # Quantizes the reference to account for max error.
        # q_min and q_max only depend on the initial torch_type.
        q_min, q_max = torch.iinfo(torch_type).min, torch.iinfo(torch_type).max

        for op_group in test_configs:
            ref_op = op_group['reference_fn']
            for q_op in op_group['quantized_fn']:

                for memory_format in (torch.channels_last, torch.contiguous_format):
                    if memory_format == torch.channels_last and len(X.shape) != 4:
                        continue
                    X = X.to(memory_format=memory_format)

                    # Retrieves the inplace keyword arguments
                    # some functions require inplace=True to test in-place.
                    # copy.copy is needed because these are modified in place
                    extra_kwargs = \
                        copy.copy(op_group.get('extra_kwargs', {}))
                    output_is_observed = \
                        copy.copy(op_group.get('output_is_observed', False))

                    # Quantizes and dequantizes to account for max error.
                    qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                                   dtype=torch_type)
                    dqX = qX.dequantize()
                    dqY_hat = ref_op(dqX.clone(), **extra_kwargs)

                    # Adjusts output_scale if needed.
                    # The output_scale determines the quantization scale for functions that
                    # have a constrained output range, e.g. sigmoid ranges from 0 to 1.
                    output_scale = scale
                    if 'output_range' in op_group:
                        (f_min, f_max) = op_group['output_range']
                        output_scale = (f_max - f_min) / (q_max - q_min + 1.0)

                    # Adjusts output_zero_point if needed (see explanation for the
                    # change_zero_point parameter above).
                    # output_zero_point determines the additional offset that will be
                    # added to a scaled value during quantization.
                    if op_group.get('change_zero_point', False):
                        output_zero_point = 0 if torch_type == torch.qint32 else q_min
                    else:
                        output_zero_point = zero_point

                    # Quantizes the dequantized version of Y_hat.
                    qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale,
                                                       zero_point=output_zero_point,
                                                       dtype=torch_type)

                    if output_is_observed:
                        extra_kwargs.update({'output_scale': output_scale, 'output_zero_point': output_zero_point})

                    # Finds qY using in-place or non-in-place quantized operators.
                    qY = q_op(qX, **extra_kwargs)

                    self.assertEqual(qY, qY_hat, msg=f'{fn_name} - {q_op} failed: ({qY} vs. {qY_hat})')

    """Tests the correctness of the quantized::relu op."""
    @override_qengines
    def test_qrelu(self):
        relu_test_configs = [
            {
                'quantized_fn': [
                    torch.relu,
                    torch.relu_,
                    torch.nn.functional.relu,
                    torch.nn.functional.relu,
                ],
                'reference_fn': torch.nn.functional.relu
            },
            {
                'quantized_fn': [
                    torch.nn.functional.relu,
                    torch.nn.functional.relu,
                ],
                'reference_fn': torch.nn.functional.relu,
                'extra_kwargs': {
                    'inplace': True
                }
            }
        ]
        devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
        for device in devices:
            shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
            dtypes = (torch.quint8, torch.qint8)
            scales = (0.05, 0.1)
            zero_points = (0, 5)
            test_cases = itertools.product(shapes, dtypes, scales, zero_points)
            for shape, dtype, scale, zero_point in test_cases:
                X = torch.randn(*shape, device=device)
                X = (X, (scale, zero_point, dtype))
                self._test_activation_function(X, 'relu', relu_test_configs)

    """Tests the correctness of the quantized::relu6 op."""
    def test_qrelu6(self):
        relu6_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.relu6,
                    torch.ao.nn.quantized.ReLU6(inplace=False),
                    torch.ao.nn.quantized.ReLU6(inplace=True)
                ],
                'reference_fn': torch.nn.functional.relu6
            }
        ]
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        scales = (0.05, 0.1)
        zero_points = (0, 5)
        test_cases = itertools.product(shapes, dtypes, scales, zero_points)
        for shape, dtype, scale, zero_point in test_cases:
            X = torch.randn(*shape) * 10
            X = (X, (scale, zero_point, dtype))
            self._test_activation_function(X, 'relu6', relu6_test_configs)

    """Tests the correctness of the quantized::sigmoid op."""
    @override_qengines
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    def test_sigmoid_non_observed(self, X):
        sigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.sigmoid
                ],
                'reference_fn': torch.sigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True
            }
        ]
        self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

    """Tests the correctness of the quantized::sigmoid op."""
    # TODO: enable after observed output is supported in qnnpack
    # @override_qengines
    @skipIfNoFBGEMM
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    def test_sigmoid(self, X):
        sigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.sigmoid
                ],
                'reference_fn': torch.sigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
                'output_is_observed': True,
            }
        ]
        self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

    @skipIfNoFBGEMM
    def test_sigmoid_dequantize_rounding_error(self):
        # issue #107030
        sigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.sigmoid
                ],
                'reference_fn': torch.sigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
                'output_is_observed': True,
            }
        ]
        X = (np.full(64, 514., dtype=np.float32), (1028.02, 255, torch.quint8))
        self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

    """Tests the correctness of the quantized::hardsigmoid op."""
    @override_qengines
    def test_qhardsigmoid(self):
        hardsigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.ao.nn.quantized.functional.hardsigmoid,
                ],
                'reference_fn': torch.nn.functional.hardsigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
            },
            {
                'quantized_fn': [
                    torch.ao.nn.quantized.functional.hardsigmoid,
                ],
                'reference_fn': torch.nn.functional.hardsigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
                'extra_kwargs': {
                    'inplace': True,
                },
            },
        ]
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        test_cases = itertools.product(shapes, dtypes)
        for shape, dtype in test_cases:
            X = (np.random.rand(*shape).astype(np.float32), (1.0, 0, dtype))
            self._test_activation_function(X, 'hardsigmoid', hardsigmoid_test_configs)

    @override_qengines
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    def test_leaky_relu_observed_output(self, X):
        leaky_relu_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.leaky_relu
                ],
                'reference_fn': torch.nn.functional.leaky_relu,
                'extra_kwargs': {
                    'negative_slope': 0.1,
                    'inplace': False,
                },
                'output_is_observed': True,
            }
        ]
        self._test_activation_function(X, 'leaky_relu', leaky_relu_test_configs)

    """Tests the correctness of the quantized::leaky_relu op."""
    def test_leaky_relu(self):
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        memory_formats = (torch.channels_last, torch.contiguous_format)
        test_cases = itertools.product(shapes, dtypes, memory_formats)
        for shape, dtype, memory_format in test_cases:
            if memory_format == torch.channels_last and len(shape) != 4:
                continue
            X, scale, zero_point, torch_type, alpha = \
                torch.randn(*shape), 0.1, 0, dtype, 0.01
            X = X.to(memory_format=memory_format)

            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            dqX = qX.dequantize()

            # torch.nn.functional
            op = torch.nn.functional.leaky_relu
            dqY = op(dqX, negative_slope=alpha)
            qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            qY_hat = op(qX, negative_slope=alpha)
            self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
                             msg=f"F.leaky_relu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::elu op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams()),
           alpha=st.floats(0.01, 10.0, allow_nan=False, allow_infinity=False))
    def test_qelu(self, X, alpha):
        X, (scale, zero_point, torch_type) = X
        output_scale = 0.5
        output_zero_point = 1

        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        # calculate ELU(dqX) and quantize
        dqX = qX.dequantize()
        dqY_hat = dqX.clone()
        dqY_hat = torch.nn.functional.elu(dqX, alpha)
        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale, zero_point=output_zero_point,
                                           dtype=torch_type)

        qY = torch.ao.nn.quantized.functional.elu(qX, output_scale, output_zero_point, alpha=alpha)
        self.assertEqual(qY, qY_hat,
                         msg=f"F.elu failed ({qY} vs {qY_hat})")


    """Tests the correctness of the quantized::celu op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       elements=hu.floats(-1e2, 1e2, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams(scale_max=9.999999747378752e-06)),
           alpha=st.floats(0.01, 100.0, allow_nan=False, allow_infinity=False))
    def test_qcelu(self, X, alpha):
        X, (scale, zero_point, torch_type) = X
        output_scale = 0.5
        output_zero_point = 1

        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        # calculate CELU(dqX) and quantize
        dqX = qX.dequantize()
        dqY_hat = torch.nn.functional.celu(dqX, alpha)
        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale, zero_point=output_zero_point,
                                           dtype=torch_type)

        # test regular
        qY = torch.ops.quantized.celu(qX, output_scale, output_zero_point, alpha=alpha)
        self.assertEqual(qY, qY_hat,
                         msg=f"F.celu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::gelu op."""
    def test_qgelu(self):
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        memory_formats = (torch.channels_last, torch.contiguous_format)
        approximation = ['none', 'tanh']
        test_cases = itertools.product(shapes, dtypes, memory_formats, approximation)
        devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
        for shape, dtype, memory_format, approximate in test_cases:
            if memory_format == torch.channels_last and len(shape) != 4:
                continue

            X, scale, zero_point, torch_type = \
                torch.randn(*shape), 0.1, 0, dtype
            X = X.to(memory_format=memory_format)
            for device in devices:
                X = X.to(device=device)
                qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                               dtype=torch_type)
                dqX = qX.dequantize()

                op = torch.nn.functional.gelu
                dqY = op(dqX, approximate=approximate)
                qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                               dtype=torch_type)
                qY_hat = op(qX)
                self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
                                 msg=f"F.gelu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::prelu op."""
    def test_qprelu(self):
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        num_params = (0, 1)  # 0: num_parameter = num_channels
        dtypes = (torch.quint8, torch.qint8)
        memory_formats = (torch.channels_last, torch.contiguous_format)
        test_cases = itertools.product(shapes, num_params, dtypes, memory_formats)
        for shape, num_param, dtype, memory_format in test_cases:
            if memory_format == torch.channels_last and len(shape) != 4:
                continue
            X, scale, zero_point, torch_type = \
                torch.randn(*shape), 0.1, 0, dtype
            X = X.to(memory_format=memory_format)
            num_parameter = 1 if num_param == 1 or len(shape) == 1 else shape[1]
            W = torch.randn(num_parameter)
            W, w_scale, w_zero_point = \
                torch.randn(num_parameter), 0.2, 0

            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            dqX = qX.dequantize()
            qW = torch.quantize_per_tensor(W, scale=w_scale, zero_point=w_zero_point,
                                           dtype=torch_type)
            dqW = qW.dequantize()

            op = torch.nn.functional.prelu
            qop = torch.ops.quantized.prelu
            dqY = op(dqX, dqW)
            qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            qY_hat = qop(qX, qW, scale, zero_point)
            self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
                             msg=f"F.prelu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::qlayer_norm op."""
    @skipIfNoFBGEMM
    def test_qlayer_norm(self):
        # hypothesis is flaky for this test, create test cases manually
        side_lens = (1, 8, 11)
        torch_types = (torch.qint8, torch.quint8)
        y_scales = (0.1, 4.23)
        y_zero_points = (0, 1)
        channels_last_list = (True, False)
        affine_list = (True, False)
        combined = [side_lens, torch_types, y_scales, y_zero_points,
                    channels_last_list, affine_list]
        test_cases = itertools.product(*combined)

        with override_quantized_engine("fbgemm"):
            for test_case in test_cases:

                side_len, torch_type, Y_scale, Y_zero_point, channels_last, \
                    affine = test_case
                shapes = [side_len] * 4

                # In the FP kernel, mean and variance are calculated in floating point.
                # In the quantized kernel, they are calculated in integer arithmetic.
                # Because of this, the numerics do not always match exactly which is
                # expected and acceptable. We do two things to allow this failure
                # in this test:
                # 1. do not use Hypothesis to generate the input tensor.  Hypothesis
                #    favors homogeneous inputs in its search strategies which isn't
                #    representative of the inputs we care about, and tends to maximize
                #    this particular numerics difference.
                # 2. allow a small % of off by Y_scale errors.  Even when the
                #    variance of the input is high, there can be off by one errors
                #    in the result if the input value happens to fall exactly on
                #    the bin boundary of the output scale.
                #
                # If we want the numerics to match we could switch to calculating
                # mean+var in floating point in the future, at the cost of speed.
                X, X_scale, X_zero_point = \
                    _get_random_tensor_and_q_params(shapes, 1.0, torch_type)

                qX = torch.quantize_per_tensor(X, scale=X_scale,
                                               zero_point=X_zero_point,
                                               dtype=torch_type)
                if channels_last:
                    qX = qX.contiguous(memory_format=torch.channels_last)
                dqX = qX.dequantize()

                # Enforce non-homogeneous inputs
                enough_unique_vals_in_each_layer = sum(
                    1 if (
                        dqX[i].shape[0] < 5 or
                        float(torch.unique(dqX[i]).shape[0]) / dqX[i].shape[0] > 0.01
                    ) else 0
                    for i in range(dqX.shape[0])
                ) == dqX.shape[0]
                assume(enough_unique_vals_in_each_layer)

                # Initialize the weights non-randomly for reproducibility, to avoid
                # flaky tests
                if affine:
                    weight = torch.ones(*qX.size()[1:], dtype=torch.float) * 0.5
                    bias = torch.ones(*qX.size()[1:], dtype=torch.float) * 1
                else:
                    weight = None
                    bias = None
                epsilon = 1e-5

                qY = torch.ops.quantized.layer_norm(
                    qX, qX.size()[1:], weight=weight, bias=bias, eps=epsilon,
                    output_scale=Y_scale, output_zero_point=Y_zero_point)

                Y_hat = F.layer_norm(
                    dqX, dqX.size()[1:], weight=weight, bias=bias, eps=epsilon)
                qY_hat = torch.quantize_per_tensor(
                    Y_hat, scale=Y_scale, zero_point=Y_zero_point, dtype=torch_type)

                # Due to the numerics difference mentioned above between calculating
                # the variance in float vs int, the results can still be slightly
                # different.
                dqY = qY.dequantize()
                dqY_hat = qY_hat.dequantize()
                diff = dqY - dqY_hat

                # off-by-one errors are on the order of Y_scale
                num_diff = torch.sum(diff > Y_scale * 1.0001)
                pct_diff = float(num_diff) / (diff.numel() + 1e-5)
                num_diff_off_by_one = torch.sum((diff > 0) * (diff <= Y_scale))
                pct_diff_off_by_one = float(num_diff_off_by_one) / (diff.numel() + 1e-5)

                self.assertTrue(pct_diff < 1e-6)
                self.assertTrue(pct_diff_off_by_one < 0.01)


    """Tests the correctness of the quantized::qnnpack_tanh op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    @unittest.skip(
        "this is broken without changes to any relevant code, "
        "we need to remove hypothesis testing in CI")
    def test_qtanh(self, X):
        # Note: QNNPACK is tested separately in TestQNNPackOps
        X, (scale, zero_point, torch_type) = X

        X = torch.from_numpy(X)
        Y = torch.tanh(X)

        qX = torch.quantize_per_tensor(X, scale=scale,
                                       zero_point=zero_point,
                                       dtype=torch_type)

        # Quantize the reference to account for max error.
        # Note that the output scale denominator has a +1 because the
        # implementation uses a scale of 2.0 / 2^BITS.
        f_min, f_max = -1.0, 1.0
        q_min, q_max = torch.iinfo(torch_type).min, torch.iinfo(torch_type).max
        output_scale = (f_max - f_min) / (q_max - q_min + 1.0)
        output_zero_point = int(round((q_max + q_min) / 2.0))
        qY = torch.quantize_per_tensor(Y, scale=output_scale,
                                       zero_point=output_zero_point,
                                       dtype=torch_type)
        qY_hat = torch.tanh(qX)
        self.assertEqual(qY, qY_hat,
                         msg=f"TanH failed: {qY} vs. {qY_hat}")

    """Tests the correctness of the quantized::threshold op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams()),
           threshold=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
           value=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False))
    def test_qthreshold(self, X, threshold, value):
        X, (scale, zero_point, torch_type) = X
        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        # calculate threshold(dqX) and quantize
        dqX = qX.dequantize()
        dqY_hat = dqX.clone()
        dqY_hat = torch.nn.functional.threshold(dqY_hat, threshold, value)
        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)

        ops_under_test = {
            'native': torch.threshold,
            'nn.functional': torch.nn.functional.threshold,
            'ao.nn.quantized.functional': torch.ao.nn.quantized.functional.threshold,
        }

        for name, op in ops_under_test.items():
            qY = op(qX, threshold, value)
            self.assertEqual(qY, qY_hat, msg=f"{name} qthreshold failed")

    """Tests the correctness of the quantized::clamp op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8, max_numel=10**5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False),
                       qparams=hu.qparams()),
           min_val=hu.floats(-1e6, 1e6, allow_nan=False),
           max_val=hu.floats(-1e6, 1e6, allow_nan=False))
    def test_qclamp(self, X, min_val, max_val):
        X, (scale, zero_point, torch_type) = X

        assume(min_val <= max_val)
        Y_clamp = torch.clamp(torch.from_numpy(X), min=min_val, max=max_val)
        qY_clamp = torch.quantize_per_tensor(Y_clamp, scale=scale,
                                             zero_point=zero_point, dtype=torch_type)

        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)
        ops_under_test = {
            'ops.quantized': torch.ops.quantized.clamp,
        }

        for name, op in ops_under_test.items():
            qY_clamp_hat = op(qX, min=min_val, max=max_val)
            self.assertEqual(qY_clamp, qY_clamp_hat, msg=f"{name} qclamp failed")

        if torch.backends.quantized.engine == 'fbgemm':
            with override_quantized_engine('fbgemm'):
                Y_min_clamp = torch.clamp(X, min=min_val)
                Y_max_clamp = torch.clamp(X, max=max_val)

                qY_min_clamp = torch.quantize_per_tensor(Y_min_clamp, scale=scale,
                                                         zero_point=zero_point, dtype=torch_type)
                qY_max_clamp = torch.quantize_per_tensor(Y_max_clamp, scale=scale,
                                                         zero_point=zero_point, dtype=torch_type)


                for name, op in ops_under_test.items():
                    qY_min_clamp_hat = op(qX, min=min_val)
                    self.assertEqual(qY_min_clamp, qY_min_clamp_hat, msg=f"{name} qclamp failed")
                    qY_max_clamp_hat = op(qX, max=max_val)
                    self.assertEqual(qY_max_clamp, qY_max_clamp_hat, msg=f"{name} qclamp failed")

    """Tests the correctness of the quantized::hardtanh op."""
    @skipIfNoFBGEMM
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8, max_numel=10**5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams()),
           min_val=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False),
           max_val=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
    def test_hardtanh(self, X, min_val, max_val):
        with override_quantized_engine('fbgemm'):
            X, (scale, zero_point, torch_type) = X

            assume(min_val <= max_val)
            Y = X.copy()
            Y[Y < min_val] = min_val
            Y[Y > max_val] = max_val
            qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale,
                                           zero_point=zero_point, dtype=torch_type)
            X = torch.from_numpy(X)
            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)

            ops_under_test = {
                'ao.nn.quantized.functional.hardtanh':
                    torch.ao.nn.quantized.functional.hardtanh,
            }

            for name, op in ops_under_test.items():
                qY_hat = op(qX, min_val, max_val)
                self.assertEqual(qY, qY_hat, msg=f"{name} hardtanh failed")

            ops_under_test_inplace = {
                'inplace ao.nn.quantized.functional.hardtanh':
                    torch.ao.nn.quantized.functional.hardtanh,
            }

            for name, op_ in ops_under_test_inplace.items():
                qY_hat = qX.clone()
                op_(qY_hat, min_val, max_val, inplace=True)
                self.assertEqual(qY, qY_hat, msg=f"{name} hardtanh failed")

    """Tests the correctness of the quantized::hardswish op."""
    @override_qengines
    def test_hardswish(self):
        max_sides = (3, 4)
        side_lens = (1, 7)
        torch_types = (torch.quint8, torch.qint8)
        y_scales = (0.1, )
        y_zero_points = (1,)
        combined = [max_sides, side_lens, torch_types, y_scales, y_zero_points]
        test_cases = itertools.product(*combined)
        for test_case in test_cases:
            max_side, side_len, torch_type, Y_scale, Y_zero_point = test_case

            if torch.backends.quantized.engine == 'qnnpack' and torch_type != torch.quint8:
                continue

            shapes = [side_len] * max_side
            X, X_scale, X_zero_point = \
                _get_random_tensor_and_q_params(shapes, 2.0, torch_type)
            for memory_format in torch.channels_last, torch.contiguous_format:
                if memory_format == torch.channels_last and len(shapes) == 4:
                    X = X.to(memory_format=memory_format)
                qX = torch.quantize_per_tensor(X, scale=X_scale, zero_point=X_zero_point,
                                               dtype=torch_type)
                dqX = qX.dequantize()

                dqY_hat = F.hardswish(dqX)
                qY_hat = torch.quantize_per_tensor(dqY_hat, scale=Y_scale,
                                                   zero_point=Y_zero_point,
                                                   dtype=torch_type)

                qY = torch.ao.nn.quantized.functional.hardswish(
                    qX, scale=Y_scale, zero_point=Y_zero_point)
                self.assertEqual(
                    qY, qY_hat,
                    msg=f"Hardswish failed: {qY} vs {qY_hat}, {torch.backends.quantized.engine}")

    """Tests the correctness of the binary op + scalar."""
    def _test_binary_op_scalar_relu(self, A, b, binary_op_name, binary_op, quantized_op, quantized_op_relu):
        import copy
        op_scalar = quantized_op
        op_scalar_relu = quantized_op_relu

        A, (scale, zero_point, dtype) = A
        A = A.astype(np.float32)
        qA = torch.quantize_per_tensor(torch.from_numpy(A), scale, zero_point, dtype)

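        # For 'add' the reference snaps the scalar to A's quantization grid
        # (round(b / scale) * scale); the quantized add-scalar kernel is
        # expected to treat the scalar the same way.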
        if binary_op_name == 'add':
            C = binary_op(qA.dequantize(), round(b / scale) * scale)
        else:
            C = binary_op(qA.dequantize(), b)
        C_relu = copy.deepcopy(C)
        C_relu[C_relu < 0] = 0

        C_hat = op_scalar(qA, b)
        C_ref = torch.quantize_per_tensor(C, C_hat.q_scale(), C_hat.q_zero_point(), dtype)
        C_relu_hat = op_scalar_relu(qA, b)
        C_relu_ref = torch.quantize_per_tensor(
            C_relu, C_relu_hat.q_scale(), C_relu_hat.q_zero_point(), dtype)

        self.assertEqual(C_ref.dequantize(), C_hat.dequantize(),
                         msg=f"{binary_op_name}_scalar results don't match: "
                         f"{C_ref.dequantize()} vs {C_hat.dequantize()}")
        self.assertEqual(C_relu_ref.dequantize(), C_relu_hat.dequantize(),
                         msg=f"{binary_op_name}_scalar_relu results don't match: "
                         f"{C_relu_ref.dequantize()} vs {C_relu_hat.dequantize()}")

    @unittest.skipIf(IS_MACOS, "skipping macos test")
    @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False),
                       qparams=hu.qparams()),
           b=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
    def test_add_scalar_relu(self, A, b):
        self._test_binary_op_scalar_relu(A, b, "add", operator.add, torch.ops.quantized.add, torch.ops.quantized.add_relu)

    @unittest.skipIf(IS_MACOS, "skipping macos test")
    @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False),
                       qparams=hu.qparams()),
           b=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
    def test_mul_scalar_relu(self, A, b):
        self._test_binary_op_scalar_relu(A, b, "mul", operator.mul, torch.ops.quantized.mul, torch.ops.quantized.mul_relu)

    """Tests the correctness of the add and add_relu op."""
    def test_qadd_relu_same_qparams(self):
        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
            add_relu = torch.ops.quantized.add_relu
            add = torch.ops.quantized.add
            add_out = torch.ops.quantized.add
            add_relu_out = torch.ops.quantized.add_relu

            # NB: This is a strange size so that we exercise both the vectorized
            # implementation (64-element chunks at a time) as well as the scalar
            # implementation
            A = torch.arange(-128, 130, dtype=torch.float)
            B = torch.arange(-128, 130, dtype=torch.float)
            scale = 2.0
            zero_point = 127
            qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                           dtype=dtype)
            qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                           dtype=dtype)

            # Add ReLU ground truth
            C = (qA.dequantize() + qB.dequantize()).numpy()
            qC = _quantize(C, scale, zero_point, dtype=np_dtype[dtype])
            qC_hat = add(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qC, qC_hat.int_repr(),
                                    "Quantized addition failed.")
            qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                                       scale=scale,
                                                       zero_point=zero_point,
                                                       dtype=dtype)
            add_out(qA, qB, out=qC_out_hat)
            self.assertEqual(qC_hat, qC_out_hat, msg="Add.out failed")

            # Add + ReLU ground truth
            Crelu = C.copy()
            Crelu[C < 0] = 0
            qCrelu = _quantize(Crelu, scale, zero_point, dtype=np_dtype[dtype])
            qCrelu_hat = add_relu(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                    "Quantized addition with ReLU failed.")
            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
                                                           scale=scale,
                                                           zero_point=zero_point,
                                                           dtype=dtype)
            add_relu_out(qA, qB, out=qCrelu_out_hat)
            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                             msg="AddReLU.out failed")

    """Tests the correctness of the cudnn add and add_relu op
    (Similar to test_qadd_relu_different_qparams, will probably merge in the future)"""
    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
    @unittest.skip("not currently working and feature isn't used")
    def test_qadd_relu_cudnn(self):
        dtype = torch.qint8
        add_relu = torch.ops.quantized.add_relu
        add = torch.ops.quantized.add

        A = torch.arange(-128, 130, dtype=torch.float).to(torch.device("cuda"))
        B = torch.arange(-128, 130, dtype=torch.float).to(torch.device("cuda"))
        scale_A = 2.5
        scale_B = 6.3
        scale_C = 12.9
        zero_point = 0
        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point,
                                       dtype=dtype)
        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point,
                                       dtype=dtype)
        # Add ground truth
        C = (qA.dequantize() + qB.dequantize()).to(device="cpu").numpy()
        qC = _quantize(C, scale_C, zero_point, dtype=np_dtype[dtype])
        qC_hat = add(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qC, qC_hat.int_repr(),
                                "Quantized addition failed.")

        # Add + ReLU ground truth
        Crelu = C.copy()
        Crelu[C < 0] = 0
        qCrelu = _quantize(Crelu, scale_C, zero_point, dtype=np_dtype[dtype])
        qCrelu_hat = add_relu(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                "Quantized addition with ReLU failed.")

    """Tests the correctness of the cudnn add and add_relu op for nhwc format"""
    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
    @unittest.skip("not currently working and feature isn't used")
    def test_qadd_relu_cudnn_nhwc(self):
        dtype = torch.qint8
        add_relu = torch.ops.quantized.add_relu
        add = torch.ops.quantized.add

        A = torch.rand(16, 8, 4, 12).to(device="cuda")
        B = torch.rand(16, 8, 4, 12).to(device="cuda")
        scale_A = 2.5
        scale_B = 6.3
        scale_C = 12.9
        zero_point = 0
        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point,
                                       dtype=dtype)
        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point,
                                       dtype=dtype)
        # Add ground truth
        C = (qA.dequantize() + qB.dequantize()).to(device="cpu").numpy()
        qC = _quantize(C, scale_C, zero_point, dtype=np_dtype[dtype])
        qC_hat = add(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qC, qC_hat.int_repr(),
                                "Quantized addition failed.")

        # Add + ReLU ground truth
        Crelu = C.copy()
        Crelu[C < 0] = 0
        qCrelu = _quantize(Crelu, scale_C, zero_point, dtype=np_dtype[dtype])
        qCrelu_hat = add_relu(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                "Quantized addition with ReLU failed.")

    """Tests the correctness of the add and add_relu op."""
    def test_qadd_relu_different_qparams(self):
        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
            add_relu = torch.ops.quantized.add_relu
            add = torch.ops.quantized.add
            add_out = torch.ops.quantized.add
            add_relu_out = torch.ops.quantized.add_relu

            # NB: This is a strange size so that we exercise both the vectorized
            # implementation (64-element chunks at a time) as well as the scalar
            # implementation
            A = torch.arange(-128, 130, dtype=torch.float)
            B = torch.arange(-128, 130, dtype=torch.float)
            scale_A = 3.0
            zero_point_A = 7
            scale_B = 5.0
            zero_point_B = 127

            scale_C = 0.5
            zero_point_C = 5

            qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
                                           dtype=dtype)
            qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
                                           dtype=dtype)

            # Add ground truth
            C = (qA.dequantize() + qB.dequantize()).numpy()
            qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype[dtype])
            qC_hat = add(qA, qB, scale=scale_C, zero_point=zero_point_C)
            np.testing.assert_equal(qC, qC_hat.int_repr(),
                                    "Quantized addition failed.")
            qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                                       scale=scale_C,
                                                       zero_point=zero_point_C,
                                                       dtype=dtype)
            add_out(qA, qB, out=qC_out_hat)
            self.assertEqual(qC_hat, qC_out_hat, msg="Add.out failed")

            # Add + ReLU ground truth
            Crelu = C.copy()
            Crelu[C < 0] = 0
            qCrelu = _quantize(Crelu, scale_C, zero_point_C, dtype=np_dtype[dtype])
            qCrelu_hat = add_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                    "Quantized addition with ReLU failed.")
            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
                                                           scale=scale_C,
                                                           zero_point=zero_point_C,
                                                           dtype=dtype)
            add_relu_out(qA, qB, out=qCrelu_out_hat)
            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                             msg="AddReLU.out failed")

    """Tests the correctness of the mul and mul_relu op."""
    def test_qmul_relu_same_qparams(self):
        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
            mul_relu = torch.ops.quantized.mul_relu
            mul = torch.ops.quantized.mul
            mul_out = torch.ops.quantized.mul
            mul_relu_out = torch.ops.quantized.mul_relu

            A = torch.arange(-100, 100, dtype=torch.float)
            B = torch.arange(-100, 100, dtype=torch.float)
            scale = 2
            zero_point = 127
            qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                           dtype=dtype)
            qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                           dtype=dtype)

            # mul ReLU ground truth
            C = (qA.dequantize() * qB.dequantize()).numpy()
            qC = _quantize(C, scale, zero_point, dtype=np_dtype[dtype])
            qC_hat = mul(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qC, qC_hat.int_repr(),
                                    "Quantized multiplication failed.")
            qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                                       scale=scale,
                                                       zero_point=zero_point,
                                                       dtype=dtype)
            mul_out(qA, qB, out=qC_out_hat)
            self.assertEqual(qC_hat, qC_out_hat, msg="mul.out failed")

            # mul + ReLU ground truth
            Crelu = C.copy()
            Crelu[C < 0] = 0
            qCrelu = _quantize(Crelu, scale, zero_point, dtype=np_dtype[dtype])
            qCrelu_hat = mul_relu(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                    "Quantized multiplication with ReLU failed.")
1072            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
1073                                                           scale=scale,
1074                                                           zero_point=zero_point,
1075                                                           dtype=dtype)
1076            mul_relu_out(qA, qB, out=qCrelu_out_hat)
1077            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
1078                             msg="mulReLU.out failed")
1079
1080            # Scalar multiplication
1081            for b in B:
1082                C_ref = qA.dequantize().numpy() * b.item()
1083                qC_hat = torch.ops.quantized.mul(qA, b.item())
1084
1085                self.assertEqual(C_ref, qC_hat.dequantize())
1086
1087            # Scalar multiplication + relu
1088            for b in B:
1089                C_ref = qA.dequantize().numpy() * b.item()
1090                C_ref[C_ref < 0] = 0
1091                qC_hat = torch.ops.quantized.mul_relu(qA, b.item())
1092
1093                self.assertEqual(C_ref, qC_hat.dequantize())
1094
1095    """Tests the correctness of the mul and mul_relu op."""
1096    def test_qmul_relu_different_qparams(self):
1097        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
1098            mul_relu = torch.ops.quantized.mul_relu
1099            mul = torch.ops.quantized.mul
1100            mul_out = torch.ops.quantized.mul
1101            mul_relu_out = torch.ops.quantized.mul_relu
1102
1103            A = torch.arange(-100, 100, dtype=torch.float)
1104            B = torch.arange(-100, 100, dtype=torch.float)
1105            scale_A = 3.0
1106            zero_point_A = 7
1107            scale_B = 5.0
1108            zero_point_B = 127
1109
1110            scale_C = 0.5
1111            zero_point_C = 5
1112
1113            qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
1114                                           dtype=dtype)
1115            qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
1116                                           dtype=dtype)
1117
1118            # mul ground truth
1119            C = (qA.dequantize() * qB.dequantize()).numpy()
1120            qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype[dtype])
1121            qC_hat = mul(qA, qB, scale=scale_C, zero_point=zero_point_C)
1122            np.testing.assert_equal(qC, qC_hat.int_repr(),
1123                                    "Quantized multiplication failed.")
1124            qC_out_hat = torch._empty_affine_quantized(qC.shape,
1125                                                       scale=scale_C,
1126                                                       zero_point=zero_point_C,
1127                                                       dtype=dtype)
1128            mul_out(qA, qB, out=qC_out_hat)
1129            self.assertEqual(qC_hat, qC_out_hat, msg="mul.out failed")
1130
1131            # mul + ReLU ground truth
1132            Crelu = C.copy()
1133            Crelu[C < 0] = 0
1134            qCrelu = _quantize(Crelu, scale_C, zero_point_C, dtype=np_dtype[dtype])
1135            qCrelu_hat = mul_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
1136            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
1137                                    "Quantized multiplication with ReLU failed.")
1138            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
1139                                                           scale=scale_C,
1140                                                           zero_point=zero_point_C,
1141                                                           dtype=dtype)
1142            mul_relu_out(qA, qB, out=qCrelu_out_hat)
1143            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
1144                             msg="mulReLU.out failed")
1145
1146    """Tests the correctness of the matmul op."""
1147    @given(num_dims=st.integers(2, 5),
1148           outer_dims=st.lists(st.integers(2, 6), min_size=3, max_size=3),
1149           m=st.integers(2, 6),
1150           k=st.integers(2, 6),
1151           n=st.integers(2, 6),
1152           dtypes=st.sampled_from(((torch.qint8, np.int8),
1153                                   (torch.quint8, np.uint8))))
1154    def test_qmatmul(self, num_dims, outer_dims, m, k, n, dtypes):
1155        (torch_dtype, np_dtype) = dtypes
1156
1157        size_a = outer_dims[:num_dims - 2] + [m, k]
1158        size_b = outer_dims[:num_dims - 2] + [k, n]
1159        A = torch.randn(size=size_a, dtype=torch.float32) * 3
1160        B = torch.randn(size=size_b, dtype=torch.float32) * 3
1161
1162        scale_A = 3.1
1163        zero_point_A = 7
1164        scale_B = 5.3
1165        zero_point_B = 127
1166
1167        scale_C = 1.3
1168        zero_point_C = 5
1169
1170        qA = torch.quantize_per_tensor(A,
1171                                       scale=scale_A,
1172                                       zero_point=zero_point_A,
1173                                       dtype=torch_dtype)
1174        qB = torch.quantize_per_tensor(B,
1175                                       scale=scale_B,
1176                                       zero_point=zero_point_B,
1177                                       dtype=torch_dtype)
1178
1179        # matmul ground truth
1180        C = torch.matmul(qA.dequantize(), qB.dequantize()).numpy()
1181        qC = _quantize(C, scale_C, zero_point_C, dtype=(np_dtype))
1182        qC_hat = torch.ops.quantized.matmul(qA,
1183                                            qB,
1184                                            scale=scale_C,
1185                                            zero_point=zero_point_C)
1186        np.testing.assert_equal(qC, qC_hat.int_repr(),
1187                                "Quantized matmul failed.")
1188
1189        # Per-channel quantized inputs are not supported and should raise a RuntimeError
1190        axis = 0
1191        scales_A = torch.rand(size=(A.shape[axis],))
1192        zero_points_A = torch.randint(low=0, high=5, size=(A.shape[axis],))
1193        scales_B = torch.rand(size=(B.shape[axis],))
1194        zero_points_B = torch.randint(low=0, high=5, size=(B.shape[axis],))
1195
1196        qA = torch.quantize_per_channel(A,
1197                                        scales=scales_A,
1198                                        zero_points=zero_points_A,
1199                                        axis=axis,
1200                                        dtype=torch.qint8)
1201        qB = torch.quantize_per_channel(B,
1202                                        scales=scales_B,
1203                                        zero_points=zero_points_B,
1204                                        axis=axis,
1205                                        dtype=torch.qint8)
1206        np.testing.assert_raises_regex(RuntimeError,
1207                                       ".*per-tensor.*",
1208                                       torch.ops.quantized.matmul,
1209                                       qA,
1210                                       qB,
1211                                       scale_C,
1212                                       zero_point_C)
1213
1214
1215    """Tests the correctness of the quantized softmax op."""
1216    @given(dims=st.lists(st.integers(2, 5), min_size=5, max_size=5))
1217    def test_qsoftmax(self, dims):
1218        for (num_dims, dim, memory_format) in [
1219            (2, 1, torch.contiguous_format),  # 2d softmax over last dim
1220            (4, 3, torch.contiguous_format),  # >2 dims, softmax along last dim
1221            (5, 2, torch.contiguous_format),  # >2 dims, softmax along a non-last dim (requires permute)
1222            (4, 3, torch.channels_last),      # >2 dims, softmax along last dim, but not contiguous
1223            (4, 1, torch.channels_last),      # Channels Last, doesn't require permute
1224            (5, 1, torch.channels_last_3d),   # Channels Last 3D, doesn't require permute
1225        ]:
1226            size = dims[:num_dims]
1227            torch_dtype = torch.quint8
1228            np_dtype = np.uint8
1229
1230            scale_X = 1.3
1231            zero_point_X = 5
1232            X = torch.rand(size=size, dtype=torch.float32) * 8 + zero_point_X
1233            X = X.to(memory_format=memory_format)
1234
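            # Softmax outputs lie in [0, 1], so a scale of 1/256 with zero point 0 spans the full quint8 range.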
1235            scale_Y = 1 / 256
1236            zero_point_Y = 0
1237
1238            qX = torch.quantize_per_tensor(X,
1239                                           scale=scale_X,
1240                                           zero_point=zero_point_X,
1241                                           dtype=torch_dtype)
1242
1243
1244            # softmax ground truth
1245            Y = torch.softmax(qX.dequantize(), dim=dim).numpy()
1246            qY = _quantize(Y, scale_Y, zero_point_Y, dtype=np_dtype)
1247            qY_hat = torch.ops.quantized.softmax(qX,
1248                                                 dim=dim,
1249                                                 output_scale=scale_Y,
1250                                                 output_zero_point=zero_point_Y)
1251
1252            np.testing.assert_equal(qY, qY_hat.int_repr(),
1253                                    "Quantized softmax failed.")
1254
1255    """Tests the correctness of the quantized softmax op using qnnpack."""
1256    @skipIfNoQNNPACK
1257    def test_qsoftmax_qnnpack(self):
1258        with override_quantized_engine('qnnpack'):
1259            self.test_qsoftmax()
1260
1261    """Tests the correctness of the mul op with broadcasting."""
1262    def test_qmul_broadcast(self):
1263        mul_relu = torch.ops.quantized.mul_relu
1264        mul = torch.ops.quantized.mul
1265        mul_out = torch.ops.quantized.mul
1266        mul_relu_out = torch.ops.quantized.mul_relu
1267
1270        A = torch.randn(8, 1, 6, 1)
1271        B = torch.randn(7, 1, 5)
1272        scale_A = 3.0
1273        zero_point_A = 7
1274        scale_B = 5.0
1275        zero_point_B = 127
1276
1277        scale_C = 0.5
1278        zero_point_C = 5
1279
1280        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
1281                                       dtype=torch.quint8)
1282        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
1283                                       dtype=torch.quint8)
1284
1285        # mul ground truth
1286        C = (qA.dequantize() * qB.dequantize()).numpy()
1287        qC = _quantize(C, scale_C, zero_point_C)
1288        qC_hat = mul(qA, qB, scale=scale_C, zero_point=zero_point_C)
1289        np.testing.assert_equal(qC, qC_hat.int_repr(),
1290                                "Quantized multiplication failed.")
1291
1292    """Tests that quantized add works with broadcasting."""
1293    def test_qadd_broadcast(self):
1294        A = torch.randn(1, 1, 4, 4)
1295        B = torch.randn(2, 1, 4, 4)
1296        qA = torch.quantize_per_tensor(A, 0.02, 0, torch.quint8)
1297        qB = torch.quantize_per_tensor(B, 0.04, 2, torch.quint8)
1298
1299        output_scale = 0.01
1300        output_zp = 1
1301
1302        # ground truth
1303        C = qA.dequantize() + qB.dequantize()
1304        qC = torch.quantize_per_tensor(C, output_scale, output_zp, torch.quint8)
1305
1306        # quantized add, with the operands in both orders to check that broadcasting is symmetric
1307        qC_hat_1 = torch.ops.quantized.add(qA, qB, output_scale, output_zp)
1308        qC_hat_2 = torch.ops.quantized.add(qB, qA, output_scale, output_zp)
1309
1310        self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_1.dequantize()))
1311        self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_2.dequantize()))
1312
1313    """Tests channel shuffle operation on quantized tensors."""
1314    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
1315                                              min_side=2, max_side=32, max_numel=10**5),
1316                       qparams=hu.qparams(dtypes=[torch.quint8])),
1317           groups=st.integers(2, 6))
1318    def test_channel_shuffle(self, X, groups):
1319        X, (scale, zero_point, torch_type) = X
1320        channels = X.shape[-3]
1321        iH, iW = X.shape[-2:]
1322        assume(channels % groups == 0)
1323
1324        a = torch.from_numpy(X)
1325        a = torch.rand(a.shape)  # the hypothesis-generated values are used only for their shape
1326        a_out = torch.nn.functional.channel_shuffle(a, groups)
1327
1328        a_ref = torch.quantize_per_tensor(a_out, scale=scale,
1329                                          zero_point=zero_point, dtype=torch_type)
1330        a_ref = a_ref.dequantize()
1331        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
1332                                       dtype=torch_type)
1333
1334        a_hat = torch.nn.functional.channel_shuffle(qa, groups)
1335        self.assertEqual(a_ref, a_hat.dequantize(),
1336                         msg="torch.nn.functional.channel_shuffle results are off")
1337
1338    """Tests 1D max pool operation on quantized tensors."""
1339    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=3,
1340                                              min_side=1, max_side=10),
1341                       qparams=hu.qparams()),
1342           kernel=st.sampled_from((3, 5, 7)),
1343           stride=st.sampled_from((None, 1, 2)),
1344           dilation=st.integers(1, 2),
1345           padding=st.integers(0, 2),
1346           ceil_mode=st.booleans())
1347    def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode):
1348        X, (scale, zero_point, torch_type) = X
1349        # Check constraints
1350        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1351        iW = X.shape[-1]
1352        oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
1353        assume(oW > 0)
1354
1355        a = torch.from_numpy(X)
1356        a_pool = torch.nn.functional.max_pool1d(a, kernel_size=kernel,
1357                                                stride=stride,
1358                                                padding=padding,
1359                                                dilation=dilation,
1360                                                ceil_mode=ceil_mode)
1361        a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
1362                                          zero_point=zero_point, dtype=torch_type)
1363        a_ref = a_ref.dequantize()
1364        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
1365                                       dtype=torch_type)
1366
1367        ops_under_test = {
1368            "torch": torch.max_pool1d,
1369            "nn.functional": torch.nn.functional.max_pool1d,
1370            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.max_pool1d,
1371        }
1372
1373        for name, op in ops_under_test.items():
1374            a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
1375                       dilation=dilation, ceil_mode=ceil_mode)
1376            self.assertEqual(a_ref, a_hat.dequantize(),
1377                             msg=f"{name} results are off")
1378        # Test torch.ops.quantized separately, because it does not accept None for stride.
1379        a_hat = torch.ops.quantized.max_pool1d(
1380            qa, kernel_size=_single(kernel),
1381            stride=_single(kernel if stride is None else stride),
1382            padding=_single(padding), dilation=_single(dilation),
1383            ceil_mode=ceil_mode)
1384        self.assertEqual(a_ref, a_hat.dequantize(),
1385                         msg="ops.quantized.max_pool1d results are off")
1386
1387    # TODO: merge this test with test_max_pool2d
1388    """Tests 2D cudnn max pool operation on quantized tensors."""
1389    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
1390                                              min_side=1, max_side=10),
1391                       # cudnn's support for quantized pooling is limited to
1392                       # int8 currently
1393                       qparams=hu.qparams(dtypes=[torch.qint8])),
1394           kernel=st.sampled_from((3, 5, 7)),
1395           stride=st.sampled_from((None, 1, 2)),
1396           # currently there is no support for dilation for cudnn
1397           # pooling
1398           dilation=st.integers(1, 1),
1399           padding=st.integers(0, 2),
1400           ceil_mode=st.booleans())
1401    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
1402    @unittest.skipIf(TEST_CUDNN_VERSION <= 90100, "cuDNN maxpool2d mishandles -128 on versions <= 90100")
1403    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
1404    def test_max_pool2d_cudnn(self, X, kernel, stride, dilation, padding, ceil_mode):
1405        X, (scale, zero_point, torch_type) = X
1406        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1407        iH, iW = X.shape[-2:]
1408        oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
1409        assume(oH > 0)
1410        oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
1411        assume(oW > 0)
1412
1413        a = torch.from_numpy(X).to(device="cuda")
1414        a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel,
1415                                                stride=stride,
1416                                                padding=padding, dilation=dilation,
1417                                                ceil_mode=ceil_mode)
1418        a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
1419                                          zero_point=zero_point, dtype=torch_type)
1420        a_ref = a_ref.dequantize()
1421        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
1422                                       dtype=torch_type)
1423
1424        # Test torch.ops.quantized separately, because it does not accept None for stride.
1425        a_hat = torch.ops.quantized.max_pool2d(
1426            qa, kernel_size=_pair(kernel),
1427            stride=_pair(kernel if stride is None else stride),
1428            padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode)
1429        self.assertEqual(a_ref, a_hat.dequantize(),
1430                         msg="ops.quantized.max_pool2d results are off")
1431
1432    """Tests 2D max pool operation on quantized tensors."""
1433    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
1434                                              min_side=1, max_side=10),
1435                       qparams=hu.qparams()),
1436           kernel=st.sampled_from((3, 5, 7)),
1437           stride=st.sampled_from((None, 1, 2)),
1438           dilation=st.integers(1, 2),
1439           padding=st.integers(0, 2),
1440           ceil_mode=st.booleans())
1441    def test_max_pool2d(self, X, kernel, stride, dilation, padding, ceil_mode):
1442        X, (scale, zero_point, torch_type) = X
1443        # Check constraints
1444        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1445        iH, iW = X.shape[-2:]
1446        oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
1447        assume(oH > 0)
1448        oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
1449        assume(oW > 0)
1450
1451        a = torch.from_numpy(X)
1452        a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel,
1453                                                stride=stride,
1454                                                padding=padding, dilation=dilation,
1455                                                ceil_mode=ceil_mode)
1456        a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
1457                                          zero_point=zero_point, dtype=torch_type)
1458        a_ref = a_ref.dequantize()
1459        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
1460                                       dtype=torch_type)
1461
1462        ops_under_test = {
1463            "torch": torch.max_pool2d,
1464            "nn.functional": torch.nn.functional.max_pool2d,
1465            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.max_pool2d,
1466        }
1467
1468        for name, op in ops_under_test.items():
1469            a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
1470                       dilation=dilation, ceil_mode=ceil_mode)
1471            self.assertEqual(a_ref, a_hat.dequantize(),
1472                             msg=f"{name} results are off")
1473        # Test torch.ops.quantized separately, because it does not accept None for stride.
1474        a_hat = torch.ops.quantized.max_pool2d(
1475            qa, kernel_size=_pair(kernel),
1476            stride=_pair(kernel if stride is None else stride),
1477            padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode)
1478        self.assertEqual(a_ref, a_hat.dequantize(),
1479                         msg="ops.quantized.max_pool2d results are off")
1480
1481
1482    def test_max_pool2d_pt2e(self):
1483        kernel_list = [2, 3]
1484        stride_list = [1, 2]
1485        padding_list = [0, 2]
1486        dilation_list = [1, 2]
1487        ceil_mode_list = [False, True]
1488        channels_last_input = [False, True]
1489        options = itertools.product(kernel_list, stride_list, padding_list, dilation_list, ceil_mode_list, channels_last_input)
1490        for kernel, stride, padding, dilation, ceil_mode, channels_last in options:
1491            if padding >= (kernel // 2):
1492                # Skip invalid combinations where the padding is too large for the kernel
1493                continue
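            # The op is exercised on a plain uint8 tensor (no quantization parameters), as in the PT2E flow.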
1494            input = torch.randint(0, 8, (1, 3, 8, 8), dtype=torch.uint8)
1495            if channels_last:
1496                input = input.contiguous(memory_format=torch.channels_last)
1497            a_pool = torch.nn.functional.max_pool2d(input.to(torch.float32), kernel_size=kernel,
1498                                                    stride=stride, padding=padding, dilation=dilation,
1499                                                    ceil_mode=ceil_mode).to(torch.uint8)
1500            a_hat = torch.ops.quantized.max_pool2d(input, kernel_size=_pair(kernel),
1501                                                   stride=_pair(stride), padding=_pair(padding),
1502                                                   dilation=_pair(dilation), ceil_mode=ceil_mode)
1503            self.assertEqual(input.is_contiguous(), a_hat.is_contiguous(),
1504                             msg="ops.quantized.max_pool2d output memory format differs from input")
1505            self.assertEqual(a_pool, a_hat,
1506                             msg="ops.quantized.max_pool2d results are off")
1507
1508
1509    """Tests 3D max pool operation on quantized tensors."""
1510    def test_max_pool3d(self):
1511        torch_types = [torch.qint8, torch.quint8]
1512        kernels = [1, 3]
1513        strides = [1, 3]
1514        dilations = [1, 3]
1515        paddings = [1, 3]
1516        ceil_modes = [True, False]
1517        options = itertools.product(torch_types, kernels, strides, dilations, paddings, ceil_modes)
1518        for torch_type, kernel, stride, dilation, padding, ceil_mode in options:
1519            X = torch.randint(20, 40, (2, 3, 16, 10, 10)).to(torch.float)
1520            scale = 15
1521            zero_point = 20
1522            # Skip invalid combinations (no hypothesis here, so the constraints are checked explicitly)
1523            if not (kernel // 2 >= padding):
1524                continue
1525            iT, iH, iW = X.shape[-3:]
1526            oT = pool_output_shape(iT, kernel, padding, stride, dilation, ceil_mode)
1527            if not (oT > 0):
1528                continue
1529            oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
1530            if not (oH > 0):
1531                continue
1532            oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
1533            if not (oW > 0):
1534                continue
1535
1536            a_pool = torch.nn.functional.max_pool3d(X, kernel_size=kernel,
1537                                                    stride=stride,
1538                                                    padding=padding, dilation=dilation,
1539                                                    ceil_mode=ceil_mode)
1540            a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
1541                                              zero_point=zero_point, dtype=torch_type)
1542            a_ref = a_ref.dequantize()
1543            qa = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
1544                                           dtype=torch_type)
1545            ops_under_test = {
1546                "torch": torch.max_pool3d,
1547                "nn.functional": torch.nn.functional.max_pool3d,
1548            }
1549            for name, op in ops_under_test.items():
1550                a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
1551                           dilation=dilation, ceil_mode=ceil_mode)
1552                self.assertEqual(a_ref, a_hat.dequantize(),
1553                                 msg=f"{name} results are off")
1554
1555    """Tests 2D max pool operation on NHWC quantized tensors."""
1556    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
1557                                              min_side=1, max_side=10),
1558                       qparams=hu.qparams()),
1559           kernel=st.sampled_from((3, 5, 7)),
1560           stride=st.sampled_from((None, 1, 2)),
1561           dilation=st.integers(1, 2),
1562           padding=st.integers(0, 2),
1563           ceil_mode=st.booleans())
1564    def test_max_pool2d_nhwc(self, X, kernel, stride, dilation, padding, ceil_mode):
1565        X, (scale, zero_point, torch_type) = X
1566        # Ensure we hit the vectorized paths
1567        # 176 = 128 + 32 + 16
1568        # 128 hits the interleaved path
1569        # 32 hits the non-interleaved path
1570        # 16 hits the scalar path
1571        if X.shape[1] < 176:
1572            X = np.repeat(X, 176 // X.shape[1], 1)
1573        # Check constraints
1574        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1575        iH, iW = X.shape[-2:]
1576        oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
1577        assume(oH > 0)
1578        oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
1579        assume(oW > 0)
1580
1581        X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))
1582        a = torch.from_numpy(X_nchw).permute([0, 3, 1, 2])
1583        a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel,
1584                                                stride=stride,
1585                                                padding=padding, dilation=dilation,
1586                                                ceil_mode=ceil_mode)
1587        a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
1588                                          zero_point=zero_point, dtype=torch_type)
1589        a_ref = a_ref.dequantize()
1590        qa = torch.quantize_per_tensor(torch.from_numpy(X_nchw), scale=scale, zero_point=zero_point,
1591                                       dtype=torch_type).permute([0, 3, 1, 2])
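        # Quantizing the channels-last data and permuting back to NCHW yields a non-contiguous
        # quantized tensor; the stride check below confirms the NHWC kernel path is exercised.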
1592        self.assertTrue(qa.stride() != sorted(qa.stride()))
1593
1594        ops_under_test = {
1595            "torch": torch.max_pool2d,
1596            "nn.functional": torch.nn.functional.max_pool2d,
1597            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.max_pool2d,
1598        }
1599
1600        for name, op in ops_under_test.items():
1601            a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
1602                       dilation=dilation, ceil_mode=ceil_mode)
1603            self.assertTrue(a_hat.stride() != sorted(a_hat.stride()))
1604            self.assertEqual(a_ref, a_hat.dequantize(),
1605                             msg=f"{name} results are off")
1606        # Test torch.ops.quantized separately, because it does not accept None for stride.
1607        a_hat = torch.ops.quantized.max_pool2d(
1608            qa, kernel_size=_pair(kernel),
1609            stride=_pair(kernel if stride is None else stride),
1610            padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode)
1611        self.assertEqual(a_ref, a_hat.dequantize(),
1612                         msg="ops.quantized.max_pool2d results are off")
1613
1614    """Tests 3D max pool operation on channels-last quantized tensors."""
1615    def test_max_pool3d_nhwc(self):
1616        torch_types = [torch.qint8, torch.quint8]
1617        kernels = [1, 3]
1618        strides = [1, 3]
1619        dilations = [1, 3]
1620        paddings = [1, 3]
1621        ceil_modes = [True, False]
1622        options = itertools.product(torch_types, kernels, strides, dilations, paddings, ceil_modes)
1623        for torch_type, kernel, stride, dilation, padding, ceil_mode in options:
1624            X = torch.randint(20, 40, (2, 67, 16, 10, 10)).to(torch.float)
1625            X_copy = copy.deepcopy(X)
1626            X = X.contiguous(memory_format=torch.channels_last_3d)
1627            scale = 15
1628            zero_point = 20
1629            # Skip invalid combinations (no hypothesis here, so the constraints are checked explicitly)
1630            if not (kernel // 2 >= padding):
1631                continue
1632            iT, iH, iW = X.shape[-3:]
1633            oT = pool_output_shape(iT, kernel, padding, stride, dilation, ceil_mode)
1634            if not (oT > 0):
1635                continue
1636            oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
1637            if not (oH > 0):
1638                continue
1639            oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
1640            if not (oW > 0):
1641                continue
1642
1643            a_pool = torch.nn.functional.max_pool3d(X, kernel_size=kernel,
1644                                                    stride=stride,
1645                                                    padding=padding, dilation=dilation,
1646                                                    ceil_mode=ceil_mode)
1647            a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
1648                                              zero_point=zero_point, dtype=torch_type)
1649            a_ref = a_ref.dequantize()
1650            qa = torch.quantize_per_tensor(X_copy, scale=scale, zero_point=zero_point,
1651                                           dtype=torch_type)
1652            qa = qa.contiguous(memory_format=torch.channels_last_3d)
1653            ops_under_test = {
1654                "torch": torch.max_pool3d,
1655                "nn.functional": torch.nn.functional.max_pool3d,
1656            }
1657            for name, op in ops_under_test.items():
1658                a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
1659                           dilation=dilation, ceil_mode=ceil_mode)
1660                self.assertEqual(a_ref, a_hat.dequantize(),
1661                                 msg=f"{name} results are off")
1662
1663    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
1664                                              min_side=5, max_side=10),
1665                       qparams=hu.qparams(dtypes=torch.quint8)),
1666           kernel=st.sampled_from((3, 5)),
1667           stride=st.sampled_from((None, 1, 2)),
1668           padding=st.integers(0, 2),
1669           ceil_mode=st.sampled_from((True, False)),
1670           count_include_pad=st.sampled_from((True, False)),
1671           divisor_override=st.just(None))
1672    def test_avg_pool2d(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override):
1673        """
1674        Note: we currently cannot test divisor_override, because the quantized op clamps the result
1675        to the representable range, while the float op does not.
1676        """
1677        X, (scale, zero_point, torch_type) = X
1678
1679        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1680        iH, iW = X.shape[-2:]
1681        oH = pool_output_shape(iH, kernel, padding, stride, dilation=1)
1682        assume(oH > 0)
1683        oW = pool_output_shape(iW, kernel, padding, stride, dilation=1)
1684        assume(oW > 0)
1685        X = torch.from_numpy(X)
1686        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
1687                                       dtype=torch_type)
1688        X = qX.dequantize()
1689        # Run reference on float tensor and then quantize the result for comparison
1690        X_ref = torch.nn.functional.avg_pool2d(
1691            X, kernel_size=kernel, stride=stride, padding=padding,
1692            ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override)
1693        ops_under_test = {
1694            "nn.functional": torch.nn.functional.avg_pool2d,
1695            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool2d,
1696        }
1697        error_message = "Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
1698        for name, op in ops_under_test.items():
1699            qX_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode,
1700                        count_include_pad=count_include_pad, divisor_override=divisor_override)
1701            qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
1702                                               dtype=torch_type)
1703
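            # Allow an off-by-one difference in the integer representation to absorb rounding
            # differences between the float reference and the quantized kernel.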
1704            self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
1705                             msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr()))
1706            self.assertEqual(scale, qX_hat.q_scale(),
1707                             msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
1708            self.assertEqual(zero_point, qX_hat.q_zero_point(),
1709                             msg=error_message.format(name + '.zero_point', scale,
1710                                                      qX_hat.q_zero_point()))
1711
1712    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
1713                                              min_side=5, max_side=10),
1714                       qparams=hu.qparams(dtypes=torch.qint8)),
1715           kernel=st.sampled_from((4, 5)),
1716           stride=st.sampled_from((None, 1, 2)),
1717           padding=st.integers(0, 2),
1718           ceil_mode=st.sampled_from((True, False)),
1719           count_include_pad=st.sampled_from((True, False)),
1720           divisor_override=st.just(None))
1721    def test_avg_pool2d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override):
1722        """
1723        Note: 1) we currently cannot test divisor_override, because the quantized op clamps the result
1724        to the representable range, while the float op does not.
1725        2) we cannot test qint32, since floating-point precision is much lower than int32 precision for
1726        large values, which would make the test very flaky.
1727        """
1728        X, (scale, zero_point, torch_type) = X
1729        H, W = X.shape[-2:]
1730
1731
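        # Widen the channel dimension to at least 176 so the vectorized channels-last code paths
        # (see the breakdown in test_max_pool2d_nhwc above) are exercised.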
1732        if X.shape[1] < 176:
1733            X = np.repeat(X, 176 // X.shape[1], 1)
1734
1735        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1736        iH, iW = X.shape[-2:]
1737        oH = pool_output_shape(iH, kernel, padding, stride, dilation=1)
1738        assume(oH > 0)
1739        oW = pool_output_shape(iW, kernel, padding, stride, dilation=1)
1740        assume(oW > 0)
1741
1742        X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))
1743
1744        qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw), scale=scale,
1745                                       zero_point=zero_point, dtype=torch_type).permute([0, 3, 1, 2])
1746        X = qX.dequantize()
1747
1748        # Run reference on float tensor and then quantize the result for comparison.
1749        X_ref = torch.nn.functional.avg_pool2d(
1750            X, kernel_size=kernel, stride=stride, padding=padding,
1751            ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override)
1752
1753        self.assertTrue(qX.stride() != sorted(qX.stride()))
1754        ops_under_test = {
1755            "nn.functional": torch.nn.functional.avg_pool2d,
1756            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool2d,
1757        }
1758        error_message = "Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
1759        for name, op in ops_under_test.items():
1760            X_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode,
1761                       count_include_pad=count_include_pad, divisor_override=divisor_override)
1762            self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
1763            qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(),
1764                                               dtype=torch_type)
1765
1766            self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
1767                             msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr()))
1768            self.assertEqual(scale, X_hat.q_scale(),
1769                             msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
1770            self.assertEqual(zero_point, X_hat.q_zero_point(),
1771                             msg=error_message.format(name + '.zero_point', scale,
1772                             X_hat.q_zero_point()))
1773
1774    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5,
1775                                              min_side=5, max_side=10),
1776                       qparams=hu.qparams(dtypes=torch.quint8)),
1777           kernel=st.sampled_from((3, 5)),
1778           stride=st.sampled_from((None, 1, 2)),
1779           padding=st.integers(0, 2),
1780           ceil_mode=st.sampled_from((True, False)),
1781           count_include_pad=st.sampled_from((True, False)),
1782           divisor_override=st.just(None))
1783    def test_avg_pool3d(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override):
1784        """
1785        Note: we currently cannot test divisor_override, because the quantized op clamps the result
1786        to the representable range, while the float op does not.
1787        """
1788        X, (scale, zero_point, torch_type) = X
1789
1790        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1791        iD, iH, iW = X.shape[-3:]
1792        oD = pool_output_shape(iD, kernel, padding, stride, dilation=1)
1793        assume(oD > 0)
1794        oH = pool_output_shape(iH, kernel, padding, stride, dilation=1)
1795        assume(oH > 0)
1796        oW = pool_output_shape(iW, kernel, padding, stride, dilation=1)
1797        assume(oW > 0)
1798
1799        X = torch.from_numpy(X)
1800        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
1801                                       dtype=torch_type)
1802        X = qX.dequantize()
1803        # Run reference on float tensor and then quantize the result for comparison
1804        X_ref = torch.nn.functional.avg_pool3d(
1805            X, kernel_size=kernel, stride=stride, padding=padding,
1806            ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override)
1807
1808        ops_under_test = {
1809            "nn.functional": torch.nn.functional.avg_pool3d,
1810            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool3d,
1811        }
1812        error_message = "Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
1813        for name, op in ops_under_test.items():
1814            qX_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode,
1815                        count_include_pad=count_include_pad, divisor_override=divisor_override)
1816            qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
1817                                               dtype=torch_type)
1818            self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
1819                             msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr()))
1820            self.assertEqual(scale, qX_hat.q_scale(),
1821                             msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
1822            self.assertEqual(zero_point, qX_hat.q_zero_point(),
1823                             msg=error_message.format(name + '.zero_point', scale,
1824                                                      qX_hat.q_zero_point()))
1825
1826    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5,
1827                                              min_side=5, max_side=10),
1828                       qparams=hu.qparams(dtypes=torch.qint8)),
1829           kernel=st.sampled_from((4, 5)),
1830           stride=st.sampled_from((None, 1, 2)),
1831           padding=st.integers(0, 2),
1832           ceil_mode=st.sampled_from((True, False)),
1833           count_include_pad=st.sampled_from((True, False)),
1834           divisor_override=st.just(None))
1835    def test_avg_pool3d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override):
1836        """
1837        Note: 1) we currently cannot test divisor_override, because the quantized op clamps the result
1838        to the representable range, while the float op does not.
1839        2) we cannot test qint32, since floating-point precision is much lower than int32 precision for
1840        large values, which would make the test very flaky.
1841        """
1842        X, (scale, zero_point, torch_type) = X
1843        D, H, W = X.shape[-3:]
1844
1845
1846        if X.shape[1] < 176:
1847            X = np.repeat(X, 176 // X.shape[1], 1)
1848
1849        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
1850        iD, iH, iW = X.shape[-3:]
1851        oD = pool_output_shape(iD, kernel, padding, stride, dilation=1)
1852        assume(oD > 0)
1853        oH = pool_output_shape(iH, kernel, padding, stride, dilation=1)
1854        assume(oH > 0)
1855        oW = pool_output_shape(iW, kernel, padding, stride, dilation=1)
1856        assume(oW > 0)
1857
1858        X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 4, 1]))
1859
1860        qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw), scale=scale,
1861                                       zero_point=zero_point, dtype=torch_type).permute([0, 4, 1, 2, 3])
1862        X = qX.dequantize()
1863
1864        # Run reference on float tensor and then quantize the result for comparison.
1865        X_ref = torch.nn.functional.avg_pool3d(
1866            X, kernel_size=kernel, stride=stride, padding=padding,
1867            ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override)
1868
1869        self.assertTrue(qX.stride() != sorted(qX.stride()))
1870        ops_under_test = {
1871            "nn.functional": torch.nn.functional.avg_pool3d,
1872            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool3d,
1873        }
1874        error_message = "Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
1875        for name, op in ops_under_test.items():
1876            X_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode,
1877                       count_include_pad=count_include_pad, divisor_override=divisor_override)
1878            self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
1879            qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(),
1880                                               dtype=torch_type)
1881
1882            self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
1883                             msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr()))
1884            self.assertEqual(scale, X_hat.q_scale(),
1885                             msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
1886            self.assertEqual(zero_point, X_hat.q_zero_point(),
1887                             msg=error_message.format(name + '.zero_point', scale,
1888                             X_hat.q_zero_point()))
1889
1890    """Tests adaptive average pool operation on NHWC quantized tensors."""
1891    def test_adaptive_avg_pool2d_nhwc(self):
1892        side_lens = (range(1, 10))
1893        dim_lens = (range(3, 4))
1894        torch_type = torch.qint8
1895        zero_points = (0, 1)
1896        combined = [side_lens, dim_lens, zero_points]
1897        test_cases = itertools.product(*combined)
1898        for test_case in test_cases:
1899            output_size_h = random.randint(1, 10)
1900            output_size_w = random.randint(1, 10)
1901            side_len, dim_len, zero_point = test_case
1902            shapes = [side_len] * dim_len
1903            X, X_scale, X_zero_point = \
1904                _get_random_tensor_and_q_params(shapes, 1.0, zero_point)
1905            X = np.array(X)
1906            scale = 1
1907            H, W = X.shape[-2:]
1908            output_size_h = min(output_size_h, H)
1909            output_size_w = min(output_size_w, W)
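            # Use a scalar output_size when the height and width targets coincide, and a tuple
            # otherwise, so both call forms are covered.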
1910            if output_size_h == output_size_w:
1911                output_size = output_size_h
1912            else:
1913                output_size = (output_size_h, output_size_w)
1914
1915            if X.shape[1] < 176:
1916                X = np.repeat(X, 176 // X.shape[1], 1)
1917
1918            if X.ndim == 4:
1919                X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))
1920                X = torch.from_numpy(X_nchw).permute([0, 3, 1, 2])
1921                qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw),
1922                                               scale=scale,
1923                                               zero_point=zero_point,
1924                                               dtype=torch_type).permute([0, 3, 1, 2])
1925            else:  # ndim == 3
1926                X_nchw = np.ascontiguousarray(X.transpose([1, 2, 0]))
1927                X = torch.from_numpy(X_nchw).permute([2, 0, 1])
1928                qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw),
1929                                               scale=scale,
1930                                               zero_point=zero_point,
1931                                               dtype=torch_type).permute([2, 0, 1])
1932
1933            # Run reference on int_repr + round to avoid double rounding error.
1934            X_ref = torch.nn.functional.adaptive_avg_pool2d(qX.int_repr().to(torch.double), output_size).round()
1935
1936            self.assertTrue(qX.stride() != sorted(qX.stride()))
1937
1938            ops_under_test = {
1939                "nn.functional": torch.nn.functional.adaptive_avg_pool2d,
1940                "ao.nn.quantized.functional":
1941                    torch.ao.nn.quantized.functional.adaptive_avg_pool2d,
1942            }
1943            error_message = "Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
1944            for name, op in ops_under_test.items():
1945                X_hat = op(qX, output_size=output_size)
1946                self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
1947                self.assertEqual(X_ref, X_hat.int_repr(), atol=1.0, rtol=0,
1948                                 msg=error_message.format(name, X_ref, X_hat.int_repr()),
1949                                 exact_dtype=False)
1950                self.assertEqual(scale, X_hat.q_scale(),
1951                                 msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
1952                self.assertEqual(zero_point, X_hat.q_zero_point(),
1953                                 msg=error_message.format(name + '.zero_point', scale,
1954                                 X_hat.q_zero_point()))
1955
1956    @unittest.skip("not currently working and feature isn't used")
1957    def test_adaptive_avg_pool(self):
1958
1959        side_lens = (range(1, 10))
1960        dim_lens = (range(3, 5))
1961        torch_type = torch.qint8
1962        zero_points = (0, 1)
1963        combined = [side_lens, dim_lens, zero_points]
1964        test_cases = itertools.product(*combined)
1965        for test_case in test_cases:
1966            output_size_d = random.randint(1, 10)
1967            output_size_h = random.randint(1, 10)
1968            output_size_w = random.randint(1, 10)
1969            side_len, dim_len, zero_point = test_case
1970            shapes = [side_len] * dim_len
1971            X, X_scale, X_zero_point = \
1972                _get_random_tensor_and_q_params(shapes, 1.0, zero_point)
1973            X = np.array(X)
1974            scale = 1
1975            ndim = X.ndim
1976            dim_to_check = []
1977            if ndim <= 4:
1978                dim_to_check.append(2)
1979            if ndim >= 4:
1980                dim_to_check.append(3)
1981
1982            D, H, W = X.shape[-3:]
1983            output_size_d = min(output_size_d, D)
1984            output_size_h = min(output_size_h, H)
1985            output_size_w = min(output_size_w, W)
1986
1987            X = torch.from_numpy(X)
1988            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
1989                                           dtype=torch_type)
1990
1991            for dim in dim_to_check:
1992                if dim == 2:
1993                    if output_size_h == output_size_w:
1994                        output_size = output_size_h
1995                    else:
1996                        output_size = (output_size_h, output_size_w)
1997                elif dim == 3:
1998                    if output_size_d == output_size_h == output_size_w:
1999                        output_size = output_size_h
2000                    else:
2001                        output_size = (output_size_d, output_size_h, output_size_w)
2002
2003                # Run reference on int_repr + round to avoid double rounding error.
2004                ref_op = getattr(torch.nn.functional, f'adaptive_avg_pool{dim}d')
2005                X_ref = ref_op(qX.int_repr().to(torch.float), output_size).round()
2006
2007                ops_under_test = {
2008                    "nn.functional":
2009                        getattr(torch.nn.functional, f'adaptive_avg_pool{dim}d'),
2010                    "nn.quantized.functional":
2011                        getattr(torch.ao.nn.quantized.functional, f'adaptive_avg_pool{dim}d'),
2012                    "ao.nn.quantized.functional":
2013                        getattr(torch.ao.nn.quantized.functional, f'adaptive_avg_pool{dim}d')
2014                }
2015
2016                error_message = "Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
2017
2018                for name, op in ops_under_test.items():
2019                    # TODO: torch.cuda.is_available() should be swapped for a flag that checks if cudnn
2020                    # is enabled in the build when cudnn supports adaptive average pooling
2021                    devices = ["cpu", "cuda"] if (dim == 2 and torch.cuda.is_available()) else ["cpu"]
2022                    for device in devices:
2023                        qX_hat = op(qX.to(device=device), output_size=output_size)
2024                        self.assertEqual(
2025                            X_ref, qX_hat.int_repr(), atol=1.0,
2026                            rtol=0, msg=error_message.format(name, X_ref, qX_hat), exact_dtype=False)
2027                        self.assertEqual(
2028                            scale, qX_hat.q_scale(),
2029                            msg=error_message.format(name + '.scale', scale,
2030                                                     qX_hat.q_scale()))
2031                        self.assertEqual(
2032                            zero_point, qX_hat.q_zero_point(),
2033                            msg=error_message.format(name + '.zero_point', scale,
2034                                                     qX_hat.q_zero_point()))
2035
2036    """Tests 3D adaptive average pool operation on NDHWC quantized tensors."""
2037    def test_adaptive_avg_pool3d_ndhwc(self):
2038        side_lens = (range(1, 10))
2039        dim_lens = (range(4, 5))
2040        torch_type = torch.qint8
2041        zero_point = 0
2042        combined = [side_lens, dim_lens]
2043        test_cases = itertools.product(*combined)
2044        for test_case in test_cases:
2045            output_size_d = random.randint(1, 10)
2046            output_size_h = random.randint(1, 10)
2047            output_size_w = random.randint(1, 10)
2048            side_len, dim_len = test_case
2049            shapes = [side_len] * dim_len
2050            X, X_scale, X_zero_point = \
2051                _get_random_tensor_and_q_params(shapes, 1.0, zero_point)
2052            X = np.array(X)
2053            scale = 1
2054            D, H, W = X.shape[-3:]
2055            output_size_d = min(output_size_d, D)
2056            output_size_h = min(output_size_h, H)
2057            output_size_w = min(output_size_w, W)
2058            if output_size_d == output_size_h == output_size_w:
2059                output_size = output_size_h
2060            else:
2061                output_size = (output_size_d, output_size_h, output_size_w)
2062
2063            if X.shape[1] < 176:
2064                X = np.repeat(X, 176 // X.shape[1], 1)
2065
2066            if X.ndim == 5:
2067                X_ncdhw = np.ascontiguousarray(X.transpose([0, 2, 3, 4, 1]))
2068                X = torch.from_numpy(X_ncdhw).permute([0, 4, 1, 2, 3])
2069                qX = torch.quantize_per_tensor(torch.from_numpy(X_ncdhw),
2070                                               scale=scale,
2071                                               zero_point=zero_point,
2072                                               dtype=torch_type).permute([0, 4, 1, 2, 3])
2073            else:  # ndim == 4
2074                X_ncdhw = np.ascontiguousarray(X.transpose([1, 2, 3, 0]))
2075                X = torch.from_numpy(X_ncdhw).permute([3, 0, 1, 2])
2076                qX = torch.quantize_per_tensor(torch.from_numpy(X_ncdhw),
2077                                               scale=scale,
2078                                               zero_point=zero_point,
2079                                               dtype=torch_type).permute([3, 0, 1, 2])
2080
2081            # Run reference on int_repr + round to avoid double rounding error.
2082            X_ref = torch.nn.functional.adaptive_avg_pool3d(
2083                qX.int_repr().to(torch.double), output_size).round()
2084
2085            self.assertTrue(qX.stride() != sorted(qX.stride()))
2086
2087            ops_under_test = {
2088                "nn.functional": torch.nn.functional.adaptive_avg_pool3d,
2089                "ao.nn.quantized.functional":
2090                    torch.ao.nn.quantized.functional.adaptive_avg_pool3d,
2091            }
2092            error_message = "Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
2093            for name, op in ops_under_test.items():
2094                X_hat = op(qX, output_size=output_size)
2095                self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
2096                self.assertEqual(X_ref, X_hat.int_repr(), atol=1.0, rtol=0,
2097                                 msg=error_message.format(name, X_ref, X_hat.int_repr()),
2098                                 exact_dtype=False)
2099                self.assertEqual(scale, X_hat.q_scale(),
2100                                 msg=error_message.format(name + '.scale', scale, X_hat.q_scale()))
2101                self.assertEqual(zero_point, X_hat.q_zero_point(),
2102                                 msg=error_message.format(name + '.zero_point', scale,
2103                                 X_hat.q_zero_point()))
2104
2105    def test_qtopk(self):
2106        x_dims = [3, 4]  # Number of dimensions in the generated tensor
2107        sides = [3, 5]  # Side length of the generated tensor
2108        dims = [0, 1, 2, 3]  # dimension over which to perform topk
2109        largest = [False, True]  # Return largest or smallest element
2110        sorted = [False, True]  # Return sorted or not
2111        dtypes = [torch.qint8, torch.quint8]
2112        is_nhwc = [False, True]  # Is input in the NHWC format?
2113
2114        test_cases = itertools.product(x_dims, sides, dims, largest, sorted, dtypes, is_nhwc)
2115        k = 2
2116        for x_dim, side, dim, larg, sort, dtype, nhwc in test_cases:
2117            if nhwc and x_dim != 4:  # NHWC requires 4 dimensions
2118                continue
2119            if dim >= x_dim:  # Dimension to find top-k for should exist
2120                continue
2121            shape = [side] * x_dim
2122            X, scale, zp = _get_random_tensor_and_q_params(shape, 1.0, dtype)
2123            qX = torch.quantize_per_tensor(X, scale, zp, dtype)
2124
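            # For NHWC, permute the quantized tensor (and the reference data) so topk runs on a
            # non-contiguous, channels-last layout.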
2125            if nhwc:
2126                qX = qX.permute([0, 3, 1, 2])
2127                X = np.transpose(X, [0, 3, 1, 2])
2128
2129            unquantized_out = torch.topk(qX.dequantize(), k, dim=dim, largest=larg, sorted=sort)
2130
2131            values = torch.quantize_per_tensor(X, scale, zp, dtype)
2132            indices = torch.tensor(X).long()
2133
2134            quantized_out = torch.topk(qX, k, dim=dim, largest=larg, sorted=sort)
2135
2136            assert len(unquantized_out) == len(quantized_out)
2137            torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0])
2138            torch.testing.assert_close(quantized_out[1], unquantized_out[1])
2139
2140    """Tests quantized concatenation (both fused with ReLU and not)."""
2141    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
2142                                              min_side=1, max_side=10),
2143                       qparams=hu.qparams()),
2144           num=st.integers(1, 4),
2145           dim=st.integers(1, 4),
2146           relu=st.booleans())
2147    def test_cat(self, X, num, dim, relu):
2148        tensors_q = []
2149        tensors_ref = []
2150        X, (scale, zero_point, torch_type) = X
2151        assume(dim < X.ndim)
2152        X = torch.from_numpy(X)
2153        new_shape = np.array(X.shape)
2154        new_shape[dim] = 0
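        # Track the total size along the concat dimension so the out= variant below can
        # preallocate its output tensor.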
2155        for idx in range(num):
2156            tensors_q.append(torch.quantize_per_tensor(X, scale, zero_point,
2157                                                       torch_type))
2158            tensors_ref.append(X)
2159            new_shape[dim] += tensors_ref[-1].shape[dim]
2160
2161        cat_ref = torch.cat(tensors_ref, dim=dim)
2162        cat_ref = torch.quantize_per_tensor(cat_ref, scale, zero_point, torch_type)
2163        cat_ref = cat_ref.dequantize()
2164
2165        if relu:
2166            cat_ref = F.relu(cat_ref)
2167            q_cat_op = torch.ops.quantized.cat_relu
2168            q_cat_out_op = torch.ops.quantized.cat_relu_out
2169        else:
2170            q_cat_op = torch.ops.quantized.cat
2171            q_cat_out_op = torch.ops.quantized.cat_out
2172
2173        cat_q = q_cat_op(tensors_q, dim=dim, scale=scale,
2174                         zero_point=zero_point)
2175        cat_q = cat_q.dequantize()
2176        np.testing.assert_equal(cat_ref.numpy(), cat_q.numpy())
2177
2178        cat_q_out = torch._empty_affine_quantized(
2179            list(new_shape), scale=scale,
2180            zero_point=zero_point, dtype=torch_type)
2181        q_cat_out_op(tensors_q, dim=dim, out=cat_q_out)
2182        cat_q_out = cat_q_out.dequantize()
2183        np.testing.assert_equal(cat_ref.numpy(), cat_q_out.numpy())
2184
2185        # Test the cat on per-channel quantized tensor.
2186        ch_axis = 1
2187        scales = torch.from_numpy(np.array([1.0] * X.shape[ch_axis]))
2188        scales = scales.to(torch.float64)
2189        zero_points = torch.from_numpy(np.array([0] * X.shape[ch_axis]))
2190        zero_points = zero_points.to(torch.long)
2191        tensors_q[0] = torch.quantize_per_channel(
2192            X, scales, zero_points, axis=ch_axis, dtype=torch_type)
2193        with self.assertRaisesRegex(RuntimeError, "supported.*cat"):
2194            cat_q = q_cat_op(tensors_q, dim=ch_axis, scale=scale,
2195                             zero_point=zero_point)
2196
2197    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
2198                                              min_side=5, max_side=10),
2199                       qparams=hu.qparams()),
2200           size=st.sampled_from((1, 3, 5, 10)),
2201           mode=st.sampled_from(("bilinear", "nearest", "nearest-exact")),
2202           scale_factor=st.sampled_from((None, 1.5, 2.0)),
2203           align_corners=st.sampled_from((True, False)),
2204           nhwc_layout=st.sampled_from((True, False)))
2205    def test_interpolate(self, X, size, mode, scale_factor, align_corners, nhwc_layout):
2206        """
2207        This test covers upsample_nearest2d and upsample_bilinear2d
2208        """
2209        X, (scale, zero_point, torch_type) = X
2210
2211        if scale_factor is not None:
2212            size = None
2213        if mode in ("nearest", "nearest-exact"):
2214            align_corners = None
2215
2216        if nhwc_layout:
2217            if X.shape[1] < 176:
2218                X = np.repeat(X, 176 // X.shape[1], 1)
2219
2220            X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))
2221            X = torch.from_numpy(X_nchw).permute([0, 3, 1, 2])
2222
2223            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
2224                                           dtype=torch_type).permute([0, 3, 1, 2])
2225        else:
2226            X = torch.from_numpy(X)
2227            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
2228                                           dtype=torch_type)
2229
2230        X_ref = torch.nn.functional.interpolate(
2231            qX.int_repr().to(torch.float), size=size, scale_factor=scale_factor,
2232            mode=mode, align_corners=align_corners)
2233
2234        ops_under_test = {
2235            "nn.functional": torch.nn.functional.interpolate,
2236            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.interpolate,
2237        }
2238        error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
2239        for name, op in ops_under_test.items():
2240            qX_hat = op(qX, size=size, scale_factor=scale_factor,
2241                        mode=mode, align_corners=align_corners)
2242            self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0,
2243                             msg=f"{name} results are off: qX_hat={qX_hat.int_repr()} X_ref={X_ref}",
2244                             exact_dtype=False)
2245            self.assertEqual(scale, qX_hat.q_scale(),
2246                             msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
2247            self.assertEqual(zero_point, qX_hat.q_zero_point(),
2248                             msg=error_message.format(name + '.zero_point', zero_point,
2249                                                      qX_hat.q_zero_point()))
2250
2251    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5,
2252                                              min_side=5, max_side=10),
2253                       qparams=hu.qparams()),
2254           size=st.sampled_from((1, 3, 5, 5, 10)),
2255           mode=st.sampled_from(("nearest", "nearest-exact")),
2256           scale_factor=st.sampled_from((None, 1.5, 2.0)),
2257           align_corners=st.sampled_from((True, False)),
2258           nhwc_layout=st.sampled_from((True, False)))
2259    def test_interpolate3d(self, X, size, mode, scale_factor, align_corners, nhwc_layout):
2260        """
2261        This test covers upsample_nearest3d
2262        """
2263        X, (scale, zero_point, torch_type) = X
2264        if scale_factor is not None:
2265            size = None
2266
2267        align_corners = None
2268
2269        if nhwc_layout:
2270            if X.shape[1] < 176:
2271                X = np.repeat(X, 176 // X.shape[1], 1)
2272
2273            X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 4, 1]))
2274            X = torch.from_numpy(X_nchw).permute([0, 4, 1, 2, 3])
2275
2276            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
2277                                           dtype=torch_type).permute([0, 4, 1, 2, 3])
2278        else:
2279            X = torch.from_numpy(X)
2280            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
2281                                           dtype=torch_type)
2282        X_ref = torch.nn.functional.interpolate(
2283            qX.int_repr().to(torch.float), size=size, scale_factor=scale_factor,
2284            mode=mode, align_corners=align_corners)
2285
2286        ops_under_test = {
2287            "nn.functional": torch.nn.functional.interpolate,
2288            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.interpolate,
2289        }
2290
2291        error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
2292        for name, op in ops_under_test.items():
2293            qX_hat = op(qX, size=size, scale_factor=scale_factor,
2294                        mode=mode, align_corners=align_corners)
2295            self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0,
2296                             msg=f"{name} results are off: qX_hat={qX_hat.int_repr()}, X_ref={X_ref}", exact_dtype=False)
2297            self.assertEqual(scale, qX_hat.q_scale(),
2298                             msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
2299            self.assertEqual(zero_point, qX_hat.q_zero_point(),
2300                             msg=error_message.format(name + '.zero_point', zero_point,
2301                                                      qX_hat.q_zero_point()))
2302
2303    """Tests quantized concatenation of NHWC tensors (both fused and not)."""
2304    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
2305                                              min_side=1, max_side=10),
2306                       qparams=hu.qparams()),
2307           relu=st.booleans())
2308    def test_cat_nhwc(self, X, relu):
2309        # X is NHWC
2310        X, (scale, zero_point, torch_type) = X
2311
2312        # Tile X so that the number of channels is > 64
2313        X = np.repeat(X, 70 // X.shape[3], 3)
2314        X = torch.from_numpy(np.ascontiguousarray(X))
2315        Y = X.clone()
2316        Y = torch.from_numpy(np.ascontiguousarray(Y))
2317        # We add a fast path in qcat: when the inputs share the same scale and zero_point,
2318        # it goes through a direct memcpy instead of dequantize-cat-requantize.
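        # The first pair below keeps both inputs at the same scale and should exercise
        # that fast path; the second pair (scale, scale * 1.1) forces the general
        # dequantize-cat-requantize path.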
2319        for scaleX, scaleY in ((scale, scale), (scale, scale * 1.1)):
2320            # Here, we quantize and get quantized tensors in NHWC for both dims and strides. The
2321            # permute switches it so that the tensor looks like NCHW but is laid out in memory as
2322            # NHWC.
2323            qX = torch.quantize_per_tensor(X, scaleX, zero_point, torch_type).permute([0, 3, 1, 2])
2324            qY = torch.quantize_per_tensor(Y, scaleY, zero_point, torch_type).permute([0, 3, 1, 2])
2325
2326            ref = torch.cat([qX.dequantize(), qY.dequantize()], dim=1)
2327            if relu:
2328                ref[ref < 0] = 0.0
2329            ref = torch.quantize_per_tensor(ref, scale=scale, zero_point=zero_point, dtype=torch_type)
2330
2331            if relu:
2332                out = torch.ops.quantized.cat_relu(
2333                    [qX, qY], dim=1, scale=scale, zero_point=zero_point)
2334            else:
2335                out = torch.ops.quantized.cat([qX, qY], dim=1, scale=scale, zero_point=zero_point)
2336
2337            torch.testing.assert_close(out.dequantize(), ref.dequantize())
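            # Sanity check on the layout: the output strides should not already be in
            # ascending (sorted) order, since the output is expected to keep the
            # NHWC-style striding of the inputs.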
2338            self.assertNotEqual(out.stride(), sorted(out.stride()))
2339
2340    @override_qengines
2341    def test_mean(self):
2342        scale_list = (1, 0.25)
2343        zero_point_list = (0, 2)
2344        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4, 4))
2345        dtypes = (torch.quint8, torch.qint8)
2346        dims = ((), (-1,), (0,), (1,), (2,), (3,), (0, 1), (1, 2), (3, 4))
2347        test_cases = itertools.product(scale_list, zero_point_list, shapes, dtypes, dims)
2348        op = torch.mean
2349        for scale, zp, shape, dtype, dim in test_cases:
2350            if not all(d < len(shape) for d in dim):
2351                continue
2352            X = torch.randn(*shape) * 10
2353            qX = torch.quantize_per_tensor(X, scale, zp, dtype)
2354            Y = op(qX.dequantize(), dim)
2355            Y = torch.quantize_per_tensor(Y, scale, zp, dtype).dequantize()
2356            qY = op(qX, dim)
2357            self.assertEqual(Y, qY.dequantize())
2358
2359    @skipIfNoQNNPACK
2360    @given(keep=st.booleans())
2361    def test_quantized_mean_qnnpack(self, keep):
2362        with override_quantized_engine("qnnpack"):
2363            # Use sizes that are multiples of 4 to satisfy the 4-byte alignment requirement of pytorch_q8gavgpool_ukernel_up8xm__sse2() under ASAN
2364            in_dim = (4, 4, 4, 4)
2365            if keep:
2366                out_dim = (4, 4, 1, 1)
2367            else:
2368                out_dim = (4, 4)
2369            X = torch.ones(in_dim)
2370            Y = torch.ones(out_dim)
2371            XQ = torch.quantize_per_tensor(X, scale=0.2, zero_point=0, dtype=torch.quint8)
2372            YQ = torch.quantize_per_tensor(Y, scale=0.2, zero_point=0, dtype=torch.quint8)
2373            MQ = XQ.mean((2, 3), keepdim=keep)
2374            self.assertTrue(torch.equal(MQ, YQ))
2375
2376    @override_qengines
2377    def test_std(self):
2378        scale_list = (1, 0.25)
2379        zero_point_list = (0, 2)
2380        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4, 4))
2381        dtypes = (torch.quint8, torch.qint8)
2382        dims = ((), (-1,), (0,), (1,), (2,), (3,), (0, 1), (1, 2), (3, 4))
2383        unbiased_list = (True, False)
2384        keep_dim_list = (True, False)
2385        test_cases = itertools.product(scale_list, zero_point_list, shapes,
2386                                       dtypes, dims, unbiased_list, keep_dim_list)
2387        op = torch.std
2388        for scale, zp, shape, dtype, dim, unbiased, keep_dim in test_cases:
2389            if not all(d < len(shape) for d in dim):
2390                continue
2391            X = torch.randn(*shape) * 10
2392            qX = torch.quantize_per_tensor(X, scale, zp, dtype)
2393            Y = op(qX.dequantize(), dim, unbiased, keep_dim)
2394            Y = torch.quantize_per_tensor(Y, scale, zp, dtype).dequantize()
2395            qY = op(qX, dim, unbiased, keep_dim)
2396            self.assertEqual(Y, qY.dequantize())
2397
2398    """Tests the correctness of the quantized equal op."""
2399    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
2400                       qparams=hu.qparams()),
2401           X2=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
2402                        qparams=hu.qparams()),
2403           X_per_channel=st.booleans(),
2404           X2_per_channel=st.booleans())
2405    def test_equal(self, X, X2, X_per_channel, X2_per_channel):
2406        X, X_params = X
2407        (scale, zero_point, torch_type) = X_params
2408        X2, X2_params = X2
2409        (scale2, zero_point2, torch_type2) = X2_params
2410
2411        X = torch.from_numpy(X)
2412        if X_per_channel:
2413            X_scheme = 'per_channel'
2414            channels = X.shape[-1]
2415            qX = torch.quantize_per_channel(
2416                X,
2417                scales=torch.tensor([scale] * channels),
2418                zero_points=torch.tensor([zero_point] * channels),
2419                dtype=torch_type,
2420                axis=X.ndim - 1)
2421        else:
2422            X_scheme = 'per_tensor'
2423            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
2424                                           dtype=torch_type)
2425        X2 = torch.from_numpy(X2)
2426        if X2_per_channel:
2427            X2_scheme = 'per_channel'
2428            channels = X2.shape[-1]
2429            qX2 = torch.quantize_per_channel(
2430                X2,
2431                scales=torch.tensor([scale2] * channels),
2432                zero_points=torch.tensor([zero_point2] * channels),
2433                dtype=torch_type2,
2434                axis=X2.ndim - 1)
2435        else:
2436            X2_scheme = 'per_tensor'
2437            qX2 = torch.quantize_per_tensor(X2, scale=scale2, zero_point=zero_point2,
2438                                            dtype=torch_type2)
2439
2440        def equal_ref(qX, qX2):
2441            if qX.qscheme() != qX2.qscheme():
2442                return False
2443            if qX.shape != qX2.shape:
2444                return False
2445            if qX.dtype != qX2.dtype:
2446                return False
2447            if qX.qscheme() == torch.per_tensor_affine:
2448                if qX.q_scale() != qX2.q_scale():
2449                    return False
2450                if qX.q_zero_point() != qX2.q_zero_point():
2451                    return False
2452            elif qX.qscheme() == torch.per_channel_affine:
2453                if (qX.q_per_channel_scales() !=
2454                   qX2.q_per_channel_scales()).any():
2455                    return False
2456                if (qX.q_per_channel_zero_points() !=
2457                   qX2.q_per_channel_zero_points()).any():
2458                    return False
2459            else:
2460                raise NotImplementedError(
2461                    f"Don't know what to do with qscheme {qX.qscheme()}")
2462            if (qX.int_repr().to(float) != qX2.int_repr().to(float)).any():
2463                return False
2464            return True
2465
2466        self.assertEqual(qX.equal(qX), equal_ref(qX, qX))
2467        self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2))
2468
2469    """Tests the quantized equal op when one input is a non-quantized tensor."""
2470    def test_quantized_equal(self):
2471        x = torch.rand(1)
2472        y = torch.quantize_per_tensor(x, scale=0.5, zero_point=0, dtype=torch.qint8)
2473        self.assertTrue(not torch.equal(x, y))
2474        self.assertTrue(not torch.equal(y, x))
2475
2476    @skipIfNoFBGEMM
2477    def test_group_norm(self):
2478        # hypothesis is flaky for this test, create test cases manually
2479        batches_list = (1, 7)
2480        num_groups_list = (1, 4)
2481        channels_per_groups = (1, 36, 72)
2482        elements_per_channels = (8, 128, 1024)
2483        torch_types = (torch.qint8, torch.quint8)
2484        y_scales = (0.1, 4.23)
2485        y_zero_points = (0, 1)
2486        channels_last_list = [True, False]
2487        affine_list = [True, False]
2488        combined = [batches_list, num_groups_list, channels_per_groups, elements_per_channels,
2489                    torch_types, y_scales, y_zero_points, channels_last_list, affine_list]
2490        test_cases = itertools.product(*combined)
2491
2492        with override_quantized_engine("fbgemm"):
2493            for test_case in test_cases:
2494
2495                batches, num_groups, channels_per_group, elements_per_channel, \
2496                    torch_type, Y_scale, Y_zero_point, channels_last, \
2497                    affine = test_case
2498                num_channels = num_groups * channels_per_group
2499                # channels_last requires rank 4, so pad the shape with a trailing 1
2500                shapes = (batches, num_channels, elements_per_channel, 1)
2501
2502                # In the FP kernel, sums and sums of squares are calculated in floating point.
2503                # In the int8 and uint8 versions of the quantized kernel, they are
2504                # calculated in integer arithmetic (which is exact).
2505                # Because of this, the numerics do not always match exactly which is
2506                # expected and acceptable. We do the following to allow this failure
2507                # in this test:
2508                # 1. do not use Hypothesis to generate the input tensor.  Hypothesis
2509                #    favors homogeneous inputs in its search strategies which isn't
2510                #    representative of the inputs we care about, and tends to maximize
2511                #    this particular numerics difference.
2512                # 2. allow a small % of off by Y_scale errors.  Even when the
2513                #    variance of the input is high, there can be off by one errors
2514                #    in the result if the input value happens to fall exactly on
2515                #    the bin boundary of the output scale.
2516                #
2517                # If we want the numerics to match we could switch to calculating
2518                # mean+var in floating point in the future, at the cost of speed.
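                # For example, with Y_scale = 0.5 and Y_zero_point = 0, a true output
                # of exactly 0.25 sits on a quantization bin boundary; a tiny numerics
                # difference between the float and integer reductions can push it to
                # either side, so the two results may differ by exactly one Y_scale.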
2519                X, X_scale, X_zero_point = \
2520                    _get_random_tensor_and_q_params(shapes, 1.0, torch_type)
2521
2522                # Initialize the weights non-randomly for reproducibility
2523                if affine:
2524                    weight = torch.ones(num_channels).float() * 0.5
2525                    bias = torch.ones(num_channels).float()
2526                    for i in range(num_channels):
2527                        weight[i] *= i
2528                        bias[i] *= i
2529                else:
2530                    weight = None
2531                    bias = None
2532
2533                eps = 0.001
2534
2535                qX = torch.quantize_per_tensor(X, X_scale, X_zero_point, torch_type)
2536                if channels_last:
2537                    qX = qX.contiguous(memory_format=torch.channels_last)
2538                dqX = qX.dequantize()
2539
2540                # Enforce non-homogeneous inputs
2541                for batch_idx in range(batches):
2542                    for group_idx in range(num_groups):
2543                        ch_start = group_idx * channels_per_group
2544                        ch_end = ch_start + channels_per_group
2545                        group_vals = dqX[batch_idx][ch_start:ch_end]
2546                        assume(
2547                            float(torch.unique(group_vals).shape[0]) / group_vals.numel() > 0.001
2548                            or group_vals.numel() < 5)
2549
2550                qY = torch.ops.quantized.group_norm(qX, num_groups, weight, bias, eps, Y_scale, Y_zero_point)
2551
2552                dqY_hat = F.group_norm(dqX, num_groups=num_groups, weight=weight, bias=bias, eps=eps)
2553                qY_hat = torch.quantize_per_tensor(dqY_hat, Y_scale, Y_zero_point, torch_type)
2554
2555                # Due to the numerics difference mentioned above between calculating
2556                # the variance in float vs int, the results can still be slightly
2557                # different.
2558                dqY = qY.dequantize()
2559                dqY_hat = qY_hat.dequantize()
2560                diff = dqY - dqY_hat
2561
2562                # off-by-one errors have magnitude Y_scale
2563                num_diff = torch.sum(diff > Y_scale * 1.0001)
2564                pct_diff = float(num_diff) / (diff.numel() + 1e-5)
2565                num_diff_off_by_one = torch.sum((diff > 0) * (diff <= Y_scale))
2566                pct_diff_off_by_one = float(num_diff_off_by_one) / (diff.numel() + 1e-5)
2567
2568                self.assertTrue(pct_diff < 1e-6)
2569                self.assertTrue(pct_diff_off_by_one < 0.01)
2570
2571    @skipIfNoFBGEMM
2572    def test_instance_norm(self):
2573        max_sides = (4, 5)
2574        shape_list = ([2, 2, 2, 2], [8, 8, 8, 8], [11, 11, 11, 11])
2575        torch_types = (torch.qint8, torch.quint8)
2576        y_scales = (0.1, 4.23)
2577        y_zero_points = (0, 1)
2578        channels_last_list = (True, False)
2579        affine_list = (True, False)
2580        combined = [shape_list, torch_types, y_scales, y_zero_points, channels_last_list, affine_list]
2581        test_cases_product = itertools.product(*combined)
2582        test_cases = list(test_cases_product)
2583        # NB: Add just one test case to test overflow. This case is too slow to run
2584        # internally in @fbcode//mode/dev; the long pole is the 4x calls to torch.sort
2585        # inside the current torch.unique implementation.
2586        if not IS_SANDCASTLE:
2587            test_cases.append([
2588                [1, 4, 224, 224, 160],  # shape,
2589                torch.qint8,  # torch_type
2590                0.1,  # scale
2591                0,  # zero_point
2592                False,   # channels_last
2593                True,  # affine
2594            ])
2595        with override_quantized_engine("fbgemm"):
2596            for test_case in test_cases:
2597
2598                shapes, torch_type, Y_scale, Y_zero_point, channels_last, affine = test_case
2599                if channels_last and len(shapes) >= 5:
2600                    # channels_last requires a rank-4 tensor
2601                    continue
2602
2603                # In the FP kernel, sums and sums of squares are calculated in floating point.
2604                # In the int8 and uint8 versions of the quantized kernel, they are
2605                # calculated in integer arithmetic (which is exact).
2606                # Because of this, the numerics do not always match exactly which is
2607                # expected and acceptable. We do the following to allow this failure
2608                # in this test:
2609                # 1. do not use Hypothesis to generate the input tensor.  Hypothesis
2610                #    favors homogeneous inputs in its search strategies which isn't
2611                #    representative of the inputs we care about, and tends to maximize
2612                #    this particular numerics difference.
2613                # 2. allow a small % of off by Y_scale errors.  Even when the
2614                #    variance of the input is high, there can be off by one errors
2615                #    in the result if the input value happens to fall exactly on
2616                #    the bin boundary of the output scale.
2617                #
2618                # If we want the numerics to match we could switch to calculating
2619                # mean+var in floating point in the future, at the cost of speed.
2620                X, X_scale, X_zero_point = \
2621                    _get_random_tensor_and_q_params(shapes, 1.0, torch_type)
2622
2623                num_channels = shapes[1]
2624                if affine:
2625                    weight = torch.rand(num_channels).float() * 0.5
2626                    bias = torch.rand(num_channels).float()
2627                    for i in range(num_channels):
2628                        weight[i] *= i
2629                        bias[i] *= i
2630                else:
2631                    weight = None
2632                    bias = None
2633                eps = 0.001
2634
2635                qX = torch.quantize_per_tensor(X, X_scale, X_zero_point, torch_type)
2636                if channels_last:
2637                    qX = qX.contiguous(memory_format=torch.channels_last)
2638                dqX = qX.dequantize()
2639
2640                # Enforce non-homogeneous inputs
2641                batches = shapes[0]
2642                for batch_idx in range(batches):
2643                    for ch_idx in range(num_channels):
2644                        ch_vals = dqX[batch_idx][ch_idx]
2645                        assume(
2646                            float(torch.unique(ch_vals).shape[0]) / ch_vals.numel() > 0.01
2647                            or ch_vals.numel() < 5 or ch_vals.numel() > 25600)
2648
2649                qY = torch.ops.quantized.instance_norm(qX, weight, bias, eps, Y_scale, Y_zero_point)
2650
2651                dqY_hat = F.instance_norm(dqX, weight=weight, bias=bias, eps=eps)
2652                qY_hat = torch.quantize_per_tensor(dqY_hat, Y_scale, Y_zero_point, torch_type)
2653
2654                # Due to the numerics difference mentioned above between calculating
2655                # the variance in float vs int, the results can still be slightly
2656                # different.
2657                dqY = qY.dequantize()
2658                dqY_hat = qY_hat.dequantize()
2659                diff = dqY - dqY_hat
2660
2661                # off-by-one errors are magnitude of Y_scale
2662                # off-by-one errors have magnitude Y_scale
2663                pct_diff = float(num_diff) / (diff.numel() + 1e-5)
2664                num_diff_off_by_one = torch.sum((diff > 0) * (diff <= Y_scale))
2665                pct_diff_off_by_one = float(num_diff_off_by_one) / (diff.numel() + 1e-5)
2666
2667                self.assertTrue(pct_diff < 1e-6)
2668                self.assertTrue(pct_diff_off_by_one < 0.01)
2669
2670    @skipIfNoFBGEMM
2671    def test_batch_norm_relu(self):
2672        # hypothesis too slow for this test, create test cases manually
2673        max_sides = (2, 3, 4, 5)
2674        side_lens = (1, 8, 11)
2675        torch_types = (torch.qint8, torch.quint8)
2676        combined = [max_sides, side_lens, torch_types]
2677        test_cases = itertools.product(*combined)
2678
2679        with override_quantized_engine("fbgemm"):
2680            for test_case in test_cases:
2681                max_side, side_len, torch_type = test_case
2682                Y_zero_point = 1
2683                Y_scale = 0.5
2684
2685                shapes = [side_len] * max_side
2686                X, scale_x, zero_point_x = \
2687                    _get_random_tensor_and_q_params(shapes, 1.0, torch_type)
2688                dtype_x = torch_type
2689
2690                c = X.shape[1]
2691                mean = torch.rand(c).float()
2692                var = torch.rand(c).float()
2693                weight = torch.rand(c).float()
2694                bias = torch.rand(c).float()
2695                eps = 0.001
2696                qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
2697                if len(X.shape) == 2 or len(X.shape) == 3:
2698                    qy = torch.ops.quantized.batch_norm1d_relu(
2699                        qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
2700                elif len(X.shape) == 4:
2701                    qy = torch.ops.quantized.batch_norm2d_relu(
2702                        qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
2703                else:
2704                    qy = torch.ops.quantized.batch_norm3d_relu(
2705                        qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
2706
2707
2708                float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias,
2709                                         running_mean=mean, running_var=var,
2710                                         training=False, momentum=0, eps=eps).numpy()
2711
2712                float_ref_relu = float_ref.copy()
2713                float_ref_relu[float_ref < 0] = 0
2714                quantize_ref = torch.quantize_per_tensor(
2715                    torch.from_numpy(float_ref_relu), Y_scale, Y_zero_point, dtype_x)
2716                self.assertEqual(
2717                    qy.int_repr().numpy(),
2718                    quantize_ref.int_repr().numpy(),
2719                    msg=f"{qy} vs {quantize_ref}")
2720
2721    @skipIfNoFBGEMM
2722    def test_batch_norm(self):
2723        # hypothesis too slow for this test, create test cases manually
2724        max_sides = (2, 3, 4, 5)
2725        side_lens = (1, 8, 11)
2726        torch_types = (torch.qint8, torch.quint8)
2727        combined = [max_sides, side_lens, torch_types]
2728        test_cases = itertools.product(*combined)
2729
2730        with override_quantized_engine("fbgemm"):
2731            for test_case in test_cases:
2732                max_side, side_len, torch_type = test_case
2733                Y_zero_point = 1
2734                Y_scale = 0.5
2735
2736                shapes = [side_len] * max_side
2737                X, scale_x, zero_point_x = \
2738                    _get_random_tensor_and_q_params(shapes, 1.0, torch_type)
2739                dtype_x = torch_type
2740
2741                c = X.shape[1]
2742                mean = torch.rand(c).float()
2743                var = torch.rand(c).float()
2744                weight = torch.rand(c).float()
2745                bias = torch.rand(c).float()
2746                eps = 0.001
2747                qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
2748                if len(X.shape) == 2 or len(X.shape) == 3:
2749                    qy = torch.ops.quantized.batch_norm1d(
2750                        qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
2751                elif len(X.shape) == 4:
2752                    qy = torch.ops.quantized.batch_norm2d(
2753                        qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
2754                elif len(X.shape) == 5:
2755                    qy = torch.ops.quantized.batch_norm3d(
2756                        qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
2757
2758                float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias,
2759                                         running_mean=mean, running_var=var, training=False,
2760                                         momentum=0, eps=eps)
2761                quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
2762                self.assertEqual(
2763                    qy.int_repr().numpy(), quantize_ref.int_repr().numpy(),
2764                    msg=f"{qy} vs {quantize_ref}")
2765
2766    @override_qengines
2767    def test_empty_batch(self):
2768        scale = 1.0
2769        zero_point = 0
2770        X = torch.ones((0, 2, 4, 4), dtype=torch.float32)
2771        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
2772                                       dtype=torch.quint8)
2773
2774        # upsample_nearest2d
2775        qY = torch.nn.functional.upsample_nearest(qX, scale_factor=2)
2776        np.testing.assert_equal(qY.size(), (0, 2, 8, 8),
2777                                "Quantized upsample_nearest2d with batch size 0 failed.")
2778
2779        # relu
2780        qY = torch.nn.functional.relu(qX)
2781        np.testing.assert_equal(qY.size(), qX.size(),
2782                                "Quantized relu with batch size 0 failed.")
2783
2784        # tanh
2785        qY = torch.tanh(qX)
2786        np.testing.assert_equal(qY.size(), qX.size(),
2787                                "Quantized tanh with batch size 0 failed.")
2788        # sigmoid
2789        qY = torch.sigmoid(qX)
2790        np.testing.assert_equal(qY.size(), qX.size(),
2791                                "Quantized sigmoid with batch size 0 failed.")
2792
2793        # interpolate
2794        op = torch.ao.nn.quantized.functional.interpolate
2795        for mode in ["nearest", "bilinear", "nearest-exact"]:
2796            qY = op(qX, scale_factor=2, mode=mode)
2797            np.testing.assert_equal(qY.size(), (0, 2, 8, 8),
2798                                    "Quantized interpolate with batch size 0 failed.")
2799
2800        # avg_pool
2801        kernel = (2, 2)
2802        stride = (1, 1)
2803        padding = (0, 0)
2804        op = torch.ao.nn.quantized.functional.avg_pool2d
2805        qY = op(qX, kernel, stride, padding)
2806        np.testing.assert_equal(qY.size(), (0, 2, 3, 3),
2807                                "Quantized avg_pool2d with batch size 0 failed.")
2808
2809        # adaptive_avg_pool
2810        op = torch.ao.nn.quantized.functional.adaptive_avg_pool2d
2811        qY = op(qX, (3, 3))
2812        np.testing.assert_equal(qY.size(), (0, 2, 3, 3),
2813                                "Quantized adaptive_avg_pool2d with batch size 0 failed.")
2814
2815        # max_pool
2816        dilation = (1, 1)
2817        qY = torch.ops.quantized.max_pool2d(qX, kernel, stride, padding, dilation, ceil_mode=False)
2818        oH = pool_output_shape(4, 2, 0, 1, 1)
2819        oW = pool_output_shape(4, 2, 0, 1, 1)
2820        np.testing.assert_equal(qY.size(), (0, 2, oH, oW),
2821                                "Quantized maxpool2d with batch size 0 failed.")
2822
2823        # hardtanh
2824        qY = torch.ao.nn.quantized.functional.hardtanh(qX, -1, 6)
2825        np.testing.assert_equal(qY.size(), qX.size(),
2826                                "Quantized hardtanh with batch size 0 failed.")
2827
2828        # mul
2829        qY = torch.ops.quantized.mul(qX, qX, 1.0, 0)
2830        np.testing.assert_equal(qY.size(), qX.size(),
2831                                "Quantized mul with batch size 0 failed.")
2832        # add
2833        qY = torch.ops.quantized.add(qX, qX, 1.0, 0)
2834        np.testing.assert_equal(qY.size(), qX.size(),
2835                                "Quantized addition with batch size 0 failed.")
2836
2837        # conv
2838        w = torch.randn((2, 2, 2, 2), dtype=torch.float)
2839        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
2840        bias_float = torch.ones(2, dtype=torch.float)
2841        strides = [1, 1]
2842        pads = [0, 0]
2843        dilations = [1, 1]
2844
2845        w_packed = torch.ops.quantized.conv2d_prepack(qw, bias_float, strides, pads, dilations, 1)
2846        result = torch.ops.quantized.conv2d(qX, w_packed, 1.0, 0)
2847        self.assertEqual(result.shape, (0, 2, 3, 3))
2848
2849        # linear
2850        X = torch.ones((0, 2), dtype=torch.float32)
2851        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
2852                                       dtype=torch.quint8)
2853        w = torch.randn((2, 2), dtype=torch.float)
2854        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
2855        w_packed = torch.ops.quantized.linear_prepack(qw, bias_float)
2856        result = torch.ops.quantized.linear(qX, w_packed, 1.0, 0)
2857        self.assertEqual(result.shape, (0, 2))
2858
2859        # dynamic linear
2860        result = torch.ops.quantized.linear_dynamic(X, w_packed)
2861        self.assertEqual(result.shape, (0, 2))
2862
2863    @override_qengines
2864    def test_linear_bias_unpack(self):
2865        """
2866        Verifies the correctness of the bias() and unpack() APIs for LinearPackedParamsBase.
2867        """
2868        bias_float = torch.ones(2, dtype=torch.float)
2869        w = torch.randn((2, 2), dtype=torch.float)
2870        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
2871        w_packed = torch.ops.quantized.linear_prepack(qw, bias_float)
2872        # test bias()
2873        self.assertEqual(w_packed.bias(), bias_float)
2874        # test unpack()
2875        self.assertEqual(w_packed.unpack()[0], qw)
2876
2877    def test_advanced_indexing(self):
2878        """
2879        Verifies that the x[:, [0], :, :] syntax works for quantized tensors.
2880        """
2881        for dtype in (torch.qint8, torch.quint8, torch.qint32):
2882            scale = 0.1
2883            zp = 0
2884            x_q = torch.quantize_per_tensor(
2885                torch.randn(1, 4, 4, 4), scale, zp, dtype)
2886            # reference
2887            x_fp32 = x_q.dequantize()
2888
2889            # single dim, single index
2890            x_q_s1 = x_q[:, [0], :, :]
2891            x_fp32_s1 = x_fp32[:, [0], :, :]
2892            x_fp32_s1_ref = \
2893                torch.quantize_per_tensor(x_fp32_s1, scale, zp, dtype)
2894            self.assertEqual(x_q_s1, x_fp32_s1_ref)
2895
2896            # multiple dim, single index
2897            x_q_s2 = x_q[:, [0], [2], :]
2898            x_fp32_s2 = x_fp32[:, [0], [2], :]
2899            x_fp32_s2_ref = \
2900                torch.quantize_per_tensor(x_fp32_s2, scale, zp, dtype)
2901            self.assertEqual(x_q_s2, x_fp32_s2_ref)
2902
2903            # single dim, multiple indices
2904            x_q_s3 = x_q[:, [2, 0, 1], :, :]
2905            x_fp32_s3 = x_fp32[:, [2, 0, 1], :, :]
2906            x_fp32_s3_ref = \
2907                torch.quantize_per_tensor(x_fp32_s3, scale, zp, dtype)
2908            self.assertEqual(x_q_s3, x_fp32_s3_ref)
2909
2910            # multiple dim, multiple indices
2911            x_q_s4 = x_q[:, [2, 0, 1], :, [1]]
2912            x_fp32_s4 = x_fp32[:, [2, 0, 1], :, [1]]
2913            x_fp32_s4_ref = \
2914                torch.quantize_per_tensor(x_fp32_s4, scale, zp, dtype)
2915            self.assertEqual(x_q_s4, x_fp32_s4_ref)
2916
2917    @override_qengines
2918    def test_custom_module_lstm(self):
2919        qengine = torch.backends.quantized.engine
2920
2921        batch_size = 4
2922        seq_len = 8
2923        input_size = 12
2924
2925        hidden_size = 8
2926        num_layers = 2
2927
2928        dropout = 0  # This is not supported
2929
2930        Bias = [False, True]
2931        Batch_first = [False, True]
2932        Bidirectional = [False, True]
2933
2934        dtype = np.uint8
2935        qtype = torch.quint8
2936
2937        x = np.random.randn(seq_len, batch_size, input_size)
2938        scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype)
2939        x = torch.from_numpy(x).to(torch.float)
2940        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point,
2941                                       dtype=qtype)
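        # Use the dequantized tensor as the float reference input so that the
        # comparison below measures only the error introduced by the quantized
        # LSTM, not the error from quantizing the input itself.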
2942        x = qx.dequantize()
2943
2944        with torch.no_grad():
2945            for bias, batch_first, bidirectional in itertools.product(
2946                    Bias, Batch_first, Bidirectional):
2947                # Require at least 10 dB SNR (or a small MSE) for functional equivalence.
2948                # Without the bias, linear performs poorly, so the thresholds are relaxed.
2949                min_power = 10 if bias else 5
2950                max_mse = 5e-6 if bias else 5e-1
2951
2952                if batch_first:
2953                    x = x.reshape(batch_size, seq_len, input_size)
2954                    qx = qx.reshape(batch_size, seq_len, input_size)
2955                else:
2956                    x = x.reshape(seq_len, batch_size, input_size)
2957                    qx = qx.reshape(seq_len, batch_size, input_size)
2958
2959                lstm = torch.nn.Sequential(
2960                    torch.nn.LSTM(input_size, hidden_size,
2961                                  num_layers=num_layers,
2962                                  bias=bias, batch_first=batch_first,
2963                                  dropout=dropout,
2964                                  bidirectional=bidirectional))
2965                lstm.eval()
2966                y_ref = lstm(x)
2967
2968                # Prepare
2969                lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
2970                lstm_prepared = torch.ao.quantization.prepare(lstm)
2971                self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
2972                self.assertEqual(num_layers, len(lstm_prepared[0].layers))
2973                assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM
2974
2975                # Calibrate
2976                y = lstm_prepared(x)
2977                self.assertEqual(y_ref, y)
2978
2979                # Quantize
2980                lstm_quantized = torch.ao.quantization.convert(lstm_prepared)
2981                assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
2982                qy = lstm_quantized(qx)
2983
2984                snr = _snr(y, qy)
2985                snr = [snr[0]] + snr[1]
2986
2987                for signal, mse, power in snr:
2988                    self.assertTrue(
2989                        power > min_power or mse < max_mse,
2990                        msg=(f"Error is too high: SNR(dB): {power}, "
2991                             f"Signal: {signal}, MSE: {mse}"))
2992
2993                # Trace
2994                jit_qmodule = torch.jit.trace(lstm_quantized, qx)
2995
2996                # Script
2997                jit_qmodule = torch.jit.script(lstm_quantized)
2998
2999    @override_qengines
3000    def test_custom_module_multi_head_attention(self):
3001        class MultiheadAttentionModel(torch.nn.Module):
3002            def __init__(self, *args, **kwargs):
3003                super().__init__()
3004                self.layer = torch.nn.MultiheadAttention(*args, **kwargs)
3005
3006            def forward(
3007                self,
3008                query,
3009                key,
3010                value,
3011                key_padding_mask: Optional[torch.Tensor] = None,
3012                need_weights: bool = True,
3013                attn_mask: Optional[torch.Tensor] = None,
3014            ):
3015                return self.layer(query, key, value, key_padding_mask, need_weights, attn_mask)
3016
3017        qengine = torch.backends.quantized.engine
3018
3019        min_power = 30
3020        max_mse = 2
3021
3022        num_heads = 16
3023        batch_size = 4
3024        target_seq_length = 128
3025        source_seq_length = 64
3026        qembed_dim = 512  # Must be divisible by the number of heads
3027        kembed_dim = 128
3028        vembed_dim = 256
3029
3030        dropout = 0.0  # This is not supported
3031
3032        Bias = [False, True]
3033        Add_bias_kv = [False, True]
3034        Add_zero_attn = [False, True]
3035
3036        dtype = np.uint8
3037        qtype = torch.quint8
3038
3039        for kdim, vdim in ((kembed_dim, vembed_dim), (None, None)):
3040            fp_data = [
3041                torch.randn(target_seq_length, batch_size, qembed_dim),  # Q
3042                torch.randn(source_seq_length, batch_size,
3043                            qembed_dim if kdim is None else kembed_dim),  # K
3044                torch.randn(source_seq_length, batch_size,
3045                            qembed_dim if vdim is None else vembed_dim)   # V
3046            ]
3047
3048            q_data = []
3049            reduce_range = (qengine in ('x86', 'fbgemm', 'onednn'))
3050            for idx, x in enumerate(fp_data):
3051                scale, zero_point = _calculate_dynamic_qparams(
3052                    x, dtype=dtype, reduce_range=reduce_range)
3053                x = x.to(torch.float)
3054                qx = torch.quantize_per_tensor(x, scale=scale,
3055                                               zero_point=zero_point, dtype=qtype)
3056                q_data.append(qx)
3057
3058                # Dequantize the data back for reference
3059                fp_data[idx] = qx.dequantize()
3060
3061            with torch.no_grad():
3062                for bias, add_bias_kv, add_zero_attn in itertools.product(
3063                        Bias, Add_bias_kv, Add_zero_attn):
3064                    mha = MultiheadAttentionModel(qembed_dim, num_heads, dropout,
3065                                                  bias, add_bias_kv, add_zero_attn,
3066                                                  kdim=kdim, vdim=vdim)
3067                    mha.eval()
3068
3069                    # Prepare
3070                    if qengine_is_onednn():
3071                        # `reduce_range` is False by default for ONEDNN backend
3072                        # but the test fails on earlier CPUs without VNNI.
3073                        # So we use a default qconfig with `reduce_range=True` here
3074                        mha.qconfig = torch.ao.quantization.get_default_qconfig()
3075                    else:
3076                        mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
3077                    mha_prepared = torch.ao.quantization.prepare(
3078                        mha)
3079
3080                    # Calibrate
3081                    y = mha_prepared(*fp_data)
3082                    y_ref = mha(*fp_data)
3083                    # Check the result of the prepare
3084                    self.assertEqual(y_ref[0], y[0])  # Attention
3085                    self.assertEqual(y_ref[1], y[1])  # Weight
3086
3087                    # Quantize
3088                    mha_quantized = torch.ao.quantization.convert(mha_prepared)
3089
3090                    for name, param in mha_quantized.named_parameters():
3091                        self.assertTrue("in_proj_weight" not in name)
3092
3093                    qy = mha_quantized(*q_data)
3094
3095                    # Reference result
3096                    mha.layer = mha_quantized.layer.dequantize()
3097                    y_ref = mha(*fp_data)
3098
3099                    snr = _snr(y, qy)
3100                    for signal, mse, power in snr:
3101                        self.assertTrue(
3102                            power > min_power or mse < max_mse,
3103                            msg=(f"Error is too high: SNR(dB): {power}, "
3104                                 f"Signal: {signal}, MSE: {mse}; "
3105                                 f"Run with bias={bias}, "
3106                                 f"add_bias_kv={add_bias_kv}, "
3107                                 f"add_zero_attn={add_zero_attn}"))
3108
3109                    # Verify the result is scriptable
3110                    mha_quantized_scripted = torch.jit.script(mha_quantized)
3111
3112
3113class TestDynamicQuantizedOps(TestCase):
3114    """Tests the correctness of the dynamic quantized linear and linear_relu op."""
3115    @override_qengines
3116    @given(
3117        batch_size=st.integers(1, 4),
3118        input_channels=st.integers(16, 32),
3119        output_channels=st.integers(4, 8),
3120        use_bias=st.booleans(),
3121        use_relu=st.booleans(),
3122        use_multi_dim_input=st.booleans(),
3123        use_channelwise=st.booleans(),
3124        reduce_range=st.booleans())
3125    def test_qlinear(self, batch_size, input_channels, output_channels,
3126                     use_bias, use_relu, use_multi_dim_input, use_channelwise, reduce_range):
3127        if torch.backends.quantized.engine == 'qnnpack':
3128            reduce_range = False
3129
3130        qlinear_prepack = torch.ops.quantized.linear_prepack
3131        if use_relu:
3132            qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic
3133        else:
3134            qlinear_dynamic = torch.ops.quantized.linear_dynamic
3135
3136        if use_multi_dim_input:
3137            batch_size *= 3  # Test the multi-dim input tensor
3138
3139        X_scale = 1.0
3140        X_zp = 0
3141        X_value_min = 0
3142        X_value_max = 255
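        # With reduce_range the activations are restricted to 7 bits (0..127); on the
        # x86 backends this leaves enough headroom that the vpmaddubsw-based int8
        # kernels do not overflow their 16-bit accumulators.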
3143        if reduce_range:
3144            X_value_max = 127
3145        X_q0 = np.round(np.random.rand(batch_size, input_channels) *
3146                        (X_value_max - X_value_min) + X_value_min).astype(np.uint8)
3147        X_q0[0, 0] = X_value_min
3148        X_q0[0, 1] = X_value_max
3149
3150        # W_scale = 1.0
3151        # W_zp = 0
3152        W_scales = np.ones(output_channels)
3153        W_zps = np.zeros(output_channels).astype(int)
3154        W_value_min = -128
3155        W_value_max = 127
3156        W_q0 = np.round(
3157            np.random.rand(output_channels, input_channels)
3158            * (W_value_max - W_value_min)
3159            + W_value_min
3160        ).astype(np.int8)
3161        W_q0[0, 0] = W_value_min
3162        W_q0[1, 0] = W_value_max
3163
3164        b_value_min = -10
3165        b_value_max = 10
3166        b_q0 = np.round(
3167            np.random.rand(output_channels) *
3168            (b_value_max - b_value_min) + b_value_min
3169        ).astype(np.int32) if use_bias else None
3170
3171        if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
3172            avoid_vpmaddubsw_overflow_linear(
3173                batch_size,
3174                input_channels,
3175                output_channels,
3176                X_q0,
3177                X_value_min,
3178                X_value_max,
3179                W_q0,
3180                W_value_min,
3181                W_value_max,
3182            )
3183
3184        X_fp32 = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float)
3185        if use_multi_dim_input:
3186            X_fp32 = X_fp32.view(3, int(batch_size / 3), input_channels)
3187
3188        # W_scale, W_zp = _calculate_dynamic_qparams(W_fp32, torch.qint8)
3189        # We currently only check the case where W_scale = 1.0, W_zp = 0.
3190
3191        if use_channelwise:
3192            W_fp32 = torch.from_numpy(_dequantize(W_q0, W_scales.reshape(
3193                (-1, 1)), W_zps.reshape((-1, 1)))).to(dtype=torch.float)
3194            W_q = torch.quantize_per_channel(W_fp32, scales=torch.from_numpy(W_scales),
3195                                             zero_points=torch.from_numpy(W_zps), axis=0, dtype=torch.qint8)
3196            b_fp32 = torch.from_numpy(
3197                _dequantize(b_q0, X_scale * W_scales, 0)
3198            ).to(dtype=torch.float) if use_bias else None
3199        else:
3200            W_fp32 = torch.from_numpy(_dequantize(
3201                W_q0, W_scales[0], W_zps[0])).to(dtype=torch.float)
3202            W_q = torch.quantize_per_tensor(W_fp32, scale=W_scales[0], zero_point=(
3203                W_zps[0].astype(int).item()), dtype=torch.qint8)
3204            b_fp32 = torch.from_numpy(
3205                _dequantize(b_q0, X_scale * int(W_scales[0].item()), 0)
3206            ).to(dtype=torch.float) if use_bias else None
3207
3208        # Observe X_fp32 and determine X_scale and X_zero_point, this should match
3209        # internals of dynamic linear.
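        # Roughly, for quint8 the dynamic observer picks
        #   scale = (max(X) - min(X)) / (qmax - qmin)
        #   zero_point = qmin - round(min(X) / scale)
        # with qmin/qmax = 0/255 (0/127 when reduce_range is set), clamped to the
        # quantized range.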
3210        X_scale, X_zp = _calculate_dynamic_qparams(X_fp32, torch.quint8, reduce_range)
3211        X_q = torch.quantize_per_tensor(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
3212
3213        # Weight prepacking operator for dynamic quantized Linear
3214        W_prepack = qlinear_prepack(W_q, b_fp32)
3215        # Dynamic quantized Linear operator with prepacked weight
3216        Y_fp32 = qlinear_dynamic(X_q.dequantize(), W_prepack, reduce_range)
3217        # Y_fp32 = qlinear_dynamic(X_fp32, W_prepack, b_fp32)
3218
3219        Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32)
3220        # Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32)
3221        # if use_multi_dim_input:
3222        #     Y_fp32_ref = Y_fp32_ref.view(3, int(batch_size / 3), output_channels)
3223
3224        if use_relu:
3225            Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0
3226        self.assertEqual(Y_fp32, Y_fp32_ref,
3227                         msg="torch.ops.quantized.linear_dynamic results are off")
3228
3229    @skipIfNoFBGEMM
3230    @given(
3231        batch_size=st.integers(1, 4),
3232        input_channels=st.integers(16, 32),
3233        output_channels=st.integers(4, 8),
3234    )
3235    def test_qlinear_legacy(self, batch_size, input_channels, output_channels):
3236        X_scale = 1.0
3237        X_zp = 0
3238        X_value_min = 0
3239        X_value_max = 255
3240        X_q0 = np.round(np.random.rand(batch_size, input_channels) * (
3241            X_value_max - X_value_min) + X_value_min
3242        ).astype(np.uint8)
3243        X_q0[0, 0] = X_value_min
3244        X_q0[0, 1] = X_value_max
3245
3246        W_scale = 1.0
3247        W_zp = 0
3248        W_value_min = -128
3249        W_value_max = 127
3250        W_q0 = np.round(
3251            np.random.rand(output_channels, input_channels)
3252            * (W_value_max - W_value_min)
3253            + W_value_min
3254        ).astype(np.int8)
3255        W_q0[0, 0] = W_value_min
3256        W_q0[1, 0] = W_value_max
3257
3258        b_value_min = -10
3259        b_value_max = 10
3260        b_q0 = np.round(
3261            np.random.rand(output_channels) * (b_value_max - b_value_min) +
3262            b_value_min
3263        ).astype(np.int32)
3264
3265        avoid_vpmaddubsw_overflow_linear(
3266            batch_size,
3267            input_channels,
3268            output_channels,
3269            X_q0,
3270            X_value_min,
3271            X_value_max,
3272            W_q0,
3273            W_value_min,
3274            W_value_max,
3275        )
3276
3277        X_fp32 = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float)
3278        W_fp32 = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)
3279        b_fp32 = torch.from_numpy(
3280            _dequantize(b_q0, X_scale * W_scale, 0)
3281        ).to(dtype=torch.float)
3282
3283        W_scale, W_zp = _calculate_dynamic_qparams(W_fp32, torch.qint8)
3284        W_q = torch.quantize_per_tensor(W_fp32, scale=W_scale, zero_point=W_zp, dtype=torch.qint8)
3285
3286        # Observe X_fp32 and determine X_scale and X_zero_point, this should match
3287        # internals of dynamic linear.
3288        X_scale, X_zp = _calculate_dynamic_qparams(X_fp32, torch.quint8)
3289        X_q = torch.quantize_per_tensor(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
3290
3291        W_int8, col_offsets, W_scale, W_zp = torch.fbgemm_linear_quantize_weight(W_q.dequantize())
3292        W_prepack = torch.fbgemm_pack_quantized_matrix(W_int8.clone(), W_int8.size(1), W_int8.size(0))
3293        # Quantized Linear operator with prepacked weight
3294        Y_fp32 = torch.fbgemm_linear_int8_weight(
3295            X_q.dequantize(), W_q.dequantize(), W_prepack, col_offsets,
3296            W_scale, W_zp, b_fp32)
3297
3298        Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32)
3299        # Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32)
3300
3301        self.assertEqual(Y_fp32, Y_fp32_ref,
3302                         msg="torch.fbgemm_linear_int8_weight results are off")
3303
3304    @skipIfNoFBGEMM
3305    @given(
3306        input_channels=st.integers(16, 32),
3307        output_channels=st.integers(4, 8),
3308        exponent=st.integers(0, 8))
3309    def test_linear_prepack_fp16_numerics(self, input_channels, output_channels, exponent):
3310        w = torch.randn(output_channels, input_channels) * 10**exponent
3311        bias = None
3312        w_packed_fp16 = torch.ops.quantized.linear_prepack_fp16(w, bias)
3313        w_unpacked_fp16 = torch.ops.quantized.linear_unpack_fp16(w_packed_fp16)
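        # Prepacking stores the weight in fp16, so the reference is the weight
        # round-tripped through fp16 back to fp32.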
3314        w_fp16 = w.to(torch.float16).to(torch.float32)
3315        self.assertTrue(torch.equal(w_fp16, w_unpacked_fp16[0]))
3316
3317    @skipIfNoFBGEMM
3318    def test_qlinear_dynamic_fp16(self):
3319
3320        options = itertools.product(
3321            (2, 4),         # batch_size
3322            (4, 5, 12),     # input_channels
3323            (4, 7, 8),      # output_channels
3324            (True, False),  # use_bias
3325            (True, False),  # use_relu
3326        )
3327        for batch_size, input_channels, output_channels, use_bias, use_relu in options:
3328            qlinear_prepack = torch.ops.quantized.linear_prepack_fp16
3329            if use_relu:
3330                qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic_fp16
3331            else:
3332                qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16
3333
3334            x = torch.randn(batch_size, input_channels)
3335            w = torch.randn(output_channels, input_channels)
3336            bias = torch.randn(output_channels) if use_bias else None
3337
3338            w_packed = qlinear_prepack(w, bias)
3339            out = qlinear_dynamic(x, w_packed)
3340
3341            # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors
3342            # output is FP32
3343            w_fp16 = w.to(torch.float16).to(torch.float32)
3344            ref = F.linear(x, w_fp16, bias)
3345            if use_relu:
3346                ref.relu_()
3347
3348            self.assertEqual(out, ref)
3349
3350    @skipIfNoFBGEMM
3351    def test_unpacked_qlinear_dynamic_fp16(self):
3352
3353        options = itertools.product(
3354            (2, 4),         # batch_size
3355            (4, 5, 12),     # input_channels
3356            (4, 7, 8),      # output_channels
3357        )
3358        for batch_size, input_channels, output_channels in options:
3359            qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight
3360
3361            x = torch.randn(batch_size, input_channels)
3362            w = torch.randn(output_channels, input_channels)
3363            bias = torch.randn(output_channels)
3364
3365            out = qlinear_dynamic(x, w, bias)
3366
3367            # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors
3368            # output is FP32
3369            w_fp16 = w.to(torch.float16).to(torch.float32)
3370            ref = F.linear(x, w_fp16, bias)
3371
3372            self.assertEqual(out, ref)
3373
3374
3375    @skipIfNoFBGEMM
3376    def test_unpacked_qlinear_dynamic_fp16_opcheck(self):
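        # opcheck runs the standard custom-op correctness checks (schema, fake-tensor
        # kernel, etc.) on the unpacked-weight dynamic fp16 linear overload.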
3377        qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default
3378
3379        x = torch.randn(4, 4, device='cpu')
3380        w = torch.randn(4, 4, device='cpu')
3381        bias = torch.randn(4, device='cpu')
3382
3383        opcheck(qlinear_dynamic, (x, w, bias))
3384
3385    @skipIfNoFBGEMM
3386    def test_wrapped_fbgemm_linear_fp16(self):
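        # The wrapped fbgemm fp16 ops pack the weight once and run an fp32-activation,
        # fp16-weight linear; the reference casts the weight to fp16 and back and uses F.linear.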
3387        options = itertools.product(
3388            (2, 4),         # batch_size
3389            (4, 5),     # input_channels
3390            (4, 7),      # output_channels
3391        )
3392        for batch_size, input_channels, output_channels in options:
3393            pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
3394            linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
3395
3396            x = torch.randn(batch_size, input_channels)
3397            w = torch.randn(output_channels, input_channels)
3398            bias = torch.randn(output_channels)
3399
3400            w_packed = pack_op(w)
3401            out = linear_op(x, w_packed, bias, output_channels)
3402
3403            w_fp16 = w.to(torch.float16).to(torch.float32)
3404            ref = F.linear(x, w_fp16, bias)
3405
3406            self.assertEqual(out, ref)
3407
3408    @skipIfNoFBGEMM
3409    def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self):
3410        # We are not using opcheck here because the output of the op under test
3411        # (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not deterministic
3412        # due to the C-struct it produces. This would fail the check when trying
3413        # to match the results of the compiled and eager versions.
3414        #
3415        # This is only a temporary solution; long term, we should be able to support PT2
3416        # with torchbind natively.
3417        def func(X, W, B):
3418            packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
3419            return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0))
3420
3421        x = torch.randn(1, 4, device="cpu")
3422        w = torch.randn(4, 4, device="cpu")
3423        b = torch.zeros(4, device="cpu")
3424
3425        ref_out = func(x, w, b)
3426
3427        compiled = torch.compile(func)
3428        compiled_out = compiled(x, w, b)
3429
3430        self.assertEqual(ref_out, compiled_out)
3431
3432    """Tests the correctness of the dynamic quantized lstm/gru."""
3433
3434    def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range):
3435        # For Input (seq_len, batch, input_size)
3436        X = torch.randn(seq_len, num_batches, input_size)
3437        s, z = _calculate_dynamic_qparams(X, torch.quint8, reduce_range)
3438        Xq = torch.quantize_per_tensor(X, s, z, torch.quint8)
3439
3440        # For H and C: (num_layers(1) * num_directions, batch, hidden_size)
3441
3442        if num_directions == 1:
3443            H = torch.randn(num_directions, num_batches, hidden_size)
3444            C = torch.randn(num_directions, num_batches, hidden_size)
3445        else:
3446            H = torch.zeros(num_directions, num_batches, hidden_size)
3447            C = torch.zeros(num_directions, num_batches, hidden_size)
3448
3449        s, z = _calculate_dynamic_qparams(H, torch.quint8, reduce_range)
3450        Hq = torch.quantize_per_tensor(H, s, z, torch.quint8)
3451        s, z = _calculate_dynamic_qparams(C, torch.quint8, reduce_range)
3452        Cq = torch.quantize_per_tensor(C, s, z, torch.quint8)
3453        return Xq, Hq, Cq
3454
3455    def _get_rnn_weights_and_bias(self, input_size, hidden_size, num_directions, per_channel_quant, rnn_type):
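        # hidden_mult sets how many gate blocks worth of rows the stacked weight
        # matrices get for each RNN type (e.g. 4 for LSTM, 3 for GRU).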
3456        hidden_mult_map = {'LSTM': 4, 'LSTMCell': 4, 'GRU': 3, 'GRUCell': 3, 'RNNTanh': 2, 'RNNReLU': 2}
3457        hidden_mult = hidden_mult_map[rnn_type]
3458        weights1 = torch.randn(hidden_mult * hidden_size, input_size)
3459        weights2 = torch.randn(hidden_mult * hidden_size, hidden_size)
3460        scale1 = 0.1 * torch.ones([weights1.size()[0]])
3461        scale2 = 0.3 * torch.ones([weights2.size()[0]])
3462        zero_point1 = torch.zeros(scale1.size()).to(int)
3463        zero_point2 = torch.zeros(scale2.size()).to(int)
3464        b1 = torch.zeros(hidden_mult * hidden_size)
3465        if per_channel_quant:
3466            Wq1 = torch.quantize_per_channel(weights1, scale1, zero_point1, 0, torch.qint8)
3467            Wq2 = torch.quantize_per_channel(weights2, scale2, zero_point2, 0, torch.qint8)
3468
3469        else:
3470            Wq1 = torch.quantize_per_tensor(weights1, float(scale1[0]), int(zero_point1[0]), torch.qint8)
3471            Wq2 = torch.quantize_per_tensor(weights2, float(scale2[0]), int(zero_point2[0]), torch.qint8)
3472        return Wq1, Wq2, b1, b1
3473
3474    @given(
3475        num_batches=st.integers(1, 4),
3476        input_size=st.integers(16, 32),
3477        hidden_size=st.integers(4, 8),
3478        num_directions=st.integers(1, 2),
3479        per_channel_quant=st.booleans())
3480    @override_qengines
3481    def test_qlstmGRU(self, num_batches, input_size, hidden_size,
3482                      num_directions, per_channel_quant):
3483        # We test only for a seq length of 1 and num layers of 1, as dynamic quantization occurs multiple times
3484        # within the LSTM op and we do not model the quantization between the multiple calls of the linear op
3485        # within the LSTM op.
3486        seq_len = 1
3487
3488        for rnn_type in ['LSTM', 'GRU']:
3489            for dtype in [torch.qint8, torch.float16]:
3490                # Fp16 quantization is not supported for qnnpack or onednn
3491                if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16:
3492                    continue
3493
3494                if torch.backends.quantized.engine == 'qnnpack':
3495                    reduce_range = False
3496                else:
3497                    reduce_range = True
3498                Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size,
3499                                                  hidden_size, num_directions, reduce_range)
3500                Wq1, Wq2, b1, b2 = self._get_rnn_weights_and_bias(input_size,
3501                                                                  hidden_size,
3502                                                                  num_directions,
3503                                                                  per_channel_quant,
3504                                                                  rnn_type)
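                # Build the cell params consumed by the dynamic RNN op: the int8 path prepacks the
                # quantized weights, the fp16 path prepacks fp16-converted weights; the reference
                # weights are the corresponding dequantized (and, for fp16, round-tripped) tensors.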
3505                if dtype == torch.qint8:
3506                    packed_ih = torch.ops.quantized.linear_prepack(Wq1, b1)
3507                    packed_hh = torch.ops.quantized.linear_prepack(Wq2, b2)
3508                    cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
3509                        packed_ih, packed_hh, b1, b2, reduce_range)
3510                    W_ref1 = Wq1.dequantize()
3511                    W_ref2 = Wq2.dequantize()
3512
3513                else:
3514                    packed_ih = torch.ops.quantized.linear_prepack_fp16(Wq1.dequantize(), b1)
3515                    packed_hh = torch.ops.quantized.linear_prepack_fp16(Wq2.dequantize(), b2)
3516                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(packed_ih, packed_hh)
3517                    W_ref1 = Wq1.dequantize().to(torch.float16).to(torch.float32)
3518                    W_ref2 = Wq2.dequantize().to(torch.float16).to(torch.float32)
3519
3520                if rnn_type == 'LSTM':
3521                    if num_directions > 1:
3522                        result_ref = _VF.lstm(Xq.dequantize(),
3523                                              (Hq.dequantize(), Cq.dequantize()),
3524                                              [W_ref1, W_ref2, b1, b2, W_ref1, W_ref2, b1, b2],
3525                                              True,
3526                                              1,
3527                                              0,
3528                                              False,
3529                                              num_directions > 1,
3530                                              False)
3531
3532                        result_dynamic = torch.quantized_lstm(Xq.dequantize(),
3533                                                              (Hq.dequantize(), Cq.dequantize()),
3534                                                              ([cell_params, cell_params]),
3535                                                              True,
3536                                                              1,
3537                                                              0,
3538                                                              False,
3539                                                              True,
3540                                                              False,
3541                                                              dtype=torch.qint8,
3542                                                              use_dynamic=True)
3543                    else:
3544                        result_ref = _VF.lstm(Xq.dequantize(),
3545                                              (Hq.dequantize(), Cq.dequantize()),
3546                                              [W_ref1, W_ref2, b1, b2],
3547                                              True,
3548                                              1,
3549                                              0,
3550                                              False,
3551                                              num_directions > 1,
3552                                              False)
3553
3554                        result_dynamic = torch.quantized_lstm(Xq.dequantize(),
3555                                                              (Hq.dequantize(), Cq.dequantize()),
3556                                                              ([cell_params]),
3557                                                              True,
3558                                                              1,
3559                                                              0,
3560                                                              False,
3561                                                              num_directions > 1,
3562                                                              False,
3563                                                              dtype=torch.qint8,
3564                                                              use_dynamic=True)
3565
3566                if rnn_type == 'GRU':
3567                    if num_directions > 1:
3568                        result_ref = _VF.gru(Xq.dequantize(),
3569                                             Hq.dequantize(),
3570                                             [W_ref1, W_ref2, b1, b2, W_ref1, W_ref2, b1, b2],
3571                                             True,
3572                                             1,
3573                                             0,
3574                                             False,
3575                                             True,
3576                                             False)
3577
3578                        result_dynamic = torch.quantized_gru(Xq.dequantize(),
3579                                                             Hq.dequantize(),
3580                                                             ([cell_params, cell_params]),
3581                                                             True,
3582                                                             1,
3583                                                             0,
3584                                                             False,
3585                                                             True,
3586                                                             False)
3587                    else:
3588                        result_ref = _VF.gru(Xq.dequantize(),
3589                                             Hq.dequantize(),
3590                                             [W_ref1, W_ref2, b1, b2],
3591                                             True,
3592                                             1,
3593                                             0,
3594                                             False,
3595                                             False,
3596                                             False)
3597
3598                        result_dynamic = torch.quantized_gru(Xq.dequantize(),
3599                                                             Hq.dequantize(),
3600                                                             ([cell_params]),
3601                                                             True,
3602                                                             1,
3603                                                             0,
3604                                                             False,
3605                                                             False,
3606                                                             False)
3607
3608                self.assertEqual(result_ref[0], result_dynamic[0], msg="torch.quantized_lstm results are off")
3609
3610    @given(
3611        num_batches=st.integers(1, 4),
3612        input_size=st.integers(16, 32),
3613        hidden_size=st.integers(4, 8),
3614        per_channel_quant=st.booleans())
3615    @override_qengines
3616    def test_qrnncell(self, num_batches, input_size, hidden_size, per_channel_quant):
3617        # We test only for a seq length of 1 and num layers of 1, as dynamic quantization occurs multiple times
3618        # within the LSTM op and we do not model the quantization between the multiple calls of the linear op
3619        # within the LSTM op.
3620        seq_len = 1
3621
3622        for rnn_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']:
3623            for dtype in [torch.qint8, torch.float16]:
3624                # Fp16 quantization is not supported for qnnpack or onednn
3625                if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16:
3626                    continue
3627
3628                if torch.backends.quantized.engine == 'qnnpack':
3629                    reduce_range = False
3630                else:
3631                    reduce_range = True
3632
3633                Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size, hidden_size, 1, reduce_range)
3634                Wq1, Wq2, b1, b2 = self._get_rnn_weights_and_bias(
3635                    input_size, hidden_size, 1, per_channel_quant, rnn_type)
3636                if dtype == torch.qint8:
3637                    packed_ih = torch.ops.quantized.linear_prepack(Wq1, b1)
3638                    packed_hh = torch.ops.quantized.linear_prepack(Wq2, b2)
3639                    W_ref1 = Wq1.dequantize()
3640                    W_ref2 = Wq2.dequantize()
3641                else:
3642                    packed_ih = torch.ops.quantized.linear_prepack_fp16(Wq1.dequantize(), b1)
3643                    packed_hh = torch.ops.quantized.linear_prepack_fp16(Wq2.dequantize(), b2)
3644                    W_ref1 = Wq1.dequantize().to(torch.float16).to(torch.float32)
3645                    W_ref2 = Wq2.dequantize().to(torch.float16).to(torch.float32)
3646
3647                state = {'LSTMCell': (Hq.dequantize()[0], Cq.dequantize()[0]),
3648                         'GRUCell': Hq.dequantize()[0],
3649                         'RNNTanh': Hq.dequantize()[0],
3650                         'RNNReLU': Hq.dequantize()[0]}
3651                fn_dict = {'LSTMCell': torch._VF.lstm_cell,
3652                           'GRUCell': torch._VF.gru_cell,
3653                           'RNNTanh': torch._VF.rnn_tanh_cell,
3654                           'RNNReLU': torch._VF.rnn_relu_cell}
3655                qfn_dict = {'LSTMCell': torch.ops.quantized.quantized_lstm_cell_dynamic,
3656                            'GRUCell': torch.ops.quantized.quantized_gru_cell_dynamic,
3657                            'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic,
3658                            'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic}
3659                W_ref_dict = {torch.float16: (Wq1.dequantize().to(torch.float16).to(torch.float32),
3660                                              Wq2.dequantize().to(torch.float16).to(torch.float32)),
3661                              torch.qint8: (Wq1.dequantize(), Wq2.dequantize())}
3662
3663                result_ref = fn_dict[rnn_type](Xq.dequantize()[0], state[rnn_type], W_ref1, W_ref2, b1, b2)
3664                result_dynamic = qfn_dict[rnn_type](Xq.dequantize()[0], state[rnn_type], packed_ih, packed_hh, b1, b2)
3665                self.assertEqual(result_ref[0], result_dynamic[0], msg="torch.quantized_rnncell results are off")
3666
3667    def _test_qconv_op_impl(self, q_mod, dq_op, dim, dtype):
3668        # The goal here is to show that the dynamic op is the same as
3669        # calc params->quantize_input->quantized op->dequantize output
3670
3671        if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN):
3672            return  # not supported by QNNPACK
3673
3674        if qengine_is_qnnpack():
3675            reduce_range = False
3676        else:
3677            reduce_range = True
3678
3679        X_fp32 = torch.randn(*([2] * dim))
3680        s, z = _calculate_dynamic_qparams(X_fp32, dtype, reduce_range)
3681
3682        quantized_module = q_mod(2, 3, 1)
3683        packed_params = quantized_module._packed_params
3684
3685        quantized_module.scale, quantized_module.zero_point = s, z
3686
3687        X_q = torch.quantize_per_tensor(X_fp32, s, z, dtype)
3688        Y_q_ref = quantized_module(X_q)
3689        Y_ref = torch.dequantize(Y_q_ref)
3690
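        # Dynamic path: the dynamic op takes the fp32 input and is expected to compute the same
        # qparams internally (given the same reduce_range), so its output should match Y_ref.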
3691        X_dq = torch.dequantize(X_q)
3692        Y = dq_op(X_dq, packed_params, reduce_range)
3693
3694        self.assertEqual(Y, Y_ref)
3695
3696    @override_qengines
3697    def test_dynamic_conv1d(self):
3698        q_mod = torch.ao.nn.quantized.Conv1d
3699        dq_op = torch.ops.quantized.conv1d_dynamic
3700        dim = 3
3701        dtype = torch.quint8
3702
3703        self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)
3704
3705    @override_qengines
3706    def test_dynamic_conv2d(self):
3707        q_mod = torch.ao.nn.quantized.Conv2d
3708        dq_op = torch.ops.quantized.conv2d_dynamic
3709        dim = 4
3710        dtype = torch.quint8
3711
3712        self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)
3713
3714    @override_qengines
3715    def test_dynamic_conv3d(self):
3716        q_mod = torch.ao.nn.quantized.Conv3d
3717        dq_op = torch.ops.quantized.conv3d_dynamic
3718        dim = 5
3719        dtype = torch.quint8
3720
3721        self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)
3722
3723    @override_qengines
3724    def test_dynamic_convtranspose1d(self):
3725        q_mod = torch.ao.nn.quantized.ConvTranspose1d
3726        dq_op = torch.ops.quantized.conv_transpose1d_dynamic
3727        dim = 3
3728        dtype = torch.quint8
3729
3730        self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)
3731
3732    @override_qengines
3733    def test_dynamic_convtranspose2d(self):
3734        q_mod = torch.ao.nn.quantized.ConvTranspose2d
3735        dq_op = torch.ops.quantized.conv_transpose2d_dynamic
3736        dim = 4
3737        dtype = torch.quint8
3738
3739        self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)
3740
3741    @override_qengines
3742    def test_dynamic_convtranspose3d(self):
3743        q_mod = torch.ao.nn.quantized.ConvTranspose3d
3744        dq_op = torch.ops.quantized.conv_transpose3d_dynamic
3745        dim = 5
3746        dtype = torch.quint8
3747
3748        if qengine_is_qnnpack():
3749            return  # TODO: fix MakeDeConvOutputShape overflowing for convT3d with qnnpack
3750        self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)
3751
3752
3753class TestQuantizedLinear(TestCase):
3754    def _test_qlinear_impl(self, batch_size, input_channels, output_channels, use_bias,
3755                           post_op, use_multi_dim_input, use_channelwise, **post_op_kwargs):
3756        decimal_val = 4
3757        dtypes = [torch.quint8]
3758        if torch.backends.quantized.engine == 'qnnpack':
3759            # QNNPACK supports uint8 in the kernels. In the op we shift the int8
3760            # weight values to uint8 to be on par with fbgemm. However, this causes
3761            # some rounding issues in rare cases. So, we relax the check to allow
3762            # off by one results.
3763            decimal_val = 0
3764
3765            # only qnnpack qengine supports qint8 when xnnpack is available
3766            if torch.backends.xnnpack.enabled:
3767                dtypes.append(torch.qint8)
3768
3769        for dtype in dtypes:
3770            # No support for channelwise in xnnpack (int8)
3771            # ONEDNN does not support qint8
3772            if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()):
3773                return
3774
3775            nptype = np_dtype[dtype]
3776            qlinear_prepack = torch.ops.quantized.linear_prepack
3777            if post_op == 'relu':
3778                qlinear = torch.ops.quantized.linear_relu
3779            elif post_op == 'leaky_relu':
3780                qlinear = torch.ops.quantized.linear_leaky_relu
3781            else:
3782                qlinear = torch.ops.quantized.linear
3783            if use_multi_dim_input:
3784                batch_size *= 3  # Test the multi-dim input tensor
3785            X_scale = 1.5
3786            X_zp = 5
3787            X_value_min = -128 if dtype == torch.qint8 else 0
3788            X_value_max = 127 if dtype == torch.qint8 else 255
3789            X_q0 = np.round(
3790                np.random.rand(batch_size, input_channels) *
3791                (X_value_max - X_value_min)
3792                + X_value_min
3793            ).astype(nptype)
3794
3795            W_scales = np.random.rand(output_channels)
3796            # xnnpack forces W_zp to 0 when using symmetric quantization
3797            # ONEDNN only supports symmetric quantization of weight
3798            if dtype == torch.qint8 or qengine_is_onednn():
3799                W_zps = np.zeros(output_channels).astype(int)
3800            else:
3801                W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(int)
3802            # When using symmetric quantization, there is a special restriction on the
3803            # xnnpack fully connected op weight:
3804            # [-127, 127] instead of [-128, 127]
3805            W_value_min = -127 if dtype == torch.qint8 else -128
3806            W_value_max = 127
3807            W_q0 = np.round(
3808                np.random.rand(output_channels, input_channels)
3809                * (W_value_max - W_value_min)
3810                + W_value_min
3811            ).astype(np.int8)  # weight is always int8_t
3812            b_value_min = -10
3813            b_value_max = 10
3814            b_q0 = np.round(
3815                np.random.rand(output_channels) *
3816                (b_value_max - b_value_min) + b_value_min
3817            ).astype(np.int32) if use_bias else None
3818            if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
3819                avoid_vpmaddubsw_overflow_linear(
3820                    batch_size,
3821                    input_channels,
3822                    output_channels,
3823                    X_q0,
3824                    X_value_min,
3825                    X_value_max,
3826                    W_q0,
3827                    W_value_min,
3828                    W_value_max,
3829                )
3830            X = torch.from_numpy(_dequantize(
3831                X_q0, X_scale, X_zp)).to(dtype=torch.float)
3832            X_q = torch.quantize_per_tensor(
3833                X, scale=X_scale, zero_point=X_zp, dtype=dtype)
3834            if use_channelwise:
3835                W = torch.from_numpy(_dequantize(W_q0, W_scales.reshape(
3836                    (-1, 1)), W_zps.reshape((-1, 1)))).to(dtype=torch.float)
3837                W_q = torch.quantize_per_channel(W, scales=torch.from_numpy(W_scales),
3838                                                 zero_points=torch.from_numpy(W_zps), axis=0, dtype=torch.qint8)
3839                b = torch.from_numpy(_dequantize(
3840                    b_q0, X_scale * W_scales, 0)).to(dtype=torch.float) if use_bias else None
3841                b_q = torch.quantize_per_channel(b, scales=torch.from_numpy(X_scale * W_scales),
3842                                                 zero_points=torch.zeros(output_channels, dtype=torch.long),
3843                                                 axis=0, dtype=torch.qint32) if use_bias else None
3844            else:
3845                W = torch.from_numpy(_dequantize(
3846                    W_q0, W_scales[0], W_zps[0])).to(dtype=torch.float)
3847                W_q = torch.quantize_per_tensor(W, scale=W_scales[0], zero_point=(
3848                    W_zps[0].astype(int).item()), dtype=torch.qint8)
3849                b = torch.from_numpy(_dequantize(
3850                    b_q0, X_scale * (W_scales[0].item()), 0)).to(dtype=torch.float) if use_bias else None
3851                b_q = torch.quantize_per_tensor(
3852                    b, scale=X_scale * (W_scales[0].item()), zero_point=0, dtype=torch.qint32) if use_bias else None
3853            # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
3854            # Y_scale * 255 (max for uint8).
3855            Y_scale = 12.34
3856            Y_zp = 5
3857            # Weight prepacking operator for quantized Linear
3858            float_bias = b if use_bias else None
3859            W_prepack = qlinear_prepack(W_q, float_bias)
3860            if use_multi_dim_input:
3861                X_q = X_q.view(3, int(batch_size / 3), input_channels)
3862            # Quantized Linear operator with prepacked weight
3863            Y_q = qlinear(X_q, W_prepack, Y_scale, Y_zp, **post_op_kwargs)
3864            if not use_channelwise and post_op in ('none', 'relu'):
3865                # Test the per-tensor quantization only
3866                # Reference quantized Linear operator
3867                Y_q_ref = qlinear_ref(X_q0, X_scale, X_zp, W_q0,
3868                                      W_scales[0], W_zps[0], b_q0, Y_scale, Y_zp, dtype=nptype)
3869                if post_op == 'relu':
3870                    Y_q_ref[Y_q_ref < Y_zp] = Y_zp
3871                if use_multi_dim_input:
3872                    Y_q_ref = np.reshape(
3873                        Y_q_ref, (3, int(batch_size / 3), output_channels))
3874                # Assert equal
3875                np.testing.assert_array_almost_equal(Y_q_ref, Y_q.int_repr().numpy(), decimal=decimal_val)
3876            # Test both per-tensor and per-channel quantization
3877            # Reference quantized result from PyTorch Linear operator
3878            W_fp32 = W_q.dequantize().to(dtype=torch.float)
3879            X_fp32 = X_q.dequantize().to(dtype=torch.float)
3880            b_fp32 = b_q.dequantize().to(dtype=torch.float) if use_bias else None
3881            Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32)
3882            if post_op == 'relu':
3883                Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0
3884            elif post_op == 'leaky_relu':
3885                Y_fp32_ref = F.leaky_relu(Y_fp32_ref, **post_op_kwargs)
3886            Y_q_ref2 = torch.quantize_per_tensor(
3887                Y_fp32_ref, Y_scale, Y_zp, dtype)
3888            # Assert equal
3889            np.testing.assert_array_almost_equal(
3890                Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=decimal_val)
3891
3892    """Tests the correctness of the quantized linear op."""
3893    @override_qengines
3894    def test_qlinear(self):
3895        batch_size_list = [1, 4]
3896        input_channels_list = [16, 32]
3897        output_channels_list = [4, 8]
3898        use_bias_list = [True, False]
3899        use_multi_dim_input_list = [True, False]
3900        use_channelwise_list = [True, False]
3901        post_op = 'none'
3902        cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
3903                                  use_bias_list, use_multi_dim_input_list, use_channelwise_list)
3904        for batch_size, input_channels, output_channels, use_bias, \
3905                use_multi_dim_input, use_channelwise in cases:
3906            self._test_qlinear_impl(batch_size, input_channels, output_channels,
3907                                    use_bias, post_op, use_multi_dim_input, use_channelwise)
3908
3909    """Tests the correctness of the quantized linear_relu op."""
3910    @override_qengines
3911    def test_qlinear_relu(self):
3912        batch_size_list = [1, 4]
3913        input_channels_list = [16, 32]
3914        output_channels_list = [4, 8]
3915        use_bias_list = [True, False]
3916        use_multi_dim_input_list = [True, False]
3917        use_channelwise_list = [True, False]
3918        post_op = 'relu'
3919        cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
3920                                  use_bias_list, use_multi_dim_input_list, use_channelwise_list)
3921        for batch_size, input_channels, output_channels, use_bias, \
3922                use_multi_dim_input, use_channelwise in cases:
3923            self._test_qlinear_impl(batch_size, input_channels, output_channels,
3924                                    use_bias, post_op, use_multi_dim_input, use_channelwise)
3925
3926    @given(batch_size=st.integers(1, 4),
3927           input_channels=st.integers(16, 32),
3928           output_channels=st.integers(4, 8),
3929           use_bias=st.booleans(),
3930           use_relu=st.booleans(),
3931           use_multi_dim_input=st.booleans(),
3932           use_channelwise=st.booleans())
3933    @skipIfNoFBGEMM
3934    def test_qlinear_with_input_q_dq_qweight_dq_output_fp32(
3935            self, batch_size, input_channels, output_channels, use_bias,
3936            use_relu, use_multi_dim_input, use_channelwise):
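        # These ops take an fp32 input plus its qparams, quantize/dequantize the input and
        # weight internally, and return an fp32 output, so the reference is a plain fp32
        # F.linear on the dequantized tensors (plus an optional ReLU).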
3937        decimal_val = 4
3938        dtypes = [torch.quint8]
3939        for dtype in dtypes:
3940            # No support for channelwise in xnnpack (int8)
3941            # ONEDNN does not support qint8
3942            if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()):
3943                return
3944
3945            nptype = np_dtype[dtype]
3946            qlinear_prepack = torch.ops.quantized.linear_prepack
3947            if use_relu:
3948                qlinear = torch.ops.quantized.linear_with_input_q_dq_qweight_dq_relu_output_fp32
3949            else:
3950                qlinear = torch.ops.quantized.linear_with_input_q_dq_qweight_dq_output_fp32
3951            if use_multi_dim_input:
3952                batch_size *= 3  # Test the multi-dim input tensor
3953            X_scale = 1.5
3954            X_zp = 5
3955            X_value_min = -128 if dtype == torch.qint8 else 0
3956            X_value_max = 127 if dtype == torch.qint8 else 255
3957            X_q0 = np.round(
3958                np.random.rand(batch_size, input_channels) *
3959                (X_value_max - X_value_min)
3960                + X_value_min
3961            ).astype(nptype)
3962
3963            W_scales = np.random.rand(output_channels)
3964            # xnnpack forces W_zp to 0 when using symmetric quantization
3965            # ONEDNN only supports symmetric quantization of weight
3966            if dtype == torch.qint8 or qengine_is_onednn():
3967                W_zps = np.zeros(output_channels).astype(int)
3968            else:
3969                W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(int)
3970            # When using symmetric quantization, there is a special restriction on the
3971            # xnnpack fully connected op weight:
3972            # [-127, 127] instead of [-128, 127]
3973            W_value_min = -127 if dtype == torch.qint8 else -128
3974            W_value_max = 127
3975            W_q0 = np.round(
3976                np.random.rand(output_channels, input_channels)
3977                * (W_value_max - W_value_min)
3978                + W_value_min
3979            ).astype(np.int8)  # weight is always int8_t
3980            b_value_min = -10
3981            b_value_max = 10
3982            b_q0 = np.round(
3983                np.random.rand(output_channels) *
3984                (b_value_max - b_value_min) + b_value_min
3985            ).astype(np.int32) if use_bias else None
3986            if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
3987                avoid_vpmaddubsw_overflow_linear(
3988                    batch_size,
3989                    input_channels,
3990                    output_channels,
3991                    X_q0,
3992                    X_value_min,
3993                    X_value_max,
3994                    W_q0,
3995                    W_value_min,
3996                    W_value_max,
3997                )
3998            X = torch.from_numpy(_dequantize(
3999                X_q0, X_scale, X_zp)).to(dtype=torch.float)
4000            X_q = torch.quantize_per_tensor(
4001                X, scale=X_scale, zero_point=X_zp, dtype=dtype)
4002            if use_channelwise:
4003                W = torch.from_numpy(_dequantize(W_q0, W_scales.reshape(
4004                    (-1, 1)), W_zps.reshape((-1, 1)))).to(dtype=torch.float)
4005                W_q = torch.quantize_per_channel(W, scales=torch.from_numpy(W_scales),
4006                                                 zero_points=torch.from_numpy(W_zps), axis=0, dtype=torch.qint8)
4007                b = torch.from_numpy(_dequantize(
4008                    b_q0, X_scale * W_scales, 0)).to(dtype=torch.float) if use_bias else None
4009                b_q = torch.quantize_per_channel(b, scales=torch.from_numpy(X_scale * W_scales),
4010                                                 zero_points=torch.zeros(output_channels, dtype=torch.long),
4011                                                 axis=0, dtype=torch.qint32) if use_bias else None
4012            else:
4013                W = torch.from_numpy(_dequantize(
4014                    W_q0, W_scales[0], W_zps[0])).to(dtype=torch.float)
4015                W_q = torch.quantize_per_tensor(W, scale=W_scales[0], zero_point=(
4016                    W_zps[0].astype(int).item()), dtype=torch.qint8)
4017                b = torch.from_numpy(_dequantize(
4018                    b_q0, X_scale * (W_scales[0].item()), 0)).to(dtype=torch.float) if use_bias else None
4019                b_q = torch.quantize_per_tensor(
4020                    b, scale=X_scale * (W_scales[0].item()), zero_point=0, dtype=torch.qint32) if use_bias else None
4021            # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
4022            # Y_scale * 255 (max for uint8).
4023            Y_scale = 125.1234
4024            Y_zp = 5
4025            # Weight prepacking operator for quantized Linear
4026            float_bias = b if use_bias else None
4027            W_prepack = qlinear_prepack(W_q, float_bias)
4028            if use_multi_dim_input:
4029                X = X.view(3, int(batch_size / 3), input_channels)
4030                X_q = X_q.view(3, int(batch_size / 3), input_channels)
4031            # Quantized Linear operator with prepacked weight
4032            Y_q_dq = qlinear(X, X_scale, X_zp, W_prepack)
4033            # Test both per-tensor and per-channel quantization
4034            # Reference quantized result from PyTorch Linear operator
4035            W_fp32 = W_q.dequantize().to(dtype=torch.float)
4036            X_fp32 = X_q.dequantize().to(dtype=torch.float)
4037            b_fp32 = b_q.dequantize().to(dtype=torch.float) if use_bias else None
4038            Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32)
4039            if use_relu:
4040                Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0
4041            decimal_val = 1
4042            np.testing.assert_array_almost_equal(Y_fp32_ref.numpy(), Y_q_dq.numpy(), decimal=decimal_val)
4043
4044    @given(batch_size=st.integers(1, 4),
4045           # In cuDNN v8.4.0, there is a limitation that input channels
4046           # should be a multiple of 4 for int8 tensors. In cuDNN v8.3.3
4047           # they should be a multiple of 16
4048           input_channels=st.sampled_from([4, 8, 12, 16, 32]),
4049           # constraints on output channels appear to be relaxed, as it seems we can use any positive integer here
4050           # except 1. It is not clear why 1 does not work. TODO: check with Yang
4051           output_channels=st.integers(2, 36),
4052           use_bias=st.booleans(),
4053           use_relu=st.booleans(),
4054           use_multi_dim_input=st.booleans(),
4055           use_channelwise=st.sampled_from([False]))  # channelwise currently not supported for qlinear cudnn
4056    @skipIfNoFBGEMM
4057    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
4058    @unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0")
4059    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
4060    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
4061    # TODO: check with yang regarding CUDNN flags
4062    @unittest.skip("not currently working and feature isn't used")
4063    def test_qlinear_cudnn(self, batch_size, input_channels, output_channels, use_bias,
4064                           use_relu, use_multi_dim_input, use_channelwise):
4065        qlinear_prepack = torch.ops.quantized.linear_prepack
4066        if use_relu:
4067            qlinear_op = torch.ops.quantized.linear_relu
4068        else:
4069            qlinear_op = torch.ops.quantized.linear
4070        X_scale = 1.5
4071        X_zp = 0
4072        X_value_min = -128
4073        X_value_max = 127
4074        X_q0 = np.round(
4075            np.random.rand(batch_size, input_channels) *
4076            (X_value_max - X_value_min)
4077            + X_value_min).astype(np.int8)
4078        W_scale = 2.5
4079        W_zp = 0
4080        W_value_min = -128
4081        W_value_max = 127
4082        W_q0 = np.round(
4083            np.random.rand(output_channels, input_channels)
4084            * (W_value_max - W_value_min)
4085            + W_value_min
4086        ).astype(np.int8)
4087        b_value_min = -10
4088        b_value_max = 10
4089        b_q0 = np.round(
4090            np.random.rand(output_channels) *
4091            (b_value_max - b_value_min) + b_value_min
4092        ).astype(np.int32) if use_bias else None
4102        avoid_vpmaddubsw_overflow_linear(
4103            batch_size,
4104            input_channels,
4105            output_channels,
4106            X_q0,
4107            X_value_min,
4108            X_value_max,
4109            W_q0,
4110            W_value_min,
4111            W_value_max,
4112        )
4113        quant_dtype = torch.qint8
4114        X = torch.from_numpy(_dequantize(
4115            X_q0, X_scale, X_zp)).to(dtype=torch.float).to(device="cuda")
4116        X_q = torch.quantize_per_tensor(
4117            X, scale=X_scale, zero_point=X_zp, dtype=quant_dtype)
4118        W = torch.from_numpy(_dequantize(
4119            W_q0, W_scale, W_zp)).to(dtype=torch.float).to(device="cuda")
4120        W_q = torch.quantize_per_tensor(W, scale=W_scale, zero_point=W_zp, dtype=quant_dtype)
4121        b = torch.from_numpy(_dequantize(
4122            b_q0, X_scale * W_scale, 0)).to(dtype=torch.float).to(device="cuda") if use_bias else None
4123        b_q = torch.quantize_per_tensor(
4124            b, scale=X_scale * W_scale, zero_point=0, dtype=quant_dtype) if use_bias else None
4125        Y_scale = 0.5
4126        Y_zp = 0
4127        # Weight prepacking operator for quantized Linear
4128        float_bias = b if use_bias else None
4129        W_prepack = qlinear_prepack(W_q, float_bias if use_bias else None)
4130        # Quantized Linear operator with prepacked weight
4131        Y_q = qlinear_op(X_q, W_prepack, Y_scale, Y_zp).to(device="cpu")
4132        Y_q_ref = qlinear_ref(X_q0, X_scale, X_zp, W_q0,
4133                              W_scale, W_zp, b_q0, Y_scale, Y_zp, dtype=np.int8)
4134        if use_relu:
4135            Y_q_ref[Y_q_ref < Y_zp] = Y_zp
4136        decimal_val = 0
4137        np.testing.assert_array_almost_equal(Y_q_ref, Y_q.int_repr().numpy(), decimal=decimal_val)
4138
4139    """Tests the correctness of the quantized::linear_unpack op."""
4140    @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,),
4141                       qparams=hu.qparams(dtypes=torch.qint8)),
4142           use_channelwise=st.booleans())
4143    @override_qengines
4144    def test_qlinear_unpack(self, W, use_channelwise):
4145        W, (W_scale, W_zp, torch_type) = W
4146        if use_channelwise:
4147            output_channels = W.shape[0]
4148            W_scales = torch.rand(output_channels).to(torch.double)
4149            W_zps = torch.round(torch.rand(output_channels)
4150                                * 100 - 50).to(torch.int64)
4151        qlinear_prepack = torch.ops.quantized.linear_prepack
4152        qlinear_unpack = torch.ops.quantized.linear_unpack
4153
4154        # ONEDNN only supports symmetric quantization of weight
4155        if qengine_is_onednn():
4156            if use_channelwise:
4157                W_zps = torch.zeros(output_channels).to(torch.int64)
4158            else:
4159                W_zp = 0
4160
4161        W = torch.from_numpy(W)
4162        if use_channelwise:
4163            W_q = torch.quantize_per_channel(
4164                W, W_scales, W_zps, 0, dtype=torch_type)
4165        else:
4166            W_q = torch.quantize_per_tensor(W, scale=W_scale, zero_point=W_zp,
4167                                            dtype=torch_type)
4168        # Weight prepacking operator for quantized Linear
4169        W_prepack = qlinear_prepack(W_q)
4170        # Weight unpack operator for quantized Linear (Used for serialization)
4171        W_q_origin = qlinear_unpack(W_prepack)[0]
4172        # Assert equal
4173        np.testing.assert_equal(W_q.int_repr(), W_q_origin.int_repr().numpy())
4174        if use_channelwise:
4175            np.testing.assert_array_almost_equal(np.float32(W_q.q_per_channel_scales().numpy()),
4176                                                 np.float32(
4177                                                     W_q_origin.q_per_channel_scales().numpy()),
4178                                                 decimal=4)
4179            np.testing.assert_equal(W_q.q_per_channel_zero_points(
4180            ).numpy(), W_q_origin.q_per_channel_zero_points().numpy())
4181        else:
4182            np.testing.assert_equal(np.float32(
4183                W_q.q_scale()), np.float32(W_q_origin.q_scale()))
4184            np.testing.assert_equal(
4185                W_q.q_zero_point(), W_q_origin.q_zero_point())
4186
4187    """Tests the correctness of the _quantized::wrapped_quantized_linear op."""
4188    @skipIfNoFBGEMM
4189    @given(
4190        m=st.integers(2, 6),
4191        k=st.integers(2, 6),
4192        n=st.integers(2, 6),
4193    )
4194    def test_wrapped_quantized_linear(self, m, n, k):
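        # The wrapped op takes fp32 input/weight plus explicit qparams and returns a
        # dequantized fp32 output; compare against an explicit quantize -> quantized.linear
        # -> dequantize reference.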
4195        input = torch.randn(m, k, dtype=torch.float32)
4196        input_scale = torch.tensor(0.1)
4197        input_zero_point = torch.tensor(0)
4198        weight = torch.randn(n, k, dtype=torch.float32)
4199        weight_scale = torch.tensor(0.1)
4200        weight_zero_point = torch.tensor(0)
4201        bias = torch.randn(n, dtype=torch.float32)
4202        output_scale = torch.tensor(0.1)
4203        output_zero_point = torch.tensor(0)
4204        out_channel = n
4205
4206        ret = torch.ops._quantized.wrapped_quantized_linear(
4207            input,
4208            input_scale,
4209            input_zero_point,
4210            weight,
4211            weight_scale,
4212            weight_zero_point,
4213            bias,
4214            output_scale,
4215            output_zero_point,
4216            out_channel,
4217        )
4218
4219        qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8)
4220        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8)
4221        qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias)
4222        qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point)
4223        ret_ref = qlinear.dequantize()
4224        self.assertEqual(ret, ret_ref)
4225
4226    """Tests the correctness of the _quantized::_wrapped_linear_prepack and
4227    _quantized::_wrapped_quantized_linear_prepacked ops."""
4228    @skipIfNoFBGEMM
4229    @given(
4230        m=st.integers(2, 6),
4231        k=st.integers(2, 6),
4232        n=st.integers(2, 6),
4233    )
4234    def test_wrapped_quantized_linear_prepacked(self, m, n, k):
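        # Same check as above, but split into a prepack step and a separate
        # *_quantized_linear_prepacked step that consumes the packed weight.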
4235        input = torch.randn(m, k, dtype=torch.float32)
4236        input_scale = torch.tensor(0.1)
4237        input_zero_point = torch.tensor(0)
4238        weight = torch.randn(n, k, dtype=torch.float32)
4239        weight_scale = torch.tensor(0.1)
4240        weight_zero_point = torch.tensor(0)
4241        bias = torch.randn(n, dtype=torch.float32)
4242        output_scale = torch.tensor(0.1)
4243        output_zero_point = torch.tensor(0)
4244        out_channel = n
4245
4246        ret_1 = torch.ops._quantized._wrapped_linear_prepack(
4247            weight,
4248            weight_scale,
4249            weight_zero_point,
4250            bias
4251        )
4252        ret_2 = torch.ops._quantized._wrapped_quantized_linear_prepacked(
4253            input,
4254            input_scale,
4255            input_zero_point,
4256            ret_1,
4257            output_scale,
4258            output_zero_point,
4259            out_channel
4260        )
4261        qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8)
4262        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8)
4263        qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias)
4264        qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point)
4265        ret_ref = qlinear.dequantize()
4266        self.assertEqual(ret_2, ret_ref)
4267
4268    """Tests the correctness of the quantized::linear_unpack op after freeing the original tensor."""
4269    @skipIfNoQNNPACK
4270    @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,),
4271                       qparams=hu.qparams(dtypes=torch.qint8)))
4272    @override_qengines
4273    def test_qlinear_qnnpack_free_memory_and_unpack(self, W):
4274        assert qengine_is_qnnpack()
4275        W, (W_scale, W_zp, torch_type) = W
4276        qlinear_prepack = torch.ops.quantized.linear_prepack
4277        qlinear_unpack = torch.ops.quantized.linear_unpack
4278
4279        W = torch.from_numpy(W)
4280        # ONEDNN only supports symmetric quantization of weight
4281        if qengine_is_onednn():
4282            W_zp = 0
4283        W_q = torch.quantize_per_tensor(W, scale=W_scale, zero_point=W_zp, dtype=torch_type)
4284        # Weight prepacking operator for quantized Linear
4285        W_prepack = qlinear_prepack(W_q)
4286        dummy_input = torch.randn((1, W.shape[1]))
4287        # Make sure we free original tensor by running matrix multiplication in backend.
4288        torch.ops.quantized.linear_dynamic(dummy_input, W_prepack)
4289        torch.ops.quantized.linear_dynamic(dummy_input, W_prepack)
4290        # At this step, original tensor should be recovered from a data_ptr
4291        W_q_origin = qlinear_unpack(W_prepack)[0]
4292        # Assert equal
4293        np.testing.assert_equal(W_q.int_repr(), W_q_origin.int_repr().numpy())
4294        np.testing.assert_equal(np.float32(
4295            W_q.q_scale()), np.float32(W_q_origin.q_scale()))
4296        np.testing.assert_equal(
4297            W_q.q_zero_point(), W_q_origin.q_zero_point())
4298
4299    @skipIfNoONEDNN
4300    def test_qlinear_leaky_relu(self):
4301        with override_quantized_engine('onednn'):
4302            batch_size_list = [1, 4]
4303            input_channels_list = [16, 32]
4304            output_channels_list = [4, 8]
4305            use_bias_list = [True, False]
4306            use_multi_dim_input_list = [True, False]
4307            use_channelwise_list = [True, False]
4308            negative_slopes_list = [0.01, 0.05]
4309            post_op = 'leaky_relu'
4310            cases = itertools.product(batch_size_list, input_channels_list, output_channels_list,
4311                                      use_bias_list, use_multi_dim_input_list,
4312                                      use_channelwise_list, negative_slopes_list)
4313            for batch_size, input_channels, output_channels, use_bias, \
4314                    use_multi_dim_input, use_channelwise, neg_slope in cases:
4315                self._test_qlinear_impl(batch_size, input_channels, output_channels,
4316                                        use_bias, post_op, use_multi_dim_input,
4317                                        use_channelwise, negative_slope=neg_slope)
4318
4319    @skipIfNoONEDNN
4320    def test_qlinear_tanh(self):
4321        with override_quantized_engine('onednn'):
4322            batch_size_list = [1, 4]
4323            input_channels_list = [16, 32]
4324            output_channels_list = [4, 8]
4325            use_bias_list = [True, False]
4326            use_multi_dim_input_list = [True, False]
4327            use_channelwise_list = [True, False]
4328            post_op = 'tanh'
4329            cases = itertools.product(batch_size_list, input_channels_list,
4330                                      output_channels_list, use_bias_list,
4331                                      use_multi_dim_input_list, use_channelwise_list)
4332            for batch_size, input_channels, output_channels, use_bias, \
4333                    use_multi_dim_input, use_channelwise in cases:
4334                self._test_qlinear_impl(batch_size, input_channels, output_channels,
4335                                        use_bias, post_op, use_multi_dim_input,
4336                                        use_channelwise)
4337
4338    def _test_qlinear_pt2e_helper(
4339        self,
4340        qlinear_op,
4341        post_op="none",
4342        unary_post_op_args=(),
4343        post_op_algorithms=("none",),  # a tuple, so itertools.product iterates algorithms, not characters
4344    ):
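        # Exercises the onednn qlinear_pointwise ops against an fp32 F.linear reference,
        # sweeping per-tensor/per-channel weights, optional bias, unary/binary post ops and
        # int8/fp32/bf16 output dtypes.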
4345        qlinear_prepack = torch.ops.onednn.qlinear_prepack
4346        linear_op = F.linear
4347        in_channels_list = [4, 8]
4348        out_channels_list = [16, 32]
4349        batch_size = 1
4350        use_bias_list = [True, False]
4351        weight_quant_per_channel_list = [True, False]
4352        output_dtype_list = [None, torch.float32, torch.bfloat16]
4353        x_scale, x_zp = 1.2, 1
4354        w_scale, w_zp = 0.8, 0
4355        y_scale, y_zp = 4.7, 2
4356        input_dim_list = [2, 3]
4357        cases = itertools.product(
4358            in_channels_list, out_channels_list, use_bias_list,
4359            weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list)
4360        with override_quantized_engine('onednn'):
4361            for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases:
4362                used_y_scale = y_scale
4363                used_y_zp = y_zp
4364                fp32_out = output_dtype == torch.float32
4365                bfloat16_out = output_dtype == torch.bfloat16
4366                if fp32_out or bfloat16_out:
4367                    used_y_scale, used_y_zp = 1.0, 0
4368                    x2_scale, x2_zp = 1.0, 0
4369                else:
4370                    x2_scale, x2_zp = 2.3, 5
4371                x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
4372                w = torch.rand(oc, ic) * 10
4373                qx = torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8)
4374                if weight_quant_per_channel:
4375                    w_scales = torch.Tensor([w_scale] * oc)
4376                    w_zps = torch.zeros(oc).to(dtype=torch.int)
4377                    qw = torch.quantize_per_channel(w, w_scales, w_zps, 0, torch.qint8)
4378                else:
4379                    w_scales = torch.Tensor([w_scale])
4380                    w_zps = torch.Tensor([w_zp]).to(dtype=torch.int)
4381                    qw = torch.quantize_per_tensor(w, w_scale, w_zp, torch.qint8)
4382                if use_bias:
4383                    b = torch.rand(oc) * 10
4384                else:
4385                    b = None
4386
4387                x_ref = qx.dequantize()
4388                w_ref = qw.dequantize()
4389                y_ref = linear_op(x_ref, w_ref, b)
4390
4391                # compute with CPU tensors
4392                qx_cpu = qx.int_repr()
4393                qw_cpu = qw.int_repr()
4394                qw_packed = qlinear_prepack(qw_cpu, x.shape)
4395
4396                if post_op in ("none", "relu", "gelu"):
4397                    qy_cpu = qlinear_op(
4398                        qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
4399                        b, used_y_scale, used_y_zp, output_dtype,
4400                        post_op, unary_post_op_args, post_op_algo
4401                    )
4402                    if post_op == "relu":
4403                        y_ref = F.relu(y_ref)
4404                    elif post_op == "gelu":
4405                        y_ref = F.gelu(y_ref, approximate=post_op_algo)
4406                    qy_ref = torch.quantize_per_tensor(y_ref, used_y_scale, used_y_zp, torch.quint8)
4407                elif post_op in ("sum", "sum_relu"):
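                    # The 'sum' post op accumulates the extra input x2 into the linear output
                    # (reference: y = linear(x, w, b) + alpha * x2, then optional ReLU); the accum
                    # tensor is int8 for quantized output, fp32/bf16 otherwise.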
4408                    x2_int8 = torch.randint(0, 4, y_ref.size())
4409                    x2 = x2_scale * ((x2_int8 - x2_zp).float())
4410                    qx2 = torch.quantize_per_tensor(
4411                        x2, scale=x2_scale, zero_point=x2_zp, dtype=torch.quint8
4412                    )
4413                    unary_post_op = "relu" if post_op == "sum_relu" else "none"
4414                    binary_alpha = 1.0  # we only support alpha=1.0 now
4415                    accum = qx2.int_repr() if output_dtype is None else qx2.dequantize()
4416                    if bfloat16_out:
4417                        accum = accum.bfloat16()
4418                    qy_cpu = qlinear_op(
4419                        qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
4420                        accum, b, used_y_scale, used_y_zp, output_dtype,
4421                        x2_scale, x2_zp, "sum", binary_alpha,
4422                        unary_post_op, unary_post_op_args, post_op_algo
4423                    )
4424                    y_ref = y_ref + x2 * binary_alpha
4425                    if unary_post_op == "relu":
4426                        y_ref = F.relu(y_ref)
4427                    qy_ref = torch.quantize_per_tensor(y_ref, used_y_scale, used_y_zp, torch.quint8)
4428                elif post_op in ("add", "add_relu"):
4429                    used_y_scale, used_y_zp = 1.0, 0
4430                    if output_dtype is not None:
                        # add/add_relu only support int8 output; skip fp32/bf16 cases
4432                        continue
4433                    x2 = torch.randn(y_ref.size()) * 10
4434                    unary_post_op = "relu" if post_op == "add_relu" else "none"
4435                    binary_alpha = 1.0  # we only support alpha=1.0 now
4436                    qy_cpu = qlinear_op(
4437                        qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
4438                        x2, b, used_y_scale, used_y_zp, output_dtype,
4439                        1.0, 0, "add", binary_alpha,
4440                        unary_post_op, unary_post_op_args, post_op_algo
4441                    )
4442                    y_ref = y_ref + x2 * binary_alpha
4443                    if unary_post_op == "relu":
4444                        y_ref = F.relu(y_ref)
4445                    qy_ref = torch.quantize_per_tensor(y_ref, used_y_scale, used_y_zp, torch.quint8)
4446
4447                # Compare results
4448                if fp32_out or bfloat16_out:
4449                    qy_cpu = torch.quantize_per_tensor(
4450                        qy_cpu.to(torch.float32),
4451                        used_y_scale,
4452                        used_y_zp, dtype=torch.quint8
4453                    ).int_repr()
4454
4455                self.assertEqual(x.dim(), qy_cpu.dim())
4456
4457                np.testing.assert_array_almost_equal(
4458                    qy_ref.int_repr().cpu().numpy(),
4459                    qy_cpu.cpu().numpy(),
4460                    decimal=0,
4461                    err_msg=f"""X: {x}, W: {w}, b: {b},
4462                    x_s: {x_scale}, x_zp: {x_zp},
4463                    w_s: {w_scale}, w_zp: {w_zp},
4464                    y_s: {y_scale}, y_zp: {y_zp}""",
4465                )
4466
4467    @skipIfNoONEDNN
4468    def test_qlinear_pt2e(self):
4469        qlinear = torch.ops.onednn.qlinear_pointwise
4470        self._test_qlinear_pt2e_helper(qlinear, "none")
4471
4472    @skipIfNoONEDNN
4473    def test_qlinear_relu_pt2e(self):
4474        qlinear = torch.ops.onednn.qlinear_pointwise
4475        self._test_qlinear_pt2e_helper(qlinear, "relu")
4476
4477    @skipIfNoONEDNN
4478    def test_qlinear_gelu_pt2e(self):
4479        qlinear = torch.ops.onednn.qlinear_pointwise
4480        post_op_algorithms = ['none', 'tanh']
4481        self._test_qlinear_pt2e_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms)
4482
4483    @skipIfNoONEDNN
4484    def test_qlinear_sum_pt2e(self):
4485        qlinear = torch.ops.onednn.qlinear_pointwise.binary
4486        self._test_qlinear_pt2e_helper(qlinear, "sum")
4487
4488    @skipIfNoONEDNN
4489    def test_qlinear_sum_relu_pt2e(self):
4490        qlinear = torch.ops.onednn.qlinear_pointwise.binary
4491        self._test_qlinear_pt2e_helper(qlinear, "sum_relu")
4492
4493    @skipIfNoONEDNN
4494    def test_qlinear_add_pt2e(self):
4495        qlinear = torch.ops.onednn.qlinear_pointwise.binary
4496        self._test_qlinear_pt2e_helper(qlinear, "add")
4497
4498    @skipIfNoONEDNN
4499    def test_qlinear_add_relu_pt2e(self):
4500        qlinear = torch.ops.onednn.qlinear_pointwise.binary
4501        self._test_qlinear_pt2e_helper(qlinear, "add_relu")
4502
4503
4504@unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
4505class TestQuantizedEmbeddingOps(TestCase):
4506
4507    def _test_embedding_bag_unpack_impl(self, pack_fn, unpack_fn, bit_rate, optimized_qparams, weights):
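        # Pack the weights with pack_fn and unpack with unpack_fn, then check the
        # TorchBind embedding_bag_prepack/unpack path against a per-channel quantized reference.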
4508        data_type = weights.dtype
4509
4510        qtype = torch.quint8
4511        if bit_rate == 8:
4512            w_packed = pack_fn(weights)
4513        else:
4514            w_packed = pack_fn(weights, optimized_qparams=optimized_qparams)
4515        w_unpacked = unpack_fn(w_packed)
4516
4517        if (bit_rate == 8 or bit_rate == 4) and data_type != torch.float16:
4518            # torch.quantize_per_channel does not support float16 yet.
4519
4520            obs_weights = weights
4521            # Combine 3D embeddings (e.g. stacked combination of embeddings)
4522            # in a dimension orthogonal to channels.
4523            if (len(obs_weights.shape) > 2):
4524                stacked_shape = list(weights.size())
4525                stacked_shape[1] *= stacked_shape[0]
4526                obs_weights = weights.reshape(stacked_shape[1:])
4527
4528            # Check numerics of prepack function that accepts qtensor as input.
4529            # We use min-max observer to mimic the quantization performed in the original function.
4530            obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
4531            obs(obs_weights)
4532            # Get the scale and zero point for the weight tensor
4533            qparams = obs.calculate_qparams()
4534            if bit_rate == 4:
4535                qtype = torch.quint4x2
            # Quantize the weights to the target bit width (8-bit, or 4-bit when bit_rate == 4)
4537            qweight = torch.quantize_per_channel(obs_weights, qparams[0], qparams[1], axis=0, dtype=qtype)
4538            real_packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight)
4539            self.assertEqual(isinstance(real_packed_weight, torch._C.ScriptObject), True)
4540            unpacked_weight = torch.ops.quantized.embedding_bag_unpack(real_packed_weight)
4541            self.assertEqual(unpacked_weight.int_repr().numpy(), qweight.int_repr().numpy())
4542            self.assertEqual(unpacked_weight.q_per_channel_scales(), qweight.q_per_channel_scales())
4543            self.assertEqual(unpacked_weight.q_per_channel_zero_points(), qweight.q_per_channel_zero_points())
4544
4545
4546
4547
4548    def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate,
4549                                      optimized_qparams, num_batches, data_type=np.float32):
4550
        # when num_batches = 1, the squeeze() below creates a 2D tensor
4552        unsplit_weight = torch.from_numpy((np.random.random_sample((
4553            num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(np.float32))
4554
4555        # test unsplit weight (memory format is `contiguous`)
4556        self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, unsplit_weight)
4557
4558        # test split weights (memory format is not `contiguous`)
4559        split_dim = len(unsplit_weight.shape) - 2
4560        split_weights = torch.split(unsplit_weight, 1, dim=split_dim)
4561        for weight in split_weights:
4562            self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight)
4563
4564
4565
4566    def embedding_bag_rowwise_offsets_run(
4567            self, bit_rate, num_embeddings,
4568            embedding_dim, num_offsets,
4569            use_32bit_indices, use_32bit_offsets,
4570            enable_per_sample_weights,
4571            include_last_offset, fallback_to_no_sparse, sparsity, atol, rtol):
4572        pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
4573        pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack
4574        if bit_rate == 4:
4575            pt_op = torch.ops.quantized.embedding_bag_4bit_rowwise_offsets
4576            pt_prepack_op = torch.ops.quantized.embedding_bag_4bit_prepack
4577        elif bit_rate == 2:
4578            pt_op = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets
4579            pt_prepack_op = torch.ops.quantized.embedding_bag_2bit_prepack
4580
4581        weights = torch.from_numpy((np.random.random_sample((
4582            num_embeddings, embedding_dim)) + 1).astype(np.float32))
4583
4584        max_segments = 5
4585        max_segment_length = 20
4586        num_lengths = np.random.randint(1, max_segments + 1)
4587        lengths = np.random.randint(0, max_segment_length + 1,
4588                                    size=num_lengths).astype(np.int32)
4589        num_indices = np.sum(lengths)
4590
4591        def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
4592            """
4593            Convert lengths to offsets
4594            """
4595            tt = np.zeros((t.shape[0] + 1,), dtype=offset_type)
4596            tt[1:] = t
4597            tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type))
4598            if use_begin_offset:
4599                return tt[:-1]
4600            return tt[1:]
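        # e.g. lengths [2, 3, 1] -> begin offsets [0, 2, 5], end offsets [2, 5, 6]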
4601
4602        offsets = lengths_to_offsets(lengths)
4603        indices = torch.from_numpy(np.random.randint(
4604            low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
4605
4606        q_weights = pt_prepack_op(weights)
4607        per_sample_weights = torch.from_numpy(np.random.uniform(
4608            low=0.01, high=0.5, size=[len(indices)]).astype(np.float32)) if \
4609            enable_per_sample_weights else None
4610        if include_last_offset:
4611            offsets = torch.cat(
4612                (offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0
4613            )
4614
4615        # Reference result will be the floating point torch.nn.EmbeddingBag.
4616        def get_reference_result(
4617                num_embeddings, embedding_dim,
4618                include_last_offset, weights, per_sample_weights,
4619                indices, offsets):
4620            embedding_bag = torch.nn.EmbeddingBag(
4621                num_embeddings=num_embeddings,
4622                embedding_dim=embedding_dim,
4623                include_last_offset=include_last_offset, _weight=weights,
4624                scale_grad_by_freq=False, mode='sum'
4625            )
4626            return embedding_bag(indices, offsets,
4627                                 per_sample_weights=per_sample_weights)
4628
4629        mapping_table = np.zeros(num_embeddings, dtype=np.int32)
4630        pruned_weights = weights
4631        prune_weights = sparsity > 0
4632        if prune_weights:
4633            if fallback_to_no_sparse:
                # Test that pruned weights with a mapping_table of {0} fall
                # back to the non-sparse embedding lookup kernel.
4636                mapping_table = np.zeros(1, dtype=np.int32)
4637            else:
4638                # Prune and generate mapping table
4639                num_compressed_rows = 0
4640                unpruned_ids = []
4641                for i in range(num_embeddings):
4642                    if np.random.uniform() < sparsity:
4643                        mapping_table[i] = -1
4644                        q_weights[i, :] = 0
4645                        weights[i, :] = 0
4646                    else:
4647                        mapping_table[i] = num_compressed_rows
4648                        num_compressed_rows += 1
4649                        unpruned_ids.append(i)
4650                q_weights = q_weights[unpruned_ids]
4651                pruned_weights = weights[unpruned_ids]
4652
4653        result = pt_op(q_weights,
4654                       indices.int() if use_32bit_indices else indices,
4655                       offsets.int() if use_32bit_offsets else offsets,
4656                       mode=0,
4657                       pruned_weights=prune_weights,
4658                       per_sample_weights=per_sample_weights,
4659                       compressed_indices_mapping=torch.tensor(mapping_table),
4660                       include_last_offset=include_last_offset)
4661
4662        reference_result = get_reference_result(
4663            num_embeddings, embedding_dim, include_last_offset, weights,
4664            per_sample_weights, indices, offsets)
4665
4666        torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol)
4667
4668
4669        if bit_rate == 8 or bit_rate == 4:
4670            # Test operator that accepts TorchBind packed weights.
4671            if bit_rate == 4:
4672                qdtype = torch.quint4x2
4673                op = torch.ops.quantized.embedding_bag_4bit
4674            else:
4675                qdtype = torch.quint8
4676                op = torch.ops.quantized.embedding_bag_byte
4677            obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
4678            obs(pruned_weights)
4679            # Get the scale and zero point for the weight tensor
4680            qparams = obs.calculate_qparams()
            # Quantize the weights to the target bit width (8-bit or 4-bit)
4682            qweight = torch.quantize_per_channel(pruned_weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
4683            packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight)
4684            result = op(packed_weight, indices, offsets, mode=0,
4685                        pruned_weights=prune_weights,
4686                        per_sample_weights=per_sample_weights,
4687                        compressed_indices_mapping=torch.tensor(mapping_table),
4688                        include_last_offset=include_last_offset)
4689            torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol)
4690
4691    """ Tests the correctness of the embedding_bag_8bit quantized operator """
4692    @given(num_embeddings=st.integers(10, 100),
4693           embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
4694           num_offsets=st.integers(1, 20),
4695           use_32bit_indices=st.booleans(),
4696           use_32bit_offsets=st.booleans(),
4697           enable_per_sample_weights=st.booleans(),
4698           include_last_offset=st.booleans(),
4699           fallback_to_no_sparse=st.booleans(),
4700           sparsity=st.sampled_from([0.0, 0.5, 0.7]))
4701    def test_embedding_bag_byte(self, num_embeddings,
4702                                embedding_dim, num_offsets,
4703                                use_32bit_indices,
4704                                use_32bit_offsets,
4705                                enable_per_sample_weights,
4706                                include_last_offset,
4707                                fallback_to_no_sparse,
4708                                sparsity):
4709        self.embedding_bag_rowwise_offsets_run(
4710            8, num_embeddings, embedding_dim, num_offsets,
4711            use_32bit_indices, use_32bit_offsets,
4712            enable_per_sample_weights, include_last_offset,
4713            fallback_to_no_sparse,
4714            sparsity=sparsity, atol=0.005, rtol=1e-3)
4715
4716    """ Tests the correctness of the embedding_bag_4bit quantized operator """
4717    @given(num_embeddings=st.integers(10, 100),
4718           embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
4719           num_offsets=st.integers(1, 20),
4720           use_32bit_indices=st.booleans(),
4721           use_32bit_offsets=st.booleans(),
4722           enable_per_sample_weights=st.booleans(),
4723           include_last_offset=st.booleans(),
4724           fallback_to_no_sparse=st.booleans(),
4725           sparsity=st.sampled_from([0.0, 0.5, 0.7]))
4726    def test_embedding_bag_4bit(self, num_embeddings,
4727                                embedding_dim, num_offsets,
4728                                use_32bit_indices,
4729                                use_32bit_offsets,
4730                                enable_per_sample_weights,
4731                                include_last_offset,
4732                                fallback_to_no_sparse,
4733                                sparsity):
4734        self.embedding_bag_rowwise_offsets_run(4, num_embeddings,
4735                                               embedding_dim, num_offsets,
4736                                               use_32bit_indices, use_32bit_offsets,
4737                                               enable_per_sample_weights,
4738                                               include_last_offset,
4739                                               fallback_to_no_sparse,
4740                                               sparsity=sparsity,
4741                                               atol=0.1, rtol=1e-2)
4742
4743    """ Tests the correctness of the embedding_bag_2bit quantized operator """
4744    @given(num_embeddings=st.integers(10, 100),
4745           embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
4746           num_offsets=st.integers(1, 20),
4747           use_32bit_indices=st.booleans(),
4748           use_32bit_offsets=st.booleans(),
4749           enable_per_sample_weights=st.booleans(),
4750           include_last_offset=st.booleans(),
4751           fallback_to_no_sparse=st.booleans(),
4752           sparsity=st.sampled_from([0.0, 0.5, 0.7]))
4753    def test_embedding_bag_2bit(self, num_embeddings,
4754                                embedding_dim, num_offsets,
4755                                use_32bit_indices,
4756                                use_32bit_offsets,
4757                                enable_per_sample_weights,
4758                                include_last_offset,
4759                                fallback_to_no_sparse,
4760                                sparsity):
4761        self.embedding_bag_rowwise_offsets_run(2, num_embeddings,
4762                                               embedding_dim, num_offsets,
4763                                               use_32bit_indices, use_32bit_offsets,
4764                                               enable_per_sample_weights,
4765                                               include_last_offset,
4766                                               fallback_to_no_sparse,
4767                                               sparsity=sparsity,
4768                                               atol=1.0, rtol=1e-1)
4769
4770    """ Tests the correctness of the quantized 8 bit embedding lookup operator """
4771    @given(num_embeddings=st.integers(10, 100),
4772           embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))
4773    def test_embedding(self, num_embeddings, embedding_dim):
4774        dtypes = [torch.quint8, torch.quint4x2]
4775        quant_ops = [torch.ops.quantized.embedding_byte, torch.ops.quantized.embedding_4bit]
4776        atols = [0.005, 0.1]
4777        rtols = [1e-3, 1e-2]
4778        prepack_op = torch.ops.quantized.embedding_bag_prepack
4779        for quant_op, dtype, atol, rtol in zip(quant_ops, dtypes, atols, rtols):
4780            weights = torch.from_numpy((np.random.random_sample((
4781                num_embeddings, embedding_dim)) + 1).astype(np.float32))
4782
4783            obs = PerChannelMinMaxObserver(dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
4784            obs(weights)
4785            # Get the scale and zero point for the weight tensor
4786            qparams = obs.calculate_qparams()
4787
            # Quantize the weights to the target bit width (8-bit or 4-bit)
4789            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype)
4790            max_segments = 5
4791            max_segment_length = 20
4792            num_lengths = np.random.randint(1, max_segments + 1)
4793            lengths = np.random.randint(1, max_segment_length + 1,
4794                                        size=num_lengths).astype(np.int32)
4795            num_indices = np.sum(lengths)
4796            indices = torch.from_numpy(np.random.randint(
4797                low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
4798
4799            packed_weight = prepack_op(qweight)
4800            qresult = quant_op(packed_weight, indices, pruned_weights=False)
4801
4802            ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)
4803            torch.testing.assert_close(ref, qresult, atol=atol, rtol=rtol)
4804
4805    def test_embedding_2d_indices(self):
4806        """
4807        Tests the case where 2D indices are passed into the operator
4808        In this case the operator computes the correct offsets argument.
4809        Output shape is dependent on the indices dimension.
4810        """
4811        quant_op = torch.ops.quantized.embedding_byte
4812        prepack_op = torch.ops.quantized.embedding_bag_prepack
4813
4814        indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]])
4815        weights = torch.randn(10, 12, dtype=torch.float32)
4816
4817        ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)
4818        obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
4819        obs(weights)
4820        qparams = obs.calculate_qparams()
4821
4822        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
4823        packed_weight = prepack_op(qweight)
4824        qresult = quant_op(packed_weight, indices, pruned_weights=False)
4825        torch.testing.assert_close(ref, qresult, atol=0.05, rtol=1e-3)
4826
4827    def test_embedding_bag_2d_indices(self):
4828        """
4829        Tests the case where 2D indices are passed into the operator
4830        In this case the operator computes the correct offsets argument.
4831        """
4832        indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]])
4833        weights = torch.randn(10, 12, dtype=torch.float32)
4834
4835        embedding_bag = torch.nn.EmbeddingBag(
4836            num_embeddings=10,
4837            embedding_dim=12,
4838            include_last_offset=False, _weight=weights,
4839            scale_grad_by_freq=False, mode='sum'
4840        )
4841        result = embedding_bag(indices)
4842
4843        pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
4844        pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack
4845        q_weights = pt_prepack_op(weights)
4846        qresult = pt_op(q_weights, indices, mode=0, pruned_weights=False)
4847        torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3)
4848
4849        # Test TorchBind based embedding_bag operator
4850        obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
4851        obs(weights)
4852        # Get the scale and zero point for the weight tensor
4853        qparams = obs.calculate_qparams()
4854
4855        # Quantize the weights to 8bits
4856        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
4857
4858        packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight)
4859        qresult = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, mode=0)
4860
4861        torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3)
4862
4863
4864class TestQuantizedConv(TestCase):
4865    def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs,
4866                                strides, i_pads, o_pads, channelwise):
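        # Quantize the weight (per-tensor or per-channel), prepack it together with the bias,
        # then unpack and check that values, scales and zero points round-trip unchanged.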
4867        (X_data, W_data, bias_data, groups, transposed) = inputs
4868        (X, (X_scale, X_zero_point, X_qtype)) = X_data
4869        (W, (W_scale, W_zero_point, W_qtype)) = W_data
4870        (bias, (bias_scale, bias_zero_point, bias_qtype)) = bias_data
4871
4872        W = torch.from_numpy(W).float()
4873        bias = torch.from_numpy(bias).float()
4874        if channelwise and transposed:
            # transposed conv with per-channel quantization is currently not supported
4876            return
4877        # ONEDNN only supports symmetric quantization of weight and zero output padding
4878        if qengine_is_onednn():
4879            W_zero_point = 0
4880            o_pads = len(o_pads) * [0] if o_pads is not None else None
4881        if channelwise:
4882            if transposed:
4883                output_channels = W.shape[1]  # IC OC/G
4884            else:
4885                output_channels = W.shape[0]  # OC IC/G
4886            W_scale = torch.tensor([W_scale] * output_channels)
4887            W_zero_point = torch.tensor([W_zero_point] * output_channels)
4888            W_q = torch.quantize_per_channel(
4889                W, scales=W_scale, zero_points=W_zero_point,
4890                axis=int(transposed), dtype=W_qtype)
4891        else:
4892            W_q = torch.quantize_per_tensor(
4893                W, scale=W_scale, zero_point=W_zero_point, dtype=W_qtype)
4894
4895        if isinstance(strides, int):
4896            dilations = [1]
4897        else:
4898            dilations = (1,) * len(strides)
4899
4900        if transposed:
4901            W_packed = qconv_prepack_fn(W_q, bias, strides, i_pads, o_pads,
4902                                        dilations, groups)
4903        else:
4904            W_packed = qconv_prepack_fn(W_q, bias, strides, i_pads, dilations,
4905                                        groups)
4906        (W_unpacked, bias) = qconv_unpack_fn(W_packed)
4907
4908        # Assert equal
4909        np.testing.assert_equal(W_q.int_repr().numpy(),
4910                                W_unpacked.int_repr().numpy())
4911        if channelwise:
4912            np.testing.assert_array_almost_equal(
4913                np.float32(W_q.q_per_channel_scales().numpy()),
4914                np.float32(W_unpacked.q_per_channel_scales().numpy()),
4915                decimal=4)
4916            np.testing.assert_equal(W_q.q_per_channel_zero_points(
4917            ).numpy(), W_unpacked.q_per_channel_zero_points().numpy())
4918        else:
4919            np.testing.assert_equal(np.float32(
4920                W_q.q_scale()), np.float32(W_unpacked.q_scale()))
4921            np.testing.assert_equal(
4922                W_q.q_zero_point(), W_unpacked.q_zero_point())
4923
4924    def _make_qconv_tensors(
4925        self, batch_size, input_channels_per_group, input_feature_map_shape,
4926        output_channels_per_group, groups, kernels, strides, pads, dilations,
4927        X_scale, X_zero_point, W_scale, W_zero_point,
4928        use_bias, use_channelwise, use_transpose,
4929        device=torch.device("cpu"),
4930        input_dtype=torch.quint8,
4931        weight_dtype=torch.qint8,
4932    ):
4933        assert not (use_channelwise and use_transpose), \
4934               "Cannot generate channelwise qconv_transpose_tensors "
4935        input_channels = input_channels_per_group * groups
4936        output_channels = output_channels_per_group * groups
4937        # Padded input size should be at least as big as dilated kernel
4938        kernels = _single(kernels)
4939        strides = _single(strides)
4940        pads = _single(pads)
4941        dilations = _single(dilations)
4942        for i in range(len(kernels)):
4943            assume(input_feature_map_shape[i] + 2 * pads[i]
4944                   >= dilations[i] * (kernels[i] - 1) + 1)
4945        W_scale = W_scale * output_channels
4946        W_zero_point = W_zero_point * output_channels
        # Truncate W_scale and W_zero_point to exactly output_channels entries
4948        W_scale = W_scale[:output_channels]
4949        W_zero_point = W_zero_point[:output_channels]
        # For testing, we use small values for weights and activations so that
        # no overflow occurs in the vpmaddubsw instruction. If overflow occurs
        # in the qconv implementation but not in the reference, we can't
        # exactly match the reference results.
        # Please see the comment in the qconv implementation file
        # aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
4957        (W_value_min, W_value_max) = (-5, 5)
        # The operator expects the weights in the format
        # (output_channels, input_channels/groups, kernel_d, kernel_h, kernel_w) for conv and
        # (input_channels, output_channels/groups, kernel_d, kernel_h, kernel_w) for transposed conv
4961        if use_transpose:
4962            output_shape = (input_channels, output_channels_per_group,)
4963        else:
4964            output_shape = (output_channels, input_channels_per_group,)
4965        W_init = torch.randint(
4966            W_value_min,
4967            W_value_max,
4968            output_shape + kernels,
4969            device=device,
4970        )
4971        b_init = torch.randint(0, 10, (output_channels,), device=device)
4972
4973        (X_value_min, X_value_max) = (0, 4)
4974        X_init = torch.randint(
4975            X_value_min,
4976            X_value_max,
4977            (batch_size, input_channels,) + input_feature_map_shape,
4978            device=device
4979        )
4980        X = X_scale * (X_init - X_zero_point).float()
4981
4982        if use_channelwise:
4983            W_shape = (-1, 1) + (1,) * len(kernels)
4984            W_scales_tensor = torch.tensor(W_scale, dtype=torch.float, device=device)
4985            W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float, device=device)
4986            W = W_scales_tensor.reshape(*W_shape) * (
4987                W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
4988            b = X_scale * W_scales_tensor * b_init.float()
4989        else:
4990            W = W_scale[0] * (W_init - W_zero_point[0]).float()
4991            b = X_scale * W_scale[0] * b_init.float()
4992
4993        X_q = torch.quantize_per_tensor(
4994            X, scale=X_scale, zero_point=X_zero_point, dtype=input_dtype)
4995        if use_channelwise:
4996            W_q = torch.quantize_per_channel(
4997                W, W_scales_tensor, W_zero_points_tensor.long(), 0,
4998                dtype=weight_dtype)
4999        else:
5000            W_q = torch.quantize_per_tensor(
5001                W, scale=W_scale[0], zero_point=W_zero_point[0],
5002                dtype=weight_dtype)
5003
5004        bias_float = b if use_bias else None
5005
5006        return (X, W), (X_q, W_q), bias_float
5007
5008    def _test_qconv_impl(
5009        self, qconv_fn, qconv_prepack_fn, conv_op, batch_size,
5010        input_channels_per_group, input_feature_map_shape,
5011        output_channels_per_group, groups, kernels, strides, pads, o_pads,
5012        dilations, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale,
5013        Y_zero_point, use_bias, post_op, use_channelwise, use_transpose,
5014        device=torch.device("cpu"),
5015        input_dtype=torch.quint8,
5016        weight_dtype=torch.qint8,
5017        output_dtype=torch.quint8,
5018        X2_scale=1.0,
5019        X2_zero_point=128
5020    ):
5021        # ONEDNN only supports symmetric quantization of weight
5022        if qengine_is_onednn() and W_zero_point is not None:
5023            W_zero_point = len(W_zero_point) * [0]
5024        (X, W), (X_q, W_q), bias_float = self._make_qconv_tensors(
5025            batch_size, input_channels_per_group, input_feature_map_shape,
5026            output_channels_per_group, groups, kernels,
5027            strides, pads, dilations, X_scale, X_zero_point, W_scale,
5028            W_zero_point, use_bias, use_channelwise, use_transpose,
5029            device=device, input_dtype=input_dtype, weight_dtype=weight_dtype)
5030        if bias_float is not None:
5031            bias_float = bias_float.to(device)
5032        # Assign weights
5033        W = W_q.dequantize()
5034        X = X_q.dequantize()
5035        conv_op.weight = torch.nn.Parameter(W, requires_grad=False)
5036        conv_op.bias = torch.nn.Parameter(
5037            bias_float, requires_grad=False) if use_bias else None
5038        result_ref = conv_op(X)
5039        if post_op == 'relu':
5040            assert not use_transpose, "Cannot fuse ReLU with ConvTranspose"
5041            relu = torch.nn.ReLU()
5042            result_ref = relu(result_ref)
5043        elif post_op == 'add':
5044            (X_value_min, X_value_max) = (0, 4)
5045            X2_init = torch.randint(
5046                X_value_min,
5047                X_value_max,
5048                result_ref.size(),
5049                device=device
5050            )
5051            X2 = X2_scale * (X2_init - X2_zero_point).float()
5052            X2_q = torch.quantize_per_tensor(
5053                X2, scale=X2_scale, zero_point=X2_zero_point, dtype=input_dtype)
5054            result_ref = result_ref + X2
5055        elif post_op == 'add_relu':
5056            (X_value_min, X_value_max) = (0, 4)
5057            X2_init = torch.randint(
5058                X_value_min,
5059                X_value_max,
5060                result_ref.size(),
5061                device=device
5062            )
5063            X2 = X2_scale * (X2_init - X2_zero_point).float()
5064            X2_q = torch.quantize_per_tensor(
5065                X2, scale=X2_scale, zero_point=X2_zero_point, dtype=input_dtype)
5066            result_ref = result_ref + X2
5067            relu = torch.nn.ReLU()
5068            result_ref = relu(result_ref)
5069        # Quantize reference results for comparison
5070        result_ref_q = torch.quantize_per_tensor(
5071            result_ref, scale=Y_scale, zero_point=Y_zero_point,
5072            dtype=output_dtype)
5073
5074        if qconv_prepack_fn is not None:
5075            if use_transpose:
5076                W_prepack = qconv_prepack_fn(
5077                    W_q, bias_float, strides, pads, o_pads, dilations, groups)
5078            else:
5079                W_prepack = qconv_prepack_fn(
5080                    W_q, bias_float, strides, pads, dilations, groups)
5081            if post_op == 'add' or post_op == 'add_relu':
5082                Y_q = qconv_fn(
5083                    X_q,
5084                    X2_q,
5085                    W_prepack,
5086                    Y_scale,
5087                    Y_zero_point,
5088                )
5089            else:
5090                Y_q = qconv_fn(
5091                    X_q,
5092                    W_prepack,
5093                    Y_scale,
5094                    Y_zero_point,
5095                )
5096        else:
5097            # quantized conv op without prepacking
5098            Y_q = qconv_fn(X_q, W_q, bias_float, strides, pads, dilations, groups, Y_scale, Y_zero_point)
5099
5100        # Make sure the results match
5101        # assert_array_almost_equal compares using the following formula:
5102        #     abs(desired-actual) < 1.5 * 10**(-decimal)
5103        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
5104        # We use decimal = 0 to ignore off-by-1 differences between
5105        # reference and test. Off-by-1 differences arise due to the order of
5106        # round and zero_point addition operation, i.e., if addition
5107        # followed by round is used by reference and round followed by
5108        # addition is used by test, the results may differ by 1.
5109        # For example, the result of round(2.5) + 1 is 3 while
5110        # round(2.5 + 1) is 4 assuming the rounding mode is
5111        # round-to-nearest, ties-to-even.
5112        np.testing.assert_array_almost_equal(
5113            result_ref_q.int_repr().cpu().numpy(), Y_q.int_repr().cpu().numpy(), decimal=0,
5114            err_msg=f'''X: {X_q}, W: {W_q}, b: {bias_float}, strides: {strides},
5115            pads: {pads}, o_pads: {o_pads}, dilations: {dilations},
5116            groups: {groups}, y_s: {Y_scale}, y_zp: {Y_zero_point}''')
5117
5118        # Return the quantized data for later reuse
5119        return X_q, W_q, bias_float
5120
5121    """Tests the correctness of quantized convolution op."""
5122    @given(batch_size=st.integers(1, 3),
5123           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5124           height=st.integers(10, 16),
5125           width=st.integers(7, 14),
5126           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5127           groups=st.integers(1, 300),
5128           kernel_h=st.integers(1, 7),
5129           kernel_w=st.integers(1, 7),
5130           stride_h=st.integers(1, 2),
5131           stride_w=st.integers(1, 2),
5132           pad_h=st.integers(0, 2),
5133           pad_w=st.integers(0, 2),
5134           dilation=st.integers(1, 2),
5135           X_scale=st.floats(1.2, 1.6),
5136           X_zero_point=st.integers(0, 4),
5137           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
5138           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
5139           Y_scale=st.floats(4.2, 5.6),
5140           Y_zero_point=st.integers(0, 4),
5141           use_bias=st.booleans(),
5142           use_channelwise=st.booleans())
5143    @override_qengines
5144    def test_qconv2d(
5145            self,
5146            batch_size,
5147            input_channels_per_group,
5148            height,
5149            width,
5150            output_channels_per_group,
5151            groups,
5152            kernel_h,
5153            kernel_w,
5154            stride_h,
5155            stride_w,
5156            pad_h,
5157            pad_w,
5158            dilation,
5159            X_scale,
5160            X_zero_point,
5161            W_scale,
5162            W_zero_point,
5163            Y_scale,
5164            Y_zero_point,
5165            use_bias,
5166            use_channelwise,
5167    ):
5168        input_channels = input_channels_per_group * groups
5169        output_channels = output_channels_per_group * groups
5170        kernels = (kernel_h, kernel_w)
5171        strides = (stride_h, stride_w)
5172        pads = (pad_h, pad_w)
5173        dilations = (dilation, dilation)
5174
5175        qconv = torch.ops.quantized.conv2d
5176        qconv_prepack = torch.ops.quantized.conv2d_prepack
5177        conv_op = torch.nn.Conv2d(
5178            input_channels,
5179            output_channels,
5180            kernels,
5181            strides,
5182            pads,
5183            dilations,
5184            groups,
5185        )
5186
5187        act_qdtypes = [torch.quint8]
        # Only the qnnpack qengine supports qint8
5189        if qengine_is_qnnpack() and torch.backends.xnnpack.enabled:
5190            act_qdtypes.append(torch.qint8)
5191
5192        for X_qdtype in act_qdtypes:
5193            if X_qdtype == torch.qint8:
5194                W_zero_point = [0 for i in range(len(W_zero_point))]
5195
5196            self._test_qconv_impl(
5197                qconv, qconv_prepack, conv_op, batch_size,
5198                input_channels_per_group, (height, width),
5199                output_channels_per_group, groups, kernels, strides, pads, None,
5200                dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5201                Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False, input_dtype=X_qdtype, output_dtype=X_qdtype)
5202
5203    @given(batch_size=st.integers(1, 3),
5204           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5205           height=st.integers(10, 16),
5206           width=st.integers(7, 14),
5207           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5208           groups=st.integers(1, 300),
5209           kernel_h=st.integers(1, 7),
5210           kernel_w=st.integers(1, 7),
5211           stride_h=st.integers(1, 2),
5212           stride_w=st.integers(1, 2),
5213           pad_h=st.integers(0, 2),
5214           pad_w=st.integers(0, 2),
5215           dilation=st.integers(1, 2),
5216           X_scale=st.floats(1.2, 1.6),
5217           X_zero_point=st.integers(0, 4),
5218           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
5219           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
5220           Y_scale=st.floats(4.2, 5.6),
5221           Y_zero_point=st.integers(0, 4),
5222           use_bias=st.booleans(),
5223           use_channelwise=st.booleans())
5224    @override_qengines
5225    def test_qconv2d_relu(
5226            self,
5227            batch_size,
5228            input_channels_per_group,
5229            height,
5230            width,
5231            output_channels_per_group,
5232            groups,
5233            kernel_h,
5234            kernel_w,
5235            stride_h,
5236            stride_w,
5237            pad_h,
5238            pad_w,
5239            dilation,
5240            X_scale,
5241            X_zero_point,
5242            W_scale,
5243            W_zero_point,
5244            Y_scale,
5245            Y_zero_point,
5246            use_bias,
5247            use_channelwise,
5248    ):
5249        input_channels = input_channels_per_group * groups
5250        output_channels = output_channels_per_group * groups
5251        kernels = (kernel_h, kernel_w)
5252        strides = (stride_h, stride_w)
5253        pads = (pad_h, pad_w)
5254        dilations = (dilation, dilation)
5255
5256        qconv = torch.ops.quantized.conv2d_relu
5257        qconv_prepack = torch.ops.quantized.conv2d_prepack
5258        conv_op = torch.nn.Conv2d(
5259            input_channels,
5260            output_channels,
5261            kernels,
5262            strides,
5263            pads,
5264            dilations,
5265            groups,
5266        )
5267
5268        act_qdtypes = [torch.quint8]
        # Only the qnnpack qengine supports qint8
5270        if qengine_is_qnnpack() and torch.backends.xnnpack.enabled:
5271            act_qdtypes.append(torch.qint8)
5272
5273        for X_qdtype in act_qdtypes:
5274            if X_qdtype == torch.qint8:
5275                W_zero_point = [0 for i in range(len(W_zero_point))]
5276
5277            self._test_qconv_impl(
5278                qconv, qconv_prepack, conv_op, batch_size,
5279                input_channels_per_group, (height, width),
5280                output_channels_per_group, groups, kernels, strides, pads, None,
5281                dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5282                Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False, input_dtype=X_qdtype, output_dtype=X_qdtype)
5283
5284    @skipIfNoONEDNN
5285    def test_qconv2d_add(self):
5286        batch_size = 3
5287        groups_list = [1, 10]
5288        input_channels_per_group = 2
5289        output_channels_per_group = 2
5290        height = 10
5291        width = 10
5292        kernel_h = 3
5293        kernel_w = 3
5294        stride_h = 2
5295        stride_w = 2
5296        pad_h = 1
5297        pad_w = 1
5298        dilation = 1
5299        X_scale = 1.5
5300        X_zero_point = 2
5301        W_scale = [1.5]
5302        W_zero_point = [-3]
5303        Y_scale = 4.2
5304        Y_zero_point = 0
5305        use_bias_list = [False, True]
5306        use_channelwise_list = [False, True]
5307        X2_scale = 1.2
5308        X2_zero_point_list = [0, 4]
5309        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list)
5310        for groups, use_bias, use_channelwise, X2_zero_point in options:
5311            with override_quantized_engine('onednn'):
5312                input_channels = input_channels_per_group * groups
5313                output_channels = output_channels_per_group * groups
5314                kernels = (kernel_h, kernel_w)
5315                strides = (stride_h, stride_w)
5316                pads = (pad_h, pad_w)
5317                dilations = (dilation, dilation)
5318
5319                qconv = torch.ops.quantized.conv2d_add
5320                qconv_prepack = torch.ops.quantized.conv2d_prepack
5321                conv_op = torch.nn.Conv2d(
5322                    input_channels,
5323                    output_channels,
5324                    kernels,
5325                    strides,
5326                    pads,
5327                    dilations,
5328                    groups,
5329                )
5330
5331                X_qdtype = torch.quint8
5332                self._test_qconv_impl(
5333                    qconv, qconv_prepack, conv_op, batch_size,
5334                    input_channels_per_group, (height, width),
5335                    output_channels_per_group, groups, kernels, strides, pads, None,
5336                    dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5337                    Y_scale, Y_zero_point, use_bias, "add", use_channelwise, False,
5338                    input_dtype=X_qdtype, output_dtype=X_qdtype, X2_scale=X2_scale, X2_zero_point=X2_zero_point)
5339
5340    @skipIfNoONEDNN
5341    def test_qconv2d_add_relu(self):
5342        batch_size = 3
5343        height = 10
5344        width = 10
5345        groups_list = [1, 10]
5346        input_channels_per_group = 2
5347        output_channels_per_group = 2
5348        kernel_h = 3
5349        kernel_w = 3
5350        stride_h = 2
5351        stride_w = 2
5352        pad_h = 1
5353        pad_w = 1
5354        dilation = 1
5355        X_scale = 1.5
5356        X_zero_point = 2
5357        W_scale = [1.5]
5358        W_zero_point = [-3]
5359        Y_scale = 4.2
5360        Y_zero_point = 0
5361        use_bias_list = [False, True]
5362        use_channelwise_list = [False, True]
5363        X2_scale = 1.2
5364        X2_zero_point_list = [0, 4]
5365
5366        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list)
5367        for groups, use_bias, use_channelwise, X2_zero_point in options:
5368            with override_quantized_engine('onednn'):
5369                input_channels = input_channels_per_group * groups
5370                output_channels = output_channels_per_group * groups
5371                kernels = (kernel_h, kernel_w)
5372                strides = (stride_h, stride_w)
5373                pads = (pad_h, pad_w)
5374                dilations = (dilation, dilation)
5375
5376                qconv = torch.ops.quantized.conv2d_add_relu
5377                qconv_prepack = torch.ops.quantized.conv2d_prepack
5378                conv_op = torch.nn.Conv2d(
5379                    input_channels,
5380                    output_channels,
5381                    kernels,
5382                    strides,
5383                    pads,
5384                    dilations,
5385                    groups,
5386                )
5387
5388                X_qdtype = torch.quint8
5389                self._test_qconv_impl(
5390                    qconv, qconv_prepack, conv_op, batch_size,
5391                    input_channels_per_group, (height, width),
5392                    output_channels_per_group, groups, kernels, strides, pads, None,
5393                    dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5394                    Y_scale, Y_zero_point, use_bias, "add_relu", use_channelwise, False,
5395                    input_dtype=X_qdtype, output_dtype=X_qdtype, X2_scale=X2_scale, X2_zero_point=X2_zero_point)
5396
    # TODO: merge this test with test_qconv2d when CUDNN runtime flags become available
5398    """Tests the correctness of quantized 2D convolution cudnn op."""
5399    @given(batch_size=st.integers(1, 3),
5400           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
5401           input_channels_per_group=st.integers(1, 32),
5402           height=st.integers(10, 16),
5403           width=st.integers(7, 14),
5404           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
5405           output_channels_per_group=st.integers(1, 32),
5406           groups=st.integers(1, 1),  # currently padding only supports groups=1
5407           kernel_h=st.integers(1, 7),
5408           kernel_w=st.integers(1, 7),
5409           stride_h=st.integers(1, 2),
5410           stride_w=st.integers(1, 2),
5411           pad_h=st.integers(0, 2),
5412           pad_w=st.integers(0, 2),
5413           # result for dilation == 2 is not correct
5414           # dilation=st.integers(1, 2),
5415           # currently cudnn has only been verified to work for dilation = 1
5416           # TODO: check backend works for dilation > 1
5417           dilation=st.integers(1, 1),
5418           X_scale=st.floats(1.2, 1.6),
5419           X_zero_point=st.sampled_from([0]),
5420           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
5421           W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2),
5422           Y_scale=st.floats(4.2, 5.6),
5423           Y_zero_point=st.sampled_from([0]),
5424           use_bias=st.booleans(),
5425           # TODO: enable channelwise
5426           use_channelwise=st.sampled_from([False]))
5427    @skipIfNoFBGEMM
5428    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
5429    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
5430    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
5431    @unittest.skip("not currently working and feature isn't used")
5432    def test_qconv2d_cudnn(
5433            self,
5434            batch_size,
5435            input_channels_per_group,
5436            height,
5437            width,
5438            output_channels_per_group,
5439            groups,
5440            kernel_h,
5441            kernel_w,
5442            stride_h,
5443            stride_w,
5444            pad_h,
5445            pad_w,
5446            dilation,
5447            X_scale,
5448            X_zero_point,
5449            W_scale,
5450            W_zero_point,
5451            Y_scale,
5452            Y_zero_point,
5453            use_bias,
5454            use_channelwise,
5455    ):
5456        input_channels = input_channels_per_group * groups
5457        output_channels = output_channels_per_group * groups
5458        kernels = (kernel_h, kernel_w)
5459        strides = (stride_h, stride_w)
5460        pads = (pad_h, pad_w)
5461        dilations = (dilation, dilation)
5462
5463        qconv = torch.ops.quantized.conv2d
5464        conv_op = torch.nn.Conv2d(
5465            input_channels,
5466            output_channels,
5467            kernels,
5468            strides,
5469            pads,
5470            dilations,
5471            groups,
5472        ).to(torch.device("cuda"))
5473        self._test_qconv_impl(
5474            qconv, torch.ops.quantized.conv2d_prepack, conv_op, batch_size,
5475            input_channels_per_group, (height, width),
5476            output_channels_per_group, groups, kernels, strides, pads, None,
5477            dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5478            Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False,
5479            device=torch.device("cuda"),
5480            input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8)
5481
5482    @given(batch_size=st.integers(1, 3),
5483           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
5484           input_channels_per_group=st.integers(1, 32),
5485           height=st.integers(10, 16),
5486           width=st.integers(7, 14),
5487           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
5488           output_channels_per_group=st.integers(1, 32),
5489           groups=st.integers(1, 1),  # currently padding only supports groups=1
5490           kernel_h=st.integers(1, 7),
5491           kernel_w=st.integers(1, 7),
5492           stride_h=st.integers(1, 2),
5493           stride_w=st.integers(1, 2),
5494           pad_h=st.integers(0, 2),
5495           pad_w=st.integers(0, 2),
5496           # result for dilation == 2 is not correct
5497           # dilation=st.integers(1, 2),
5498           # currently cudnn has only been verified to work for dilation = 1
5499           # TODO: check backend works for dilation > 1
5500           dilation=st.integers(1, 1),
5501           X_scale=st.floats(1.2, 1.6),
5502           X_zero_point=st.sampled_from([0]),
5503           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
5504           W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2),
5505           Y_scale=st.floats(4.2, 5.6),
5506           Y_zero_point=st.sampled_from([0]),
5507           use_bias=st.booleans(),
5508           # TODO: enable channelwise
5509           use_channelwise=st.sampled_from([False]))
5510    @skipIfNoFBGEMM
5511    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
5512    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
5513    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
5514    @unittest.skip("not currently working and feature isn't used")
5515    def test_qconv2d_relu_cudnn(
5516            self,
5517            batch_size,
5518            input_channels_per_group,
5519            height,
5520            width,
5521            output_channels_per_group,
5522            groups,
5523            kernel_h,
5524            kernel_w,
5525            stride_h,
5526            stride_w,
5527            pad_h,
5528            pad_w,
5529            dilation,
5530            X_scale,
5531            X_zero_point,
5532            W_scale,
5533            W_zero_point,
5534            Y_scale,
5535            Y_zero_point,
5536            use_bias,
5537            use_channelwise,
5538    ):
5539        input_channels = input_channels_per_group * groups
5540        output_channels = output_channels_per_group * groups
5541        kernels = (kernel_h, kernel_w)
5542        strides = (stride_h, stride_w)
5543        pads = (pad_h, pad_w)
5544        dilations = (dilation, dilation)
5545
5546        qconv = torch.ops.quantized.conv2d_relu
5547        conv_op = torch.nn.Conv2d(
5548            input_channels,
5549            output_channels,
5550            kernels,
5551            strides,
5552            pads,
5553            dilations,
5554            groups,
5555        ).to(torch.device("cuda"))
5556        self._test_qconv_impl(
5557            qconv, torch.ops.quantized.conv2d_prepack, conv_op, batch_size,
5558            input_channels_per_group, (height, width),
5559            output_channels_per_group, groups, kernels, strides, pads, None,
5560            dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5561            Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False,
5562            device=torch.device("cuda"),
5563            input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8)
5564
    @unittest.skip("used for local benchmarking; comment out this skip to run it")
5566    def test_benchmark(self):
5567        batch_size = 16
5568        in_channel = 64
5569        out_channel = 64
5570        kernel_size = 3
5571        height = 256
5572        width = 256
        print(
            "parameters:",
            "batch_size:", batch_size,
            "in_channel:", in_channel,
            "out_channel:", out_channel,
            "kernel_size:", kernel_size,
            "height:", height,
            "width:", width
        )
5581        )
5582        conv = torch.nn.Conv2d(in_channel, out_channel, kernel_size).cuda()
5583        input = torch.randn((batch_size, in_channel, height, width), device='cuda')
5584        weight = conv.weight.detach()
5585        stride = (1, 1)
5586        padding = (0, 0)
5587        dilation = (1, 1)
5588        groups = 1
5589        conv_op = torch.nn.functional.conv2d
5590        # profile
5591        from torch.profiler import profile, ProfilerActivity
5592
5593        def trace_handler(p):
5594            output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)
5595            p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")
5596
5597        my_schedule = torch.profiler.schedule(
5598            wait=5,
5599            warmup=5,
5600            active=20)
5601
5602        # fp32 benchmark
5603        with profile(
5604                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
5605                schedule=my_schedule,
5606                on_trace_ready=trace_handler) as prof:
5607            for i in range(30):
5608                conv_op(input, weight, None, stride, padding, dilation, groups)
5609                prof.step()
5610
5611        print("fp32 benchmark result:")
5612        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
5613
5614        # fp16 benchmark
5615        input_fp16 = input.to(torch.float16)
5616        weight_fp16 = weight.to(torch.float16)
5617
5618        with profile(
5619                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
5620                schedule=my_schedule,
5621                on_trace_ready=trace_handler) as prof:
5622            for i in range(30):
5623                conv_op(input_fp16, weight_fp16, None, stride, padding, dilation, groups)
5624                prof.step()
5625
5626        print("fp16 benchmark result:")
5627        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
5628
5629        input_int8 = torch.quantize_per_tensor(input, 1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
5630        weight_int8 = torch.quantize_per_tensor(weight, 1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
5631        scale = 1.0
5632        zero_point = 0
5633        conv_op = torch.ops.quantized.conv2d
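            # Prepack the quantized weight once; the packed weight is reused by every
            # quantized conv2d call in the benchmark loop below.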
5634        weight_prepacked = torch.ops.quantized.conv2d_prepack(weight_int8, None, stride, padding, dilation, groups)
5635        with profile(
5636                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
5637                schedule=my_schedule,
5638                on_trace_ready=trace_handler) as prof:
5639            for i in range(30):
5640                conv_op(input_int8, weight_prepacked, scale, zero_point)
5641                prof.step()
5642
5643        print("int8 benchmark result:")
5644        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
5645
5646    """Tests the correctness of quantized convolution op."""
5647    @override_qengines
5648    def test_qconv_transpose1d(self):
5649        if not qengine_is_qnnpack():
5650            return  # Currently only QNNPACK is supported
5651        if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN):
5652            return  # QNNPACK doesn't support these
5653        batch_size = 2
5654        input_channels_per_group_list = [2, 32]
5655        width = 14
5656        output_channels_per_group_list = [2, 8]
5657        groups_list = [1, 3]
5658        kernel_list = [1, 7]
5659        stride_list = [1, 2]
5660        pad = 2
5661        o_pad = 0
5662        dilation = 1
5663        X_scale = 1.2
5664        X_zero_point = 1
5665        W_scale = [1.2]
5666        W_zero_point = [1]
5667        Y_scale = 4.2
5668        Y_zero_point = 2
5669        use_bias_list = [True, False]
5670
5671        test_cases = itertools.product(
5672            input_channels_per_group_list, output_channels_per_group_list,
5673            groups_list, kernel_list, stride_list, use_bias_list)
5674        for input_channels_per_group, output_channels_per_group, \
5675                groups, kernel, stride, use_bias in test_cases:
5676
5677            input_channels = input_channels_per_group * groups
5678            output_channels = output_channels_per_group * groups
5679            kernels = (kernel,)
5680            strides = (stride,)
5681            pads = (pad,)
5682            o_pads = (o_pad,)
5683            dilations = (dilation,)
5684
5685            qconv = torch.ops.quantized.conv_transpose1d
5686            qconv_prepack = torch.ops.quantized.conv_transpose1d_prepack
5687            conv_op = torch.nn.ConvTranspose1d(
5688                in_channels=input_channels,
5689                out_channels=output_channels,
5690                kernel_size=kernels,
5691                stride=strides,
5692                padding=pads,
5693                output_padding=o_pads,
5694                groups=groups,
5695                dilation=dilations,
5696                bias=use_bias
5697            )
5698
5699            act_qdtypes = [torch.quint8]
5700            # Only the qnnpack qengine supports qint8
5701            if qengine_is_qnnpack() and torch.backends.xnnpack.enabled:
5702                act_qdtypes.append(torch.qint8)
5703
5704            for X_qdtype in act_qdtypes:
5705                if X_qdtype == torch.qint8:
5706                    W_zero_point = [0 for i in range(len(W_zero_point))]
5707
5708                X_q, W_q, bias_float = self._test_qconv_impl(
5709                    qconv, qconv_prepack, conv_op, batch_size,
5710                    input_channels_per_group, (width, ),
5711                    output_channels_per_group, groups, kernels, strides, pads, o_pads,
5712                    dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5713                    Y_scale, Y_zero_point, use_bias, post_op="none",
5714                    use_channelwise=False, use_transpose=True, input_dtype=X_qdtype, output_dtype=X_qdtype)
5715
5716                # check that this doesn't error
5717                test_conv = torch.ao.nn.quantized.ConvTranspose1d(input_channels, output_channels, 1)
5718                test_conv.scale = Y_scale
5719                test_conv(X_q)
5720
5721                # Test the module implementation
5722                qconv_op = torch.ao.nn.quantized.ConvTranspose1d(
5723                    in_channels=input_channels,
5724                    out_channels=output_channels,
5725                    kernel_size=kernels,
5726                    stride=strides,
5727                    padding=pads,
5728                    output_padding=o_pads,
5729                    groups=groups,
5730                    dilation=dilations,
5731                    bias=use_bias
5732                )
5733                qconv_op.scale = Y_scale
5734                qconv_op.zero_point = Y_zero_point
5735                qconv_op.set_weight_bias(W_q, bias_float)
5736
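                    # Reference: dequantize the input, run the float ConvTranspose1d, then
                    # requantize with the module's output scale and zero_point.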
5737                Y_dq_ref = conv_op(X_q.dequantize())
5738                Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale,
5739                                                    zero_point=Y_zero_point,
5740                                                    dtype=X_qdtype)
5741                Y_q = qconv_op(X_q)
5742                self.assertEqual(Y_q_ref, Y_q)
5743
5744
5745    """Tests the correctness of quantized convolution op."""
5746    @given(batch_size=st.integers(1, 3),
5747           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5748           height=st.integers(10, 16),
5749           width=st.integers(7, 14),
5750           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5751           groups=st.integers(1, 300),
5752           kernel_h=st.integers(1, 7),
5753           kernel_w=st.integers(1, 7),
5754           stride_h=st.integers(1, 2),
5755           stride_w=st.integers(1, 2),
5756           pad_h=st.integers(0, 2),
5757           pad_w=st.integers(0, 2),
5758           o_pad_h=st.integers(0, 2),
5759           o_pad_w=st.integers(0, 2),
5760           dilation=st.integers(1, 2),
5761           X_scale=st.floats(1.2, 1.6),
5762           X_zero_point=st.integers(0, 4),
5763           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
5764           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
5765           Y_scale=st.floats(4.2, 5.6),
5766           Y_zero_point=st.integers(0, 4),
5767           use_bias=st.booleans())
5768    @override_qengines
5769    @unittest.skip(
5770        "this is broken without changes to any relevant code, "
5771        "we need to remove hypothesis testing in CI")
5772    def test_qconv_transpose2d(
5773            self,
5774            batch_size,
5775            input_channels_per_group,
5776            height,
5777            width,
5778            output_channels_per_group,
5779            groups,
5780            kernel_h,
5781            kernel_w,
5782            stride_h,
5783            stride_w,
5784            pad_h,
5785            pad_w,
5786            o_pad_h,
5787            o_pad_w,
5788            dilation,
5789            X_scale,
5790            X_zero_point,
5791            W_scale,
5792            W_zero_point,
5793            Y_scale,
5794            Y_zero_point,
5795            use_bias):
5796        if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN):
5797            return  # QNNPACK doesn't support these
5798        # ONEDNN does not support output paddings
5799        if qengine_is_onednn() and (o_pad_h, o_pad_w) != (0, 0):
5800            return
5801        assume(o_pad_h < stride_h and o_pad_h < dilation)
5802        assume(o_pad_w < stride_w and o_pad_w < dilation)
5803
5804        input_channels = input_channels_per_group * groups
5805        output_channels = output_channels_per_group * groups
5806        kernels = (kernel_h, kernel_w)
5807        strides = (stride_h, stride_w)
5808        pads = (pad_h, pad_w)
5809        o_pads = (o_pad_h, o_pad_w)
5810        dilations = (dilation, dilation)
5811
5812        qconv = torch.ops.quantized.conv_transpose2d
5813        qconv_prepack = torch.ops.quantized.conv_transpose2d_prepack
5814        conv_op = torch.nn.ConvTranspose2d(
5815            in_channels=input_channels,
5816            out_channels=output_channels,
5817            kernel_size=kernels,
5818            stride=strides,
5819            padding=pads,
5820            output_padding=o_pads,
5821            groups=groups,
5822            dilation=dilations,
5823            bias=use_bias
5824        )
5825        act_qdtypes = [torch.quint8]
5826        # Only the qnnpack qengine supports qint8
5827        if qengine_is_qnnpack() and torch.backends.xnnpack.enabled:
5828            act_qdtypes.append(torch.qint8)
5829
5830        for X_qdtype in act_qdtypes:
5831            if X_qdtype == torch.qint8:
5832                W_zero_point = [0 for i in range(len(W_zero_point))]
5833
5834            X_q, W_q, bias_float = self._test_qconv_impl(
5835                qconv, qconv_prepack, conv_op, batch_size,
5836                input_channels_per_group, (height, width),
5837                output_channels_per_group, groups, kernels, strides, pads, o_pads,
5838                dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5839                Y_scale, Y_zero_point, use_bias, post_op="none",
5840                use_channelwise=False, use_transpose=True, input_dtype=X_qdtype, output_dtype=X_qdtype)
5841
5842            # check that this doesn't error
5843            test_conv = torch.ao.nn.quantized.ConvTranspose2d(input_channels, output_channels, 1)
5844            test_conv.scale = Y_scale
5845            test_conv(X_q)
5846
5847            # Test the module implementation
5848            qconv_op = torch.ao.nn.quantized.ConvTranspose2d(
5849                in_channels=input_channels,
5850                out_channels=output_channels,
5851                kernel_size=kernels,
5852                stride=strides,
5853                padding=pads,
5854                output_padding=o_pads,
5855                groups=groups,
5856                dilation=dilations,
5857                bias=use_bias
5858            )
5859            qconv_op.scale = Y_scale
5860            qconv_op.zero_point = Y_zero_point
5861            qconv_op.set_weight_bias(W_q, bias_float)
5862
5863            Y_dq_ref = conv_op(X_q.dequantize())
5864            Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale,
5865                                                zero_point=Y_zero_point,
5866                                                dtype=X_qdtype)
5867            Y_q = qconv_op(X_q)
5868            self.assertEqual(Y_q_ref, Y_q)
5869
5870    """Tests the correctness of quantized convolution op."""
5871    @given(batch_size=st.integers(1, 3),
5872           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5873           time=st.integers(2, 5),
5874           height=st.integers(10, 16),
5875           width=st.integers(7, 14),
5876           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
5877           groups=st.integers(1, 300),
5878           kernel_t=st.integers(1, 7),
5879           kernel_h=st.integers(1, 7),
5880           kernel_w=st.integers(1, 7),
5881           stride_t=st.integers(1, 2),
5882           stride_h=st.integers(1, 2),
5883           stride_w=st.integers(1, 2),
5884           pad_t=st.integers(0, 2),
5885           pad_h=st.integers(0, 2),
5886           pad_w=st.integers(0, 2),
5887           o_pad_t=st.integers(0, 2),
5888           o_pad_h=st.integers(0, 2),
5889           o_pad_w=st.integers(0, 2),
5890           dilation=st.integers(1, 2),
5891           X_scale=st.floats(1.2, 1.6),
5892           X_zero_point=st.integers(0, 4),
5893           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
5894           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
5895           Y_scale=st.floats(4.2, 5.6),
5896           Y_zero_point=st.integers(0, 4),
5897           use_bias=st.booleans())
5898    @override_qengines
5899    @unittest.skip(
5900        "this is broken without changes to any relevant code, "
5901        "we need to remove hypothesis testing in CI")
5902    def test_qconv_transpose3d(
5903            self,
5904            batch_size,
5905            input_channels_per_group,
5906            time,
5907            height,
5908            width,
5909            output_channels_per_group,
5910            groups,
5911            kernel_t,
5912            kernel_h,
5913            kernel_w,
5914            stride_t,
5915            stride_h,
5916            stride_w,
5917            pad_t,
5918            pad_h,
5919            pad_w,
5920            o_pad_t,
5921            o_pad_h,
5922            o_pad_w,
5923            dilation,
5924            X_scale,
5925            X_zero_point,
5926            W_scale,
5927            W_zero_point,
5928            Y_scale,
5929            Y_zero_point,
5930            use_bias):
5931        if qengine_is_qnnpack():
5932            return  # QNNPACK doesn't support this
5933        # ONEDNN doesn't support output paddings
5934        if qengine_is_onednn() and (o_pad_t, o_pad_h, o_pad_w) != (0, 0, 0):
5935            return
5936        assume(o_pad_t < stride_t or o_pad_t < dilation)
5937        assume(o_pad_h < stride_h or o_pad_h < dilation)
5938        assume(o_pad_w < stride_w or o_pad_w < dilation)
5939
5940        input_channels = input_channels_per_group * groups
5941        output_channels = output_channels_per_group * groups
5942        kernels = (kernel_t, kernel_h, kernel_w)
5943        strides = (stride_t, stride_h, stride_w)
5944        pads = (pad_t, pad_h, pad_w)
5945        o_pads = (o_pad_t, o_pad_h, o_pad_w)
5946        dilations = (dilation, dilation, dilation)
5947
5948        qconv = torch.ops.quantized.conv_transpose3d
5949        qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack
5950        conv_op = torch.nn.ConvTranspose3d(
5951            in_channels=input_channels,
5952            out_channels=output_channels,
5953            kernel_size=kernels,
5954            stride=strides,
5955            padding=pads,
5956            output_padding=o_pads,
5957            groups=groups,
5958            dilation=dilations,
5959            bias=use_bias
5960        )
5961        X_q, W_q, bias_float = self._test_qconv_impl(
5962            qconv, qconv_prepack, conv_op, batch_size,
5963            input_channels_per_group, (time, height, width),
5964            output_channels_per_group, groups, kernels, strides, pads, o_pads,
5965            dilations, X_scale, X_zero_point, W_scale, W_zero_point,
5966            Y_scale, Y_zero_point, use_bias, post_op="none",
5967            use_channelwise=False, use_transpose=True)
5968
5969        # check that this doesn't error
5970        test_conv = torch.ao.nn.quantized.ConvTranspose3d(input_channels, output_channels, 1)
5971        test_conv.scale = Y_scale
5972        test_conv(X_q)
5973
5974        # Test the module implementation
5975        qconv_op = torch.ao.nn.quantized.ConvTranspose3d(
5976            in_channels=input_channels,
5977            out_channels=output_channels,
5978            kernel_size=kernels,
5979            stride=strides,
5980            padding=pads,
5981            output_padding=o_pads,
5982            groups=groups,
5983            dilation=dilations,
5984            bias=use_bias
5985        )
5986        qconv_op.scale = Y_scale
5987        qconv_op.zero_point = Y_zero_point
5988        qconv_op.set_weight_bias(W_q, bias_float)
5989
5990        Y_dq_ref = conv_op(X_q.dequantize())
5991        Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale,
5992                                            zero_point=Y_zero_point,
5993                                            dtype=torch.quint8)
5994        Y_q = qconv_op(X_q)
5995        self.assertEqual(Y_q_ref, Y_q)
5996
5997    @given(
5998        inputs=hu.tensor_conv(
5999            spatial_dim=1, batch_size_range=(1, 3),
6000            input_channels_per_group_range=(1, 4),
6001            output_channels_per_group_range=(1, 4), feature_map_range=(4, 8),
6002            kernel_range=(1, 4), max_groups=4,
6003            can_be_transposed=False,
6004            qparams=[hu.qparams(dtypes=torch.quint8,
6005                                zero_point_min=0,
6006                                zero_point_max=0),
6007                     hu.qparams(dtypes=torch.qint8,
6008                                zero_point_min=0,
6009                                zero_point_max=0),
6010                     hu.qparams(dtypes=torch.qint32,
6011                                zero_point_min=0,
6012                                zero_point_max=0)]),
6013        stride=st.integers(1, 3),
6014        pad=st.integers(1, 2),
6015        o_pad=st.integers(1, 2),
6016        channelwise=st.booleans())
6017    @override_qengines
6018    def test_qconv1d_unpack(self, inputs, stride, pad, o_pad, channelwise):
6019        transposed = inputs[-1]
6020        qengine = torch.backends.quantized.engine
6021        if qengine not in supported_qengines:
6022            return
6023        if qengine == 'qnnpack':
6024            assume(not channelwise)  # QNNPACK doesn't support channelwise
6025        else:
6026            assume(not transposed)  # Only QNNPACK supports transposed conv
6027        if transposed:
6028            qconv_prepack = torch.ops.quantized.conv_transpose1d_prepack
6029            qconv_unpack = torch.ops.quantized.conv_transpose1d_unpack
6030        else:
6031            qconv_prepack = torch.ops.quantized.conv1d_prepack
6032            qconv_unpack = torch.ops.quantized.conv1d_unpack
6033        self._test_qconv_unpack_impl(
6034            qconv_prepack, qconv_unpack, inputs, [stride],
6035            [pad], [o_pad], channelwise)
6036
6037    @given(
6038        inputs=hu.tensor_conv(
6039            spatial_dim=2, batch_size_range=(1, 3),
6040            input_channels_per_group_range=(1, 4),
6041            output_channels_per_group_range=(1, 4), feature_map_range=(4, 8),
6042            kernel_range=(1, 4), max_groups=4,
6043            can_be_transposed=True,
6044            qparams=[hu.qparams(dtypes=torch.quint8,
6045                                zero_point_min=0,
6046                                zero_point_max=0),
6047                     hu.qparams(dtypes=torch.qint8,
6048                                zero_point_min=0,
6049                                zero_point_max=0),
6050                     hu.qparams(dtypes=torch.qint32,
6051                                zero_point_min=0,
6052                                zero_point_max=0)]),
6053        stride=st.integers(1, 3),
6054        pad=st.integers(0, 2),
6055        o_pad=st.integers(0, 2),
6056        channelwise=st.booleans())
6057    @override_qengines
6058    def test_qconv2d_unpack(self, inputs, stride, pad, o_pad, channelwise):
6059        transposed = inputs[-1]
6060        qengine = torch.backends.quantized.engine
6061        if qengine not in supported_qengines:
6062            return
6063        if qengine == 'qnnpack':
6064            assume(not channelwise)  # QNNPACK doesn't support channelwise
6065        if transposed:
6066            qconv_prepack = torch.ops.quantized.conv_transpose2d_prepack
6067            qconv_unpack = torch.ops.quantized.conv_transpose2d_unpack
6068        else:
6069            qconv_prepack = torch.ops.quantized.conv2d_prepack
6070            qconv_unpack = torch.ops.quantized.conv2d_unpack
6071        self._test_qconv_unpack_impl(
6072            qconv_prepack, qconv_unpack, inputs, [stride, stride],
6073            [pad, pad], [o_pad, o_pad], channelwise)
6074
6075    """Tests the correctness of quantized 1D convolution op."""
6076    @given(batch_size=st.integers(1, 6),
6077           input_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)),
6078           output_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)),
6079           groups=st.integers(1, 3),
6080           length=st.integers(4, 16),
6081           kernel=st.integers(1, 7),
6082           stride=st.integers(1, 2),
6083           pad=st.integers(0, 2),
6084           dilation=st.integers(1, 2),
6085           X_scale=st.floats(1.2, 1.6),
6086           X_zero_point=st.integers(0, 4),
6087           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
6088           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
6089           Y_scale=st.floats(4.2, 5.6),
6090           Y_zero_point=st.integers(0, 4),
6091           use_bias=st.booleans(),
6092           use_channelwise=st.booleans())
6093    @override_qengines
6094    def test_qconv1d(
6095        self,
6096        batch_size,
6097        input_channels_per_group,
6098        output_channels_per_group,
6099        groups,
6100        length,
6101        kernel,
6102        stride,
6103        pad,
6104        dilation,
6105        X_scale,
6106        X_zero_point,
6107        W_scale,
6108        W_zero_point,
6109        Y_scale,
6110        Y_zero_point,
6111        use_bias,
6112        use_channelwise,
6113    ):
6114        input_channels = input_channels_per_group * groups
6115        output_channels = output_channels_per_group * groups
6116        if torch.backends.quantized.engine == 'qnnpack':
6117            use_channelwise = False
6118        conv1d = torch.nn.Conv1d(
6119            input_channels,
6120            output_channels,
6121            kernel,
6122            stride,
6123            pad,
6124            dilation,
6125            groups,
6126        )
6127        qconv_prepack = torch.ops.quantized.conv1d_prepack
6128        qconv = torch.ops.quantized.conv1d
6129
6130        act_qdtypes = [torch.quint8]
6131        # Only the qnnpack qengine supports qint8
6132        if qengine_is_qnnpack() and torch.backends.xnnpack.enabled:
6133            act_qdtypes.append(torch.qint8)
6134
6135        for X_qdtype in act_qdtypes:
6136            if X_qdtype == torch.qint8:
6137                W_zero_point = [0 for i in range(len(W_zero_point))]
6138
6139            self._test_qconv_impl(
6140                qconv, qconv_prepack, conv1d, batch_size,
6141                input_channels_per_group, (length, ),
6142                output_channels_per_group, groups, kernel, [stride], [pad], None,
6143                [dilation], X_scale, X_zero_point, W_scale, W_zero_point,
6144                Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False,
6145                input_dtype=X_qdtype, output_dtype=X_qdtype)
6146
6147    @given(batch_size=st.integers(1, 6),
6148           input_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)),
6149           output_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)),
6150           groups=st.integers(1, 3),
6151           length=st.integers(4, 16),
6152           kernel=st.integers(1, 7),
6153           stride=st.integers(1, 2),
6154           pad=st.integers(0, 2),
6155           dilation=st.integers(1, 2),
6156           X_scale=st.floats(1.2, 1.6),
6157           X_zero_point=st.integers(0, 4),
6158           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
6159           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
6160           Y_scale=st.floats(4.2, 5.6),
6161           Y_zero_point=st.integers(0, 4),
6162           use_bias=st.booleans(),
6163           use_channelwise=st.booleans())
6164    @override_qengines
6165    def test_qconv1d_relu(
6166        self,
6167        batch_size,
6168        input_channels_per_group,
6169        output_channels_per_group,
6170        groups,
6171        length,
6172        kernel,
6173        stride,
6174        pad,
6175        dilation,
6176        X_scale,
6177        X_zero_point,
6178        W_scale,
6179        W_zero_point,
6180        Y_scale,
6181        Y_zero_point,
6182        use_bias,
6183        use_channelwise,
6184    ):
6185        input_channels = input_channels_per_group * groups
6186        output_channels = output_channels_per_group * groups
6187        if torch.backends.quantized.engine == 'qnnpack':
6188            use_channelwise = False
6189        conv1d = torch.nn.Conv1d(
6190            input_channels,
6191            output_channels,
6192            kernel,
6193            stride,
6194            pad,
6195            dilation,
6196            groups,
6197        )
6198        qconv_prepack = torch.ops.quantized.conv1d_prepack
6199        qconv = torch.ops.quantized.conv1d_relu
6200
6201        act_qdtypes = [torch.quint8]
6202        # Only the qnnpack qengine supports qint8
6203        if qengine_is_qnnpack() and torch.backends.xnnpack.enabled:
6204            act_qdtypes.append(torch.qint8)
6205
6206        for X_qdtype in act_qdtypes:
6207            if X_qdtype == torch.qint8:
6208                W_zero_point = [0 for i in range(len(W_zero_point))]
6209
6210            self._test_qconv_impl(
6211                qconv, qconv_prepack, conv1d, batch_size,
6212                input_channels_per_group, (length, ),
6213                output_channels_per_group, groups, kernel, [stride], [pad], None,
6214                [dilation], X_scale, X_zero_point, W_scale, W_zero_point,
6215                Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False,
6216                input_dtype=X_qdtype, output_dtype=X_qdtype)
6217
6218    # TODO: merge this test with test_qconv1d when CUDNN runtime flags becomes available
6219    """Tests the correctness of quantized 1D convolution cudnn op."""
6220    @given(batch_size=st.integers(1, 6),
6221           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
6222           input_channels_per_group=st.integers(1, 32),
6223           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
6224           output_channels_per_group=st.integers(1, 32),
6225           groups=st.integers(1, 1),  # currently padding only supports groups=1
6226           length=st.integers(4, 16),
6227           kernel=st.integers(1, 7),
6228           stride=st.integers(1, 2),
6229           pad=st.integers(0, 2),
6230           # currently cudnn has only been verified to work for dilation = 1
6231           # TODO: check backend works for dilation > 1
6232           dilation=st.integers(1, 1),
6233           X_scale=st.floats(1.2, 1.6),
6234           # currently conv cudnn backend is only implemented for int8 symmetric
6235           X_zero_point=st.sampled_from([0]),
6236           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
6237           # currently conv cudnn backend is only implemented for int8 symmetric
6238           W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2),
6239           Y_scale=st.floats(4.2, 5.6),
6240           # currently conv cudnn backend is only implemented for int8 symmetric
6241           Y_zero_point=st.sampled_from([0]),
6242           use_bias=st.booleans(),
6243           # TODO: enable channelwise
6244           use_channelwise=st.sampled_from([False]))
6245    @skipIfNoFBGEMM
6246    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
6247    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
6248    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
6249    @unittest.skip("not currently working and feature isn't used")
6250    def test_qconv1d_cudnn(
6251        self,
6252        batch_size,
6253        input_channels_per_group,
6254        output_channels_per_group,
6255        groups,
6256        length,
6257        kernel,
6258        stride,
6259        pad,
6260        dilation,
6261        X_scale,
6262        X_zero_point,
6263        W_scale,
6264        W_zero_point,
6265        Y_scale,
6266        Y_zero_point,
6267        use_bias,
6268        use_channelwise,
6269    ):
6270        input_channels = input_channels_per_group * groups
6271        output_channels = output_channels_per_group * groups
6272
6273        conv1d = torch.nn.Conv1d(
6274            input_channels,
6275            output_channels,
6276            kernel,
6277            stride,
6278            pad,
6279            dilation,
6280            groups,
6281        ).to(torch.device("cuda"))
6282        qconv_prepack = torch.ops.quantized.conv1d_prepack
6283        qconv = torch.ops.quantized.conv1d
6284
6285        self._test_qconv_impl(
6286            qconv, qconv_prepack, conv1d, batch_size,
6287            input_channels_per_group, (length, ),
6288            output_channels_per_group, groups, kernel, [stride], [pad], None,
6289            [dilation], X_scale, X_zero_point, W_scale, W_zero_point,
6290            Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False,
6291            device=torch.device("cuda"),
6292            input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8)
6293
6294    @given(batch_size=st.integers(1, 6),
6295           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
6296           input_channels_per_group=st.integers(1, 32),
6297           # cudnn only supports multiples of 4, but we have explicitly added padding on the backend
6298           output_channels_per_group=st.integers(1, 32),
6299           groups=st.integers(1, 1),  # currently padding only supports groups=1
6300           length=st.integers(4, 16),
6301           kernel=st.integers(1, 7),
6302           stride=st.integers(1, 2),
6303           pad=st.integers(0, 2),
6304           # currently cudnn has only been verified to work for dilation = 1
6305           # TODO: check backend works for dilation > 1
6306           dilation=st.integers(1, 1),
6307           X_scale=st.floats(1.2, 1.6),
6308           # currently conv cudnn backend is only implemented for int8 symmetric
6309           X_zero_point=st.sampled_from([0]),
6310           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
6311           # currently conv cudnn backend is only implemented for int8 symmetric
6312           W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2),
6313           Y_scale=st.floats(4.2, 5.6),
6314           # currently conv cudnn backend is only implemented for int8 symmetric
6315           Y_zero_point=st.sampled_from([0]),
6316           use_bias=st.booleans(),
6317           # TODO: enable channelwise
6318           use_channelwise=st.sampled_from([False]))
6319    @skipIfNoFBGEMM
6320    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
6321    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
6322    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
6323    @unittest.skip("not currently working and feature isn't used")
6324    def test_qconv1d_relu_cudnn(
6325        self,
6326        batch_size,
6327        input_channels_per_group,
6328        output_channels_per_group,
6329        groups,
6330        length,
6331        kernel,
6332        stride,
6333        pad,
6334        dilation,
6335        X_scale,
6336        X_zero_point,
6337        W_scale,
6338        W_zero_point,
6339        Y_scale,
6340        Y_zero_point,
6341        use_bias,
6342        use_channelwise,
6343    ):
6344        input_channels = input_channels_per_group * groups
6345        output_channels = output_channels_per_group * groups
6346
6347        conv1d = torch.nn.Conv1d(
6348            input_channels,
6349            output_channels,
6350            kernel,
6351            stride,
6352            pad,
6353            dilation,
6354            groups,
6355        ).to(torch.device("cuda"))
6356        qconv_prepack = torch.ops.quantized.conv1d_prepack
6357        qconv = torch.ops.quantized.conv1d_relu
6358
6359        self._test_qconv_impl(
6360            qconv, qconv_prepack, conv1d, batch_size,
6361            input_channels_per_group, (length, ),
6362            output_channels_per_group, groups, kernel, [stride], [pad], None,
6363            [dilation], X_scale, X_zero_point, W_scale, W_zero_point,
6364            Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False,
6365            device=torch.device("cuda"),
6366            input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8)
6367
6368    @given(batch_size=st.integers(1, 4),
6369           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]),
6370           D=st.integers(4, 8),
6371           H=st.integers(4, 8),
6372           W=st.integers(4, 8),
6373           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]),
6374           groups=st.integers(1, 3),
6375           kernel_d=st.integers(1, 4),
6376           kernel_h=st.integers(1, 4),
6377           kernel_w=st.integers(1, 4),
6378           stride_d=st.integers(1, 2),
6379           stride_h=st.integers(1, 2),
6380           stride_w=st.integers(1, 2),
6381           pad_d=st.integers(0, 2),
6382           pad_h=st.integers(0, 2),
6383           pad_w=st.integers(0, 2),
6384           dilation=st.integers(1, 2),
6385           X_scale=st.floats(1.2, 1.6),
6386           X_zero_point=st.integers(0, 4),
6387           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
6388           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
6389           Y_scale=st.floats(4.2, 5.6),
6390           Y_zero_point=st.integers(0, 4),
6391           use_bias=st.booleans(),
6392           use_channelwise=st.booleans(),
6393           qengine=st.sampled_from(("qnnpack", "fbgemm")))
6394    def test_qconv3d(
6395        self,
6396        batch_size,
6397        input_channels_per_group,
6398        D,
6399        H,
6400        W,
6401        output_channels_per_group,
6402        groups,
6403        kernel_d,
6404        kernel_h,
6405        kernel_w,
6406        stride_d,
6407        stride_h,
6408        stride_w,
6409        pad_d,
6410        pad_h,
6411        pad_w,
6412        dilation,
6413        X_scale,
6414        X_zero_point,
6415        W_scale,
6416        W_zero_point,
6417        Y_scale,
6418        Y_zero_point,
6419        use_bias,
6420        use_channelwise,
6421        qengine
6422    ):
6423        if qengine not in supported_qengines:
6424            return
6425
6426        input_channels = input_channels_per_group * groups
6427        output_channels = output_channels_per_group * groups
6428        kernels = (kernel_d, kernel_h, kernel_w)
6429        strides = (stride_d, stride_h, stride_w)
6430        pads = (pad_d, pad_h, pad_w)
6431        dilations = (dilation, dilation, dilation)
6432
6433        with override_quantized_engine(qengine):
6434            qconv = torch.ops.quantized.conv3d
6435            qconv_prepack = torch.ops.quantized.conv3d_prepack
6436            conv_op = torch.nn.Conv3d(
6437                input_channels,
6438                output_channels,
6439                kernels,
6440                strides,
6441                pads,
6442                dilations,
6443                groups,
6444            )
6445            self._test_qconv_impl(
6446                qconv, qconv_prepack, conv_op, batch_size,
6447                input_channels_per_group, (D, H, W), output_channels_per_group,
6448                groups, kernels, strides, pads, None, dilations, X_scale,
6449                X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
6450                use_bias, "none", use_channelwise, use_transpose=False)
6451
6452    @given(batch_size=st.integers(1, 4),
6453           input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]),
6454           D=st.integers(4, 8),
6455           H=st.integers(4, 8),
6456           W=st.integers(4, 8),
6457           output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]),
6458           groups=st.integers(1, 3),
6459           kernel_d=st.integers(1, 4),
6460           kernel_h=st.integers(1, 4),
6461           kernel_w=st.integers(1, 4),
6462           stride_d=st.integers(1, 2),
6463           stride_h=st.integers(1, 2),
6464           stride_w=st.integers(1, 2),
6465           pad_d=st.integers(0, 2),
6466           pad_h=st.integers(0, 2),
6467           pad_w=st.integers(0, 2),
6468           dilation=st.integers(1, 2),
6469           X_scale=st.floats(1.2, 1.6),
6470           X_zero_point=st.integers(0, 4),
6471           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
6472           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
6473           Y_scale=st.floats(4.2, 5.6),
6474           Y_zero_point=st.integers(0, 4),
6475           use_bias=st.booleans(),
6476           use_channelwise=st.booleans(),
6477           qengine=st.sampled_from(("qnnpack", "fbgemm")))
6478    def test_qconv3d_relu(
6479        self,
6480        batch_size,
6481        input_channels_per_group,
6482        D,
6483        H,
6484        W,
6485        output_channels_per_group,
6486        groups,
6487        kernel_d,
6488        kernel_h,
6489        kernel_w,
6490        stride_d,
6491        stride_h,
6492        stride_w,
6493        pad_d,
6494        pad_h,
6495        pad_w,
6496        dilation,
6497        X_scale,
6498        X_zero_point,
6499        W_scale,
6500        W_zero_point,
6501        Y_scale,
6502        Y_zero_point,
6503        use_bias,
6504        use_channelwise,
6505        qengine
6506    ):
6507        if qengine not in supported_qengines:
6508            return
6509
6510        input_channels = input_channels_per_group * groups
6511        output_channels = output_channels_per_group * groups
6512        kernels = (kernel_d, kernel_h, kernel_w)
6513        strides = (stride_d, stride_h, stride_w)
6514        pads = (pad_d, pad_h, pad_w)
6515        dilations = (dilation, dilation, dilation)
6516
6517        with override_quantized_engine(qengine):
6518            qconv = torch.ops.quantized.conv3d_relu
6519            qconv_prepack = torch.ops.quantized.conv3d_prepack
6520            conv_op = torch.nn.Conv3d(
6521                input_channels,
6522                output_channels,
6523                kernels,
6524                strides,
6525                pads,
6526                dilations,
6527                groups,
6528            )
6529            self._test_qconv_impl(
6530                qconv, qconv_prepack, conv_op, batch_size,
6531                input_channels_per_group, (D, H, W), output_channels_per_group,
6532                groups, kernels, strides, pads, None, dilations, X_scale,
6533                X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
6534                use_bias, "relu", use_channelwise, use_transpose=False)
6535
6536    """Tests the correctness of the quantized::qconv3d_unpack op."""
6537    @given(
6538        inputs=hu.tensor_conv(
6539            spatial_dim=3, batch_size_range=(1, 3),
6540            input_channels_per_group_range=(1, 3),
6541            output_channels_per_group_range=(1, 3), feature_map_range=(3, 6),
6542            kernel_range=(1, 3), max_groups=3,
6543            qparams=[hu.qparams(dtypes=torch.quint8,
6544                                zero_point_min=0,
6545                                zero_point_max=0),
6546                     hu.qparams(dtypes=torch.qint8,
6547                                zero_point_min=0,
6548                                zero_point_max=0),
6549                     hu.qparams(dtypes=torch.qint32,
6550                                zero_point_min=0,
6551                                zero_point_max=0)]),
6552        stride_d=st.integers(1, 2), stride_h=st.integers(1, 2),
6553        stride_w=st.integers(1, 2),
6554        pad_d=st.integers(1, 2), pad_h=st.integers(1, 2),
6555        pad_w=st.integers(1, 2),
6556        o_pad=st.integers(0, 2),
6557        channelwise=st.booleans())
6558    @override_qengines
6559    def test_qconv3d_unpack(
6560        self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, o_pad,
6561        channelwise
6562    ):
6563        if qengine_is_qnnpack():
6564            return  # QNNPACK doesn't support this
6565        transposed = inputs[-1]
6566        if transposed:
6567            qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack
6568            qconv_unpack = torch.ops.quantized.conv_transpose3d_unpack
6569        else:
6570            qconv_prepack = torch.ops.quantized.conv3d_prepack
6571            qconv_unpack = torch.ops.quantized.conv3d_unpack
6572        self._test_qconv_unpack_impl(
6573            qconv_prepack, qconv_unpack, inputs,
6574            (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), (o_pad, o_pad, o_pad),
6575            channelwise)
6576
6577    def test_conv_reorder_issue_onednn(self):
6578        """ Ensure reorder failure issue in conv is fixed for onednn backend.
6579            Onednn backend used to encounter reorder failure
6580            when running conv with dynamic input shapes.
6581            Solved by https://github.com/pytorch/pytorch/pull/86876
6582        """
6583        if 'onednn' not in supported_qengines:
6584            return
6585        with override_quantized_engine('onednn'):
6586            bs = 1
6587            ic, oc = 128, 512
6588            kh, kw = 1, 1
6589            bias = None
6590            strides, paddings, dilates = (1, 1), (0, 0), (1, 1)
6591            for groups in [1, 2]:
6592                ih, iw = 28, 28
6593                w = torch.randn((oc * groups, ic, kh, kw))
6594                qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
6595                x = torch.randn((bs, ic * groups, ih, iw))
6596                qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
6597                w_packed = torch.ops.quantized.conv2d_prepack(
6598                    qw, bias, strides, paddings, dilates, groups
6599                )
6600                torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
6601                ih, iw = 5, 4
6602                x = torch.randn((bs, ic * groups, ih, iw))
6603                qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
6604                # The following should pass when input shape is changed
6605                torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
6606
6607    @skipIfNoONEDNN
6608    def test_conv_transpose_reorder_issue_onednn(self):
6609        with override_quantized_engine('onednn'):
6610            bs = 1
6611            ic, oc = 16, 33
6612            kh, kw = 3, 3
6613            ih, iw = 50, 100
6614            bias = None
6615            strides, paddings, output_paddings, dilates, groups = [2, 2], [0, 0], [0, 0], [1, 1], 1
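                # Note: ConvTranspose2d weights are laid out as (in_channels, out_channels, kH, kW).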
6616            w = torch.randn((ic, oc, kh, kw))
6617            qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8)
6618            x = torch.randn((bs, ic, ih, iw))
6619            qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
6620            w_packed = torch.ops.quantized.conv_transpose2d_prepack(
6621                qw, bias, strides, paddings, output_paddings, dilates, groups
6622            )
6623            torch.ops.quantized.conv_transpose2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
6624            ih, iw = 5, 4
6625            x = torch.randn((bs, ic, ih, iw))
6626            qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
6627            # The following should pass when input shape is changed
6628            torch.ops.quantized.conv_transpose2d(qx, w_packed, output_scale=1.0, output_zero_point=0)
6629
6630    def _test_qconv_impl_cpu_tensor(
6631        self,
6632        qconv,
6633        qconv_prepack,
6634        conv_op,
6635        input_channels_per_group=2,
6636        input_feature_map_shape=(),
6637        output_channels_per_group=2,
6638        groups=1,
6639        kernels=3,
6640        strides=(),
6641        pads=(),
6642        dilations=(),
6643        X_scale=1.3,
6644        X_zero_point=2,
6645        W_scale=(1.0,),
6646        W_zero_point=(0,),
6647        Y_scale=3.2,
6648        Y_zero_point=0,
6649        use_bias=True,
6650        post_op=PointwisePostOp(),
6651        use_channelwise=True,
6652        X2_scale=1.2,
6653        X2_zero_point=0,
6654        qconv_output_dtype=None,  # None, torch.float32, torch.bfloat16
6655        weight_in_channel_last_format=False,
6656        qconv_x2_dtype=None,
6657    ):
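            # Helper for the onednn/PT2E path: build a float reference by running conv_op on
            # dequantized tensors and applying the requested pointwise post op, then run the
            # given onednn qconv op on plain int8 CPU tensors and compare the quantized results.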
6658        # ONEDNN only supports symmetric quantization of weight
6659        if W_zero_point is not None:
6660            W_zero_point = len(W_zero_point) * [0]
6661        fp32_output = qconv_output_dtype is torch.float32
6662        bfloat16_output = qconv_output_dtype is torch.bfloat16
6663        if fp32_output or bfloat16_output:
6664            Y_scale = 1.0
6665            Y_zero_point = 0
6666            X2_scale = 1.0
6667            X2_zero_point = 0
6668        batch_size = 3
6669        o_pads = None
6670        device = torch.device("cpu")
6671        input_dtype = torch.quint8
6672        weight_dtype = torch.qint8
6673        output_dtype = torch.quint8
6674        use_transpose = False
6675        (X, W), (X_q, W_q), bias_float = self._make_qconv_tensors(
6676            batch_size,
6677            input_channels_per_group,
6678            input_feature_map_shape,
6679            output_channels_per_group,
6680            groups,
6681            kernels,
6682            strides,
6683            pads,
6684            dilations,
6685            X_scale,
6686            X_zero_point,
6687            W_scale,
6688            W_zero_point,
6689            use_bias,
6690            use_channelwise,
6691            use_transpose,
6692            device=device,
6693            input_dtype=input_dtype,
6694            weight_dtype=weight_dtype,
6695        )
6696        if bias_float is not None:
6697            bias_float = bias_float.to(device)
6698        # Assign weights
6699        W = W_q.dequantize()
6700        X = X_q.dequantize()
6701        conv_op.weight = torch.nn.Parameter(W, requires_grad=False)
6702        conv_op.bias = (
6703            torch.nn.Parameter(bias_float, requires_grad=False) if use_bias else None
6704        )
6705        result_ref = conv_op(X)
6706        X2_q = None
6707
6708        if post_op.binary_attr == "sum":
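            # Apply the requested post op to the float reference: a binary "sum" attr adds a
            # second quantized input X2 (optionally followed by relu), while unary attrs map
            # to the corresponding float modules.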
6709            (X_value_min, X_value_max) = (0, 4)
6710            X2_init = torch.randint(
6711                X_value_min, X_value_max, result_ref.size(), device=device
6712            )
6713            X2 = X2_scale * ((X2_init - X2_zero_point).float())
6714            X2_q = torch.quantize_per_tensor(
6715                X2, scale=X2_scale, zero_point=X2_zero_point, dtype=input_dtype
6716            )
6717            result_ref = result_ref + X2
6718            if post_op.unary_attr == "relu":
6719                relu = torch.nn.ReLU()
6720                result_ref = relu(result_ref)
6721        elif post_op.unary_attr == "relu":
6722            assert not use_transpose, "Cannot fuse ReLU with ConvTranspose"
6723            relu = torch.nn.ReLU()
6724            result_ref = relu(result_ref)
6725        elif post_op.unary_attr == "hardtanh":
6726            assert not use_transpose, "Cannot fuse hardtanh with ConvTranspose"
6727            assert len(post_op.scalars) == 2, "For post op hardtanh, expect 2 parameters passed in"
6728            hardtanh = torch.nn.Hardtanh(min_val=post_op.scalars[0], max_val=post_op.scalars[1])
6729            result_ref = hardtanh(result_ref)
6730        elif post_op.unary_attr == "hardswish":
6731            assert not use_transpose, "Cannot fuse hardswish with ConvTranspose"
6732            hardswish = torch.nn.Hardswish()
6733            result_ref = hardswish(result_ref)
6734        elif post_op.unary_attr == "swish":
6735            assert not use_transpose, "Cannot fuse silu with ConvTranspose"
6736            silu = torch.nn.SiLU()
6737            result_ref = silu(result_ref)
6738
6739        # Quantize reference results for comparison
6740        result_ref_q = torch.quantize_per_tensor(
6741            result_ref, scale=Y_scale, zero_point=Y_zero_point, dtype=output_dtype
6742        )
6743
6744        # Calculate the result for 2.X path
6745        X_q_cpu_tensor = X_q.int_repr()
6746        W_q_cpu_tensor = W_q.int_repr()
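            # int_repr() drops the quantization parameters; the 2.X ops below take plain
            # integer tensors plus explicit scales and zero points.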
6747
6748        weight_scale = (
6749            W_q.q_per_channel_scales()
6750            if use_channelwise
6751            else torch.tensor(W_q.q_scale(), dtype=torch.double, device=device)
6752        )
6753        weight_zero_point = (
6754            W_q.q_per_channel_zero_points()
6755            if use_channelwise
6756            else torch.tensor(W_q.q_zero_point(), dtype=torch.int64, device=device)
6757        )
6758
6759        if weight_in_channel_last_format:
6760            if W_q_cpu_tensor.dim() == 5:
6761                W_q_cpu_tensor = W_q_cpu_tensor.to(memory_format=torch.channels_last_3d)
6762            elif W_q_cpu_tensor.dim() == 4:
6763                W_q_cpu_tensor = W_q_cpu_tensor.to(memory_format=torch.channels_last)
6764
6765        packed_weight = qconv_prepack(
6766            W_q_cpu_tensor,
6767            weight_scale,
6768            X_scale,
6769            X_zero_point,
6770            strides,
6771            pads,
6772            dilations,
6773            groups,
6774            X_q_cpu_tensor.size(),
6775        )
6776
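            # With a binary "sum" post op, the fused op also takes the accumulation tensor X2
            # and its quantization parameters.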
6777        if post_op.binary_attr == "sum":
6778            X2_cpu_tensor = (
6779                X2_q.int_repr()
6780                if qconv_output_dtype is None
6781                else X2_q.dequantize().to(qconv_x2_dtype)
6782            ).contiguous(memory_format=torch.channels_last)
6783            Y_q_cpu_tensor = qconv(
6784                X_q_cpu_tensor,
6785                X_scale,
6786                X_zero_point,
6787                X2_cpu_tensor,
6788                X2_scale,
6789                X2_zero_point,
6790                packed_weight,
6791                weight_scale,
6792                weight_zero_point,
6793                bias_float,
6794                strides,
6795                pads,
6796                dilations,
6797                groups,
6798                Y_scale,
6799                Y_zero_point,
6800                qconv_output_dtype,
6801                post_op.binary_attr,
6802                post_op.alpha,
6803                post_op.unary_attr,
6804                post_op.scalars,
6805                post_op.algorithm,
6806            )
6807        else:
6808            Y_q_cpu_tensor = qconv(
6809                X_q_cpu_tensor,
6810                X_scale,
6811                X_zero_point,
6812                packed_weight,
6813                weight_scale,
6814                weight_zero_point,
6815                bias_float,
6816                strides,
6817                pads,
6818                dilations,
6819                groups,
6820                Y_scale,
6821                Y_zero_point,
6822                qconv_output_dtype,
6823                post_op.unary_attr,
6824                post_op.scalars,
6825                post_op.algorithm,
6826            )
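            # For fp32/bf16 outputs, quantize the result so it can be compared against the
            # int8 reference below.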
6827        if fp32_output or bfloat16_output:
6828            self.assertEqual(Y_q_cpu_tensor.dtype, qconv_output_dtype)
6829            Y_q_cpu_tensor = torch.quantize_per_tensor(
6830                Y_q_cpu_tensor
6831                if fp32_output
6832                else Y_q_cpu_tensor.to(torch.float32), scale=Y_scale, zero_point=Y_zero_point, dtype=output_dtype
6833            ).int_repr()
6834
6835        # Make sure the results match
6836        # assert_array_almost_equal compares using the following formula:
6837        #     abs(desired-actual) < 1.5 * 10**(-decimal)
6838        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
6839        # We use decimal = 0 to ignore off-by-1 differences between
6840        # reference and test. Off-by-1 differences arise due to the order of
6841        # round and zero_point addition operation, i.e., if addition
6842        # followed by round is used by reference and round followed by
6843        # addition is used by test, the results may differ by 1.
6844        # For example, the result of round(2.5) + 1 is 3 while
6845        # round(2.5 + 1) is 4 assuming the rounding mode is
6846        # round-to-nearest, ties-to-even.
6847
6848        np.testing.assert_array_almost_equal(
6849            result_ref_q.int_repr().cpu().numpy(),
6850            Y_q_cpu_tensor.cpu().numpy(),
6851            decimal=0,
6852            err_msg=f"""X: {X_q}, W: {W_q}, b: {bias_float}, strides: {strides},
6853            pads: {pads}, o_pads: {o_pads}, dilations: {dilations},
6854            groups: {groups}, y_s: {Y_scale}, y_zp: {Y_zero_point}, X2: {X2_q}""",
6855        )
6856
6857        # Return the quantized data for later reuse
6858        return X_q, W_q, bias_float
6859
6860    @skipIfNoONEDNN
6861    def test_qconv1d_pt2e(self):
6862        groups_list = [1, 3]
6863        input_channels_per_group = 2
6864        output_channels_per_group = 2
6865        length = 4
6866        kernel = 3
6867        stride = 1
6868        pad = 1
6869        dilation = 1
6870        W_scale = [1.5]
6871        W_zero_point = [0]
6872        use_bias_list = [False, True]
6873        use_channelwise_list = [False, True]
6874        output_dtype_list = [None, torch.float32, torch.bfloat16]
6875        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
6876        for groups, use_bias, use_channelwise, output_dtype in options:
6877            if output_dtype is not None and not (use_bias and use_channelwise):
6878                # Remove some test combinations to reduce UT test time
6879                continue
6880            conv1d = torch.nn.Conv1d(
6881                input_channels_per_group * groups,
6882                output_channels_per_group * groups,
6883                kernel,
6884                stride,
6885                pad,
6886                dilation,
6887                groups,
6888            )
6889            qconv = torch.ops.onednn.qconv1d_pointwise
6890            qconv_prepack = torch.ops.onednn.qconv_prepack
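                # The default PointwisePostOp() requests no fused binary or unary post op.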
6891            pointwise_post_op = PointwisePostOp()
6892            self._test_qconv_impl_cpu_tensor(
6893                qconv,
6894                qconv_prepack,
6895                conv1d,
6896                input_channels_per_group=input_channels_per_group,
6897                input_feature_map_shape=(length,),
6898                output_channels_per_group=output_channels_per_group,
6899                groups=groups,
6900                kernels=kernel,
6901                strides=[stride],
6902                pads=[pad],
6903                dilations=[dilation],
6904                W_scale=W_scale,
6905                W_zero_point=W_zero_point,
6906                use_bias=use_bias,
6907                post_op=pointwise_post_op,
6908                use_channelwise=use_channelwise,
6909                qconv_output_dtype=output_dtype,
6910            )
6911
6912    @skipIfNoONEDNN
6913    def test_qconv2d_pt2e(self):
6914        groups_list = [1, 3]
6915        input_channels_per_group = 2
6916        output_channels_per_group = 2
6917        input_feature_map_shape = (10, 10)
6918        kernels = (3, 3)
6919        strides = (2, 2)
6920        pads = (1, 1)
6921        dilations = (1, 1)
6922        W_scale = [1.5]
6923        W_zero_point = [0]
6924        use_bias_list = [False, True]
6925        use_channelwise_list = [False, True]
6926        channel_last_weight_format_list = [False, True]
6927        output_dtype_list = [None, torch.float32, torch.bfloat16]
6928        options = itertools.product(
6929            groups_list,
6930            use_bias_list,
6931            use_channelwise_list,
6932            channel_last_weight_format_list,
6933            output_dtype_list,
6934        )
6935        for groups, use_bias, use_channelwise, channel_last_weight_format, output_dtype in options:
6936            if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise):
6937                # Remove some test combination to reduce UT test time
6938                # Skip some test combinations to reduce UT run time
6939            qconv = torch.ops.onednn.qconv2d_pointwise
6940            qconv_prepack = torch.ops.onednn.qconv_prepack
6941            conv_op = torch.nn.Conv2d(
6942                input_channels_per_group * groups,
6943                output_channels_per_group * groups,
6944                kernels,
6945                strides,
6946                pads,
6947                dilations,
6948                groups,
6949            )
6950            pointwise_post_op = PointwisePostOp()
6951            self._test_qconv_impl_cpu_tensor(
6952                qconv,
6953                qconv_prepack,
6954                conv_op,
6955                input_channels_per_group=input_channels_per_group,
6956                input_feature_map_shape=input_feature_map_shape,
6957                output_channels_per_group=output_channels_per_group,
6958                groups=groups,
6959                kernels=kernels,
6960                strides=strides,
6961                pads=pads,
6962                dilations=dilations,
6963                W_scale=W_scale,
6964                W_zero_point=W_zero_point,
6965                use_bias=use_bias,
6966                post_op=pointwise_post_op,
6967                use_channelwise=use_channelwise,
6968                qconv_output_dtype=output_dtype,
6969                weight_in_channel_last_format=channel_last_weight_format,
6970            )
6971
6972    @skipIfNoONEDNN
6973    def test_qconv3d_pt2e(self):
6974        input_channels_per_group = 2
6975        input_feature_map_shape = (6, 6, 6)
6976        output_channels_per_group = 2
6977        groups_list = [1, 3]
6978        kernels = (3, 3, 3)
6979        strides = (2, 2, 2)
6980        pads = (1, 1, 1)
6981        dilations = (1, 1, 1)
6982        W_scale = [1.5]
6983        W_zero_point = [0]
6984        use_bias_list = [False, True]
6985        use_channelwise_list = [False, True]
6986        channel_last_weight_format_list = [False, True]
6987        output_dtype_list = [None, torch.float32, torch.bfloat16]
6988        options = itertools.product(
6989            groups_list,
6990            use_bias_list,
6991            use_channelwise_list,
6992            channel_last_weight_format_list,
6993            output_dtype_list,
6994        )
6995        for groups, use_bias, use_channelwise, channel_last_weight_format, output_dtype in options:
6996            if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise):
6997                # Remove some test combination to reduce UT test time
6998                # Skip some test combinations to reduce UT run time
6999            qconv = torch.ops.onednn.qconv3d_pointwise
7000            qconv_prepack = torch.ops.onednn.qconv_prepack
7001            conv_op = torch.nn.Conv3d(
7002                input_channels_per_group * groups,
7003                output_channels_per_group * groups,
7004                kernels,
7005                strides,
7006                pads,
7007                dilations,
7008                groups,
7009            )
7010            pointwise_post_op = PointwisePostOp()
7011            self._test_qconv_impl_cpu_tensor(
7012                qconv,
7013                qconv_prepack,
7014                conv_op,
7015                input_channels_per_group=input_channels_per_group,
7016                input_feature_map_shape=input_feature_map_shape,
7017                output_channels_per_group=output_channels_per_group,
7018                groups=groups,
7019                kernels=kernels,
7020                strides=strides,
7021                pads=pads,
7022                dilations=dilations,
7023                W_scale=W_scale,
7024                W_zero_point=W_zero_point,
7025                use_bias=use_bias,
7026                post_op=pointwise_post_op,
7027                use_channelwise=use_channelwise,
7028                qconv_output_dtype=output_dtype,
7029                weight_in_channel_last_format=channel_last_weight_format,
7030            )
7031
7032    # Test qconv with post op relu
7033    @skipIfNoONEDNN
7034    def test_qconv2d_relu_pt2e(self):
7035        input_channels_per_group = 2
7036        output_channels_per_group = 2
7037        groups_list = [1, 10]
7038        input_feature_map_shape = (10, 10)
7039        kernels = (3, 3)
7040        strides = (2, 2)
7041        pads = (1, 1)
7042        dilations = (1, 1)
7043        W_scale = [1.5]
7044        W_zero_point = [0]
7045        use_bias_list = [False, True]
7046        use_channelwise_list = [False, True]
7047        output_dtype_list = [None, torch.float32, torch.bfloat16]
7048        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
7049        for groups, use_bias, use_channelwise, output_dtype in options:
7050            qconv = torch.ops.onednn.qconv2d_pointwise
7051            qconv_prepack = torch.ops.onednn.qconv_prepack
7052            conv_op = torch.nn.Conv2d(
7053                input_channels_per_group * groups,
7054                output_channels_per_group * groups,
7055                kernels,
7056                strides,
7057                pads,
7058                dilations,
7059                groups,
7060            )
7061            pointwise_post_op = PointwisePostOp(unary_attr="relu")
7062            self._test_qconv_impl_cpu_tensor(
7063                qconv,
7064                qconv_prepack,
7065                conv_op,
7066                input_channels_per_group=input_channels_per_group,
7067                input_feature_map_shape=input_feature_map_shape,
7068                output_channels_per_group=output_channels_per_group,
7069                groups=groups,
7070                kernels=kernels,
7071                strides=strides,
7072                pads=pads,
7073                dilations=dilations,
7074                W_scale=W_scale,
7075                W_zero_point=W_zero_point,
7076                use_bias=use_bias,
7077                post_op=pointwise_post_op,
7078                use_channelwise=use_channelwise,
7079                qconv_output_dtype=output_dtype,
7080            )
7081
7082    # Test qconv with post op hardtanh
7083    @skipIfNoONEDNN
7084    def test_qconv2d_hardtanh_pt2e(self):
7085        input_channels_per_group = 2
7086        output_channels_per_group = 2
7087        groups_list = [1, 10]
7088        input_feature_map_shape = (10, 10)
7089        kernels = (3, 3)
7090        strides = (2, 2)
7091        pads = (1, 1)
7092        dilations = (1, 1)
7093        W_scale = [1.5]
7094        W_zero_point = [0]
7095        use_bias_list = [False, True]
7096        use_channelwise_list = [False, True]
7097        output_dtype_list = [None, torch.float32, torch.bfloat16]
7098        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
7099        for groups, use_bias, use_channelwise, output_dtype in options:
7100            qconv = torch.ops.onednn.qconv2d_pointwise
7101            qconv_prepack = torch.ops.onednn.qconv_prepack
7102            conv_op = torch.nn.Conv2d(
7103                input_channels_per_group * groups,
7104                output_channels_per_group * groups,
7105                kernels,
7106                strides,
7107                pads,
7108                dilations,
7109                groups,
7110            )
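            # The scalars are the hardtanh clamp bounds: min=0.0, max=6.0,
            # i.e. a ReLU6-style post op.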
7111            pointwise_post_op = PointwisePostOp(unary_attr="hardtanh", scalars=[0.0, 6.0])
7112            self._test_qconv_impl_cpu_tensor(
7113                qconv,
7114                qconv_prepack,
7115                conv_op,
7116                input_channels_per_group=input_channels_per_group,
7117                input_feature_map_shape=input_feature_map_shape,
7118                output_channels_per_group=output_channels_per_group,
7119                groups=groups,
7120                kernels=kernels,
7121                strides=strides,
7122                pads=pads,
7123                dilations=dilations,
7124                W_scale=W_scale,
7125                W_zero_point=W_zero_point,
7126                use_bias=use_bias,
7127                post_op=pointwise_post_op,
7128                use_channelwise=use_channelwise,
7129                qconv_output_dtype=output_dtype,
7130            )
7131
7132    # Test qconv with post op silu
7133    @skipIfNoONEDNN
7134    def test_qconv2d_silu_pt2e(self):
7135        input_channels_per_group = 2
7136        output_channels_per_group = 2
7137        groups_list = [1, 10]
7138        input_feature_map_shape = (10, 10)
7139        kernels = (3, 3)
7140        strides = (2, 2)
7141        pads = (1, 1)
7142        dilations = (1, 1)
7143        W_scale = [1.5]
7144        W_zero_point = [0]
7145        use_bias_list = [False, True]
7146        use_channelwise_list = [False, True]
7147        output_dtype_list = [None, torch.float32, torch.bfloat16]
7148        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
7149        for groups, use_bias, use_channelwise, output_dtype in options:
7150            qconv = torch.ops.onednn.qconv2d_pointwise
7151            qconv_prepack = torch.ops.onednn.qconv_prepack
7152            conv_op = torch.nn.Conv2d(
7153                input_channels_per_group * groups,
7154                output_channels_per_group * groups,
7155                kernels,
7156                strides,
7157                pads,
7158                dilations,
7159                groups,
7160            )
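            # oneDNN exposes SiLU as "swish": silu(x) = x * sigmoid(x),
            # i.e. swish with alpha = 1.0.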
7161            pointwise_post_op = PointwisePostOp(unary_attr="swish")
7162            self._test_qconv_impl_cpu_tensor(
7163                qconv,
7164                qconv_prepack,
7165                conv_op,
7166                input_channels_per_group=input_channels_per_group,
7167                input_feature_map_shape=input_feature_map_shape,
7168                output_channels_per_group=output_channels_per_group,
7169                groups=groups,
7170                kernels=kernels,
7171                strides=strides,
7172                pads=pads,
7173                dilations=dilations,
7174                W_scale=W_scale,
7175                W_zero_point=W_zero_point,
7176                use_bias=use_bias,
7177                post_op=pointwise_post_op,
7178                use_channelwise=use_channelwise,
7179                qconv_output_dtype=output_dtype,
7180            )
7181
7182    # Test qconv with post op hardswish
7183    @skipIfNoONEDNN
7184    def test_qconv2d_hardswish_pt2e(self):
7185        input_channels_per_group = 2
7186        output_channels_per_group = 2
7187        groups_list = [1, 10]
7188        input_feature_map_shape = (10, 10)
7189        kernels = (3, 3)
7190        strides = (2, 2)
7191        pads = (1, 1)
7192        dilations = (1, 1)
7193        W_scale = [1.5]
7194        W_zero_point = [0]
7195        use_bias_list = [False, True]
7196        use_channelwise_list = [False, True]
7197        output_dtype_list = [None, torch.float32, torch.bfloat16]
7198        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
7199
7200        for groups, use_bias, use_channelwise, output_dtype in options:
7201            qconv = torch.ops.onednn.qconv2d_pointwise
7202            qconv_prepack = torch.ops.onednn.qconv_prepack
7203            conv_op = torch.nn.Conv2d(
7204                input_channels_per_group * groups,
7205                output_channels_per_group * groups,
7206                kernels,
7207                strides,
7208                pads,
7209                dilations,
7210                groups,
7211            )
7212            pointwise_post_op = PointwisePostOp(unary_attr="hardswish")
7213            self._test_qconv_impl_cpu_tensor(
7214                qconv,
7215                qconv_prepack,
7216                conv_op,
7217                input_channels_per_group=input_channels_per_group,
7218                input_feature_map_shape=input_feature_map_shape,
7219                output_channels_per_group=output_channels_per_group,
7220                groups=groups,
7221                kernels=kernels,
7222                strides=strides,
7223                pads=pads,
7224                dilations=dilations,
7225                W_scale=W_scale,
7226                W_zero_point=W_zero_point,
7227                use_bias=use_bias,
7228                post_op=pointwise_post_op,
7229                use_channelwise=use_channelwise,
7230                qconv_output_dtype=output_dtype,
7231            )
7232
7233    # Test qconv with post op sum
7234    @skipIfNoONEDNN
7235    def test_qconv2d_sum_pt2e(self):
7236        groups_list = [1, 3]
7237        input_channels_per_group = 2
7238        output_channels_per_group = 2
7239        input_feature_map_shape = (10, 10)
7240        kernels = (3, 3)
7241        strides = (2, 2)
7242        pads = (1, 1)
7243        dilations = (1, 1)
7244        W_scale = [1.5]
7245        W_zero_point = [-3]
7246        use_bias_list = [False, True]
7247        use_channelwise_list = [False, True]
7248        output_dtype_list = [None, torch.float32, torch.bfloat16]
7249        X2_zero_point_list = [0, 1]
7250        options = itertools.product(
7251            groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list, output_dtype_list
7252        )
7253        for groups, use_bias, use_channelwise, X2_zero_point, output_dtype in options:
7254            qconv = torch.ops.onednn.qconv2d_pointwise.binary
7255            qconv_prepack = torch.ops.onednn.qconv_prepack
7256            conv_op = torch.nn.Conv2d(
7257                input_channels_per_group * groups,
7258                output_channels_per_group * groups,
7259                kernels,
7260                strides,
7261                pads,
7262                dilations,
7263                groups,
7264            )
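            # With binary_attr="sum", the extra input X2 is expected to act as the
            # accumulator (oneDNN sum post op): the conv output is added into it.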
7265            pointwise_post_op = PointwisePostOp(binary_attr="sum")
7266            self._test_qconv_impl_cpu_tensor(
7267                qconv,
7268                qconv_prepack,
7269                conv_op,
7270                input_channels_per_group=input_channels_per_group,
7271                input_feature_map_shape=input_feature_map_shape,
7272                output_channels_per_group=output_channels_per_group,
7273                groups=groups,
7274                kernels=kernels,
7275                strides=strides,
7276                pads=pads,
7277                dilations=dilations,
7278                W_scale=W_scale,
7279                W_zero_point=W_zero_point,
7280                use_bias=use_bias,
7281                post_op=pointwise_post_op,
7282                use_channelwise=use_channelwise,
7283                X2_zero_point=X2_zero_point,
7284                qconv_output_dtype=output_dtype,
7285                qconv_x2_dtype=output_dtype,
7286            )
7287
7288    # Test qconv with post op sum relu
7289    @skipIfNoONEDNN
7290    def test_qconv2d_sum_relu_pt2e(self):
7291        groups_list = [1, 3]
7292        input_channels_per_group = 2
7293        output_channels_per_group = 2
7294        input_feature_map_shape = (10, 10)
7295        kernels = (3, 3)
7296        strides = (2, 2)
7297        pads = (1, 1)
7298        dilations = (1, 1)
7299        W_scale = [1.5]
7300        W_zero_point = [-3]
7301        use_bias_list = [False, True]
7302        use_channelwise_list = [False, True]
7303        X2_zero_point_list = [0, 1]
7304        options = itertools.product(
7305            groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list
7306        )
7307        for groups, use_bias, use_channelwise, X2_zero_point in options:
7308            qconv = torch.ops.onednn.qconv2d_pointwise.binary
7309            qconv_prepack = torch.ops.onednn.qconv_prepack
7310            conv_op = torch.nn.Conv2d(
7311                input_channels_per_group * groups,
7312                output_channels_per_group * groups,
7313                kernels,
7314                strides,
7315                pads,
7316                dilations,
7317                groups,
7318            )
7319            pointwise_post_op = PointwisePostOp(binary_attr="sum", unary_attr="relu")
7320            self._test_qconv_impl_cpu_tensor(
7321                qconv,
7322                qconv_prepack,
7323                conv_op,
7324                input_channels_per_group=input_channels_per_group,
7325                input_feature_map_shape=input_feature_map_shape,
7326                output_channels_per_group=output_channels_per_group,
7327                groups=groups,
7328                kernels=kernels,
7329                strides=strides,
7330                pads=pads,
7331                dilations=dilations,
7332                W_scale=W_scale,
7333                W_zero_point=W_zero_point,
7334                use_bias=use_bias,
7335                post_op=pointwise_post_op,
7336                use_channelwise=use_channelwise,
7337                X2_zero_point=X2_zero_point,
7338            )
7339
7340    # Test qconv with post op sum (optionally with relu) and float/bfloat16 output
7341    @skipIfNoONEDNN
7342    def test_qconv2d_sum_relu_float_output_pt2e(self):
7343        groups = 1
7344        input_channels_per_group = 2
7345        output_channels_per_group = 2
7346        input_feature_map_shape = (10, 10)
7347        kernels = (3, 3)
7348        strides = (2, 2)
7349        pads = (1, 1)
7350        dilations = (1, 1)
7351        W_scale = [1.5]
7352        W_zero_point = [-3]
7353        use_bias_list = [False, True]
7354        use_channelwise = True
7355        output_dtype_list = [torch.float32, torch.bfloat16]
7356        X2_zero_point = 0
7357        use_relu_list = [True, False]
7358        options = itertools.product(
7359            use_bias_list, output_dtype_list, use_relu_list
7360        )
7361        for use_bias, output_dtype, use_relu in options:
7362            qconv_x2_dtype = output_dtype
7363            qconv = torch.ops.onednn.qconv2d_pointwise.binary
7364            qconv_prepack = torch.ops.onednn.qconv_prepack
7365            conv_op = torch.nn.Conv2d(
7366                input_channels_per_group * groups,
7367                output_channels_per_group * groups,
7368                kernels,
7369                strides,
7370                pads,
7371                dilations,
7372                groups,
7373            )
7374            pointwise_post_op = (
7375                PointwisePostOp(binary_attr="sum", unary_attr="relu")
7376                if use_relu
7377                else PointwisePostOp(binary_attr="sum")
7378            )
7379            self._test_qconv_impl_cpu_tensor(
7380                qconv,
7381                qconv_prepack,
7382                conv_op,
7383                input_channels_per_group=input_channels_per_group,
7384                input_feature_map_shape=input_feature_map_shape,
7385                output_channels_per_group=output_channels_per_group,
7386                groups=groups,
7387                kernels=kernels,
7388                strides=strides,
7389                pads=pads,
7390                dilations=dilations,
7391                W_scale=W_scale,
7392                W_zero_point=W_zero_point,
7393                use_bias=use_bias,
7394                post_op=pointwise_post_op,
7395                use_channelwise=use_channelwise,
7396                X2_zero_point=X2_zero_point,
7397                qconv_output_dtype=output_dtype,
7398                qconv_x2_dtype=qconv_x2_dtype,
7399            )
7400
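
# The qconv PT2E tests above exercise several fused unary post ops. Below is a
# minimal float-reference sketch of what those attributes correspond to
# (illustrative only; this helper and its name are assumptions, not what
# _test_qconv_impl_cpu_tensor actually uses):
def _unary_post_op_reference_sketch(y, unary_attr="none", scalars=None):
    if unary_attr == "relu":
        return torch.relu(y)
    if unary_attr == "hardtanh":
        lo, hi = scalars  # clamp bounds, e.g. [0.0, 6.0] for a ReLU6-style clamp
        return torch.clamp(y, lo, hi)
    if unary_attr == "swish":
        return y * torch.sigmoid(y)  # SiLU
    if unary_attr == "hardswish":
        return torch.nn.functional.hardswish(y)
    return y
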
7401class TestPadding(TestCase):
7402    @given(batch_size=st.integers(1, 64),
7403           channels=st.integers(1, 64),
7404           width=st.integers(16, 128),
7405           qtype=st.sampled_from(hu._ALL_QINT_TYPES))
7406    def test_reflection_pad1d(self, batch_size, channels, width, qtype):
7407        padding = width // 4
7408
7409        x = torch.arange(batch_size * channels * width).to(torch.float)
7410        x = x.resize(batch_size, channels, width)
7411        # Per-Tensor test
7412        scale, zp = _calculate_dynamic_qparams(x, qtype)
7413        qx = torch.quantize_per_tensor(x, scale, zp, qtype)
7414
7415        padding_op = torch.nn.ReflectionPad1d(padding)
7416
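        # Reflection padding only copies existing elements, so the padded quantized
        # tensor keeps the input scale/zero_point and the comparison can be exact.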
7417        y_ref = padding_op(x)
7418        qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype)
7419        qy_hat = padding_op(qx)
7420        self.assertEqual(qy_ref, qy_hat)
7421
7422        # Out variant
7423        qy_hat = torch._C._nn.reflection_pad1d(qx, padding, out=qy_hat)
7424        self.assertEqual(qy_ref, qy_hat)
7425
7426    @given(batch_size=st.integers(1, 64),
7427           channels=st.integers(1, 64),
7428           height=st.integers(16, 128),
7429           width=st.integers(16, 128),
7430           qtype=st.sampled_from(hu._ALL_QINT_TYPES))
7431    def test_reflection_pad2d(self, batch_size, channels, height, width, qtype):
7432        padding = (width // 4, width // 4, height // 4, height // 4)
7433
7434        x = torch.arange(batch_size * channels * height * width).to(torch.float)
7435        x = x.resize(batch_size, channels, height, width)
7436        # Per-Tensor test
7437        scale, zp = _calculate_dynamic_qparams(x, qtype)
7438        qx = torch.quantize_per_tensor(x, scale, zp, qtype)
7439
7440        padding_op = torch.nn.ReflectionPad2d(padding)
7441
7442        y_ref = padding_op(x)
7443        qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype)
7444        qy_hat = padding_op(qx)
7445        self.assertEqual(qy_ref, qy_hat)
7446
7447        # Out variant
7448        qy_hat = torch._C._nn.reflection_pad2d(qx, padding, out=qy_hat)
7449        self.assertEqual(qy_ref, qy_hat)
7450
7451    @given(batch_size=st.integers(1, 64),
7452           channels=st.integers(1, 64),
7453           hwd=st.integers(1, 16),  # For 3D, max input size would be 16x16x16
7454           d=st.sampled_from([1, 2, 3]),
7455           value=st.floats(-5, 5, allow_nan=False, allow_infinity=False),
7456           qtype=st.sampled_from(hu._ALL_QINT_TYPES))
7457    def test_constant_padNd(self, batch_size, channels, d, hwd, value, qtype):
7458        padding = hwd // 4
7459
7460        shape = [batch_size, channels, hwd]
7461        op = torch.nn.ConstantPad1d
7462        if d >= 2:
7463            shape.append(hwd)
7464            op = torch.nn.ConstantPad2d
7465        if d == 3:
7466            shape.append(hwd)
7467            op = torch.nn.ConstantPad3d
7468        numel = np.prod(shape)
7469
7470        x = torch.arange(numel).to(torch.float)
7471        x = x.resize(*shape)
7472        # Per-Tensor test
7473        scale, zp = _calculate_dynamic_qparams(x, qtype)
7474        qx = torch.quantize_per_tensor(x, scale, zp, qtype)
7475
7476        padding_op = op(padding, value)
7477
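        # The float fill value is quantized together with the rest of the reference
        # below, so the quantized op is expected to quantize `value` with qx's qparams too.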
7478        y_ref = padding_op(x)
7479        qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype)
7480        qy_hat = padding_op(qx)
7481
7482        self.assertEqual(qy_ref, qy_hat)
7483
7484
7485@unittest.skipUnless('qnnpack' in supported_qengines,
7486                     "This PyTorch build has not been built with or does not support QNNPACK")
7487class TestQNNPackOps(TestCase):
7488    """Tests the correctness of the quantized::qnnpack_relu op."""
7489    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
7490                       qparams=hu.qparams(dtypes=torch.quint8,
7491                                          zero_point_min=0,
7492                                          zero_point_max=0)))
7493    def test_qnnpack_relu(self, X):
7494        with override_quantized_engine('qnnpack'):
7495            X, (scale, zero_point, torch_type) = X
7496            relu = torch.nn.functional.relu
7497            X = torch.from_numpy(X)
7498            Y = X.clone()
7499
7500            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch_type)
7501            qY_hat = relu(qX)
7502
7503            Y[Y < 0] = 0
7504            qY = torch.quantize_per_tensor(Y, scale=scale, zero_point=zero_point, dtype=torch_type)
7505            self.assertEqual(qY, qY_hat)
7506
7507    """Tests the correctness of the quantized::qnnpack_tanh op."""
7508    @skipIfNoFBGEMM
7509    def test_qnnpack_tanh(self):
7510        # Note: In QNNPACK the output scale and zero_point can only be
7511        #       2.0/256, 128 respectively, as it uses a LUT with 256 bins.
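        #       (2.0 / 256 == 1.0 / 128, which matches the reference qparams used below.)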
7512
7513        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
7514        memory_formats = (torch.channels_last, torch.contiguous_format)
7515        test_cases = itertools.product(shapes, memory_formats)
7516        for shape, memory_format in test_cases:
7517            X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8
7518            if memory_format == torch.channels_last and len(shape) != 4:
7519                continue
7520            X = X.to(memory_format=memory_format)
7521            qX = torch.quantize_per_tensor(X, scale=scale,
7522                                           zero_point=zero_point,
7523                                           dtype=torch_type)
7524
7525            # Floating point reference
7526            Y = torch.tanh(qX.dequantize())
7527            qY = torch.quantize_per_tensor(Y, scale=1.0 / 128, zero_point=128,
7528                                           dtype=torch.quint8)
7529            with override_quantized_engine('fbgemm'):
7530                qYserver = torch.tanh(qX)
7531            with override_quantized_engine('qnnpack'):
7532                qY_hat = torch.tanh(qX)
7533                self.assertEqual(
7534                    qY, qY_hat,
7535                    msg=f"QNNPACK TanH failed (FP ref), memory_format {memory_format}")
7536                self.assertEqual(
7537                    qYserver, qY_hat,
7538                    msg=f"QNNPACK TanH failed (FBGEMM ref), memory_format {memory_format}")
7539
7540    """Tests the correctness of the quantized::qnnpack_sigmoid op."""
7541    @skipIfNoFBGEMM
7542    def test_qnnpack_sigmoid(self):
7543        # Note: In QNNPACK the output scale and zero_point can only be
7544        #       1.0/256, 0 respectively, as it uses a LUT with 256 bins.
7545        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
7546        memory_formats = (torch.channels_last, torch.contiguous_format)
7547        test_cases = itertools.product(shapes, memory_formats)
7548        for shape, memory_format in test_cases:
7549            X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8
7550            if memory_format == torch.channels_last and len(shape) != 4:
7551                continue
7552            X = X.to(memory_format=memory_format)
7553            qX = torch.quantize_per_tensor(X, scale=scale,
7554                                           zero_point=zero_point,
7555                                           dtype=torch_type)
7556
7557            # Floating point reference
7558            Y = torch.sigmoid(qX.dequantize())
7559            qY = torch.quantize_per_tensor(Y, scale=1.0 / 256, zero_point=0,
7560                                           dtype=torch.quint8)
7561            with override_quantized_engine('fbgemm'):
7562                qYserver = torch.sigmoid(qX)
7563            with override_quantized_engine('qnnpack'):
7564                qY_hat = torch.sigmoid(qX)
7565                self.assertEqual(
7566                    qY, qY_hat,
7567                    msg=f"QNNPACK Sigmoid failed (FP ref), memory_format {memory_format}")
7568                self.assertEqual(
7569                    qYserver, qY_hat,
7570                    msg=f"QNNPACK Sigmoid failed (FBGEMM ref), memory_format {memory_format}")
7571
7572    @skipIfNoFBGEMM
7573    def test_qnnpack_sigmoid_sweep(self):
7574        # Input parameters
7575        f_min = -4.0
7576        f_max = 4.0
7577        scale = (f_max - f_min) / 256.0
7578        zero_point = 128
7579        dtype = torch.quint8
7580
7581        step = scale / 2.0
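        # Sampling at half the quantization step hits two points per bin, so every
        # uint8 input value over [f_min, f_max] is exercised.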
7582        x = np.arange(f_min, f_max + step, step)
7583        X = torch.from_numpy(x).to(torch.float32)
7584        qX = torch.quantize_per_tensor(X, scale=scale,
7585                                       zero_point=zero_point,
7586                                       dtype=dtype)
7587
7588        dqX = qX.dequantize()
7589        # Floating point reference
7590        Y = torch.sigmoid(dqX)
7591        qY = torch.quantize_per_tensor(Y, scale=1.0 / 256, zero_point=0,
7592                                       dtype=torch.quint8)
7593        with override_quantized_engine('fbgemm'):
7594            qYserver = torch.sigmoid(qX)
7595        with override_quantized_engine('qnnpack'):
7596            qY_hat = torch.sigmoid(qX)
7597            self.assertEqual(qY, qY_hat,
7598                             msg="QNNPACK Sigmoid failed (FP ref)!")
7599            self.assertEqual(qYserver, qY_hat,
7600                             msg="QNNPACK Sigmoid failed (FBGEMM ref)!")
7601
7602    """Tests the correctness of the quantized::add (qnnpack) op."""
7603    @settings(suppress_health_check=(HealthCheck.filter_too_much,))
7604    @given(A=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
7605                       qparams=hu.qparams(dtypes=[torch.quint8, torch.qint8])),
7606           zero_point=st.sampled_from([0, 2, 5, 15, 127]),
7607           scale_A=st.sampled_from([0.001, 0.057, 0.889, 12.3]),
7608           scale_B=st.sampled_from([0.008, 0.0821, 0.67, 7]),
7609           scale_C=st.sampled_from([0.003, 0.07821, 0.457, 7.34]),)
7610    def test_qnnpack_add(self, A, zero_point, scale_A, scale_B, scale_C):
7611        with override_quantized_engine('qnnpack'):
7612            A_temp = A
7613            for channels_last in [True, False]:
7614                if channels_last and len(A_temp[0].shape) != 4:
7615                    continue
7616                A, (scale_a, zero_point_A, torch_type) = A_temp
7617                B, (scale_b, zero_point_B, torch_type) = A_temp
7618                A = torch.from_numpy(A)
7619                B = torch.from_numpy(B)
7620
7621                if torch_type == torch.qint8 and not torch.backends.xnnpack.enabled:
7622                    continue
7623
7624                if channels_last:
7625                    A = A.to(memory_format=torch.channels_last)
7626                    B = B.to(memory_format=torch.channels_last)
7627                assume(scale_A / scale_C >= 2**-14)
7628                assume(scale_A / scale_C < 2**8)
7629                assume(scale_B / scale_C >= 2**-14)
7630                assume(scale_B / scale_C < 2**8)
7631
7632                zero_point_C = 127
7633                np_dtype = np.uint8
7634
7635                if torch_type == torch.qint8:
7636                    zero_point_C = 0
7637                    np_dtype = np.int8
7638
7639                qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point,
7640                                               dtype=torch_type)
7641                qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point,
7642                                               dtype=torch_type)
7643
7644                # Add ground truth
7645                C = (qA.dequantize() + qB.dequantize()).numpy()
7646
7647                qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype)
7648
7649                qC_qnnp = torch.ops.quantized.add(qA, qB, scale_C, zero_point_C)
7650
7651                np.testing.assert_equal(qC, qC_qnnp.int_repr(),
7652                                        "Quantized addition failed.")
7653
7654                Crelu = C.copy()
7655                Crelu[C < 0] = 0
7656                qCrelu = torch.quantize_per_tensor(torch.from_numpy(Crelu), scale_C,
7657                                                   zero_point_C, dtype=torch_type)
7658                qCrelu_hat = torch.ops.quantized.add_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
7659                np.testing.assert_equal(qCrelu.int_repr().numpy(), qCrelu_hat.int_repr(),
7660                                        "Quantized addition with ReLU failed.")
7661
7662    """Tests the correctness of the quantized::mul (qnnpack) op."""
7663    @settings(suppress_health_check=(HealthCheck.filter_too_much,))
7664    @given(A=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
7665                       qparams=hu.qparams(dtypes=[torch.quint8, torch.qint8])),
7666           zero_point=st.sampled_from([0, 2, 5, 15, 127]),
7667           scale_A=st.sampled_from([0.3, 0.57, 0.889]),
7668           scale_B=st.sampled_from([0.8, 0.821, 0.67]),
7669           scale_C=st.sampled_from([0.3, 0.7821, 0.457]),)
7670    def test_qnnpack_mul(self, A, zero_point, scale_A, scale_B, scale_C):
7671        with override_quantized_engine('qnnpack'):
7672            A_temp = A
7673            for channels_last in [True, False]:
7674                if channels_last and len(A_temp[0].shape) != 4:
7675                    continue
7676                A, (scale_a, zero_point_A, torch_type) = A_temp
7677                B, (scale_b, zero_point_B, torch_type) = A_temp
7678                A = torch.from_numpy(A)
7679                B = torch.from_numpy(B)
7680
7681                if torch_type == torch.qint8 and not torch.backends.xnnpack.enabled:
7682                    continue
7683
7684                if channels_last:
7685                    A = A.to(memory_format=torch.channels_last)
7686                    B = B.to(memory_format=torch.channels_last)
7687                assume(scale_A / scale_C >= 2**-14)
7688                assume(scale_A / scale_C < 2**8)
7689                assume(scale_B / scale_C >= 2**-14)
7690                assume(scale_B / scale_C < 2**8)
7691
7692                zero_point_C = 127
7693                np_dtype = np.uint8
7694
7695                if torch_type == torch.qint8:
7696                    zero_point_C = 0
7697                    np_dtype = np.int8
7698
7699                qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point,
7700                                               dtype=torch_type)
7701                qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point,
7702                                               dtype=torch_type)
7703
7704                # Multiplication ground truth
7705                C = (qA.dequantize() * qB.dequantize()).numpy()
7706
7707                qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype)
7708                qC_qnnp = torch.ops.quantized.mul(qA, qB, scale_C, zero_point_C)
7709
7710                np.testing.assert_equal(qC, qC_qnnp.int_repr(),
7711                                        "Quantized multiplication failed.")
7712
7713                Crelu = C.copy()
7714                Crelu[C < 0] = 0
7715                qCrelu = torch.quantize_per_tensor(torch.from_numpy(Crelu), scale_C,
7716                                                   zero_point_C, dtype=torch_type)
7717                qCrelu_hat = torch.ops.quantized.mul_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
7718                np.testing.assert_equal(qCrelu.int_repr().numpy(), qCrelu_hat.int_repr(),
7719                                        "Quantized multiplication with ReLU failed.")
7720
7721
7722    """Tests that quantized add works with broadcasting """
7723    def test_qnnpack_add_broadcast(self):
7724        def _run_test(A, B):
7725            qA = torch.quantize_per_tensor(A, 0.02, 0, dtype)
7726            qB = torch.quantize_per_tensor(B, 0.04, 2, dtype)
7727
7728            output_scale = 0.01
7729            output_zp = 1
7730
7731            # ground truth
7732            C = qA.dequantize() + qB.dequantize()
7733            qC = torch.quantize_per_tensor(C, output_scale, output_zp, dtype)
7734
7735            # quantized
7736            qC_hat_1 = torch.ops.quantized.add(qA, qB, output_scale, output_zp)
7737            qC_hat_2 = torch.ops.quantized.add(qB, qA, output_scale, output_zp)
7738
7739            self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_1.dequantize()))
7740            self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_2.dequantize()))
7741
7742        with override_quantized_engine("qnnpack"):
7743            for dtype in (torch.qint8, torch.quint8):
7744                if dtype == torch.qint8 and not torch.backends.xnnpack.enabled:
7745                    continue
7746
7747                for channels_last in [True, False]:
7748                    # 4d
7749                    A = torch.randn(1, 3, 4, 4)
7750                    B = torch.randn(1, 1, 1, 1)
7751                    if channels_last:
7752                        A = A.to(memory_format=torch.channels_last)
7753                        B = B.to(memory_format=torch.channels_last)
7754                    _run_test(A, B)
7755
7756                    # 5d
7757                    C = torch.randn(1, 3, 4, 4, 4)
7758                    D = torch.randn(1, 1, 1, 1, 1)
7759                    if channels_last:
7760                        C = C.to(memory_format=torch.channels_last_3d)
7761                        D = D.to(memory_format=torch.channels_last_3d)
7762                    _run_test(C, D)
7763
7764    """Tests the correctness of quantized::qnnpack_maxpool2d op."""
7765    @given(A=hu.tensor(shapes=hu.array_shapes(4, 4, 3, 5),
7766                       qparams=hu.qparams(dtypes=torch.quint8)),
7767           kernel=st.sampled_from([2, 4]),
7768           stride=st.sampled_from([1, 2]),
7769           padding=st.sampled_from([1, 2]))
7770    def test_qnnpack_maxpool2d(self, A, kernel, stride, padding):
7771        import torch.nn.functional as F
7772
7773        with override_quantized_engine('qnnpack'):
7774            A, (scale, zero_point, torch_type) = A
7775            X = torch.from_numpy(A)
7776            np_type = np.uint8
7777            dilation = 1
7778
7779            # Check constraints
7780            assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
7781
7782            iH, iW = X.shape[-2:]
7783
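            # pool_output_shape is expected to follow the standard floor-mode formula:
            # out = (in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1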
7784            oH = pool_output_shape(iH, kernel, padding, stride, dilation)
7785            assume(oH > 0)
7786            oW = pool_output_shape(iW, kernel, padding, stride, dilation)
7787            assume(oW > 0)
7788
7789            k = (kernel, kernel)
7790            s = (stride, stride)
7791            d = (dilation, dilation)
7792            p = (padding, padding)
7793
7794            q_max_pool = torch.ops.quantized.max_pool2d
7795
7796            a = scale * (X - zero_point).to(dtype=torch.float)
7797            qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
7798                                           dtype=torch_type)
7799
7800            a_ref = qa.dequantize()
7801
7802            a_pool = F.max_pool2d(a_ref, kernel_size=k, stride=s, padding=p,
7803                                  dilation=d)
7804
7805            a_pool_nhwc = a_pool.permute([0, 2, 3, 1])
7806
7807            qa_pool = q_max_pool(qa, k, s, p, d, ceil_mode=False)
7808
7809            qa_pool_deq = qa_pool.dequantize()
7810            np.testing.assert_equal(a_pool.numpy(), qa_pool_deq.numpy())
7811
7812    @given(batch_size=st.integers(1, 5),
7813           channels=st.sampled_from([2, 4, 5, 8, 16, 32]),
7814           height=st.integers(4, 10),
7815           width=st.integers(4, 10),
7816           kernel=st.integers(2, 5),
7817           stride=st.integers(1, 2),
7818           padding=st.integers(1, 2),
7819           scale=st.floats(0.2, 1.6),
7820           zero_point=st.integers(0, 25)
7821           )
7822    def test_avg_pool2d(
7823            self,
7824            batch_size,
7825            channels,
7826            height,
7827            width,
7828            kernel,
7829            stride,
7830            padding,
7831            scale,
7832            zero_point
7833
7834    ):
7835        with override_quantized_engine('qnnpack'):
7836            import torch.nn.functional as F
7837            X_init = torch.from_numpy(np.random.randint(
7838                0, 50, (batch_size, channels, height, width)))
7839
7840            X = scale * (X_init - zero_point).to(dtype=torch.float)
7841
7842            # Check constraints
7843            assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
7844
7845            iH, iW = X.shape[-2:]
7846
7847            oH = pool_output_shape(iH, kernel, padding, stride, 1)
7848            assume(oH > 0)
7849            oW = pool_output_shape(iW, kernel, padding, stride, 1)
7850            assume(oW > 0)
7851            k = (kernel, kernel)
7852            s = (stride, stride)
7853            p = (padding, padding)
7854
7855            q_avg_pool = torch.ao.nn.quantized.functional.avg_pool2d
7856
7857            x_q = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
7858                                            dtype=torch.quint8)
7859
7860            a_pool = F.avg_pool2d(x_q.dequantize().to(torch.float), kernel_size=k, stride=s, padding=p)
7861            qa_pool = q_avg_pool(x_q, k, s, p)
7862            # Quantize Ref Output
7863            a_pool_q = torch.quantize_per_tensor(a_pool, scale=scale, zero_point=zero_point,
7864                                                 dtype=torch.quint8)
7865            np.testing.assert_array_almost_equal(a_pool_q.int_repr().numpy(),
7866                                                 qa_pool.int_repr().numpy(), decimal=0)
7867
7868
7869    @given(batch_size=st.integers(1, 5),
7870           channels=st.sampled_from([2, 4, 5, 8, 16, 32]),
7871           height=st.integers(4, 20),
7872           width=st.integers(4, 20),
7873           output_height=st.integers(2, 10),
7874           output_width=st.integers(2, 10),
7875           scale=st.floats(0.2, 1.6),
7876           zero_point=st.integers(0, 25)
7877           )
7878    def test_adaptive_avg_pool2d(
7879            self,
7880            batch_size,
7881            channels,
7882            height,
7883            width,
7884            output_height,
7885            output_width,
7886            scale,
7887            zero_point
7888
7889    ):
7890        with override_quantized_engine('qnnpack'):
7891            # Check constraints
7892            assume(height >= output_height)
7893            assume(width >= output_width)
7894
7895            import torch.nn.functional as F
7896            X_init = torch.from_numpy(np.random.randint(
7897                0, 50, (batch_size, channels, height, width)))
7898
7899            X = scale * (X_init - zero_point).to(dtype=torch.float)
7900
7901            iH, iW = X.shape[-2:]
7902
7903            q_avg_pool = torch.ao.nn.quantized.functional.adaptive_avg_pool2d
7904
7905            x_q = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
7906                                            dtype=torch.quint8)
7907
7908            a_pool = F.adaptive_avg_pool2d(x_q.dequantize().to(torch.float), (output_height, output_width))
7909            qa_pool = q_avg_pool(x_q, (output_height, output_width))
7910            # Quantize Ref Output
7911            a_pool_q = torch.quantize_per_tensor(a_pool, scale=scale, zero_point=zero_point,
7912                                                 dtype=torch.quint8)
7913            np.testing.assert_array_almost_equal(a_pool_q.int_repr().numpy(),
7914                                                 qa_pool.int_repr().numpy(), decimal=0)
7915
7916
7917    @given(batch_size=st.integers(1, 5),
7918           channels=st.sampled_from([2, 4, 5, 8, 16, 32]),
7919           height=st.integers(4, 10),
7920           width=st.integers(4, 10),
7921           scale=st.floats(0.02, 2.6),
7922           zero_point=st.integers(0, 25))
7923    def test_mean(self, batch_size, channels, height, width, scale, zero_point):
7924        with override_quantized_engine('qnnpack'):
7925            dim = (2, 3)
7926            X_init = torch.from_numpy(np.random.randint(
7927                0, 50, (batch_size, channels, height, width)))
7928            X = scale * (X_init - zero_point).to(dtype=torch.float)
7929
7930            qX = torch.quantize_per_tensor(X, scale, zero_point, torch.quint8)
7931            Y = torch.mean(qX.dequantize(), dim)
7932            Y = torch.quantize_per_tensor(Y, scale, zero_point, torch.quint8)
7933            qY = torch.mean(qX, dim)
7934            np.testing.assert_array_almost_equal(Y.int_repr().numpy(), qY.int_repr().numpy(), decimal=0)
7935
7936    """Tests the correctness of the quantized::hardtanh op."""
7937    def test_hardtanh(self):
7938        if 'qnnpack' not in torch.backends.quantized.supported_engines:
7939            return
7940        with override_quantized_engine('qnnpack'):
7941            shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
7942            memory_formats = (torch.channels_last, torch.contiguous_format)
7943            min_vals = (-0.5, -0.3, 0.5)
7944            max_vals = (-0.3, 0.3, 0.7)
7945            test_cases = itertools.product(shapes, memory_formats, min_vals, max_vals)
7946            for shape, memory_format, min_val, max_val in test_cases:
7947                X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8
7948                if memory_format == torch.channels_last and len(shape) != 4:
7949                    continue
7950
7951                Y = X.clone()
7952                Y[Y < min_val] = min_val
7953                Y[Y > max_val] = max_val
7954                qY = torch.quantize_per_tensor(Y, scale=scale,
7955                                               zero_point=zero_point, dtype=torch_type)
7956                qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
7957                                               dtype=torch_type)
7958
7959                qY_hat = torch.ao.nn.quantized.functional.hardtanh(qX, min_val, max_val)
7960                self.assertEqual(
7961                    qY, qY_hat,
7962                    msg=f"hardtanh failed:\nactual {qY_hat}\nexpected {qY}\nmemory_format {memory_format}")
7963
7964"""Tests the correctness of the tensor comparators."""
7965class TestComparatorOps(TestCase):
7966    """Tests the element-wise equality ops."""
7967    @given(A=hu.tensor(shapes=((3, 4, 5),),
7968                       qparams=hu.qparams()),
7969           B=hu.tensor(shapes=((5,), (1, 5), (1, 1, 5), (4, 5), (3, 4, 5)),
7970                       qparams=hu.qparams()))
7971    def test_compare_tensor_tensor(self, A, B):
7972        A, (scale_a, zero_point_a, dtype_a) = A
7973        B, (scale_b, zero_point_b, dtype_b) = B
7974        tA = torch.from_numpy(A)
7975        tB = torch.from_numpy(B)
7976
7977        qA = torch.quantize_per_tensor(tA, scale=scale_a, zero_point=zero_point_a,
7978                                       dtype=dtype_a)
7979        qB = torch.quantize_per_tensor(tB, scale=scale_b, zero_point=zero_point_b,
7980                                       dtype=dtype_b)
7981        dqA = qA.dequantize()
7982        dqB = qB.dequantize()
7983
7984        ops_under_test = ('__eq__', '__ne__', '__ge__', '__le__', '__gt__',
7985                          '__lt__', 'eq', 'ne', 'ge', 'le', 'gt', 'lt')
7986
7987        for op in ops_under_test:
7988            result_ref = getattr(dqA, op)(dqB)
7989            result = getattr(qA, op)(qB)
7990            self.assertEqual(result_ref, result,
7991                             msg=f"'tensor.{op}(tensor)' failed")
7992            # Reversed broadcasting.
7993            result_ref = getattr(dqB, op)(dqA)
7994            result = getattr(qB, op)(qA)
7995            self.assertEqual(result_ref, result,
7996                             msg=f"'tensor.{op}(tensor)' failed")
7997
7998    @given(A=hu.tensor(shapes=((3, 4, 5),),
7999                       qparams=hu.qparams()),
8000           b=hu.floats(allow_infinity=False, allow_nan=False))
8001    def test_compare_tensor_scalar(self, A, b):
8002        A, (scale_a, zero_point_a, dtype_a) = A
8003        tA = torch.from_numpy(A)
8004
8005        qA = torch.quantize_per_tensor(tA, scale=scale_a, zero_point=zero_point_a,
8006                                       dtype=dtype_a)
8007        dqA = qA.dequantize()
8008
8009        ops_under_test_reversible = ('__eq__', '__ne__', '__ge__', '__le__',
8010                                     '__gt__', '__lt__')
8011        ops_under_test_nonreversible = ('eq', 'ne', 'ge', 'le', 'gt', 'lt')
8012
8013        for op in ops_under_test_reversible:
8014            result_ref = getattr(dqA, op)(b)
8015            result = getattr(qA, op)(b)
8016            note(f"result_ref 1: {result_ref}")
8017            note(f"result 1: {result}")
8018            self.assertEqual(result_ref, result,
8019                             msg=f"'tensor.{op}(scalar)' failed")
8020            # Reversed broadcasting.
8021            result_ref = getattr(b, op)(dqA)
8022            result = getattr(b, op)(qA)
8023            note(f"result_ref 2: {result_ref}")
8024            note(f"result 2: {result}")
8025            self.assertEqual(result_ref, result,
8026                             msg=f"'scalar.{op}(tensor)' failed")
8027
8028        for op in ops_under_test_nonreversible:
8029            result_ref = getattr(dqA, op)(b)
8030            result = getattr(qA, op)(b)
8031            note(f"result_ref 3: {result_ref}")
8032            note(f"result 3: {result}")
8033            self.assertEqual(result_ref, result,
8034                             msg=f"'tensor.{op}(scalar)' failed")
8035