xref: /aosp_15_r20/external/pytorch/test/quantization/core/test_quantized_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2
3import numpy as np
4import math
5import random
6import torch
7import io
8import unittest
9from copy import deepcopy
10from hypothesis import given
11from hypothesis import strategies as st
12from torch.testing._internal.common_utils import TemporaryFileName
13from torch.testing._internal.common_cuda import TEST_CUDA
14from torch.testing._internal.common_utils import TestCase, DeterministicGuard
15import torch.testing._internal.hypothesis_utils as hu
16from torch.testing._internal.common_quantization import get_supported_device_types
17
18hu.assert_deadline_disabled()
19
20import itertools
21import tempfile
22
class Foo(torch.nn.Module):
    """Minimal module whose only state is a ``qscheme`` attribute."""

    def __init__(self) -> None:
        super(Foo, self).__init__()
        # Stored as a plain attribute (a torch.qscheme value, not a parameter).
        self.qscheme = torch.per_tensor_symmetric
27
28def _calculate_dynamic_qparams(X, dtype, reduce_range=False):
29    """Calculate the dynamic quantization parameters (scale, zero_point)
30    according to the min and max element of the tensor"""
31    if isinstance(X, torch.Tensor):
32        X = X.cpu().data.numpy()
33    if dtype == torch.qint8:
34        if reduce_range:
35            qmin, qmax = -64, 63
36        else:
37            qmin, qmax = -128, 127
38    else:  # dtype == torch.quint8
39        if reduce_range:
40            qmin, qmax = 0, 127
41        else:
42            qmin, qmax = 0, 255
43
44    min_val = X.min().astype(dtype=np.float32)
45    max_val = X.max().astype(dtype=np.float32)
46    min_val = min(0.0, min_val)
47    max_val = max(0.0, max_val)
48    scale = (np.float64(max_val) - min_val) / (qmax - qmin)
49    if scale == 0.0 or math.isinf(1.0 / scale):
50        scale = np.float64(0.1)
51        zero_point = 0
52
53    zero_point_from_min = qmin - min_val / float(scale)
54    zero_point_from_max = qmax - max_val / float(scale)
55    zero_point_from_min_error = abs(qmin) - abs(min_val / float(scale))
56    zero_point_from_max_error = abs(qmax) - abs(max_val / float(scale))
57    if zero_point_from_min_error < zero_point_from_max_error:
58        initial_zero_point = zero_point_from_min
59    else:
60        initial_zero_point = zero_point_from_max
61    nudged_zero_point = 0
62
63    if initial_zero_point < qmin:
64        nudged_zero_point = qmin
65    elif initial_zero_point > qmax:
66        nudged_zero_point = qmax
67    else:
68        nudged_zero_point = int(round(initial_zero_point))
69
70    return [scale.astype(np.float32), int(nudged_zero_point)]
71
72# Note we explicitly cast variables to np.float32 in a couple of places to avoid
73# the default casting in Python often resulting in double precision and to make
74# sure we're doing the same numerics as C++ code.
def param_search_greedy(x, bit_rate, n_bins=200, ratio=0.16):
    """Greedily narrow the clipping range of ``x`` to reduce quantization loss.

    Starting from ``[min(x), max(x)]``, each iteration moves either the lower
    bound up or the upper bound down by one step (whichever gives the smaller
    loss from ``_compress_uniform_simplified``) until the range is no wider
    than ``n_bins * (1 - ratio)`` steps. Ranges whose loss is strictly lower
    than both neighbours are recorded as local optima.

    Returns:
        ``(xmax, xmin)`` of the recorded local optimum with the smallest loss
        (the earliest one on ties), or the original ``(max(x), min(x))`` when
        no local optimum was recorded.
    """
    xmin, xmax = np.min(x), np.max(x)
    stepsize = (xmax - xmin) / np.float32(n_bins)
    min_bins = np.float32(n_bins) * (np.float32(1) - np.float32(ratio))
    _, initial_loss = _compress_uniform_simplified(x, bit_rate, xmin, xmax)

    local_optima = []  # entries are (left, right, loss)

    left, right, current_loss = xmin, xmax, initial_loss
    threshold = min_bins * stepsize
    while left + threshold < right:
        # Loss after shrinking from the left, and from the right.
        _, loss_shrink_left = _compress_uniform_simplified(
            x, bit_rate, left + stepsize, right
        )
        _, loss_shrink_right = _compress_uniform_simplified(
            x, bit_rate, left, right - stepsize
        )

        if current_loss < loss_shrink_left and current_loss < loss_shrink_right:
            local_optima.append((left, right, current_loss))

        if loss_shrink_left < loss_shrink_right:
            left, current_loss = left + stepsize, loss_shrink_left
        else:
            right, current_loss = right - stepsize, loss_shrink_right

    if not local_optima:
        return xmax, xmin

    best = local_optima[0]
    for candidate in local_optima:
        if candidate[-1] < best[-1]:
            best = candidate
    return best[1], best[0]  # xmax, xmin
