1# Owner(s): ["oncall: quantization"]
2
3import torch
4import math
5from typing import Tuple
6from torch.ao.quantization import (
7    FakeQuantize,
8    MovingAverageMinMaxObserver,
9    default_observer,
10    default_fixed_qparams_range_0to1_fake_quant,
11)
12
13from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
14from torch.testing._internal.common_quantized import (
15    _fake_quantize_per_channel_affine_reference,
16    _fake_quantize_per_channel_affine_grad_reference,
17    to_tensor,
18)
19import torch.nn as nn
20
21# Standard library
22import io
23import itertools
24import unittest
25import numpy as np
26
27# Testing utils
28from hypothesis import given, settings
29from hypothesis import strategies as st
30import torch.testing._internal.hypothesis_utils as hu
31hu.assert_deadline_disabled()
32from torch.testing._internal.common_cuda import TEST_CUDA
33from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
34
35# Reference method for fake quantize
36# Note: because scale/zero_point are left as floats in the actual kernel, this reference mimics how fake_quant behaves for float16/float64 inputs
37def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, quant_min, quant_max):
38    dtype = X.dtype
39    res = ((torch.clamp(torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point), quant_min, quant_max) - zero_point) * scale)
40    return res.to(dtype)
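# A small worked example of the reference above (illustrative numbers only): with
# scale=0.1, zero_point=0, quant_min=0, quant_max=255, an input of 0.23 maps to
# round(0.23 / 0.1) = 2, which dequantizes back to 2 * 0.1 = 0.2; values whose
# quantized index falls outside [0, 255] are clamped before dequantization.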
41
42# Reference method for the gradient of the fake quantize operator
43# Note: because scale/zero_point are left as floats in the actual kernel, this reference mimics how fake_quant behaves for float16/float64 inputs
44def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max):
45    dtype = X.dtype
46    Xq = torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point)
47    mask = (Xq >= quant_min) * (Xq <= quant_max)
48    res = torch.zeros_like(dY)
49    res[mask] = dY[mask]
50    return res.to(dtype)
51
52# Reference method for the gradients of the learnable fake quantize operator
53def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max, device):
54    r"""This method references the following literatures for back propagation on scale and zero point.
55    - https://arxiv.org/pdf/1902.08153.pdf
56    - https://arxiv.org/pdf/1903.08066.pdf
57    """
58    zero_point_rounded = int((zero_point + 0.5).clamp(quant_min, quant_max).item())
59    Xq = torch.round(X * (1.0 / scale) + zero_point_rounded)
60
61    indicate_small_scale = (Xq < quant_min).float().to(device)
62    indicate_big_scale = (Xq > quant_max).float().to(device)
63    indicate_middle_scale = torch.ones(indicate_small_scale.shape).to(device) - \
64        indicate_small_scale - indicate_big_scale
65
66    indicate_saturate_zp = ((Xq < quant_min).float() + (Xq > quant_max).float()).to(device)
67    indicate_unsaturate_zp = torch.ones(indicate_saturate_zp.shape).to(device) - indicate_saturate_zp
68
69    Xq = Xq.clamp(quant_min, quant_max)
70    Xfq = (Xq - zero_point_rounded) * scale
71
72    grad_small_scale = quant_min - zero_point_rounded
73    grad_big_scale = quant_max - zero_point_rounded
74    grad_middle_scale = ((Xfq - X) / scale).to(device)
75
76    grad_saturate_zp = -scale.to(device)
77    grad_unsaturate_zp = 0
78
79    grad_scale = indicate_small_scale * grad_small_scale + \
80        indicate_big_scale * grad_big_scale + \
81        indicate_middle_scale * grad_middle_scale
82    grad_zp = indicate_saturate_zp * grad_saturate_zp + \
83        indicate_unsaturate_zp * grad_unsaturate_zp
84    grad_X = _fake_quantize_per_tensor_affine_grad_reference(
85        dY, X, scale, zero_point, quant_min, quant_max).to(device)
86
87    grad_scale = (grad_scale * dY).sum().unsqueeze(dim=0)
88    grad_zp = (grad_zp * dY).sum().unsqueeze(dim=0)
89    return grad_X, grad_scale, grad_zp
90
91
92# Reference method for quantization.
93def _quantize_per_tensor(x, scale, zero_point, quant_min, quant_max):
94    return ((x / scale) + zero_point).round().clamp(quant_min, quant_max)
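# For instance (illustrative only), _quantize_per_tensor(torch.tensor([0.0, 0.5, 1.0]),
# 0.1, 0, 0, 255) yields tensor([0., 5., 10.]): divide by the scale, add the zero point,
# round, then clamp to the quantized range.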
95
96# Reference method for the per channel gradients of the learnable fake quantize operator
97def _fake_quantize_learnable_per_channel_affine_grad_reference(
98        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max, device):
99    r"""This method references the following literatures for back propagation on scale and zero point.
100    - https://arxiv.org/pdf/1902.08153.pdf
101    - https://arxiv.org/pdf/1903.08066.pdf
102    """
103    per_channel_zero_point = ((per_channel_zero_point.detach() + 0.5).clamp(quant_min, quant_max)).type(torch.int32)
104    grad_X = _fake_quantize_per_channel_affine_grad_reference(
105        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max).to(device)
106    per_channel_scale = per_channel_scale.detach().type(torch.float)
107
108    grad_scale = torch.zeros([per_channel_scale.size(0)]).to(device)
109    grad_zero_point = torch.zeros([per_channel_zero_point.size(0)]).to(device)
110
111    X_flattened = torch.unbind(X, dim=axis)
112    dY_flattened = torch.unbind(dY, dim=axis)
113
114    for i in range(len(X_flattened)):
115        scale_i = per_channel_scale[i]
116        zero_point_i = per_channel_zero_point[i]
117        X_i = X_flattened[i]
118        dY_i = dY_flattened[i]
119
120        Xq_i = ((X_i / scale_i) + zero_point_i).round()
121        Xfq_i = (Xq_i - zero_point_i) * scale_i
122
123        indicate_small_scale_i = (Xq_i < quant_min).float().to(device)
124        indicate_big_scale_i = (Xq_i > quant_max).float().to(device)
125        indicate_middle_scale_i = torch.ones(indicate_small_scale_i.shape).to(device) - \
126            indicate_small_scale_i - indicate_big_scale_i
127
128        indicate_saturate_zp_i = ((Xq_i < quant_min).float() +
129                                  (Xq_i > quant_max).float()).to(device)
130        indicate_unsaturate_zp_i = torch.ones(indicate_saturate_zp_i.shape).to(device) - \
131            indicate_saturate_zp_i
132
133        Xq_i = Xq_i.clamp(quant_min, quant_max)
134        Xfq_i = (Xq_i - zero_point_i) * scale_i
135
136        grad_small_scale_i = quant_min - zero_point_i
137        grad_big_scale_i = quant_max - zero_point_i
138        grad_middle_scale_i = ((Xfq_i - X_i) / scale_i).to(device)
139
140        grad_saturate_zp_i = -scale_i.to(device)
141        grad_unsaturate_zp_i = 0
142
143        grad_scale_i = indicate_small_scale_i * grad_small_scale_i + \
144            indicate_middle_scale_i * grad_middle_scale_i + \
145            indicate_big_scale_i * grad_big_scale_i
146        grad_zp_i = indicate_saturate_zp_i * grad_saturate_zp_i + \
147            indicate_unsaturate_zp_i * grad_unsaturate_zp_i
148
149        grad_scale_i = (grad_scale_i * dY_i).sum().unsqueeze(dim=0)
150        grad_zp_i = (grad_zp_i * dY_i).sum().unsqueeze(dim=0)
151
152        grad_scale[i] = grad_scale_i
153        grad_zero_point[i] = grad_zp_i
154    return grad_X, grad_scale, grad_zero_point
155
156def _get_tensor_min_max(
157        X: torch.Tensor,
158        running_min: float = float("inf"),
159        running_max: float = float("-inf"),
160        averaging_const: float = 0.01) -> Tuple[float, float]:
161    min_val = X.min().to(dtype=torch.float32).item()
162    max_val = X.max().to(dtype=torch.float32).item()
163
164    if not math.isinf(running_min):
165        min_val = running_min + averaging_const * (min_val - running_min)
166    if not math.isinf(running_max):
167        max_val = running_max + averaging_const * (max_val - running_max)
168
169    return min_val, max_val
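# Illustrative numbers for the moving-average update above: with running_min=0.0,
# running_max=1.0 and a batch whose min/max are -1.0/2.0, averaging_const=0.01 gives
# 0.0 + 0.01 * (-1.0 - 0.0) = -0.01 and 1.0 + 0.01 * (2.0 - 1.0) = 1.01. On the very
# first call (running stats still +/-inf) the raw batch min/max are returned unchanged.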
170
171def _get_per_row_min_max(
172        x: torch.Tensor,
173        min_vals: torch.Tensor,
174        max_vals: torch.Tensor,
175        axis: int = 0,
176        averaging_const: float = 0.01) -> Tuple[torch.Tensor, torch.Tensor]:
177    x_dim = x.size()
178    new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
179    new_axis_list[axis] = 0
180    new_axis_list[0] = axis
181    y = x.permute(*new_axis_list)
182
183    y = torch.flatten(y, start_dim=1)
184    # min_vals, max_vals = torch.aminmax(y, dim=1)
185    if math.isinf(min_vals[0]) or math.isinf(max_vals[0]):
186        min_vals, max_vals = torch.aminmax(y, dim=1)
187    else:
188        min_vals_cur, max_vals_cur = torch.aminmax(y, dim=1)
189        min_vals = min_vals + averaging_const * (min_vals_cur - min_vals)
190        max_vals = max_vals + averaging_const * (max_vals_cur - max_vals)
191    return min_vals, max_vals
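# _get_per_row_min_max is the per-channel analogue of _get_tensor_min_max: it permutes
# `axis` to the front, flattens the remaining dimensions, and applies the same
# exponential-moving-average update to each row's running min/max.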
192
193def _get_scale_zp(
194        min_val: float,
195        max_val: float,
196        dtype: torch.dtype,
197        reduce_range: bool = False,
198        preserve_sparsity: bool = False) -> Tuple[float, int]:
199    """
200    Calculate the quantization parameters (scale, zero_point)
201    based on the min and max element of the tensor
202    """
203    if dtype == torch.qint8:
204        if reduce_range:
205            qmin, qmax = -64, 63
206        else:
207            qmin, qmax = -128, 127
208    else:
209        if reduce_range:
210            qmin, qmax = 0, 127
211        else:
212            qmin, qmax = 0, 255
213
214    if min_val < 0 and max_val > 0 and preserve_sparsity:
215        symmetric_qmin = int(-((qmax - qmin) / 2 + 1))
216        symmetric_qmax = int((qmax - qmin) / 2)
217        max_scale = max(
218            abs(min_val / symmetric_qmin), abs(max_val / symmetric_qmax)
219        )
220        min_val = max_scale * symmetric_qmin
221        max_val = max_scale * symmetric_qmax
222    min_val = min(min_val, 0.0)
223    max_val = max(max_val, 0.0)
224    scale = (max_val - min_val) / (qmax - qmin)
225    if scale == 0.0 or math.isinf(1.0 / scale):
226        scale = 0.1
227        zero_point = 0
228
229    zero_point_from_min = qmin - min_val / float(scale)
230    zero_point_from_max = qmax - max_val / float(scale)
231    zero_point_from_min_error = abs(qmin) - abs(min_val / float(scale))
232    zero_point_from_max_error = abs(qmax) - abs(max_val / float(scale))
233    if zero_point_from_min_error < zero_point_from_max_error:
234        initial_zero_point = zero_point_from_min
235    else:
236        initial_zero_point = zero_point_from_max
237
238    if min_val < 0 and max_val > 0 and preserve_sparsity:
239        initial_zero_point = (qmin + qmax) / 2 + 1
240
241    nudged_zero_point = 0
242
243    if initial_zero_point < qmin:
244        nudged_zero_point = qmin
245    elif initial_zero_point > qmax:
246        nudged_zero_point = qmax
247    else:
248        nudged_zero_point = int(round(initial_zero_point))
249
250    return (scale, int(nudged_zero_point))
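# Rough worked example for _get_scale_zp (illustrative only): min_val=-1.0, max_val=1.0,
# dtype=torch.quint8, reduce_range=False gives qmin, qmax = 0, 255, so
# scale = 2.0 / 255 ~= 0.00784; both zero-point candidates land at 127.5, which nudges
# to zero_point = 128. A collapsed range (scale of 0) falls back to scale = 0.1.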
251
252NP_RANDOM_SEED = 19
253tolerance = 1e-6
254
255class TestFakeQuantizeOps(TestCase):
256    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
257           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
258                       qparams=hu.qparams(dtypes=torch.quint8)))
259    def test_forward_per_tensor(self, device, X):
260        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
261        """
262        np.random.seed(NP_RANDOM_SEED)
263        X, (scale, zero_point, torch_type) = X
264        quant_min = torch.iinfo(torch_type).min
265        quant_max = torch.iinfo(torch_type).max
266
267        X = to_tensor(X, device)
268        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
269        Y_prime = torch.fake_quantize_per_tensor_affine(
270            X, scale, zero_point, quant_min, quant_max)
271        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
272
273    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
274           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
275                       qparams=hu.qparams(dtypes=torch.quint8)))
276    @unittest.skip("temporarily disable the test")
277    def test_backward_per_tensor(self, device, X):
278        r"""Tests the backward method.
279        """
280        np.random.seed(NP_RANDOM_SEED)
281        X, (scale, zero_point, torch_type) = X
282        quant_min = torch.iinfo(torch_type).min
283        quant_max = torch.iinfo(torch_type).max
284
285        X = to_tensor(X, device)
286        X.requires_grad_()
287        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
288        Y_prime = torch.fake_quantize_per_tensor_affine(
289            X, scale, zero_point, quant_min, quant_max)
290        dout = torch.rand_like(X, dtype=torch.float).to(device)
291        dX = _fake_quantize_per_tensor_affine_grad_reference(
292            dout, X, scale, zero_point, quant_min, quant_max)
293        Y_prime.backward(dout)
294        np.testing.assert_allclose(dX.cpu(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
295
296    def test_forward_backward_per_tensor_with_amp(self):
297        net = nn.Sequential(nn.Conv2d(1, 1, 3))
298        net.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
299        net_prep = torch.ao.quantization.prepare_qat(net)
300
301        with torch.cuda.amp.autocast():
302            x = torch.randn(4, 1, 5, 5)
303            out = net_prep(x).sum()
304            out.backward()
305            self.assertTrue(net_prep[0].weight.grad is not None)
306
307    def test_forward_per_tensor_half_precision_numerics(self):
308        scale = .1
309        zero = 0
310        maxi = 255
311        mini = 0
312
313        for i in range(20):
314            X1 = torch.randn(5, 5).to(torch.float16)
315            Y1 = torch.fake_quantize_per_tensor_affine(X1, scale, zero, mini, maxi)
316            Y1r = _fake_quantize_per_tensor_affine_reference(X1, scale, zero, mini, maxi)
317            self.assertEqual(Y1, Y1r, rtol=tolerance, atol=tolerance)
318
319        # to force overflow
320        X2 = torch.tensor(2**15 + .01).to(torch.float16)
321        Y2 = torch.fake_quantize_per_tensor_affine(X2, scale, zero, mini, maxi)
322        Y2r = _fake_quantize_per_tensor_affine_reference(X2, scale, zero, mini, maxi)
323        self.assertEqual(Y2, Y2r, rtol=tolerance, atol=tolerance)
324
325        scale = 10
326
327        # to force underflow
328        X3 = torch.tensor(2**-24).to(torch.float16)
329        Y3 = torch.fake_quantize_per_tensor_affine(X3, scale, zero, mini, maxi)
330        Y3r = _fake_quantize_per_tensor_affine_reference(X3, scale, zero, mini, maxi)
331        self.assertEqual(Y3, Y3r, rtol=tolerance, atol=tolerance)
332
333    def _test_forward_per_tensor_cachemask_impl(self, device):
334        float_types = (torch.float32, torch.float16, torch.float64)
335        torch_types = (torch.qint8, torch.quint8)
336        Xs = (torch.randn(4, 8, device=device), torch.randn(4, 16, device=device)[:, ::2])
337        tensor_qparam = (True, False)
338        for float_type, torch_type, X, tensor_qparams in itertools.product(float_types, torch_types, Xs, tensor_qparam):
339            # pick the scale + zp so that some values get clipped
340            X = X.to(float_type)
341            obs = torch.ao.quantization.MinMaxObserver(torch_type)
342            obs.to(device)
343            obs(X * 0.75)
344            scale, zero_point = obs.calculate_qparams()
345            quant_min, quant_max = obs.quant_min, obs.quant_max
346            if not tensor_qparam:
347                scale, zero_point = float(scale), int(zero_point)
348            Y_test = torch.fake_quantize_per_tensor_affine(
349                X, scale, zero_point, quant_min, quant_max)
350            Y_ref = _fake_quantize_per_tensor_affine_reference(
351                X, scale, zero_point, quant_min, quant_max).to(device)
352            self.assertEqual(Y_test, Y_ref, rtol=tolerance, atol=tolerance)
353            self.assertTrue(Y_test.dtype == float_type)
354
355    def test_forward_per_tensor_cachemask_cpu(self):
356        device = torch.device('cpu')
357        self._test_forward_per_tensor_cachemask_impl(device)
358
359    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
360    def test_forward_per_tensor_cachemask_cuda(self):
361        device = torch.device('cuda')
362        self._test_forward_per_tensor_cachemask_impl(device)
363
364    def _test_backward_per_tensor_cachemask_impl(self, device):
365        float_types = (torch.float32, torch.float16, torch.float64)
366        torch_types = (torch.qint8, torch.quint8)
367        tensor_qparams = (True, False)
368        for float_type, torch_type, tensor_qparam in itertools.product(float_types, torch_types, tensor_qparams):
369            X = torch.randn(4, 8).to(device).to(float_type)
370            X.requires_grad_()
371            # pick the scale + zp so that some values get clipped
372            obs = torch.ao.quantization.MinMaxObserver(torch_type)
373            obs.to(device)
374            obs(X * 0.75)
375            scale, zero_point = obs.calculate_qparams()
376            if not tensor_qparam:
377                scale, zero_point = float(scale), int(zero_point)
378            quant_min, quant_max = obs.quant_min, obs.quant_max
379
380            # forward pass
381            Y_test = torch.fake_quantize_per_tensor_affine(
382                X, scale, zero_point, quant_min, quant_max)
383            Y_ref = _fake_quantize_per_tensor_affine_reference(
384                X, scale, zero_point, quant_min, quant_max).to(device)
385            self.assertEqual(Y_test, Y_ref, rtol=tolerance, atol=tolerance)
386
387            # backward pass
388            dout = torch.rand_like(X, dtype=torch.float).to(device)
389            dX = _fake_quantize_per_tensor_affine_grad_reference(
390                dout, X, scale, zero_point, quant_min, quant_max)
391            Y_test.backward(dout)
392            self.assertEqual(dX, X.grad)
393            self.assertTrue(X.grad.dtype == float_type)
394
395    def test_backward_per_tensor_cachemask_cpu(self):
396        device = torch.device('cpu')
397        self._test_backward_per_tensor_cachemask_impl(device)
398
399    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
400    def test_backward_per_tensor_cachemask_cuda(self):
401        device = torch.device('cuda')
402        self._test_backward_per_tensor_cachemask_impl(device)
403
404    def _test_learnable_forward_per_tensor(self, X, device, scale_base, zero_point_base):
405        X_base = torch.tensor(X).to(device)
406
407        for n_bits in (4, 8):
408            quant_min, quant_max = 0, 2 ** n_bits - 1
409
410            X = X_base.clone().float()
411            scale_base = scale_base.to(device).float()
412            zero_point_base = zero_point_base.to(dtype=torch.int32, device=device)
413            scale = scale_base.clone()
414            zero_point = zero_point_base.clamp(quant_min, quant_max)
415
416            Y = _fake_quantize_per_tensor_affine_reference(
417                X, scale, zero_point, quant_min, quant_max).to(device)
418            for grad_factor in [0.1, 1.0, 10.0]:
419                Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
420                    X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
421                self.assertTrue(
422                    torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
423                    "Expected kernel forward function to have results match the reference forward function")
424
425    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
426                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
427                       qparams=hu.qparams(dtypes=torch.quint8)))
428    @unittest.skip(
429        "this is broken without changes to any relevant code, "
430        "we need to remove hypothesis testing in CI")
431    def test_learnable_forward_per_tensor_cpu(self, X):
432        X, (_, _, _) = X
433        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
434        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
435        self._test_learnable_forward_per_tensor(
436            X, 'cpu', scale_base, zero_point_base)
437
438    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
439                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
440                       qparams=hu.qparams(dtypes=torch.quint8)))
441    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
442    def test_learnable_forward_per_tensor_cuda(self, X):
443        X, (_, _, _) = X
444        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
445        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
446        self._test_learnable_forward_per_tensor(
447            X, 'cuda', scale_base, zero_point_base)
448
449    def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base):
450        r"""Tests the backward method with additional backprop support for scale and zero point.
451        """
452        X_base = torch.tensor(X).to(device)
453
454        for n_bits in (4, 8):
455            quant_min, quant_max = 0, 2 ** n_bits - 1
456
457            X = X_base.clone().float().to(device)
458            X.requires_grad_()
459            scale_base = scale_base.to(device)
460            zero_point_base = zero_point_base.to(device)
461            scale = scale_base.clone()
462            scale.requires_grad_()
463            zero_point = zero_point_base.clone().clamp(quant_min, quant_max)
464            zero_point.requires_grad_()
465            for grad_factor in [0.1, 1.0, 10.0]:
466                Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
467                    X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
468                dout = torch.rand_like(X, dtype=torch.float).to(device)
469                dX, dScale, dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
470                    dout, X, scale, zero_point, quant_min, quant_max, device)
471                Y_prime.backward(dout)
472
473                expected_dX = dX.to(device).detach()
474                actual_dX = X.grad.to(device).detach()
475                expected_dScale = dScale.to(device).detach()
476                actual_dScale = scale.grad.to(device).detach()
477                expected_dZeroPoint = dZeroPoint.to(device).detach()
478                actual_dZeroPoint = zero_point.grad.to(device).detach()
479
480                self.assertTrue(
481                    torch.allclose(
482                        expected_dX, actual_dX, rtol=tolerance, atol=tolerance),
483                    "Expected dX to match X.grad")
484                self.assertTrue(
485                    torch.allclose(
486                        expected_dScale * grad_factor, actual_dScale, rtol=tolerance, atol=tolerance),
487                    "Expected dScale to match scale.grad")
488                self.assertTrue(
489                    torch.allclose(
490                        expected_dZeroPoint * grad_factor, actual_dZeroPoint, rtol=tolerance, atol=tolerance),
491                    "Expected dZeroPoint to match zero_point.grad")
492                X.grad.data.zero_()
493                scale.grad.data.zero_()
494                zero_point.grad.data.zero_()
495
496    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
497                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
498                       qparams=hu.qparams(dtypes=torch.quint8)))
499    def test_learnable_backward_per_tensor_cpu(self, X):
500        torch.random.manual_seed(NP_RANDOM_SEED)
501        X, (_, _, _) = X
502        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
503        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
504        self._test_learnable_backward_per_tensor(
505            X, 'cpu', scale_base, zero_point_base)
506
507    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
508                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
509                       qparams=hu.qparams(dtypes=torch.quint8)))
510    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
511    def test_learnable_backward_per_tensor_cuda(self, X):
512        torch.random.manual_seed(NP_RANDOM_SEED)
513        X, (_, _, _) = X
514        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
515        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
516        self._test_learnable_backward_per_tensor(
517            X, 'cuda', scale_base, zero_point_base)
518
519    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
520           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
521                       qparams=hu.qparams(dtypes=[torch.quint8])),
522           )
523    def test_fq_module_per_tensor(self, device, X):
524        np.random.seed(NP_RANDOM_SEED)
525        X, (scale, zero_point, torch_type) = X
526        quant_min = torch.iinfo(torch_type).min
527        quant_max = torch.iinfo(torch_type).max
528
529        X = to_tensor(X, device)
530        X.requires_grad_()
531        fq_module = torch.ao.quantization.default_fake_quant().to(device)
532        Y_prime = fq_module(X)
533        assert fq_module.scale is not None
534        assert fq_module.zero_point is not None
535        Y = _fake_quantize_per_tensor_affine_reference(X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
536        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
537
538        # Test backward
539        dout = torch.rand_like(X, dtype=torch.float, device=device)
540        Y_prime.backward(dout)
541        dX = _fake_quantize_per_tensor_affine_grad_reference(dout, X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
542        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
543
544    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
545           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
546                       qparams=hu.qparams(dtypes=torch.quint8)))
547    def test_fixed_qparams_fq_module(self, device, X):
548        X, (scale, zero_point, torch_type) = X
549        X = to_tensor(X, device)
550        fq_module = default_fixed_qparams_range_0to1_fake_quant()
551        fq_module.to(device)
552        fixed_scale = fq_module.scale.clone()
553        fixed_zero_point = fq_module.zero_point.clone()
554        # run the fq module and make sure the quantization parameters do not change
555        torch.ao.quantization.enable_observer(fq_module)
556        fq_module(X)
557        self.assertEqual(fixed_scale, fq_module.scale)
558        self.assertEqual(fixed_zero_point, fq_module.zero_point)
559
560    def test_fq_serializable_per_tensor(self):
561        observer = default_observer
562        quant_min = 0
563        quant_max = 127
564        for FakeQuantizeClass in [FakeQuantize, _LearnableFakeQuantize]:
565            fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
566            X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
567            y_ref = fq_module(X)
568            state_dict = fq_module.state_dict()
569            self.assertEqual(state_dict['scale'], 0.094488)
570            self.assertEqual(state_dict['zero_point'], 53)
571            b = io.BytesIO()
572            torch.save(state_dict, b)
573            for weights_only in [True, False]:
574                b.seek(0)
575                loaded_dict = torch.load(b, weights_only=weights_only)
576                loaded_fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
577                loaded_fq_module.load_state_dict(loaded_dict)
578                for key in state_dict:
579                    self.assertEqual(state_dict[key], loaded_fq_module.state_dict()[key])
580
581                self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())
582
583    def test_fake_quant_control(self):
584        for fq_module in [torch.ao.quantization.default_fake_quant(),
585                          _LearnableFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0,
586                                                           quant_max=255,
587                                                           dtype=torch.quint8, qscheme=torch.per_tensor_affine,
588                                                           reduce_range=True)()]:
589            torch.manual_seed(42)
590            X = torch.rand(20, 10, dtype=torch.float32)
591            # Output of fake quant is not identical to input
592            Y = fq_module(X)
593            self.assertNotEqual(Y, X)
594            if type(fq_module) == _LearnableFakeQuantize:
595                fq_module.toggle_fake_quant(False)
596            else:
597                torch.ao.quantization.disable_fake_quant(fq_module)
598            X = torch.rand(20, 10, dtype=torch.float32)
599            Y = fq_module(X)
600            # Fake quant is disabled, so the output is identical to the input
601            self.assertEqual(Y, X)
602
603            # Explicit copy at this point in time, because FakeQuant keeps internal
604            # state in mutable buffers.
605            scale = fq_module.scale.clone().detach()
606            zero_point = fq_module.zero_point.clone().detach()
607
608            if type(fq_module) == _LearnableFakeQuantize:
609                fq_module.toggle_observer_update(False)
610                fq_module.toggle_fake_quant(True)
611            else:
612                torch.ao.quantization.disable_observer(fq_module)
613                torch.ao.quantization.enable_fake_quant(fq_module)
614            X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
615            Y = fq_module(X)
616            self.assertNotEqual(Y, X)
617            # Observer is disabled, scale and zero-point do not change
618            self.assertEqual(fq_module.scale, scale)
619            self.assertEqual(fq_module.zero_point, zero_point)
620            if type(fq_module) == _LearnableFakeQuantize:
621                fq_module.toggle_observer_update(True)
622            else:
623                torch.ao.quantization.enable_observer(fq_module)
624            Y = fq_module(X)
625            self.assertNotEqual(Y, X)
626            # Observer is enabled, scale and zero-point are different
627            self.assertNotEqual(fq_module.scale, scale)
628            self.assertNotEqual(fq_module.zero_point, zero_point)
629
630    def test_fake_quant_preserves_qparam_shapes_for_activations(self):
631        class Model(nn.Module):
632            def __init__(self) -> None:
633                super().__init__()
634                self.linear = nn.Linear(4, 4)
635
636            def forward(self, x):
637                x = self.linear(x)
638                return x
639
640        m = Model()
641
642        m.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
643        torch.ao.quantization.prepare_qat(m, inplace=True)
644
645        scale_shape_before = m.linear.activation_post_process.scale.shape
646        zero_point_shape_before = m.linear.activation_post_process.zero_point.shape
647
648        x = torch.rand(4, 4, 4, 4)
649        m(x)
650        scale_shape_after = m.linear.activation_post_process.scale.shape
651        zero_point_shape_after = m.linear.activation_post_process.zero_point.shape
652        self.assertEqual(
653            scale_shape_before, scale_shape_after,
654            msg="FakeQuant scale shape must stay consistent")
655        self.assertEqual(
656            zero_point_shape_before, zero_point_shape_after,
657            msg="FakeQuant zero_point shape must stay consistent")
658
659    def fake_quant_scriptable(self):
660        observer = default_observer
661        quant_min = 0
662        quant_max = 255
663        for FakeQuantizeClass in [FakeQuantize, _LearnableFakeQuantize]:
664            fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
665            scripted_module = torch.jit.script(fq_module)
666
667            X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
668
669            fq_module(X)
670            scripted_module(X)
671            self.assertEqual(fq_module.calculate_qparams(), scripted_module.calculate_qparams())
672
673            buf = io.BytesIO()
674            torch.jit.save(scripted_module, buf)
675            buf.seek(0)
676            loaded_module = torch.jit.load(buf)
677            self.assertEqual(fq_module.calculate_qparams(), loaded_module.calculate_qparams())
678
679
680    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
681           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
682           qparams=hu.qparams(dtypes=torch.quint8)))
683    def test_forward_per_channel(self, device, X):
684        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
685        """
686        np.random.seed(NP_RANDOM_SEED)
687        X, (scale, zero_point, axis, torch_type) = X
688        quant_min = torch.iinfo(torch_type).min
689        quant_max = torch.iinfo(torch_type).max
690
691        X = to_tensor(X, device)
692        scale = to_tensor(scale, device)
693        zero_point = torch.tensor(zero_point).to(dtype=torch.int32, device=device)
694        Y = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
695        Y_prime = torch.fake_quantize_per_channel_affine(
696            X, scale, zero_point, axis, quant_min, quant_max)
697        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
698
699    def _test_forward_per_channel_cachemask_impl(self, device):
700        torch_types = (torch.qint8, torch.quint8)
701        float_types = (torch.float32, torch.float16, torch.float64)
702        zero_point_types = (torch.int, torch.float32, torch.float16)
703
704        for torch_type, float_type, zero_point_type in itertools.product(torch_types, float_types, zero_point_types):
705            X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
706            # pick the scale + zp so that some values get clipped
707            axis = 1
708            obs = torch.ao.quantization.PerChannelMinMaxObserver(axis, torch_type).to(device)
709            obs(X * 0.75)
710            scale, zero_point = obs.calculate_qparams()
711            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
712            zero_point = zero_point.to(zero_point_type)
713            quant_min, quant_max = obs.quant_min, obs.quant_max
714
715            Y = _fake_quantize_per_channel_affine_reference(
716                X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
717            Y_prime = torch.fake_quantize_per_channel_affine(
718                X, scale, zero_point, axis, quant_min, quant_max)
719            np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
720            self.assertTrue(Y.dtype == float_type)
721
722    def test_forward_per_channel_cachemask_cpu(self):
723        self._test_forward_per_channel_cachemask_impl('cpu')
724
725    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
726    def test_forward_per_channel_cachemask_cuda(self):
727        self._test_forward_per_channel_cachemask_impl('cuda')
728
729    def test_forward_per_channel_half_precision_numerics(self):
730        scale = torch.randn(5).abs()
731        zero = torch.randn(5).to(dtype=torch.int)
732        axis = 1
733        mini = 0
734        maxi = 255
735
736        for i in range(20):
737            X1 = torch.randn(4, 5).to(torch.float16)
738            Y1 = torch.fake_quantize_per_channel_affine(X1, scale, zero, axis, mini, maxi)
739            Y1r = _fake_quantize_per_channel_affine_reference(X1, scale, zero, axis, mini, maxi)
740            self.assertEqual(Y1, Y1r, rtol=tolerance, atol=tolerance)
741
742        # to force overflow
743        X2 = torch.randn(4, 5).to(torch.float16)
744        X2[0, 0] = 2**15 + .01
745        Y2 = torch.fake_quantize_per_channel_affine(X2, scale, zero, axis, mini, maxi)
746        Y2r = _fake_quantize_per_channel_affine_reference(X2, scale, zero, axis, mini, maxi)
747        self.assertEqual(Y2, Y2r, rtol=tolerance, atol=tolerance)
748
749        scale = torch.zeros(5) + 10
750
751        # to force underflow
752        X3 = torch.randn(4, 5).to(torch.float16)
753        X3[0, 0] = 2**-24
754        Y3 = torch.fake_quantize_per_channel_affine(X3, scale, zero, axis, mini, maxi)
755        Y3r = _fake_quantize_per_channel_affine_reference(X3, scale, zero, axis, mini, maxi)
756        self.assertEqual(Y3, Y3r, rtol=tolerance, atol=tolerance)
757
758    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
759           qparams=hu.qparams(dtypes=torch.quint8)))
760    def test_fake_quant_per_channel_qparam_range(self, X):
761        X, (scale, zero_point, axis, torch_type) = X
762        quant_min = torch.iinfo(torch_type).min
763        quant_max = torch.iinfo(torch_type).max
764
765        for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
766            X = to_tensor(X, device)
767            scale = to_tensor(scale, device)
768
769            # Ensure that zero_point < quant_min.
770            zero_point = torch.full(zero_point.shape, -1 - quant_min).to(dtype=torch.int32, device=device)
771
772            # For non-float zero_point, fakequant requires zero_point between quant_min and quant_max.
773            with self.assertRaisesRegex(RuntimeError, "`zero_point` must be between `quant_min` and `quant_max`."):
774                Y = torch.fake_quantize_per_channel_affine(X, scale, zero_point, axis, quant_min, quant_max)
775
776            # For a float zero_point, fake quant allows the zero_point to lie outside [quant_min, quant_max].
777            for zero_point_dtype in [torch.float32, torch.float16]:
778                zero_point = zero_point.to(dtype=zero_point_dtype)
779                Y = torch.fake_quantize_per_channel_affine(X, scale, zero_point, axis, quant_min, quant_max)
780                Y_ref = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(),
781                                                                    axis, quant_min, quant_max)
782                np.testing.assert_allclose(Y.cpu().numpy(), Y_ref.cpu().numpy(), rtol=tolerance, atol=tolerance)
783
784    def _test_learnable_forward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
785        r"""Tests the forward path of the learnable FakeQuantizePerTensorAffine op.
786        """
787        for n_bits in (4, 8):
788            quant_min, quant_max = 0, 2 ** (n_bits) - 1
789
790            scale_base = scale_base.to(device)
791            zero_point_base = zero_point_base.to(device)
792
793            X_curr = X_base.clone()
794            scale_curr = scale_base.clone()
795            zero_point_curr = zero_point_base.clone()
796
797            Y = _fake_quantize_per_channel_affine_reference(
798                X_curr, scale_curr, zero_point_curr.round().clamp(quant_min, quant_max), axis, quant_min, quant_max).to(device)
799            for grad_factor in [0.1, 1.0, 10.0]:
800                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
801                    X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, grad_factor).to(device)
802                self.assertTrue(
803                    torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
804                    "Expected kernel forward function to have results match the reference forward function")
805
806    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
807                                   qparams=hu.qparams(dtypes=torch.quint8)))
808    def test_learnable_forward_per_channel_cpu(self, X):
809        torch.random.manual_seed(NP_RANDOM_SEED)
810        X, (_, _, axis, _) = X
811        X_base = torch.tensor(X).to('cpu')
812        channel_size = X_base.size(axis)
813        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
814        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
815        self._test_learnable_forward_per_channel(
816            X_base, 'cpu', scale_base, zero_point_base, axis)
817
818    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
819                                   qparams=hu.qparams(dtypes=torch.quint8)))
820    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
821    def test_learnable_forward_per_channel_cuda(self, X):
822        torch.random.manual_seed(NP_RANDOM_SEED)
823        X, (_, _, axis, _) = X
824        X_base = torch.tensor(X).to('cuda')
825        channel_size = X_base.size(axis)
826        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
827        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
828        self._test_learnable_forward_per_channel(
829            X_base, 'cuda', scale_base, zero_point_base, axis)
830
831    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
832           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
833           qparams=hu.qparams(dtypes=torch.quint8)))
834    @unittest.skip(
835        "this is broken without changes to any relevant code, "
836        "we need to remove hypothesis testing in CI")
837    def test_backward_per_channel(self, device, X):
838        r"""Tests the backward method.
839        """
840        np.random.seed(NP_RANDOM_SEED)
841        X, (scale, zero_point, axis, torch_type) = X
842        quant_min = torch.iinfo(torch_type).min
843        quant_max = torch.iinfo(torch_type).max
844        zero_point_types = (torch.int, torch.float, torch.float16)
845
846        for zero_point_type in zero_point_types:
847            X = to_tensor(X, device)
848            scale = to_tensor(scale, device)
849            zero_point = to_tensor(zero_point, device).to(dtype=zero_point_type)
850            X.requires_grad_()
851            Y_prime = torch.fake_quantize_per_channel_affine(
852                X, scale, zero_point, axis, quant_min, quant_max)
853            dout = torch.rand_like(X, dtype=torch.float).to(device)
854            dX = _fake_quantize_per_channel_affine_grad_reference(
855                dout, X, scale, zero_point, axis, quant_min, quant_max)
856            Y_prime.backward(dout)
857            np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
858
859    def _test_backward_per_channel_cachemask_impl(self, device):
860        torch_types = (torch.qint8, torch.quint8)
861        float_types = (torch.float32, torch.float16, torch.float64)
862        zero_point_types = (torch.int, torch.float32, torch.float16)
863
864        for torch_type, float_type, zero_point_type in itertools.product(torch_types, float_types, zero_point_types):
865            X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
866            # pick the scale + zp so that some values get clipped
867            axis = 1
868            obs = torch.ao.quantization.PerChannelMinMaxObserver(axis, torch_type).to(device)
869            obs(X * 0.75)
870            scale, zero_point = obs.calculate_qparams()
871            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
872            zero_point = zero_point.to(zero_point_type)
873            quant_min, quant_max = obs.quant_min, obs.quant_max
874            X.requires_grad_()
875            Y_prime = torch.fake_quantize_per_channel_affine(
876                X, scale, zero_point, axis, quant_min, quant_max)
877            dout = torch.rand_like(X, dtype=float_type).to(device)
878            dX = _fake_quantize_per_channel_affine_grad_reference(
879                dout, X, scale, zero_point, axis, quant_min, quant_max)
880            Y_prime.backward(dout)
881            np.testing.assert_allclose(
882                dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
883            assert X.grad.dtype == float_type
884
885
886    def test_backward_per_channel_cachemask_cpu(self):
887        self._test_backward_per_channel_cachemask_impl('cpu')
888
889    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
890    def test_backward_per_channel_cachemask_cuda(self):
891        self._test_backward_per_channel_cachemask_impl('cuda')
892
893    def _test_learnable_backward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
894        r"""Tests the backward path of the learnable FakeQuantizePerTensorAffine op.
895        """
896        for n_bits in (4, 8):
897            quant_min, quant_max = 0, 2 ** n_bits - 1
898
899            scale_base = scale_base.to(device)
900            zero_point_base = zero_point_base.to(device=device)
901
902            X_curr = X_base.clone()
903            X_curr.requires_grad_()
904            scale_curr = scale_base.clone()
905            scale_curr.requires_grad_()
906            zero_point_curr = zero_point_base.clone()
907            zero_point_curr.requires_grad_()
908
909            for grad_factor in [0.1, 1.0, 10.0]:
910                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
911                    X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, grad_factor).to(device)
912
913                dout = torch.rand(X_curr.shape, dtype=torch.float).to(device)
914                dX, dScale, dZeroPoint = _fake_quantize_learnable_per_channel_affine_grad_reference(
915                    dout, X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, device)
916                Y_prime.backward(dout)
917
918                dX_expected = dX.to(device).detach()
919                dX_actual = X_curr.to(device).grad.detach()
920                dScale_expected = dScale.to(device).detach()
921                dScale_actual = scale_curr.to(device).grad.detach()
922                dZeroPoint_expected = dZeroPoint.to(device).detach()
923                dZeroPoint_actual = zero_point_curr.to(device).grad.detach()
924                tolerance = 1e-4
925
926                self.assertTrue(
927                    torch.allclose(dX_expected, dX_actual, rtol=tolerance, atol=tolerance),
928                    f"Expected dX={dX_expected} to match X.grad={dX_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}")  # noqa: B950
929                self.assertTrue(
930                    torch.allclose(dScale_expected * grad_factor, dScale_actual, rtol=tolerance, atol=tolerance),
931                    f"Expected dScale={dScale_expected * grad_factor} to match scale.grad={dScale_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}")  # noqa: B950
932                self.assertTrue(
933                    torch.allclose(dZeroPoint_expected * grad_factor, dZeroPoint_actual, rtol=tolerance, atol=tolerance),
934                    f"Expected dZeroPoint={dZeroPoint_expected * grad_factor} to match zero_point.grad={dZeroPoint_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}")  # noqa: B950
935                X_curr.grad.data.zero_()
936                scale_curr.grad.data.zero_()
937                zero_point_curr.grad.data.zero_()
938
939    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
940                                   qparams=hu.qparams(dtypes=torch.quint8)))
941    @unittest.skip(
942        "this is broken without changes to any relevant code, "
943        "we need to remove hypothesis testing in CI")
944    def test_learnable_backward_per_channel_cpu(self, X):
945        torch.random.manual_seed(NP_RANDOM_SEED)
946        X, (_, _, axis, _) = X
947        X_base = torch.tensor(X).to('cpu')
948        channel_size = X_base.size(axis)
949        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
950        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
951        self._test_learnable_backward_per_channel(
952            X_base, 'cpu', scale_base, zero_point_base, axis)
953
954    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
955                                   qparams=hu.qparams(dtypes=torch.quint8)))
956    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
957    def test_learnable_backward_per_channel_cuda(self, X):
958        torch.random.manual_seed(NP_RANDOM_SEED)
959        X, (scale, zero_point, axis, torch_type) = X
960        X_base = torch.tensor(X).to('cuda')
961        scale_base = to_tensor(scale, 'cuda')
962        zero_point_base = to_tensor(zero_point, 'cuda')
963        self._test_learnable_backward_per_channel(
964            X_base, 'cuda', scale_base, zero_point_base, axis)
965
966    def test_numerical_consistency_per_tensor(self):
967        self._test_numerical_consistency('per_tensor')
968
969    def test_numerical_consistency_per_channel(self):
970        self._test_numerical_consistency('per_channel')
971
972    def _test_numerical_consistency(self, test_type):
973        r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
974        """
975        torch.random.manual_seed(NP_RANDOM_SEED)
976        torch_types = [torch.qint8, torch.quint8]
977        float_types = [torch.float, torch.float16, torch.float64]
978        if test_type == "per_channel":
979            zero_types = [torch.int, torch.float, torch.float16]
980        else:
981            zero_types = [torch.int]
982        devices = [torch.device('cpu'), torch.device('cuda')] if torch.cuda.is_available() else [torch.device('cpu')]
983        axis = 1
984        for i in range(20):
985            for torch_type, float_type, device, zero_type in itertools.product(torch_types, float_types, devices, zero_types):
986                X = torch.randn(3, 3, device=device).to(float_type)
987                scales = (10 * torch.randn(3, device=device)).abs()
988                scale = scales.mean().to(float).item()
989                zeros = (10 * torch.randn(3, device=device)).abs().to(dtype=zero_type)
990                zero = zeros.max().view(1).item()
991                quant_min = torch.iinfo(torch_type).min
992                quant_max = torch.iinfo(torch_type).max
993
994                test_was_run = False
995                if test_type == "per_tensor":
996                    test_was_run = True
997                    Y = torch.dequantize(torch.quantize_per_tensor(X.to('cpu').to(torch.float),
998                                                                   scale, zero, torch_type)).to(device).to(float_type)
999                    Y_prime = torch.fake_quantize_per_tensor_affine(X, scale, zero, quant_min, quant_max)
1000                    self.assertEqual(
1001                        Y, Y_prime, "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor")
1002
1003                if test_type == "per_channel":
1004                    test_was_run = True
1005                    Y = torch.dequantize(torch.quantize_per_channel(X.to('cpu').to(torch.float), scales.to(
1006                        'cpu'), zeros.to('cpu'), axis, torch_type)).to(device).to(float_type)
1007                    Y_prime = torch.fake_quantize_per_channel_affine(X, scales, zeros, axis, quant_min, quant_max)
1008                    self.assertEqual(
1009                        Y, Y_prime, "Difference found between dequant+quant_per_channel and fake_quantize_per_channel")
1010                self.assertTrue(test_was_run)
1011
1012    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
1013    def test_fake_quantize_per_channel_affine_scale_dtypes(self):
1014        """
1015        Ensure that a helpful error message is raised when the scale dtype is not torch.float
1016        """
1017        dtype_list = [torch.float, torch.float64, torch.bfloat16, torch.half]
1018        for scale_dtype in dtype_list:
1019            input = torch.randn(3, 4, 5, 6)
1020            scale = torch.Tensor([0.1, 0.2, 0.3, 0.4]).to(scale_dtype)
1021            zero_point = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
1022            axis = 1
1023            quant_min = 0
1024            quant_max = 255
1025            if scale_dtype != torch.float:
1026                with self.assertRaises(RuntimeError):
1027                    torch.fake_quantize_per_channel_affine(
1028                        input, scale, zero_point, axis, quant_min, quant_max
1029                    )
1030            else:
1031                torch.fake_quantize_per_channel_affine(
1032                    input, scale, zero_point, axis, quant_min, quant_max
1033                )
1034
1035
1036class TestFusedObsFakeQuant(TestCase):
1037    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
1038           symmetric_quant=st.booleans())
1039    @settings(deadline=None)
1040    def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None:
1041        """
1042        Tests the case where we call the fused_obs_fake_quant op multiple times
1043        and update the running_min and running_max of the activation tensors.
1044        """
1045        in_running_min_ref = out_running_min_ref = float("inf")
1046        in_running_min_op = torch.tensor(float("inf"), device=device)
1047        in_running_max_ref = out_running_max_ref = float("-inf")
1048        in_running_max_op = torch.tensor(float("-inf"), device=device)
1049        avg_const = 0.01
1050        scale = torch.tensor([1.0], device=device)
1051        zero_point = torch.tensor([0], dtype=torch.int, device=device)
1052        observer_on = fake_quant_on = 0
1053
1054        pt_op = torch.fused_moving_avg_obs_fake_quant
1055        # enable the observer once i > 2 and fake_quant once i > 4
1056        for i in range(10):
1057            if i > 2:
1058                observer_on = 1
1059            if i > 4:
1060                fake_quant_on = 1
1061
1062            x = torch.randn(5, 5, device=device)
1063            out = pt_op(
1064                x,
1065                torch.tensor(observer_on, device=device),
1066                torch.tensor(fake_quant_on, device=device),
1067                in_running_min_op,
1068                in_running_max_op,
1069                scale,
1070                zero_point,
1071                avg_const,
1072                0,
1073                255,
1074                0,
1075                False,
1076                symmetric_quant,
1077            )
1078            if observer_on:
1079                (
1080                    in_running_min_ref,
1081                    in_running_max_ref,
1082                ) = _get_tensor_min_max(
1083                    x,
1084                    running_min=in_running_min_ref,
1085                    running_max=in_running_max_ref,
1086                    averaging_const=0.01,
1087                )
1088
1089            if fake_quant_on:
1090                x_scale, x_zero_point = _get_scale_zp(
1091                    in_running_min_ref,
1092                    in_running_max_ref,
1093                    torch.quint8,
1094                    preserve_sparsity=symmetric_quant,
1095                )
1096                x_in = _fake_quantize_per_tensor_affine_reference(
1097                    x, x_scale, x_zero_point, 0, 255
1098                )
1099                self.assertEqual(scale, x_scale)
1100                self.assertEqual(zero_point, x_zero_point)
1101            else:
1102                x_in = x
1103
1104            self.assertEqual(in_running_min_ref, in_running_min_op)
1105            self.assertEqual(in_running_max_ref, in_running_max_op)
1106            torch.testing.assert_close(out, x_in)

        # Check that the op also handles an empty input tensor
        x = torch.empty(0, 5, device=device)
        out = pt_op(
            x,
            torch.tensor(1, device=device),
            torch.tensor(1, device=device),
            in_running_min_op,
            in_running_max_op,
            scale,
            zero_point,
            avg_const,
            0,      # quant_min
            255,    # quant_max
            0,      # ch_axis
            False,  # per_row_fake_quant
            symmetric_quant,
        )
        output_shape = (0, 5)
        self.assertEqual(out.shape, output_shape)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           symmetric_quant=st.booleans())
    @settings(deadline=None)
    def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant) -> None:
        """
        Tests calling the fused_obs_fake_quant op multiple times with per-channel
        (per-row) fake quantization enabled and checks that the per-channel running
        min/max of the activation tensor are updated correctly.
        """
        m = 5
        sizes = [[5, 5], [5, 4, 3]]
        for size in sizes:
            in_running_min_ref = torch.empty(m, device=device).fill_(float("inf"))
            in_running_min_op = torch.empty(m, device=device).fill_(float("inf"))
            in_running_max_ref = torch.empty(m, device=device).fill_(float("-inf"))
            in_running_max_op = torch.empty(m, device=device).fill_(float("-inf"))
            avg_const = 0.01

            scale = torch.empty(m, device=device).fill_(0.1)
            zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)

            observer_on = fake_quant_on = 0

            pt_op = torch.fused_moving_avg_obs_fake_quant
            # the observer is enabled from iteration 3 onwards and fake_quant from iteration 5 onwards
            for i in range(10):
                if i > 2:
                    observer_on = 1
                if i > 4:
                    fake_quant_on = 1

                x = torch.randn(size, device=device)
                out = pt_op(
                    x,
                    torch.tensor(observer_on, device=device),
                    torch.tensor(fake_quant_on, device=device),
                    in_running_min_op,
                    in_running_max_op,
                    scale,
                    zero_point,
                    avg_const,
                    0,     # quant_min
                    255,   # quant_max
                    0,     # ch_axis
                    True,  # per_row_fake_quant (per-channel quantization enabled)
                    symmetric_quant,
                )
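                # Per-row reference: with per_row_fake_quant=True and ch_axis=0 the op
                # keeps one running min/max, scale and zero_point per row along dim 0.
                # _get_per_row_min_max (defined earlier in this file) is expected to
                # apply the same moving-average update independently to each row.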
                if observer_on:
                    (
                        in_running_min_ref,
                        in_running_max_ref,
                    ) = _get_per_row_min_max(x, in_running_min_ref, in_running_max_ref)
                if fake_quant_on:
                    x_scale = torch.empty(m, device=device)
                    x_zero_point = torch.empty(m, dtype=torch.int, device=device)

                    # compute the reference qparams row by row (use a separate loop
                    # variable so the outer iteration counter `i` is not shadowed)
                    for ch in range(x_scale.numel()):
                        x_scale[ch], x_zero_point[ch] = _get_scale_zp(
                            in_running_min_ref[ch].item(),
                            in_running_max_ref[ch].item(),
                            torch.quint8,
                            preserve_sparsity=symmetric_quant,
                        )
                    x_in = _fake_quantize_per_channel_affine_reference(
                        x, x_scale, x_zero_point, 0, 0, 255  # axis=0, quant_min=0, quant_max=255
                    )
                    self.assertEqual(scale, x_scale)
                    self.assertEqual(zero_point, x_zero_point)
                else:
                    x_in = x
                self.assertEqual(in_running_min_ref, in_running_min_op)
                self.assertEqual(in_running_max_ref, in_running_max_op)
                torch.testing.assert_close(out, x_in)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),)
    @settings(deadline=None)
    def test_fused_obs_fake_quant_backward_op(self, device) -> None:
        n = m = 10
        input_shape = (m, n)

        x = torch.randn(input_shape, device=device, requires_grad=True)

        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        x_min, x_max = _get_tensor_min_max(x)
        x_scale, x_zero_point = _get_scale_zp(
            x_min, x_max, torch.quint8
        )

        x_scale = torch.tensor(x_scale, device=device)
        x_zero_point = torch.tensor(x_zero_point, dtype=torch.int, device=device)
        x_fake_quant = torch.fake_quantize_per_tensor_affine(
            x, x_scale, x_zero_point, 0, 255
        )

        pt_op = torch.fused_moving_avg_obs_fake_quant
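        # With both observer_on and fake_quant_on set and the running min/max seeded
        # with x's current min/max, the fused op should produce the same output as a
        # plain torch.fake_quantize_per_tensor_affine call with the matching qparams,
        # which is what the forward check below verifies.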
        out = pt_op(
            x,
            torch.tensor(1, device=device),  # observer_on
            torch.tensor(1, device=device),  # fake_quant_on
            torch.tensor(x_min, device=device),
            torch.tensor(x_max, device=device),
            scale,
            zero_point,
            avg_const,
            0,      # quant_min
            255,    # quant_max
            0,      # ch_axis
            False,  # per_row_fake_quant
        )
        # verify the output matches the reference fake_quantize result
        torch.testing.assert_close(out, x_fake_quant)

        # verify the gradient matches the expectation of the fake_quant op
        dout = torch.rand_like(x, dtype=torch.float).to(device)
        out.backward(dout)

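        # The expected backward is the usual straight-through estimator for fake
        # quantization: the incoming gradient is passed through for elements whose
        # quantized value lies inside [quant_min, quant_max] and zeroed elsewhere.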
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, x, x_scale, x_zero_point, 0, 255)
        self.assertEqual(dX, x.grad)
        self.assertEqual(x.grad.dtype, torch.float32)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),)
    @settings(deadline=None)
    def test_fused_backward_op_fake_quant_off(self, device) -> None:
        n = m = 4
        input_shape = (m, n)

        x = torch.randn(input_shape, device=device, requires_grad=True)

        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        x_min, x_max = _get_tensor_min_max(x)
        x_scale, x_zero_point = _get_scale_zp(
            x_min, x_max, torch.quint8
        )

        pt_op = torch.fused_moving_avg_obs_fake_quant
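        # With observer_on=0 and fake_quant_on=0 the fused op is expected to pass the
        # input through unchanged, so the forward output equals x and the gradient is
        # simply dout. Because x_scale/x_zero_point are derived from x's own min/max,
        # no element is expected to fall outside [quant_min, quant_max], so the
        # fake_quant gradient reference used below reduces to the same pass-through
        # gradient.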
        out = pt_op(
            x,
            torch.tensor(0, device=device),  # observer_on
            torch.tensor(0, device=device),  # fake_quant_on
            torch.tensor(x_min, device=device),
            torch.tensor(x_max, device=device),
            scale,
            zero_point,
            avg_const,
            0,      # quant_min
            255,    # quant_max
            0,      # ch_axis
            False,  # per_row_fake_quant
        )
        # verify the output passes through unchanged
        torch.testing.assert_close(out, x)

        # verify the gradient matches the expectation of the fake_quant op
        dout = torch.rand_like(x, dtype=torch.float).to(device)
        out.backward(dout)

        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, x, x_scale, x_zero_point, 0, 255)
        self.assertEqual(dX, x.grad)
        self.assertEqual(x.grad.dtype, torch.float32)

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")