109
110
111def _compress_uniform_simplified(X, bit_rate, xmin, xmax, fp16_scale_bias=True):
112    # affine transform to put Xq in [0,2**bit_rate - 1]
113    # Xq = (2 ** bit_rate - 1) * (Xq - xmin) / data_range
114    if fp16_scale_bias:
115        xmin = xmin.astype(np.float16).astype(np.float32)
116    data_range = xmax - xmin
117    scale = np.where(
118        data_range == 0, np.float32(1), data_range / np.float32(2 ** bit_rate - 1)
119    )
120    if fp16_scale_bias:
121        scale = scale.astype(np.float16).astype(np.float32)
122    inverse_scale = np.float32(1) / scale
123    Xq = np.clip(np.round((X - xmin) * inverse_scale), 0, np.float32(2 ** bit_rate - 1))
124    Xq = Xq * scale + xmin
125
126    # Manually compute loss instead of using np.linalg.norm to use the same
127    # accumulation order used by C++ code
128    vlen = 8
129    loss_v = np.zeros(vlen).astype(np.float32)
130    for i in range(len(Xq) // vlen * vlen):
131        loss_v[i % vlen] += (X[i] - Xq[i]) * (X[i] - Xq[i])
132    loss = np.float32(0)
133    for i in range(vlen):
134        loss += loss_v[i]
135    for i in range(len(Xq) // vlen * vlen, len(Xq)):
136        loss += (X[i] - Xq[i]) * (X[i] - Xq[i])
137    loss = np.sqrt(loss)
138
139    return Xq, loss
140
141class TestQuantizedTensor(TestCase):
142    def test_qtensor_equal(self):
143        # ASAN regression test reported in https://github.com/pytorch/pytorch/issues/116087
144        x = torch.rand(5)
145        x_q = torch.quantize_per_tensor(x, 0.1, 10, torch.quint4x2)
146        y_q = torch.quantize_per_tensor(x, 0.1, 10, torch.quint4x2)
147        self.assertTrue(torch.equal(x_q, y_q))
148
149    def test_per_tensor_qtensor_to_memory_format(self):
150        n = np.random.randint(1, 10)
151        c = np.random.randint(2, 10)
152        h = np.random.randint(2, 10)
153        w = np.random.randint(2, 10)
154        x = torch.rand(n, c, h, w)
155        scale = np.random.uniform(0.1, 1.0)
156        zero_point = np.random.randint(0.0, 10)
157        qints = [torch.qint8, torch.quint8, torch.qint32]
158        dtype = qints[np.random.randint(0, len(qints))]
159        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=dtype)
160        x_nhwc = x.to(memory_format=torch.channels_last)
161        qx_nhwc_using_to = qx.to(memory_format=torch.channels_last)
162        qx_nhwc_using_contiguous = qx.contiguous(memory_format=torch.channels_last)
163        self.assertEqual(qx_nhwc_using_to.stride(), qx_nhwc_using_contiguous.stride())
164        self.assertEqual(qx_nhwc_using_to.stride(), x_nhwc.stride())
165
166        # When the last two dimensions of a 4D tensor are both size 1 or if c == 1, we have a degenerate case
167        # see https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html
168        # In this case, the output of torch.Tensor.to and torch.Tensor.contiguous should not be the same
169        x = torch.rand(10, 2, 1, 1)
170        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=dtype)
171        qx_nhwc_using_to = qx.to(memory_format=torch.channels_last)
172        qx_nhwc_using_contiguous = qx.contiguous(memory_format=torch.channels_last)
173        self.assertNotEqual(qx_nhwc_using_to.stride(), qx_nhwc_using_contiguous.stride())
174
175        x = torch.rand(10, 1, 2, 2)
176        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=dtype)
177        qx_nhwc_using_to = qx.to(memory_format=torch.channels_last)
178        qx_nhwc_using_contiguous = qx.contiguous(memory_format=torch.channels_last)
179        self.assertNotEqual(qx_nhwc_using_to.stride(), qx_nhwc_using_contiguous.stride())
180
181    def test_per_channel_qtensor_to_memory_format(self):
182        n = np.random.randint(1, 10)
183        c = np.random.randint(2, 10)
184        h = np.random.randint(2, 10)
185        w = np.random.randint(2, 10)
186        x = torch.rand(n, c, h, w)
187        x_nhwc = x.to(memory_format=torch.channels_last)
188        scale = np.random.uniform(0.1, 1.0)
189        zero_point = np.random.randint(0.0, 10)
190        qints = [torch.qint8, torch.quint8, torch.qint32]
191        dtype = qints[np.random.randint(0, len(qints))]
192        for axis in range(x.ndim):
193            scales = torch.rand(x.size(axis)) + 0.00001
194            zero_points = torch.randint(low=0, high=10, size=(x.size(axis), ))
195            qx = torch.quantize_per_channel(x, scales=scales, zero_points=zero_points, dtype=dtype, axis=axis)
196            qx_nhwc_using_to = qx.to(memory_format=torch.channels_last)
197            self.assertEqual(qx_nhwc_using_to.stride(), x_nhwc.stride())
198
199    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
200    def test_qtensor_cuda(self):
201        self._test_qtensor(torch.device('cuda'))
202        self._test_qtensor_dynamic(torch.device('cuda'))
203
204    def test_qtensor_cpu(self):
205        self._test_qtensor(torch.device('cpu'))
206        self._test_qtensor_dynamic(torch.device('cpu'))
207
208    def _test_qtensor_dynamic(self, device):
209        # max number of tensor dimensions
210        max_tensor_order = 4
211        # max size for any tensor dimension
212        max_dim_sz = 20
213
214        num_dim = np.random.randint(low=1, high=max_tensor_order)
215        dims = np.random.randint(low=1, high=max_dim_sz, size=num_dim)
216        mat2quant = torch.randn(*dims, dtype=torch.float, device=device)
217        reduce_flag = False
218
219        for dtype in [torch.qint8, torch.quint8]:
220            q_d = torch.quantize_per_tensor_dynamic(mat2quant, dtype, reduce_flag)
221            scale, zero_pt = _calculate_dynamic_qparams(mat2quant, dtype, reduce_flag)
222            q_s = torch.quantize_per_tensor(mat2quant, scale, zero_pt, dtype)
223
224            self.assertEqual(q_d, q_s)
225
226    def _test_qtensor(self, device):
227        device = str(device)
228        num_elements = 10
229        scale = 1.0
230        zero_point = 2
231        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
232            r = torch.ones(num_elements, dtype=torch.float, device=device)
233            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
234            self.assertEqual(qr.q_scale(), scale)
235            self.assertEqual(qr.q_zero_point(), zero_point)
236            self.assertTrue(qr.is_quantized)
237            self.assertFalse(r.is_quantized)
238            self.assertEqual(qr.qscheme(), torch.per_tensor_affine)
239            self.assertTrue(isinstance(qr.qscheme(), torch.qscheme))
240            # slicing and int_repr
241            int_repr = qr.int_repr()
242            for num in int_repr:
243                self.assertEqual(num, 3)
244            for num in qr[2:].int_repr():
245                self.assertEqual(num, 3)
246            # dequantize
247            rqr = qr.dequantize()
248            for i in range(num_elements):
249                self.assertEqual(r[i], rqr[i])
250            # we can also print a qtensor
251            empty_r = torch.ones((0, 1), dtype=torch.float, device=device)
252            empty_qr = torch.quantize_per_tensor(empty_r, scale, zero_point, dtype)
253
254            device_msg = "" if device == 'cpu' else "device='" + device + ":0', "
255            dtype_msg = str(dtype) + ", "
256            self.assertEqual(' '.join(str(empty_qr).split()),
257                             "tensor([], " + device_msg + "size=(0, 1), dtype=" + dtype_msg +
258                             "quantization_scheme=torch.per_tensor_affine, " +
259                             "scale=1.0, zero_point=2)")
260
261    def test_qtensor_int_repr(self):
262        # to catch edge case when num elements * bit rate < 8, make sure at lease allocate one byte to hold the int repr
263        num_elements = 1
264        device = torch.device('cpu')
265        scale = 1.0
266        zero_point = 2
267        dtype = torch.quint2x4
268        r = torch.ones(num_elements, dtype=torch.float, device=device)
269        qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
270        int_repr = qr.int_repr()
271        self.assertEqual(int_repr.numel(), 1)
272        # Packed one entry looks like 00000011
273        self.assertEqual(int_repr[0], 3)
274
275    def test_qtensor_sub_byte_aligned_cols(self):
276        # Packed 4 entries, each of value 3, look like 00110011, 00110011 for torch.qunit4x2, or 11111111 for torch.quint2x4
277        self._test_qtensor_sub_byte(1, 4, torch.quint4x2, 2, [51, 51])
278        self._test_qtensor_sub_byte(1, 4, torch.quint2x4, 4, [255])
279
280    def test_qtensor_sub_byte_not_aligned_cols(self):
281        # Packed 5 entries, each of value 3, look like 00110011, 00110011, 00000011 for torch.qunit4x2,
282        # or 11111111, 00000011 for torch.quint2x4
283        self._test_qtensor_sub_byte(1, 5, torch.quint4x2, 2, [51, 51, 3])
284        self._test_qtensor_sub_byte(1, 5, torch.quint2x4, 4, [255, 3])
285
286    def _test_qtensor_sub_byte(self, rows, cols, dtype, elements_per_byte, expected_packed_vals):
287        num_elements = rows * cols
288        scale = 1.0
289        zero_point = 2
290
291        r = torch.ones((rows, cols), dtype=torch.float)
292        qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
293        self.assertEqual(qr.q_scale(), scale)
294        self.assertEqual(qr.q_zero_point(), zero_point)
295        self.assertTrue(qr.is_quantized)
296        self.assertFalse(r.is_quantized)
297        self.assertEqual(qr.storage().size(), rows * math.ceil(cols / elements_per_byte), f"with {dtype}, {elements_per_byte}")
298
299        int_repr = qr.int_repr()
300        self.assertEqual(int_repr.numel(), len(expected_packed_vals))
301        for num, expected in zip(int_repr, expected_packed_vals):
302            self.assertEqual(num, expected, f"with dtype={dtype}, elements_per_byte={elements_per_byte}, rows={rows}, cols={cols}")
303
304        # Test tensor creation
305        q = torch._empty_affine_quantized([num_elements], scale=scale, zero_point=zero_point, dtype=dtype)
306        self.assertEqual(q.storage().size(), math.ceil(num_elements / elements_per_byte), f"with {dtype}, {elements_per_byte}")
307
308        # Test save/load
309        with tempfile.NamedTemporaryFile() as f:
310            torch.save(qr, f)
311            for weights_only in [True, False]:
312                f.seek(0)
313                loaded_q = torch.load(f, weights_only=weights_only)
314                loaded_int_repr = loaded_q.int_repr()
315                self.assertEqual(int_repr, loaded_int_repr)
316
317    def test_qtensor_channel_float_assignment(self):
318        t1 = torch.rand(2, 3, 5, 5)
319        t2 = torch.rand(2, 3, 5, 5)
320        for axis in range(t1.ndim):
321            scales = np.random.rand(t1.size()[axis])
322            zero_points = np.random.randint(low=0, high=50, size=t1.size()[axis])
323            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
324                qt1 = torch.quantize_per_channel(t1, scales=torch.tensor(scales),
325                                                 zero_points=torch.tensor(zero_points), dtype=dtype, axis=axis)
326                qt2 = torch.quantize_per_channel(t2, scales=torch.tensor(scales),
327                                                 zero_points=torch.tensor(zero_points), dtype=dtype, axis=axis)
328                i = 0
329                j = 1
330                k = 2
331                l = 4
332                # scalar assignment verification
333                qt1[i][j][k][l] = t2[i][j][k][l]
334                self.assertEqual(qt1[i][j][k][l], qt2[i][j][k][l])
335                # 1D tensor assignment verification
336                qt1[i][j][k][2:l] = t2[i][j][k][2:l]
337                self.assertEqual(qt1[i][j][k][2:l], qt2[i][j][k][2:l])
338                qt1[i][j][k] = t2[i][j][k]
339                self.assertEqual(qt1[i][j][k], qt2[i][j][k])
340                # 2D tensor assignment verification
341                qt1[i][j][k:] = t2[i][j][k:]
342                self.assertEqual(qt1[i][j][k:], qt2[i][j][k:])
343                qt1[i][j] = t2[i][j]
344                self.assertEqual(qt1[i][j], qt2[i][j])
345                # 3D tensor assignment verification
346                qt1[i][j:] = t2[i][j:]
347                self.assertEqual(qt1[i][j:], qt2[i][j:])
348                qt1[i] = t2[i]
349                self.assertEqual(qt1[i], qt2[i])
350                # 4D tensor assignment verification
351                qt1[:1] = t2[:1]
352                self.assertEqual(qt1[:1], qt2[:1])
353                qt1[:] = t2[:]
354                self.assertEqual(qt1[:], qt2[:])
355                # non-contiguous case **this should raise an exception**
356                with self.assertRaisesRegex(RuntimeError, "Quantized copy only works with contiguous and NHWC Tensors"):
357                    qt1[:, 0] = t2[:, 0]
358
359    def test_qtensor_float_assignment(self):
360        # Scalar Tensor
361        # item
362        scale = 1.0
363        zero_point = 2
364        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
365        for device in devices:
366            r = torch.ones(1, dtype=torch.float).to(device=device)
367            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
368                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
369                self.assertEqual(qr.item(), 1)
370                self.assertEqual(qr[0].item(), 1)
371                # assignment
372                self.assertTrue(qr[0].is_quantized)
373                qr[0] = torch.Tensor([11.3]).to(device=device)  # float assignment
374                self.assertEqual(qr.item(), 11)
375                x = torch.ones(1, dtype=torch.float).to(device=device) * 15.3
376                # Copying from a float Tensor
377                qr[:] = x
378                self.assertEqual(qr.item(), 15)
379
380                dtype_msg = str(dtype) + ", "
381                if device == "cuda":
382                    self.assertEqual(' '.join(str(qr).split()),
383                                     "tensor([15.], device='" + str(qr.device) + "', size=(1,), dtype=" + dtype_msg +
384                                     "quantization_scheme=torch.per_tensor_affine, " +
385                                     "scale=1.0, zero_point=2)")
386                else:
387                    self.assertEqual(' '.join(str(qr).split()),
388                                     "tensor([15.], size=(1,), dtype=" + dtype_msg +
389                                     "quantization_scheme=torch.per_tensor_affine, " +
390                                     "scale=1.0, zero_point=2)")
391
392    def test_qtensor_quant_dequant(self):
393        scale = 0.02
394        zero_point = 2
395        for device in get_supported_device_types():
396            r = torch.rand(3, 2, 4, 5, dtype=torch.float, device=device) * 4 - 2
397            for memory_format in [torch.contiguous_format, torch.channels_last]:
398                r = r.contiguous(memory_format=memory_format)
399                for dtype in [torch.qint8, torch.quint8, torch.qint32]:
400                    qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
401                    rqr = qr.dequantize()
402                    self.assertTrue(np.allclose(r.cpu().numpy(), rqr.cpu().numpy(), atol=2 / scale))
403        # Also check 5D tensors work.
404        for device in get_supported_device_types():
405            r = torch.rand(3, 2, 4, 5, 6, dtype=torch.float, device=device) * 4 - 2
406            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
407                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
408                rqr = qr.dequantize()
409                self.assertTrue(np.allclose(r.cpu().numpy(), rqr.cpu().numpy(), atol=2 / scale))
410
411    # legacy constructor/new doesn't support qtensors
412    def test_qtensor_legacy_new_failure(self):
413        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
414        scale = 0.02
415        zero_point = 2
416        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
417        self.assertRaises(RuntimeError, lambda: qr.new(device='cpu'))
418        self.assertRaises(RuntimeError, lambda: qr.new(r.storage()))
419        self.assertRaises(RuntimeError, lambda: qr.new(r))
420        self.assertRaises(RuntimeError, lambda: qr.new(torch.Size([2, 3])))
421        self.assertRaises(RuntimeError, lambda: qr.new([6]))
422
423    def test_per_channel_qtensor_creation_cpu(self):
424        self._test_per_channel_qtensor_creation(torch.device('cpu'))
425
426    def _test_dequantize_fp16(self, device):
427        data_orig = torch.randn(1, 2, 4, 4, dtype=torch.float, device=device)
428        data_fp16 = data_orig.to(torch.float16)
429        data_fp16_dequant = data_fp16.dequantize()
430        data_fp16_fp32 = data_fp16.to(torch.float)
431        self.assertTrue(data_fp16_dequant.dtype == torch.float)
432        self.assertTrue(torch.allclose(data_fp16_fp32, data_fp16_dequant))
433
434    def test_dequantize_fp16_cpu(self):
435        self._test_dequantize_fp16(torch.device('cpu'))
436
437    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
438    def test_dequantize_fp16_cuda(self):
439        self._test_dequantize_fp16(torch.device('cuda'))
440
441    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
442    def test_per_channel_qtensor_creation_cuda(self):
443        self._test_per_channel_qtensor_creation(torch.device('cuda'))
444
445    def _test_per_channel_qtensor_creation(self, device):
446        numel = 10
447        ch_axis = 0
448        scales = torch.rand(numel, device=device)
449        zero_points_int = torch.randint(0, 10, size=(numel,), device=device)
450        zero_points_float = torch.randn(numel, device=device)
451        for dtype, zero_points in itertools.product([torch.qint8, torch.quint8], [zero_points_float, zero_points_int]):
452            q = torch._empty_per_channel_affine_quantized(
453                [numel], scales=scales, zero_points=zero_points, axis=ch_axis, dtype=dtype, device=device)
454            self.assertEqual(scales, q.q_per_channel_scales(), exact_dtype=False)
455            self.assertEqual(zero_points, q.q_per_channel_zero_points())
456            self.assertEqual(ch_axis, q.q_per_channel_axis())
457
458        # create Tensor from uint8_t Tensor, scales and zero_points
459        for zero_points in [zero_points_float, zero_points_int]:
460            int_tensor = torch.randint(0, 100, size=(numel,), dtype=torch.uint8, device=device)
461            q = torch._make_per_channel_quantized_tensor(int_tensor, scales, zero_points, ch_axis)
462            self.assertEqual(int_tensor, q.int_repr())
463            self.assertEqual(scales, q.q_per_channel_scales(), exact_dtype=False)
464            self.assertEqual(zero_points, q.q_per_channel_zero_points())
465            self.assertEqual(ch_axis, q.q_per_channel_axis())
466
467    def test_qtensor_creation(self):
468        scale = 0.5
469        zero_point = 10
470        numel = 10
471        for device in get_supported_device_types():
472            q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
473                                              device=device, dtype=torch.quint8)
474            self.assertEqual(scale, q.q_scale())
475            self.assertEqual(zero_point, q.q_zero_point())
476
477            # create Tensor from uint8_t Tensor, scale and zero_point
478            int_tensor = torch.randint(0, 100, size=(10,), device=device, dtype=torch.uint8)
479            q = torch._make_per_tensor_quantized_tensor(int_tensor, scale, zero_point)
480            self.assertEqual(int_tensor, q.int_repr())
481            self.assertEqual(scale, q.q_scale())
482            self.assertEqual(zero_point, q.q_zero_point())
483
484            # create via empty_like
485            q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
486                                              device=device, dtype=torch.quint8)
487            q_el = torch.empty_like(q)
488            self.assertEqual(q.q_scale(), q_el.q_scale())
489            self.assertEqual(q.q_zero_point(), q_el.q_zero_point())
490            self.assertEqual(q.dtype, q_el.dtype)
491
492            # create via empty_like but change the dtype (currently not supported)
493            with self.assertRaises(RuntimeError):
494                torch.empty_like(q, dtype=torch.qint8)
495
496    def test_qtensor_dtypes(self):
497        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
498        scale = 0.2
499        zero_point = 2
500        for dtype in [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]:
501            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
502            rqr = qr.dequantize()
503            self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
504
505    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
506    def test_per_tensor_to_device(self):
507        dtypes = [
508            torch.quint8,
509            torch.qint8,
510            torch.qint32,
511        ]
512        device = torch.device('cuda')
513        for dtype in dtypes:
514            r = torch.rand(2, 2, dtype=torch.float) * 10
515            scale = torch.rand(2).abs().max().item()
516            zero_point = (torch.rand(2) * 10).round().to(torch.long).max().item()
517
518            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
519            qr = qr.to(device)
520            qr_cuda = torch.quantize_per_tensor(r.to(device), scale, zero_point, dtype)
521            qr_cuda = qr_cuda.to('cpu')
522            self.assertEqual('cuda', qr.device.type)
523            self.assertEqual('cpu', qr_cuda.device.type)
524
525    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
526    def test_per_channel_to_device(self):
527        dtype_and_zero_types = [
528            (torch.quint8, torch.float),
529            (torch.qint8, torch.float),
530            #  (torch.qint32, torch.float) not supported for quantize_per_channel
531            (torch.quint8, torch.long),
532            (torch.qint8, torch.long),
533            (torch.qint32, torch.long),
534        ]
535        axis = 1
536        device = torch.device('cuda')
537        for dtype, zero_type in dtype_and_zero_types:
538            r = torch.rand(2, 2, dtype=torch.float) * 10
539            scales = torch.rand(2).abs()
540            zero_points = (torch.rand(2) * 10).round().to(zero_type)
541
542            dqr = torch.quantize_per_channel(r, scales, zero_points, axis, dtype)
543            dqr = dqr.to(device)
544            dqr_cuda = torch.quantize_per_channel(r.to(device), scales.to(
545                device), zero_points.to(device), axis, dtype)
546            dqr_cuda = dqr_cuda.to('cpu')
547
548            self.assertEqual('cuda', dqr.device.type)
549            self.assertEqual('cuda', dqr.q_per_channel_scales().device.type)
550            self.assertEqual('cuda', dqr.q_per_channel_zero_points().device.type)
551
552            self.assertEqual('cpu', dqr_cuda.device.type)
553            self.assertEqual('cpu', dqr_cuda.q_per_channel_scales().device.type)
554            self.assertEqual('cpu', dqr_cuda.q_per_channel_zero_points().device.type)
555
556    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA is not available')
557    def test_compare_per_tensor_device_numerics(self):
558        dtypes = [
559            torch.quint8,
560            torch.qint8,
561            torch.qint32,
562        ]
563        device = torch.device('cuda')
564        for dtype in dtypes:
565            r = torch.rand(2, 2) * 10
566            r[0, 0] = 2.5
567            scale = torch.rand(2).abs().max().item()
568            zero_point = (torch.rand(2) * 10).round().to(torch.long).max().item()
569
570            qtr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
571            dqtr = qtr.dequantize()
572            qtr_cuda = torch.quantize_per_tensor(r.to(device), scale, zero_point, dtype)
573            dqtr_cuda = qtr_cuda.dequantize()
574            self.assertEqual(qtr.int_repr(), qtr_cuda.int_repr())
575            self.assertTrue(np.allclose(dqtr, dqtr_cuda.cpu()))
576
577    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA is not available')
578    def test_compare_per_channel_device_numerics(self):
579        dtype_and_zero_types = [
580            (torch.quint8, torch.float),
581            (torch.qint8, torch.float),
582            #  (torch.qint32, torch.float) not supported for quantize_per_channel
583            (torch.quint8, torch.long),
584            (torch.qint8, torch.long),
585            (torch.qint32, torch.long),
586        ]
587        axis = 1
588        device = torch.device('cuda')
589        for i in range(20):
590            for dtype, zero_type in dtype_and_zero_types:
591                r = torch.rand(2, 2) * 10
592                r[0, 0] = 2.5
593                scales = torch.rand(2).abs()
594                zero_points = (torch.rand(2) * 10).round().to(zero_type)
595
596                qr = torch.quantize_per_channel(r, scales, zero_points, axis, dtype)
597                dqr = qr.dequantize()
598                qr_cuda = torch.quantize_per_channel(r.to(device), scales.to(
599                    device), zero_points.to(device), axis, dtype)
600                dqr_cuda = qr_cuda.dequantize()
601                self.assertEqual(qr.int_repr(), qr_cuda.int_repr())
602                self.assertTrue(np.allclose(dqr, dqr_cuda.cpu()))
603
604    def _test_quantize_per_channel(self, r, scales, zero_points, axis, float_params):
605
606        def _quantize_per_channel_ref_nd(data, scales, zero_points, float_params):
607            dims = data.size()
608            data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
609            res = torch.empty_like(data)
610            quant_min, quant_max = 0, 255
611            for i in range(res.size()[0]):
612                for j in range(res.size()[1]):
613                    for k in range(res.size()[2]):
614                        if float_params:
615                            inv_scale = 1.0 / scales[j]
616                            res[i][j][k] = np.clip(
617                                np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max)
618                        else:
619                            res[i][j][k] = np.clip(
620                                np.round(data[i][j][k] / scales[j]) + zero_points[j], quant_min, quant_max)
621            res = res.view(*dims)
622            return res
623
624        contig_format = torch.channels_last if r.ndim == 4 else torch.channels_last_3d
625        for memory_format in [torch.contiguous_format, contig_format]:
626            ref_res = _quantize_per_channel_ref_nd(r, scales, zero_points, float_params)
627            r_contig = r.contiguous(memory_format=memory_format)
628            qr = torch.quantize_per_channel(r_contig, scales, zero_points, axis, torch.quint8)
629            rqr = qr.dequantize()
630            self.assertTrue(np.allclose(qr.int_repr(), ref_res))
631            self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))
632
633    def test_qtensor_quantize_per_channel(self):
634        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
635        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
636        zero_points = torch.tensor([5, 10], dtype=torch.long)
637        axis = 1
638
639        def quantize_c(data, scales, zero_points):
640            res = torch.empty((3, 2))
641            quant_min, quant_max = 0, 255
642            for i in range(3):
643                for j in range(2):
644                    res[i][j] = np.clip(np.round(data[i][j] / scales[j]) + zero_points[j], quant_min, quant_max)
645            return res
646        qr = torch.quantize_per_channel(r, scales, zero_points, axis, torch.quint8)
647        rqr = qr.dequantize()
648        self.assertTrue(np.allclose(qr.int_repr(), quantize_c(r, scales, zero_points)))
649        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))
650
651        # Check 4D tensor with 2 different memory formats.
652        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4 - 2
653        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
654        zero_points = torch.tensor([5, 10], dtype=torch.long)
655        self._test_quantize_per_channel(r, scales, zero_points, 1 , False)
656
657        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.double)
658        zero_points = torch.tensor([5, 10, 7], dtype=torch.long)
659        self._test_quantize_per_channel(r, scales, zero_points, 0, False)
660
661        # Check 5D tensor.
662        r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
663        scales = torch.tensor([0.2, 0.03], dtype=torch.double)
664        zero_points = torch.tensor([5, 10], dtype=torch.long)
665        self._test_quantize_per_channel(r, scales, zero_points, 1, False)
666
667        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.double)
668        zero_points = torch.tensor([5, 10, 7], dtype=torch.long)
669        self._test_quantize_per_channel(r, scales, zero_points, 0, False)
670
671    def test_quantize_per_channel_float_qparams(self):
672        r = torch.rand(3, 2, dtype=torch.float) * 4
673        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
674        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
675        axis = 1
676
677        # Reference quantize function with FP zero_point.
678        def quantize_ref(data, scales, zero_points):
679            res = torch.empty((3, 2))
680            quant_min, quant_max = 0, 255
681            for i in range(3):
682                for j in range(2):
683                    inv_scale = 1.0 / scales[j]
684                    res[i][j] = np.clip(np.round(data[i][j] * inv_scale + zero_points[j]), quant_min, quant_max)
685            return res
686
687        qr = torch.quantize_per_channel(r, scales, zero_points, axis, torch.quint8)
688        dequant_tensor = qr.dequantize()
689        ref = quantize_ref(r, scales, zero_points)
690        self.assertTrue(np.allclose(qr.int_repr(), ref))
691        self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1))
692
693        # Check 4D tensor with 2 different memory formats.
694        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
695        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
696        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
697        self._test_quantize_per_channel(r, scales, zero_points, 1, True)
698
699        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
700        zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
701        self._test_quantize_per_channel(r, scales, zero_points, 0, True)
702
703        # Check 5D tensor.
704        r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
705        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
706        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
707        self._test_quantize_per_channel(r, scales, zero_points, 1, True)
708
709        scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
710        zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
711        self._test_quantize_per_channel(r, scales, zero_points, 0, True)
712
713    def test_quantize_per_channel_sub_byte(self):
714        """ Tests the per channel quantization scheme for 4-bit qtensors.
715        The scale and zero point for this have to be in floating point. """
716        r = torch.rand(3, 2, dtype=torch.float) * 4
717        scales = torch.tensor([0.2, 0.3, 0.1], dtype=torch.float)
718        zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
719        qr = torch.quantize_per_channel(r, scales, zero_points, 0, torch.quint4x2)
720        dequant_tensor = qr.dequantize()
721
722        def _get_qranges(bit_width):
723            if bit_width == 4:
724                return 0, 15
725
726        def _quantize_per_channel_sub_byte_ref(data, scales, zero_points, axis, bit_width):
727            dims = data.size()
728            data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
729            qtensor_size = math.ceil(data.numel() / 2)
730            res = torch.empty(qtensor_size, dtype=torch.uint8)
731            elem_per_byte = 8 // bit_width
732            quant_min, quant_max = _get_qranges(bit_width)
733            for i in range(data.size()[0]):
734                for j in range(data.size()[1]):
735                    for k in range(data.size()[2]):
736                        inv_scale = 1.0 / scales[j]
737                        index = i * data.size()[1] * data.size()[2] + j * data.size()[2] + k
738                        qvalue = np.clip(
739                            np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max).to(dtype=torch.int)
740                        res_idx = int(index / elem_per_byte)
741                        if (index % elem_per_byte == 0):
742                            res[res_idx] = qvalue
743                        else:
744                            res[res_idx] |= (qvalue << ((index % elem_per_byte) * bit_width))
745            return res
746
747        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 0, 4)
748        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
749        self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1 / np.min(scales.numpy())))
750
751        # Check 4D tensor with non-zero axis.
752        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
753        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
754        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
755        qr = torch.quantize_per_channel(r, scales, zero_points, axis=1, dtype=torch.quint4x2)
756        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 1, 4)
757        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
758
759    def test_qtensor_permute(self):
760        scale = 0.02
761        zero_point = 1
762        for device in get_supported_device_types():
763            r = torch.rand(10, 30, 2, 2, device=device, dtype=torch.float) * 4 - 2
764            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
765                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
766                qr = qr.transpose(0, 1)
767                rqr = qr.dequantize()
768                # compare transpose + dequantized result with orignal transposed result
769                self.assertTrue(np.allclose(r.cpu().numpy().transpose([1, 0, 2, 3]), rqr.cpu().numpy(), atol=2 / scale))
770
771                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
772                qr1 = qr.permute([1, 0, 2, 3])
773                qr2 = qr.transpose(0, 1)
774                # compare int representation after transformations
775                self.assertEqual(qr1.int_repr(), qr2.int_repr())
776                self.assertEqual(qr1.q_scale(), qr2.q_scale())
777                self.assertEqual(qr1.q_zero_point(), qr2.q_zero_point())
778                # compare dequantized result
779                self.assertEqual(qr1.dequantize(), qr2.dequantize())
780                # compare permuted + dequantized result with original transposed result
781                self.assertTrue(np.allclose(qr2.dequantize().cpu().numpy(),
782                                            r.cpu().numpy().transpose([1, 0, 2, 3]), atol=2 / scale))
783                # make permuted result contiguous
784                self.assertEqual(qr2.contiguous().int_repr(), qr2.int_repr())
785
786                # change memory format
787                qlast = qr.contiguous(memory_format=torch.channels_last)
788                self.assertEqual(qr.stride(), sorted(qr.stride(), reverse=True))
789                self.assertNotEqual(qlast.stride(), sorted(qlast.stride(), reverse=True))
790                self.assertEqual(qr.int_repr(), qlast.int_repr())
791                self.assertEqual(qr.q_scale(), qlast.q_scale())
792                self.assertEqual(qr.q_zero_point(), qlast.q_zero_point())
793                self.assertEqual(qlast.dequantize(), qr.dequantize())
794
795                # permuting larger tensors
796                x = torch.randn(64, 64, device=device)
797                qx = torch.quantize_per_tensor(x, 1.0, 0, dtype)
798                # should work
799                qx.permute([1, 0])
800
801    def test_qtensor_per_channel_permute(self):
802        for device in get_supported_device_types():
803            r = torch.rand(20, 10, 2, 2, dtype=torch.float, device=device) * 4 - 2
804            dtype = torch.qint8
805            scales = torch.rand(10, device=device) * 0.02 + 0.01
806            zero_points = torch.round(torch.rand(10, device=device) * 2 - 1).to(torch.long)
807            qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
808
809            # we can't reorder the axis
810            with self.assertRaises(RuntimeError):
811                qr.transpose(0, 1)
812
813            # but we can change memory format
814            qlast = qr.contiguous(memory_format=torch.channels_last)
815            self.assertEqual(qr.stride(), sorted(qr.stride(), reverse=True))
816            self.assertNotEqual(qlast.stride(), sorted(qlast.stride(), reverse=True))
817            self.assertEqual(qr.int_repr(), qlast.int_repr())
818            self.assertEqual(scales.to(dtype=torch.float64), qlast.q_per_channel_scales())
819            self.assertEqual(zero_points, qlast.q_per_channel_zero_points())
820            self.assertEqual(1, qlast.q_per_channel_axis())
821            self.assertEqual(qlast.dequantize(), qr.dequantize())
822
823    def test_qtensor_load_save(self):
824        scale = 0.2
825        zero_point = 10
826        # storage is not accessible on the cuda right now
827        device = "cpu"
828        r = torch.rand(15, 2, dtype=torch.float32, device=device) * 2
829        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
830            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
831            qrv = qr[:, 1]
832            with tempfile.NamedTemporaryFile() as f:
833                # Serializing and Deserializing Tensor
834                torch.save((qr, qrv), f)
835                for weights_only in [True, False]:
836                    f.seek(0)
837                    qr2, qrv2 = torch.load(f, weights_only=weights_only)
838                    self.assertEqual(qr, qr2)
839                    self.assertEqual(qrv, qrv2)
840                    self.assertEqual(qr2.storage().data_ptr(), qrv2.storage().data_ptr())
841
842    def test_qtensor_per_channel_load_save(self):
843        r = torch.rand(20, 10, dtype=torch.float) * 4 - 2
844        scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
845        zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
846        # quint32, cuda is not supported yet
847        for dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
848            if dtype == torch.quint4x2:
849                zero_points = torch.ones(10, dtype=torch.float)
850            qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
851            with tempfile.NamedTemporaryFile() as f:
852                # Serializing and Deserializing Tensor
853                torch.save(qr, f)
854                for weights_only in [True, False]:
855                    f.seek(0)
856                    qr2 = torch.load(f, weights_only=weights_only)
857                    self.assertEqual(qr, qr2)
858
859    def test_qtensor_copy(self):
860        scale = 0.5
861        zero_point = 10
862        numel = 10
863        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
864            for device in get_supported_device_types():
865                # copy from same scale and zero_point
866                q = torch._empty_affine_quantized([numel], scale=scale,
867                                                  zero_point=zero_point, device=device, dtype=dtype)
868                q2 = torch._empty_affine_quantized([numel], scale=scale,
869                                                   zero_point=zero_point, device=device, dtype=dtype)
870                q.copy_(q2)
871                self.assertEqual(q.int_repr(), q2.int_repr())
872                self.assertEqual(q.q_scale(), q2.q_scale())
873                self.assertEqual(q.q_zero_point(), q2.q_zero_point())
874                # copying from different scale and zero_point
875                new_scale = 3.2
876                new_zero_point = 5
877                q = torch._empty_affine_quantized([numel], scale=new_scale,
878                                                  zero_point=new_zero_point, device=device, dtype=dtype)
879                # check original scale and zero_points are set correctly
880                self.assertEqual(q.q_scale(), new_scale)
881                self.assertEqual(q.q_zero_point(), new_zero_point)
882                q.copy_(q2)
883                # check scale and zero_points has been copied
884                self.assertEqual(q, q2)
885                # can't copy from quantized tensor to non-quantized tensor
886                r = torch.empty([numel], dtype=torch.float)
887                q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=dtype)
888                with self.assertRaisesRegex(RuntimeError, "please use dequantize"):
889                    r.copy_(q)
890            # copy from float doesn't support cuda
891            device = 'cpu'
892            # check copy from non-quantized to quantized
893            r = torch.randn([numel], dtype=torch.float, device=device)
894            q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=dtype, device=device)
895            q.copy_(r)
896            qr = torch.quantize_per_tensor(r, scale=scale, zero_point=zero_point, dtype=dtype)
897            self.assertEqual(q, qr)
898
899    def test_torch_qtensor_deepcopy(self):
900        # cuda is not supported yet
901        device = "cpu"
902        q_int = torch.randint(0, 100, [3, 5], device=device, dtype=torch.uint8)
903        scale, zero_point = 2.0, 3
904        q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
905        qc = deepcopy(q)
906        self.assertEqual(qc, q)
907
908    def test_clone(self):
909        numel = 10
910        scale = 0.5
911        zero_point = 10
912
913        options = itertools.product(
914            get_supported_device_types(),
915            [torch.qint8, torch.quint8, torch.qint32])
916
917        for device, dtype in options:
918            per_tensor_quantized = torch._empty_affine_quantized(
919                [numel], scale=scale, zero_point=zero_point,
920                device=device, dtype=dtype)
921            per_channel_quantized = torch._empty_per_channel_affine_quantized(
922                [numel],
923                scales=torch.tensor([scale] * numel, device=device),
924                zero_points=torch.tensor([zero_point] * numel, device=device),
925                axis=0,
926                device=device,
927                dtype=dtype
928            )
929            qtensors = [per_tensor_quantized, per_channel_quantized]
930
931            for q in qtensors:
932                q2 = q.clone()
933                # Check to make sure the scale and zero_point has been copied.
934                self.assertEqual(q, q2)
935
936    def test_qtensor_fill_per_tensor(self):
937        numel = 10
938        scale = 0.5
939        zero_point = 10
940
941        ones = torch.ones(numel).to(torch.float)
942
943        qtypes = [torch.qint8, torch.quint8, torch.qint32]
944        vals2fill = [-1, 1, 2**32]  # positive, negative, overflow
945
946        devices = get_supported_device_types()
947        for qtype, val2fill, device in itertools.product(qtypes, vals2fill, devices):
948            ones = ones.to(device)
949            q_filled = torch._empty_affine_quantized(
950                [numel], scale=scale, zero_point=zero_point, device=device,
951                dtype=qtype)
952            q_filled.fill_(val2fill)
953            # reference tensor for comparing q_filled
954            q_ref = torch.quantize_per_tensor(ones * val2fill, scale,
955                                              zero_point, qtype)
956            self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
957            self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
958            # Make sure the scale and zero_point don't change
959            self.assertEqual(q_filled.q_scale(), scale)
960            self.assertEqual(q_filled.q_zero_point(), zero_point)
961
962    # Adapted from test_qtensor_fill_per_tensor but for a NHWC tensor (requires 4D)
963    def test_qtensor_fill_per_tensor_nhwc(self):
964        dims = torch.randint(low=1, high=10, size=(4, )).tolist()
965        scale = 0.5
966        zero_point = 10
967
968        ones = torch.ones(dims).to(torch.float)
969
970        qtypes = [torch.qint8, torch.quint8, torch.qint32]
971        vals2fill = [-1, 1, 2**32]  # positive, negative, overflow
972        memory_formats = [torch.contiguous_format, torch.channels_last]
973        devices = get_supported_device_types()
974        for qtype, val2fill, memory_format, device in itertools.product(qtypes, vals2fill, memory_formats, devices):
975            q_filled = torch._empty_affine_quantized(
976                dims, scale=scale, zero_point=zero_point, device=device,
977                dtype=qtype, memory_format=memory_format)
978            q_filled.fill_(val2fill)
979            # reference tensor for comparing q_filled
980            q_ref = torch.quantize_per_tensor(ones * val2fill, scale,
981                                              zero_point, qtype)
982            self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
983            self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
984            # Make sure the scale and zero_point don't change
985            self.assertEqual(q_filled.q_scale(), scale)
986            self.assertEqual(q_filled.q_zero_point(), zero_point)
987
988    # adapted from test_qtensor_fill_per_tensor
989    def test_qtensor_fill_per_channel(self):
990        dims = [4, 5]
991        axis = 0
992        # adding a constant to avoid too small of a scale
993        scales = torch.rand(dims[axis], dtype=torch.float64) + 0.1
994        zero_points = torch.randint(low=0, high=10, size=(dims[axis], ))
995
996        ones = torch.ones(dims).to(torch.float)
997
998        qtypes = [torch.qint8, torch.quint8, torch.qint32]
999        vals2fill = [-1, 1, 2**32]  # positive, negative, overflow
1000
1001        devices = get_supported_device_types()
1002        for qtype, val2fill, device in itertools.product(qtypes, vals2fill, devices):
1003            scales = scales.to(device)
1004            zero_points = zero_points.to(device)
1005            ones = ones.to(device)
1006            q_filled = torch._empty_per_channel_affine_quantized(
1007                dims, scales=scales, zero_points=zero_points, device=device,
1008                axis=axis, dtype=qtype)
1009            q_filled.fill_(val2fill)
1010            # reference tensor for comparing q_filled
1011            q_ref = torch.quantize_per_channel(ones * val2fill, scales=scales,
1012                                               zero_points=zero_points, axis=axis, dtype=qtype)
1013            self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
1014            self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
1015            # Make sure the scale and zero_point don't change
1016            self.assertEqual(q_filled.q_per_channel_scales(), scales)
1017            self.assertEqual(q_filled.q_per_channel_zero_points(), zero_points)
1018
1019    def test_qtensor_masked_fill_cpu(self):
1020        self._test_qtensor_masked_fill('cpu')
1021
1022    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
1023    def test_qtensor_masked_fill_cuda(self):
1024        self._test_qtensor_masked_fill('cuda')
1025
1026    # adapted from test_qtensor_fill_per_tensor
1027    def _test_qtensor_masked_fill(self, device):
1028        numel = 10
1029        scale = 0.5
1030        zero_point = 10
1031
1032        ones = torch.ones(numel, dtype=torch.float, device=device)
1033
1034        types = [torch.qint8, torch.quint8, torch.qint32]
1035        fills = [-1, 1, 2**32]  # positive, negative, overflow
1036
1037        for qtype, fill_with in itertools.product(types, fills):
1038            q_filled = torch._empty_affine_quantized(
1039                [numel], scale=scale, zero_point=zero_point, device=device,
1040                dtype=qtype)
1041            q_filled.fill_(fill_with)
1042            q_masked_fill = torch._empty_affine_quantized(
1043                [numel], scale=scale, zero_point=zero_point, device=device,
1044                dtype=qtype)
1045            # mask fill the whole tensor, equivalent to calling plain vanilla fill
1046            mask = torch.tensor(True, device=device)
1047            q_masked_fill.masked_fill_(mask, fill_with)
1048            int_repr = torch.quantize_per_tensor(ones * fill_with, scale,
1049                                                 zero_point, qtype)
1050            fill_with = int_repr.dequantize()
1051            int_repr = int_repr.int_repr()
1052
1053            self.assertEqual(q_filled, q_masked_fill)
1054            self.assertEqual(q_masked_fill.int_repr(), int_repr)
1055            self.assertEqual(q_masked_fill.dequantize(), fill_with)
1056            # Make sure the scale and zero_point don't change
1057            self.assertEqual(q_masked_fill.q_scale(), scale)
1058            self.assertEqual(q_masked_fill.q_zero_point(), zero_point)
1059
1060        # the above loop does the same test as test_qtensor_fill
1061        # now we will check masked_fill for subset of indices
1062        mask = torch.randint(0, 2, (numel, ), device=device)
1063        mask = mask.bool()
1064        x = torch.rand(numel, device=device)
1065        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qtype)
1066        for qtype, fill_with in itertools.product(types, fills):
1067            q_masked_fill = qx.clone()
1068            q_masked_fill.masked_fill_(mask, fill_with)
1069            ref = qx.clone()
1070
1071            for i in range(numel):
1072                if mask[i]:
1073                    # this assignment doesn't end up calling masked_fill, allowing us to compare the different implementations
1074                    ref[i] = torch.tensor([fill_with], device=device, dtype=torch.float)
1075
1076            self.assertEqual(q_masked_fill, ref)
1077            self.assertEqual(q_masked_fill.int_repr(), ref.int_repr())
1078            self.assertEqual(q_masked_fill.dequantize(), ref.dequantize())
1079
1080    def test_qtensor_index_put_cpu(self):
1081        self._test_qtensor_index_put('cpu')
1082        self._test_qtensor_index_put_non_accumulate_deterministic('cpu')
1083
1084    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
1085    def test_qtensor_index_put_cuda(self):
1086        self._test_qtensor_index_put('cuda')
1087        self._test_qtensor_index_put_non_accumulate_deterministic('cuda')
1088
1089    def _test_qtensor_index_put(self, device):
1090        n = 10
1091        m = 10
1092        x_orig = torch.rand(n, m, device=device)
1093        indices = tuple(torch.tensor([[0, 0], [1, 1], [5, 5], [7, 3], [0, 5], [6, 9], [-1, -1]], device=device).t())
1094        # for the scalar tensor case, index_put routes to masked_fill
1095        values_list = [torch.tensor(2.5, device=device), torch.rand(len(indices[0]), device=device) * 1000]
1096        scale = 0.5
1097        zero_point = 10
1098        types = [torch.qint8, torch.quint8, torch.qint32]
1099        for qtype, values in itertools.product(types, values_list):
1100            x_ref = x_orig.clone()
1101            x_ref[indices] = values.to(dtype=x_ref.dtype)
1102            qx_ref = torch.quantize_per_tensor(x_ref, scale=scale, zero_point=zero_point, dtype=qtype)
1103
1104            x = x_orig.clone()
1105            qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qtype)
1106            qx[indices] = values
1107
1108            self.assertEqual(qx_ref, qx)
1109
1110    def _test_qtensor_index_put_non_accumulate_deterministic(self, device):
1111        with DeterministicGuard(True):
1112            scale = 0.5
1113            zero_point = 10
1114            types = [torch.qint8, torch.quint8, torch.qint32]
1115            for qtype in types:
1116                for i in range(3):
1117                    m = random.randint(10, 20)
1118                    elems = random.randint(20000, 30000)
1119                    values = torch.rand(elems, device=device)
1120                    indices = torch.randint(m, (elems,), device=device)
1121                    x_orig = torch.rand(m, device=device)
1122
1123                    x = x_orig.clone()
1124                    qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qtype)
1125                    output = qx.index_put((indices,), values, accumulate=False)
1126
1127
1128                    x_ref = x_orig.clone()
1129                    output_ref = x_ref.index_put((indices,), values, accumulate=False)
1130                    qx_ref = torch.quantize_per_tensor(output_ref, scale=scale, zero_point=zero_point, dtype=qtype)
1131
1132                    self.assertEqual(output, qx_ref)
1133
1134    # adapted from test_qtensor_fill_per_channel and test_qtensor_fill_per_tensor_nhwc
1135    def test_qtensor_fill_per_channel_nhwc(self):
1136        dims = torch.randint(low=1, high=10, size=(4, )).tolist()
1137        axis = 0
1138        # adding a constant to avoid too small of a scale
1139        scales = torch.rand(dims[axis], dtype=torch.float64) + 0.1
1140        zero_points = torch.randint(low=0, high=10, size=(dims[axis], ))
1141
1142        ones = torch.ones(dims).to(torch.float)
1143
1144        qtypes = [torch.qint8, torch.quint8, torch.qint32]
1145        vals2fill = [-1, 1, 2**32]  # positive, negative, overflow
1146        memory_formats = [torch.contiguous_format, torch.channels_last]
1147        devices = get_supported_device_types()
1148        for qtype, val2fill, memory_format, device in itertools.product(qtypes, vals2fill, memory_formats, devices):
1149            scales = scales.to(device)
1150            zero_points = zero_points.to(device)
1151            ones = ones.to(device)
1152            q_filled = torch._empty_per_channel_affine_quantized(
1153                dims, scales=scales, zero_points=zero_points, device=device,
1154                axis=axis, dtype=qtype, memory_format=memory_format)
1155            q_filled.fill_(val2fill)
1156            # reference tensor for comparing q_filled
1157            q_ref = torch.quantize_per_channel(ones * val2fill, scales=scales,
1158                                               zero_points=zero_points, axis=axis, dtype=qtype)
1159            self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
1160            self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
1161            # Make sure the scale and zero_point don't change
1162            self.assertEqual(q_filled.q_per_channel_scales(), scales)
1163            self.assertEqual(q_filled.q_per_channel_zero_points(), zero_points)
1164
1165    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
1166    def test_qtensor_index_select_cuda(self):
1167        self._test_qtensor_index_select('cuda')
1168
1169    def test_qtensor_index_select_cpu(self):
1170        self._test_qtensor_index_select('cpu')
1171
1172    def _test_qtensor_index_select(self, device):
1173        for quant_type in [torch.quint8, torch.qint8]:
1174            dims = 3
1175            index = torch.randint(dims, [1]).item()
1176            selected = torch.randperm(dims)[:2].to(device)
1177            scale = 1
1178            zp = 0
1179            x = torch.randn([3] * dims, device=device) * 10
1180
1181            x_selected = torch.index_select(x, index, selected)
1182            x_selected_quantized = torch.quantize_per_tensor(x_selected, scale, zp, quant_type)
1183
1184            x_quantized = torch.quantize_per_tensor(x, scale, zp, quant_type)
1185            x_quantized_selected = torch.index_select(x_quantized, index, selected)
1186
1187            self.assertEqual(x_quantized_selected, x_selected_quantized)
1188
1189    def test_qtensor_view(self):
1190        scale, zero_point, dtype = 1.0, 2, torch.uint8
1191        for device in get_supported_device_types():
1192            q_int = torch.randint(0, 100, [1, 2, 3], device=device, dtype=dtype)
1193            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
1194            q2 = q.view(1, 3, 2)
1195            self.assertEqual(q.numel(), q2.numel())
1196            # testing -1
1197            self.assertEqual(q, q2.view(1, -1, 3))
1198
1199            a_int = torch.randint(0, 100, [1, 2, 3, 4], device=device, dtype=dtype)
1200            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
1201            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
1202            c = a.view(1, 3, 2, 4)  # does not change tensor layout in memory
1203            self.assertEqual(b.size(), c.size())
1204            self.assertEqual(b.q_scale(), c.q_scale())
1205            self.assertEqual(b.q_zero_point(), c.q_zero_point())
1206            self.assertNotEqual(b.stride(), c.stride())
1207            # size is the same but the underlying data is different
1208            self.assertNotEqual(b.int_repr(), c.int_repr())
1209            # torch.equal is not supported for the cuda backend
1210            if device == 'cpu':
1211                self.assertFalse(torch.equal(b, c))
1212
1213            # a case can't view non-contiguos Tensor
1214            a_int = torch.randint(0, 100, [1, 2, 3, 4], device=device, dtype=dtype)
1215            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
1216            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
1217            err_str = "view size is not compatible with input tensor's size and stride*"
1218            with self.assertRaisesRegex(RuntimeError, err_str):
1219                b.view(1, 4, 2, 3)
1220            # view on contiguous tensor is fine
1221            b.contiguous().view(1, 4, 2, 3)
1222
1223    def test_qtensor_resize(self):
1224        for device in get_supported_device_types():
1225            scale, zero_point, dtype = 1.0, 2, torch.uint8
1226            sizes1 = [1, 2, 3, 4]
1227            sizes2 = [1 * 2, 3 * 4]
1228            sizes3 = [1, 2 * 3, 4]
1229            sizes4 = [1 * 2 * 3 * 4]
1230            sizes5 = [1, 2, 1, 3, 1, 4]
1231
1232            q1_int = torch.randint(0, 100, sizes1, dtype=dtype, device=device)
1233            q1 = torch._make_per_tensor_quantized_tensor(q1_int, scale=scale, zero_point=zero_point)
1234            q2 = q1.resize(*sizes2)
1235            q3 = q2.resize(*sizes3)
1236            q4 = q3.resize(*sizes4)
1237            q5 = q4.resize(*sizes5)
1238
1239            self.assertEqual(q1.numel(), q2.numel())
1240            self.assertEqual(q1.numel(), q3.numel())
1241            self.assertEqual(q1.numel(), q4.numel())
1242            self.assertEqual(q1.numel(), q5.numel())
1243
1244            # Compare original and post-transpose
1245            a_int = torch.randint(0, 100, sizes1, dtype=dtype, device=device)
1246            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
1247            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
1248            c = b.resize(*sizes1)  # Change the sizes back to the original
1249
1250            self.assertEqual(a.size(), c.size())
1251            self.assertEqual(b.q_scale(), c.q_scale())
1252            self.assertEqual(b.q_zero_point(), c.q_zero_point())
1253            self.assertNotEqual(b.stride(), c.stride())
1254            # size is the same but the underlying data is different
1255            self.assertNotEqual(b.int_repr(), c.int_repr())
1256            # torch.equal is not supported for the cuda backend
1257            if device == 'cpu':
1258                self.assertFalse(torch.equal(b, c))
1259
1260            # Throws an error if numel is wrong
1261            q1_int = torch.randint(0, 100, sizes1, dtype=dtype, device=device)
1262            q1 = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
1263            err_str = "requested resize to*"
1264            with self.assertRaisesRegex(RuntimeError, err_str):
1265                q2 = q1.resize(*sizes1[:-1])
1266            # resize on both contiguous and non-contiguous tensor should be fine
1267            q3 = q1.resize(*sizes2)
1268            q4 = q1.contiguous().resize(*sizes2)
1269
1270    def test_qtensor_reshape(self):
1271        scale, zero_point, dtype = 1.0, 2, torch.uint8
1272        for device in get_supported_device_types():
1273            q_int = torch.randint(0, 100, [3, 5], dtype=dtype, device=device)
1274            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
1275            q2 = q.reshape([15])
1276            self.assertEqual(q.numel(), q2.numel())
1277            self.assertEqual(q2.size(), [15])
1278            # testing -1
1279            self.assertEqual(q, q2.reshape([3, -1]))
1280
1281            a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype, device=device)
1282            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
1283            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
1284            c = a.reshape(1, 3, 2, 4)  # does not change tensor layout
1285            self.assertEqual(b.size(), c.size())
1286            self.assertEqual(b.q_scale(), c.q_scale())
1287            self.assertEqual(b.q_zero_point(), c.q_zero_point())
1288            self.assertNotEqual(b.stride(), c.stride())
1289            self.assertNotEqual(b.int_repr(), c.int_repr())
1290            # torch.equal is not supported for the cuda backend
1291            if device == 'cpu':
1292                self.assertFalse(torch.equal(b, c))
1293
1294            # we can use reshape for non-contiguous Tensor
1295            a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype, device=device)
1296            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
1297            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
1298            c = b.reshape(1, 4, 2, 3)
1299
1300    def test_qtensor_unsqueeze(self):
1301        for device in get_supported_device_types():
1302            x = torch.randn((1, 3, 4), device=device)
1303            qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
1304            qy = qx.unsqueeze(2)
1305            self.assertEqual(qy.size(), (1, 3, 1, 4))
1306            qy = qy.squeeze(2)
1307            self.assertEqual(qy.size(), qx.size())
1308
1309            # Per channel qtensor
1310            scales = torch.tensor([1.0], device=device)
1311            zero_points = torch.tensor([0], device=device)
1312            qx = torch.quantize_per_channel(x, scales=scales, zero_points=zero_points, dtype=torch.quint8, axis=0)
1313            qy = qx.unsqueeze(0)
1314            self.assertEqual(qy.size(), (1, 1, 3, 4))
1315            self.assertEqual(qy.q_per_channel_axis(), 1)
1316
1317            qz = qy.squeeze(0)
1318            self.assertEqual(qz.size(), x.size())
1319            self.assertEqual(qz.q_per_channel_axis(), 0)
1320            with self.assertRaisesRegex(RuntimeError, "Squeeze is only possible on non-axis dimension for Per-Channel"):
1321                qz = qy.squeeze(1)
1322
1323            # squeeze without dim specified
1324            x = torch.randn((3, 1, 2, 1, 4), device=device)
1325            scales = torch.tensor([1.0, 1.0], device=device)
1326            zero_points = torch.tensor([0, 0], device=device)
1327            qx = torch.quantize_per_channel(x, scales=scales, zero_points=zero_points, dtype=torch.quint8, axis=2)
1328            qz = qx.squeeze()
1329            self.assertEqual(qz.size(), (3, 2, 4))
1330            self.assertEqual(qz.q_per_channel_axis(), 1)
1331            with self.assertRaisesRegex(RuntimeError, "Squeeze is only possible on non-axis dimension for Per-Channel"):
1332                qz = qy.squeeze()
1333
1334    def test_repeat(self):
1335        scale, zero_point, dtype = 1.0, 2, torch.uint8
1336        for device in get_supported_device_types():
1337            q_int = torch.randint(0, 100, [3], dtype=dtype, device=device)
1338            q_int_repeat = q_int.repeat(4, 2)
1339            q_ref = torch._make_per_tensor_quantized_tensor(q_int_repeat, scale=scale, zero_point=zero_point)
1340
1341            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
1342            q_repeat = q.repeat(4, 2)
1343            self.assertEqual(q_ref, q_repeat)
1344
1345    def test_qscheme_pickle(self):
1346        f = Foo()
1347        buf = io.BytesIO()
1348        torch.save(f, buf)
1349
1350        buf.seek(0)
1351        # weights_only=False as this is legacy code that saves the model
1352        f2 = torch.load(buf, weights_only=False)
1353
1354        self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)
1355
1356    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4,
1357                                              min_side=1, max_side=10),
1358                       qparams=hu.qparams()),
1359           reduce_range=st.booleans()
1360           )
1361    @unittest.skip(
1362        "this is broken without changes to any relevant code, "
1363        "we need to remove hypothesis testing in CI")
1364    def test_choose_qparams(self, X, reduce_range):
1365        X, (scale, zero_point, torch_type) = X
1366        X = torch.from_numpy(X)
1367        X_scale, X_zp = _calculate_dynamic_qparams(X, torch.quint8, reduce_range=reduce_range)
1368        qparams = torch._choose_qparams_per_tensor(X, reduce_range)
1369        np.testing.assert_array_almost_equal(X_scale, qparams[0], decimal=3)
1370        self.assertEqual(X_zp, qparams[1])
1371
1372    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA is not available')
1373    def test_cuda_quantization_does_not_pin_memory(self):
1374        # Context - https://github.com/pytorch/pytorch/issues/41115
1375        x = torch.randn(3)
1376        self.assertEqual(x.is_pinned(), False)
1377
1378        q_int = torch.randint(0, 100, [1, 2, 3], device="cuda", dtype=torch.uint8)
1379        q = torch._make_per_tensor_quantized_tensor(q_int, scale=0.1, zero_point=0)
1380
1381        x = torch.randn(3)
1382        self.assertEqual(x.is_pinned(), False)
1383
1384    # There's no way to actually pin the memory of a quantized tensor
1385    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA is not available')
1386    def test_quant_pin_memory(self):
1387        x = torch.randn(3).pin_memory()
1388        self.assertEqual(x.is_pinned(), True)
1389        x_q = torch.quantize_per_tensor(x, 1, 0, torch.quint8)
1390        self.assertEqual(x_q.is_pinned(), False)
1391        x_pin = torch.empty_quantized([3], x_q, pin_memory=True, dtype=torch.quint8)
1392        self.assertEqual(x_pin.is_pinned(), False)
1393        self.assertRaises(RuntimeError, lambda: x_q.pin_memory())
1394
1395    def test_fp16_saturate_op(self):
1396        x = torch.ones(5, 5, dtype=torch.float32) * 65532
1397        x[0] = torch.ones(5) * -65532
1398        # range of fp16 value is [-65504, + 65504]
1399        ref = torch.ones(5, 5) * 65504
1400        ref[0] = torch.ones(5) * -65504
1401        y = torch._saturate_weight_to_fp16(x)
1402        self.assertEqual(y, ref)
1403
1404    def test_choose_qparams_optimized(self):
1405        for bit_width in [4, 2]:
1406            x = torch.randn(64, dtype=torch.float)
1407            y = torch.choose_qparams_optimized(x, numel=64, n_bins=200, ratio=0.16, bit_width=bit_width)
1408            ref = param_search_greedy(x.numpy(), bit_rate=bit_width)
1409            self.assertEqual(y[0].numpy(), ref[0])
1410            self.assertEqual(y[1].numpy(), ref[1])
1411
1412    def _test_pickle_checkpoint_qtensor(self, device):
1413        with TemporaryFileName() as fname:
1414            class M(torch.jit.ScriptModule):
1415                __constants__ = ['fname']
1416
1417                def __init__(self) -> None:
1418                    super().__init__()
1419                    self.fname = fname
1420
1421                @torch.jit.script_method
1422                def forward(self, x, y):
1423                    torch.save((x, y), self.fname)
1424                    return y
1425
1426            q = torch.quantize_per_tensor(
1427                torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
1428            qc = torch.quantize_per_channel(
1429                torch.rand(2, 3, dtype=torch.float),
1430                scales=torch.tensor([0.1, 0.5, 0.01]),
1431                zero_points=torch.tensor([10, 0, 20]),
1432                axis=1, dtype=torch.quint8).to(device)
1433            m = M()
1434            m(q, qc)
1435            with open(fname, "rb") as handle:
1436                for weights_only in [True, False]:
1437                    loaded_q, loaded_qc = torch.load(fname, weights_only=weights_only)
1438                    self.assertEqual(loaded_q, q)
1439                    self.assertEqual(loaded_qc, qc)
1440
1441    def test_pickle_checkpoint_qtensor(self):
1442        self._test_pickle_checkpoint_qtensor('cpu')
1443
1444    def test_jit_serialization(self):
1445        class SimpleQTensor(torch.jit.ScriptModule):
1446            def __init__(self, per_channel):
1447                super().__init__()
1448                x = torch.rand(5, 5).float()
1449                if not per_channel:
1450                    x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
1451                else:
1452                    s = torch.rand(5, dtype=torch.float64) + 0.1
1453                    zp = torch.randint(5, 15, (5,))
1454                    x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
1455                self.x = torch.nn.Buffer(x_q)
1456
1457            @torch.jit.script_method
1458            def forward(self):
1459                return self.x
1460
1461        for per_channel in [False, True]:
1462            model = SimpleQTensor(per_channel)
1463            buffer = io.BytesIO()
1464            torch.jit.save(model, buffer)
1465            buffer.seek(0)
1466            model_loaded = torch.jit.load(buffer)
1467            self.assertEqual(model_loaded(), model())
1468
1469    def test_bfp16_quantize(self):
1470        X = torch.randn(5 , 10)
1471        quantized_X = X.to(torch.bfloat16)
1472        dedequantized_X = quantized_X.to(torch.float32)
1473        torch.testing.assert_close(X, dedequantized_X, rtol=1e-4, atol=5e-3)
1474
1475    def test_decomposed_quantize_per_tensor(self):
1476        # register the ops
1477        import torch.ao.quantization.fx._decomposed
1478        X = torch.randn(5, 10)
1479        test_cases = [
1480            (torch.quint8, torch.uint8, 0, 255),
1481            (torch.qint8, torch.int8, -128, 127),
1482            (torch.qint32, torch.int32, -2**31, 2**31 - 1),
1483        ]
1484        for qdtype, dtype, quant_min, quant_max in test_cases:
1485            scale, zero_point = _calculate_dynamic_qparams(X, qdtype)
1486            quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype)
1487            quantized_decomposed_X = \
1488                torch.ops.quantized_decomposed.quantize_per_tensor(
1489                    X, scale, zero_point, quant_min, quant_max, dtype)
1490            self.assertEqual(quantized_decomposed_X.dtype, dtype)
1491            self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
1492
1493    def test_decomposed_quantize_per_tensor_bfloat16_input(self):
1494        # register the ops
1495        import torch.ao.quantization.fx._decomposed
1496        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
1497        scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
1498        quantized_X = torch.quantize_per_tensor(X, scale, zero_point, torch.quint8)
1499        quantized_decomposed_X = \
1500            torch.ops.quantized_decomposed.quantize_per_tensor(
1501                X.to(torch.bfloat16), scale, zero_point, 0, 255, torch.uint8)
1502        self.assertEqual(quantized_decomposed_X.dtype, torch.uint8)
1503        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
1504
1505    def test_decomposed_dequantize_per_tensor(self):
1506        import torch.ao.quantization.fx._decomposed
1507        X = torch.randn(5, 10)
1508        test_cases = [
1509            (torch.quint8, torch.uint8, 0, 255),
1510            (torch.qint8, torch.int8, -128, 127),
1511            (torch.qint32, torch.int32, -2**31, 2**31 - 1),
1512        ]
1513
1514        for qdtype, dtype, quant_min, quant_max in test_cases:
1515            scale, zero_point = _calculate_dynamic_qparams(X, qdtype)
1516            quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype)
1517            dequantized_X = torch.dequantize(quantized_X)
1518
1519            quantized_decomposed_X = torch.ops.quantized_decomposed.quantize_per_tensor(
1520                X, scale, zero_point, quant_min, quant_max, dtype)
1521            dequantized_decomposed_X = torch.ops.quantized_decomposed.dequantize_per_tensor(
1522                quantized_decomposed_X, scale, zero_point, quant_min, quant_max, dtype
1523            )
1524            self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
1525            self.assertEqual(dequantized_X, dequantized_decomposed_X)
1526
1527    def test_decomposed_dynamic_quant_pattern(self):
1528        import torch.ao.quantization.fx._decomposed
1529        X = torch.randn(5, 10)
1530        dtype = torch.uint8
1531        qdtype = torch.quint8
1532        scale, zero_point = torch._choose_qparams_per_tensor(X, False)
1533        quant_min, quant_max = 0, 255
1534
1535        quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype)
1536        dequantized_X = torch.dequantize(quantized_X)
1537
1538        # Now try decomposed pattern
1539        (scale_decomposed, zero_point_decomposed) = torch.ops.quantized_decomposed.choose_qparams.tensor(
1540            X, quant_min, quant_max, torch.Tensor([torch.finfo(torch.float32).eps]), dtype)
1541        quantized_decomposed_X = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
1542            X, scale_decomposed, zero_point_decomposed, quant_min, quant_max, dtype)
1543
1544        dequantized_decomposed_X = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor(
1545            quantized_decomposed_X, scale_decomposed, zero_point_decomposed, quant_min, quant_max, dtype
1546        )
1547        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
1548        self.assertEqual(dequantized_X, dequantized_decomposed_X)
1549
1550    def test_decomposed_quantize_per_channel(self):
1551        # register the ops
1552        import torch.ao.quantization.fx._decomposed
1553        X = torch.randn(5, 10)
1554        qdtype = torch.quint8
1555        dtype = torch.uint8
1556        scales = torch.randn(5,)
1557        zero_points = torch.randint(0, 100, (5,))
1558        quant_min, quant_max = 0, 255
1559        axis = 0
1560
1561        quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype)
1562        quantized_decomposed_X = \
1563            torch.ops.quantized_decomposed.quantize_per_channel(
1564                X, scales, zero_points, axis, quant_min, quant_max, dtype)
1565        self.assertEqual(quantized_decomposed_X.dtype, dtype)
1566        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
1567
1568    def test_decomposed_quantize_per_channel_bfloat16_input(self):
1569        # register the ops
1570        import torch.ao.quantization.fx._decomposed
1571        X = torch.randint(1, 10, (5, 5)).to(torch.float32)
1572        qdtype = torch.quint8
1573        dtype = torch.uint8
1574        scales = torch.randn(5,)
1575        zero_points = torch.randint(0, 100, (5,))
1576        quant_min, quant_max = 0, 255
1577        axis = 0
1578
1579        quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype)
1580        quantized_decomposed_X = \
1581            torch.ops.quantized_decomposed.quantize_per_channel(
1582                X.to(torch.bfloat16), scales, zero_points, axis, quant_min, quant_max, dtype)
1583        self.assertEqual(quantized_decomposed_X.dtype, dtype)
1584        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
1585
1586    def test_decomposed_dequantize_per_channel(self):
1587        # register the ops
1588        import torch.ao.quantization.fx._decomposed
1589        X = torch.randn(5, 10)
1590        qdtype = torch.quint8
1591        dtype = torch.uint8
1592        scales = torch.randn(5,)
1593        zero_points = torch.randint(0, 100, (5,))
1594        quant_min, quant_max = 0, 255
1595        axis = 0
1596
1597        quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype)
1598        dequantized_X = torch.dequantize(quantized_X)
1599
1600        quantized_decomposed_X = \
1601            torch.ops.quantized_decomposed.quantize_per_channel(
1602                X, scales, zero_points, axis, quant_min, quant_max, dtype)
1603        dequantized_decomposed_X = \
1604            torch.ops.quantized_decomposed.dequantize_per_channel(
1605                quantized_decomposed_X, scales, zero_points, axis, quant_min, quant_max, dtype)
1606
1607        self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X)
1608        self.assertEqual(dequantized_X, dequantized_decomposed_X)
1609
1610    def test_decomposed_choose_qparams_per_token_asymmetric_backward(self):
1611        # register the ops
1612        import torch.ao.quantization.fx._decomposed
1613        x = torch.randn(2, 3).requires_grad_()
1614        (s, zp) = torch.ops.quantized_decomposed._choose_qparams_per_token_asymmetric_impl(x, torch.int8)
1615        out = x.div(s).add(zp).round()
1616        out.sum().backward()
1617
1618    def test_decomposed_quantize_per_channel_group(self):
1619        # register the ops
1620        import torch.ao.quantization.fx._decomposed
1621        qmin, qmax = (-8, 7)
1622        group_size = 128
1623        x = torch.randn(100, 256)
1624        s = torch.randn(100, 2)
1625        zp = torch.randint(qmax, size=(100, 2), dtype=torch.int32)
1626
1627        # simulate fake quantize per channel group with qdq
1628        q = torch.ops.quantized_decomposed.quantize_per_channel_group(
1629            x, s, zp, qmin, qmax, torch.int8, group_size,
1630        )
1631        dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
1632            q, s, zp, qmin, qmax, torch.int8, group_size, torch.float32
1633        )
1634
1635        # express per group fake quant using `torch.fake_quantize_per_channel_affine`
1636        x_grouped = x.reshape(-1, group_size)
1637        s_flattened = s.flatten()
1638        zp_flattened = zp.flatten()
1639        fq = torch.fake_quantize_per_channel_affine(
1640            x_grouped, s_flattened, zp_flattened, 0, qmin, qmax,
1641        )
1642        fq = fq.reshape_as(x)
1643        torch.testing.assert_close(dq, fq, rtol=0, atol=0)
1644
1645    def test_decomposed_quantize_per_token(self):
1646        # register the ops
1647        import torch.ao.quantization.fx._decomposed
1648        qmin, qmax = (-8, 7)
1649        x = torch.randn(100, 256)
1650        s = torch.randn(100, 1)
1651        zp = torch.randint(qmax, size=(100, 1), dtype=torch.int32)
1652
1653        # simulate fake quantize per token with qdq
1654        q = torch.ops.quantized_decomposed.quantize_per_token(
1655            x, s, zp, qmin, qmax, torch.int8,
1656        )
1657        dq = torch.ops.quantized_decomposed.dequantize_per_token(
1658            q, s, zp, qmin, qmax, torch.int8, torch.float32
1659        )
1660
1661        # express per token fake quant using `torch.fake_quantize_per_channel_affine`
1662        s_flattened = s.flatten()
1663        zp_flattened = zp.flatten()
1664        fq = torch.fake_quantize_per_channel_affine(
1665            x, s_flattened, zp_flattened, 0, qmin, qmax,
1666        )
1667        torch.testing.assert_close(dq, fq, rtol=0, atol=0)
1668
1669
if __name__ == '__main__':
    # Running this file directly is not supported; the message names the
    # command to use instead.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )
1674