# Owner(s): ["oncall: quantization"]

import torch
import math
from typing import Tuple
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    default_observer,
    default_fixed_qparams_range_0to1_fake_quant,
)

from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
from torch.testing._internal.common_quantized import (
    _fake_quantize_per_channel_affine_reference,
    _fake_quantize_per_channel_affine_grad_reference,
    to_tensor,
)
import torch.nn as nn

# Standard library
import io
import itertools
import unittest
import numpy as np

# Testing utils
from hypothesis import given, settings
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo

# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, quant_min, quant_max):
    dtype = X.dtype
    res = ((torch.clamp(torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point), quant_min, quant_max) - zero_point) * scale)
    return res.to(dtype)

# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max):
    dtype = X.dtype
    Xq = torch.round(X.to(torch.float32) * (1.0 / scale) + zero_point)
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)
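
# Note on the reference gradient above: fake quantize uses a straight-through
# estimator for the input, so dY is passed through unchanged wherever the
# rounded value round(X / scale + zero_point) lands inside
# [quant_min, quant_max] and is zeroed wherever it falls outside (those values
# are clamped in the forward pass and therefore carry no gradient).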

# Reference method for the gradients of the fake quantize operator
def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max, device):
    r"""This method references the following papers for back propagation on scale and zero point.
    - https://arxiv.org/pdf/1902.08153.pdf
    - https://arxiv.org/pdf/1903.08066.pdf
    """
    zero_point_rounded = int((zero_point + 0.5).clamp(quant_min, quant_max).item())
    Xq = torch.round(X * (1.0 / scale) + zero_point_rounded)

    indicate_small_scale = (Xq < quant_min).float().to(device)
    indicate_big_scale = (Xq > quant_max).float().to(device)
    indicate_middle_scale = torch.ones(indicate_small_scale.shape).to(device) - \
        indicate_small_scale - indicate_big_scale

    indicate_saturate_zp = ((Xq < quant_min).float() + (Xq > quant_max).float()).to(device)
    indicate_unsaturate_zp = torch.ones(indicate_saturate_zp.shape).to(device) - indicate_saturate_zp

    Xq = Xq.clamp(quant_min, quant_max)
    Xfq = (Xq - zero_point_rounded) * scale

    grad_small_scale = quant_min - zero_point_rounded
    grad_big_scale = quant_max - zero_point_rounded
    grad_middle_scale = ((Xfq - X) / scale).to(device)

    grad_saturate_zp = -scale.to(device)
    grad_unsaturate_zp = 0

    grad_scale = indicate_small_scale * grad_small_scale + \
        indicate_big_scale * grad_big_scale + \
        indicate_middle_scale * grad_middle_scale
    grad_zp = indicate_saturate_zp * grad_saturate_zp + \
        indicate_unsaturate_zp * grad_unsaturate_zp
    grad_X = _fake_quantize_per_tensor_affine_grad_reference(
        dY, X, scale, zero_point, quant_min, quant_max).to(device)

    grad_scale = (grad_scale * dY).sum().unsqueeze(dim=0)
    grad_zp = (grad_zp * dY).sum().unsqueeze(dim=0)
    return grad_X, grad_scale, grad_zp


# Reference method for quantization.
def _quantize_per_tensor(x, scale, zero_point, quant_min, quant_max):
    return ((x / scale) + zero_point).round().clamp(quant_min, quant_max)
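
# The per-channel reference below applies the same rules as the per-tensor
# version above, independently to each slice along `axis`:
#   d(out)/d(scale)      = (Xfq - X) / scale       if the value is in range
#                        = quant_min - zero_point  if clamped below
#                        = quant_max - zero_point  if clamped above
#   d(out)/d(zero_point) = -scale if clamped, 0 otherwise
# following the papers cited in the docstrings.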

# Reference method for the per channel gradients of the learnable fake quantize operator
def _fake_quantize_learnable_per_channel_affine_grad_reference(
        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max, device):
    r"""This method references the following papers for back propagation on scale and zero point.
    - https://arxiv.org/pdf/1902.08153.pdf
    - https://arxiv.org/pdf/1903.08066.pdf
    """
    per_channel_zero_point = ((per_channel_zero_point.detach() + 0.5).clamp(quant_min, quant_max)).type(torch.int32)
    grad_X = _fake_quantize_per_channel_affine_grad_reference(
        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max).to(device)
    per_channel_scale = per_channel_scale.detach().type(torch.float)

    grad_scale = torch.zeros([per_channel_scale.size(0)]).to(device)
    grad_zero_point = torch.zeros([per_channel_zero_point.size(0)]).to(device)

    X_flattened = torch.unbind(X, dim=axis)
    dY_flattened = torch.unbind(dY, dim=axis)

    for i, X_i in enumerate(torch.unbind(X, dim=axis), 0):
        scale_i = per_channel_scale[i]
        zero_point_i = per_channel_zero_point[i]
        X_i = X_flattened[i]
        dY_i = dY_flattened[i]

        Xq_i = ((X_i / scale_i) + zero_point_i).round()
        Xfq_i = (Xq_i - zero_point_i) * scale_i

        indicate_small_scale_i = (Xq_i < quant_min).float().to(device)
        indicate_big_scale_i = (Xq_i > quant_max).float().to(device)
        indicate_middle_scale_i = torch.ones(indicate_small_scale_i.shape).to(device) - \
            indicate_small_scale_i - indicate_big_scale_i

        indicate_saturate_zp_i = ((Xq_i < quant_min).float() +
                                  (Xq_i > quant_max).float()).to(device)
        indicate_unsaturate_zp_i = torch.ones(indicate_saturate_zp_i.shape).to(device) - \
            indicate_saturate_zp_i

        Xq_i = Xq_i.clamp(quant_min, quant_max)
        Xfq_i = (Xq_i - zero_point_i) * scale_i

        grad_small_scale_i = quant_min - zero_point_i
        grad_big_scale_i = quant_max - zero_point_i
        grad_middle_scale_i = ((Xfq_i - X_i) / scale_i).to(device)

        grad_saturate_zp_i = -scale_i.to(device)
        grad_unsaturate_zp_i = 0

        grad_scale_i = indicate_small_scale_i * grad_small_scale_i + \
            indicate_middle_scale_i * grad_middle_scale_i + \
            indicate_big_scale_i * grad_big_scale_i
        grad_zp_i = indicate_saturate_zp_i * grad_saturate_zp_i + \
            indicate_unsaturate_zp_i * grad_unsaturate_zp_i

        grad_scale_i = (grad_scale_i * dY_i).sum().unsqueeze(dim=0)
        grad_zp_i = (grad_zp_i * dY_i).sum().unsqueeze(dim=0)

        grad_scale[i] = grad_scale_i
        grad_zero_point[i] = grad_zp_i
    return grad_X, grad_scale, grad_zero_point


def _get_tensor_min_max(
        X: torch.Tensor,
        running_min: float = float("inf"),
        running_max: float = float("-inf"),
        averaging_const: float = 0.01) -> Tuple[float, float]:
    min_val = X.min().to(dtype=torch.float32).item()
    max_val = X.max().to(dtype=torch.float32).item()

    if not math.isinf(running_min):
        min_val = running_min + averaging_const * (min_val - running_min)
    if not math.isinf(running_max):
        max_val = running_max + averaging_const * (max_val - running_max)

    return min_val, max_val


def _get_per_row_min_max(
        x: torch.Tensor,
        min_vals: torch.Tensor,
        max_vals: torch.Tensor,
        axis: int = 0,
        averaging_const: float = 0.01) -> Tuple[torch.Tensor, torch.Tensor]:
    x_dim = x.size()
    new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = x.permute(*new_axis_list)

    y = torch.flatten(y, start_dim=1)
    # min_vals, max_vals = torch.aminmax(y, dim=1)
    if math.isinf(min_vals[0]) or math.isinf(max_vals[0]):
        min_vals, max_vals = torch.aminmax(y, dim=1)
    else:
        min_vals_cur, max_vals_cur = torch.aminmax(y, dim=1)
        min_vals = min_vals + averaging_const * (min_vals_cur - min_vals)
        max_vals = max_vals + averaging_const * (max_vals_cur - max_vals)
    return min_vals, max_vals
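
# _get_scale_zp below mirrors the asymmetric (affine) qparam computation used by
# the fused-observer tests: scale = (max_val - min_val) / (qmax - qmin), and the
# zero point is taken from whichever endpoint gives the smaller error, then
# nudged into [qmin, qmax]. For example, min_val=-1.0 and max_val=1.0 with
# torch.quint8 give scale ~= 2 / 255 and a zero_point of about 128. With
# preserve_sparsity=True the range is first made symmetric so that 0.0 keeps an
# exactly representable zero point.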
def _get_scale_zp(
        min_val: float,
        max_val: float,
        dtype: torch.dtype,
        reduce_range: bool = False,
        preserve_sparsity: bool = False) -> Tuple[float, int]:
    """
    Calculate the quantization parameters (scale, zero_point)
    based on the min and max element of the tensor
    """
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255

    if min_val < 0 and max_val > 0 and preserve_sparsity:
        symmetric_qmin = int(-((qmax - qmin) / 2 + 1))
        symmetric_qmax = int((qmax - qmin) / 2)
        max_scale = max(
            abs(min_val / symmetric_qmin), abs(max_val / symmetric_qmax)
        )
        min_val = max_scale * symmetric_qmin
        max_val = max_scale * symmetric_qmax
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0 or math.isinf(1.0 / scale):
        scale = 0.1
    zero_point = 0

    zero_point_from_min = qmin - min_val / float(scale)
    zero_point_from_max = qmax - max_val / float(scale)
    zero_point_from_min_error = abs(qmin) - abs(min_val / float(scale))
    zero_point_from_max_error = abs(qmax) - abs(max_val / float(scale))
    if zero_point_from_min_error < zero_point_from_max_error:
        initial_zero_point = zero_point_from_min
    else:
        initial_zero_point = zero_point_from_max

    if min_val < 0 and max_val > 0 and preserve_sparsity:
        initial_zero_point = (qmin + qmax) / 2 + 1

    nudged_zero_point = 0

    if initial_zero_point < qmin:
        nudged_zero_point = qmin
    elif initial_zero_point > qmax:
        nudged_zero_point = qmax
    else:
        nudged_zero_point = int(round(initial_zero_point))

    return (scale, int(nudged_zero_point))


NP_RANDOM_SEED = 19
tolerance = 1e-6
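
# TestFakeQuantizeOps covers the forward and backward paths of the per-tensor
# and per-channel fake quantize ops, their learnable variants, and the
# FakeQuantize module wrappers (serialization, scriptability, and
# enable/disable controls).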
class TestFakeQuantizeOps(TestCase):
    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_tensor(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_backward_per_tensor(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        dout = torch.rand_like(X, dtype=torch.float).to(device)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, scale, zero_point, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_forward_backward_per_tensor_with_amp(self):
        net = nn.Sequential(nn.Conv2d(1, 1, 3))
        net.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
        net_prep = torch.ao.quantization.prepare_qat(net)

        with torch.cuda.amp.autocast():
            x = torch.randn(4, 1, 5, 5)
            out = net_prep(x).sum()
            out.backward()
        self.assertTrue(net_prep[0].weight.grad is not None)

    def test_forward_per_tensor_half_precision_numerics(self):
        scale = .1
        zero = 0
        maxi = 255
        mini = 0

        for i in range(20):
            X1 = torch.randn(5, 5).to(torch.float16)
            Y1 = torch.fake_quantize_per_tensor_affine(X1, scale, zero, mini, maxi)
            Y1r = _fake_quantize_per_tensor_affine_reference(X1, scale, zero, mini, maxi)
            self.assertEqual(Y1, Y1r, rtol=tolerance, atol=tolerance)

        # to force overflow
        X2 = torch.tensor(2**15 + .01).to(torch.float16)
        Y2 = torch.fake_quantize_per_tensor_affine(X2, scale, zero, mini, maxi)
        Y2r = _fake_quantize_per_tensor_affine_reference(X2, scale, zero, mini, maxi)
        self.assertEqual(Y2, Y2r, rtol=tolerance, atol=tolerance)

        scale = 10

        # to force underflow
        X3 = torch.tensor(2**-24).to(torch.float16)
        Y3 = torch.fake_quantize_per_tensor_affine(X3, scale, zero, mini, maxi)
        Y3r = _fake_quantize_per_tensor_affine_reference(X3, scale, zero, mini, maxi)
        self.assertEqual(Y3, Y3r, rtol=tolerance, atol=tolerance)

    def _test_forward_per_tensor_cachemask_impl(self, device):
        float_types = (torch.float32, torch.float16, torch.float64)
        torch_types = (torch.qint8, torch.quint8)
        Xs = (torch.randn(4, 8, device=device), torch.randn(4, 16, device=device)[:, ::2])
        tensor_qparams = (True, False)
        for float_type, torch_type, X, tensor_qparam in itertools.product(float_types, torch_types, Xs, tensor_qparams):
            # pick the scale + zp so that some values get clipped
            X = X.to(float_type)
            obs = torch.ao.quantization.MinMaxObserver(torch_type)
            obs.to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            quant_min, quant_max = obs.quant_min, obs.quant_max
            if not tensor_qparam:
                scale, zero_point = float(scale), int(zero_point)
            Y_test = torch.fake_quantize_per_tensor_affine(
                X, scale, zero_point, quant_min, quant_max)
            Y_ref = _fake_quantize_per_tensor_affine_reference(
                X, scale, zero_point, quant_min, quant_max).to(device)
            self.assertEqual(Y_test, Y_ref, rtol=tolerance, atol=tolerance)
            self.assertTrue(Y_test.dtype == float_type)

    def test_forward_per_tensor_cachemask_cpu(self):
        device = torch.device('cpu')
        self._test_forward_per_tensor_cachemask_impl(device)

    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_forward_per_tensor_cachemask_cuda(self):
        device = torch.device('cuda')
        self._test_forward_per_tensor_cachemask_impl(device)

    def _test_backward_per_tensor_cachemask_impl(self, device):
        float_types = (torch.float32, torch.float16, torch.float64)
        torch_types = (torch.qint8, torch.quint8)
        tensor_qparams = (True, False)
        for float_type, torch_type, tensor_qparam in itertools.product(float_types, torch_types, tensor_qparams):
            X = torch.randn(4, 8).to(device).to(float_type)
            X.requires_grad_()
            # pick the scale + zp so that some values get clipped
            obs = torch.ao.quantization.MinMaxObserver(torch_type)
            obs.to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            if not tensor_qparam:
                scale, zero_point = float(scale), int(zero_point)
            quant_min, quant_max = obs.quant_min, obs.quant_max

            # forward pass
            Y_test = torch.fake_quantize_per_tensor_affine(
                X, scale, zero_point, quant_min, quant_max)
            Y_ref = _fake_quantize_per_tensor_affine_reference(
                X, scale, zero_point, quant_min, quant_max).to(device)
            self.assertEqual(Y_test, Y_ref, rtol=tolerance, atol=tolerance)

            # backward pass
            dout = torch.rand_like(X, dtype=torch.float).to(device)
            dX = _fake_quantize_per_tensor_affine_grad_reference(
                dout, X, scale, zero_point, quant_min, quant_max)
            Y_test.backward(dout)
            self.assertEqual(dX, X.grad)
            self.assertTrue(X.grad.dtype == float_type)

    def test_backward_per_tensor_cachemask_cpu(self):
        device = torch.device('cpu')
        self._test_backward_per_tensor_cachemask_impl(device)

    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_backward_per_tensor_cachemask_cuda(self):
        device = torch.device('cuda')
        self._test_backward_per_tensor_cachemask_impl(device)

    def _test_learnable_forward_per_tensor(self, X, device, scale_base, zero_point_base):
        X_base = torch.tensor(X).to(device)

        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2 ** n_bits - 1

            X = X_base.clone().float()
            scale_base = scale_base.to(device).float()
            zero_point_base = zero_point_base.to(dtype=torch.int32, device=device)
            scale = scale_base.clone()
            zero_point = zero_point_base.clamp(quant_min, quant_max)

            Y = _fake_quantize_per_tensor_affine_reference(
                X, scale, zero_point, quant_min, quant_max).to(device)
            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
                    X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
                self.assertTrue(
                    torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                    "Expected kernel forward function to have results match the reference forward function")

    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip(
        "this is broken without changes to any relevant code, "
        "we need to remove hypothesis testing in CI")
    def test_learnable_forward_per_tensor_cpu(self, X):
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
        self._test_learnable_forward_per_tensor(
            X, 'cpu', scale_base, zero_point_base)

    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA,
                     "No GPU is available.")
    def test_learnable_forward_per_tensor_cuda(self, X):
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
        self._test_learnable_forward_per_tensor(
            X, 'cuda', scale_base, zero_point_base)

    def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base):
        r"""Tests the backward method with additional backprop support for scale and zero point.
        """
        X_base = torch.tensor(X).to(device)

        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2 ** n_bits - 1

            X = X_base.clone().float().to(device)
            X.requires_grad_()
            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device)
            scale = scale_base.clone()
            scale.requires_grad_()
            zero_point = zero_point_base.clone().clamp(quant_min, quant_max)
            zero_point.requires_grad_()
            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
                    X, scale, zero_point, quant_min, quant_max, grad_factor).to(device)
                dout = torch.rand_like(X, dtype=torch.float).to(device)
                dX, dScale, dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
                    dout, X, scale, zero_point, quant_min, quant_max, device)
                Y_prime.backward(dout)

                expected_dX = dX.to(device).detach()
                actual_dX = X.grad.to(device).detach()
                expected_dScale = dScale.to(device).detach()
                actual_dScale = scale.grad.to(device).detach()
                expected_dZeroPoint = dZeroPoint.to(device).detach()
                actual_dZeroPoint = zero_point.grad.to(device).detach()

                self.assertTrue(
                    torch.allclose(
                        expected_dX, actual_dX, rtol=tolerance, atol=tolerance),
                    "Expected dX to match X.grad")
                self.assertTrue(
                    torch.allclose(
                        expected_dScale * grad_factor, actual_dScale, rtol=tolerance, atol=tolerance),
                    "Expected dScale to match scale.grad")
                self.assertTrue(
                    torch.allclose(
                        expected_dZeroPoint * grad_factor, actual_dZeroPoint, rtol=tolerance, atol=tolerance),
                    "Expected dZeroPoint to match zero_point.grad")
                X.grad.data.zero_()
                scale.grad.data.zero_()
                zero_point.grad.data.zero_()

    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_backward_per_tensor_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
        self._test_learnable_backward_per_tensor(
            X, 'cpu', scale_base, zero_point_base)

    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_learnable_backward_per_tensor_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, _) = X
        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
        self._test_learnable_backward_per_tensor(
            X, 'cuda', scale_base, zero_point_base)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=[torch.quint8])),
           )
    def test_fq_module_per_tensor(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = torch.ao.quantization.default_fake_quant().to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_tensor_affine_reference(X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand_like(X, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_tensor_affine_grad_reference(dout, X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_fixed_qparams_fq_module(self, device, X):
        X, (scale, zero_point, torch_type) = X
        X = to_tensor(X, device)
        fq_module = default_fixed_qparams_range_0to1_fake_quant()
        fq_module.to(device)
        fixed_scale = fq_module.scale.clone()
        fixed_zero_point = fq_module.zero_point.clone()
        # run fq module and make sure the quantization parameters do not change
        torch.ao.quantization.enable_observer(fq_module)
        fq_module(X)
        self.assertEqual(fixed_scale, fq_module.scale)
        self.assertEqual(fixed_zero_point, fq_module.zero_point)

    def test_fq_serializable_per_tensor(self):
        observer = default_observer
        quant_min = 0
        quant_max = 127
        for FakeQuantizeClass in [FakeQuantize, _LearnableFakeQuantize]:
            fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
            X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
            y_ref = fq_module(X)
            state_dict = fq_module.state_dict()
            self.assertEqual(state_dict['scale'], 0.094488)
            self.assertEqual(state_dict['zero_point'], 53)
            b = io.BytesIO()
            torch.save(state_dict, b)
            for weights_only in [True, False]:
                b.seek(0)
                loaded_dict = torch.load(b, weights_only=weights_only)
                loaded_fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
                loaded_fq_module.load_state_dict(loaded_dict)
                for key in state_dict:
                    self.assertEqual(state_dict[key], loaded_fq_module.state_dict()[key])

                self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())

    def test_fake_quant_control(self):
        for fq_module in [torch.ao.quantization.default_fake_quant(),
                          _LearnableFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0,
                                                           quant_max=255,
                                                           dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                                                           reduce_range=True)()]:
            torch.manual_seed(42)
            X = torch.rand(20, 10, dtype=torch.float32)
            # Output of fake quant is not identical to input
            Y = fq_module(X)
            self.assertNotEqual(Y, X)
            if type(fq_module) == _LearnableFakeQuantize:
                fq_module.toggle_fake_quant(False)
            else:
                torch.ao.quantization.disable_fake_quant(fq_module)
            X = torch.rand(20, 10, dtype=torch.float32)
            Y = fq_module(X)
            # Fake quant is disabled, output is identical to input
            self.assertEqual(Y, X)

            # Explicit copy at this point in time, because FakeQuant keeps internal
            # state in mutable buffers.
            scale = fq_module.scale.clone().detach()
            zero_point = fq_module.zero_point.clone().detach()

            if type(fq_module) == _LearnableFakeQuantize:
                fq_module.toggle_observer_update(False)
                fq_module.toggle_fake_quant(True)
            else:
                torch.ao.quantization.disable_observer(fq_module)
                torch.ao.quantization.enable_fake_quant(fq_module)
            X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
            Y = fq_module(X)
            self.assertNotEqual(Y, X)
            # Observer is disabled, scale and zero-point do not change
            self.assertEqual(fq_module.scale, scale)
            self.assertEqual(fq_module.zero_point, zero_point)
            if type(fq_module) == _LearnableFakeQuantize:
                fq_module.toggle_observer_update(True)
            else:
                torch.ao.quantization.enable_observer(fq_module)
            Y = fq_module(X)
            self.assertNotEqual(Y, X)
            # Observer is enabled, scale and zero-point are different
            self.assertNotEqual(fq_module.scale, scale)
            self.assertNotEqual(fq_module.zero_point, zero_point)

    def test_fake_quant_preserves_qparam_shapes_for_activations(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 4)

            def forward(self, x):
                x = self.linear(x)
                return x

        m = Model()

        m.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
        torch.ao.quantization.prepare_qat(m, inplace=True)

        scale_shape_before = m.linear.activation_post_process.scale.shape
        zero_point_shape_before = m.linear.activation_post_process.zero_point.shape

        x = torch.rand(4, 4, 4, 4)
        m(x)
        scale_shape_after = m.linear.activation_post_process.scale.shape
        zero_point_shape_after = m.linear.activation_post_process.zero_point.shape
        self.assertEqual(
            scale_shape_before, scale_shape_after,
            msg="FakeQuant scale shape must stay consistent")
        self.assertEqual(
            zero_point_shape_before, zero_point_shape_after,
            msg="FakeQuant zero_point shape must stay consistent")

    def fake_quant_scriptable(self):
        observer = default_observer
        quant_min = 0
        quant_max = 255
        for FakeQuantizeClass in [FakeQuantize, _LearnableFakeQuantize]:
            fq_module = FakeQuantizeClass(observer, quant_min, quant_max)
            scripted_module = torch.jit.script(fq_module)

            X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)

            fq_module(X)
            scripted_module(X)
            self.assertEqual(fq_module.calculate_qparams(), scripted_module.calculate_qparams())

            buf = io.BytesIO()
            torch.jit.save(scripted_module, buf)
            buf.seek(0)
            loaded_module = torch.jit.load(buf)
            self.assertEqual(fq_module.calculate_qparams(), loaded_module.calculate_qparams())


    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_channel(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerChannelAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int32, device=device)
        Y = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    def _test_forward_per_channel_cachemask_impl(self, device):
        torch_types = (torch.qint8, torch.quint8)
        float_types = (torch.float32, torch.float16, torch.float64)
        zero_point_types = (torch.int, torch.float32, torch.float16)

        for torch_type, float_type, zero_point_type in itertools.product(torch_types, float_types, zero_point_types):
            X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
            # pick the scale + zp so that some values get clipped
            axis = 1
            obs = torch.ao.quantization.PerChannelMinMaxObserver(axis, torch_type).to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
            zero_point = zero_point.to(zero_point_type)
            quant_min, quant_max = obs.quant_min, obs.quant_max

            Y = _fake_quantize_per_channel_affine_reference(
                X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
            Y_prime = torch.fake_quantize_per_channel_affine(
                X, scale, zero_point, axis, quant_min, quant_max)
            np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)
            self.assertTrue(Y.dtype == float_type)

    def test_forward_per_channel_cachemask_cpu(self):
        self._test_forward_per_channel_cachemask_impl('cpu')

    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_forward_per_channel_cachemask_cuda(self):
        self._test_forward_per_channel_cachemask_impl('cuda')

    def test_forward_per_channel_half_precision_numerics(self):
        scale = torch.randn(5).abs()
        zero = torch.randn(5).to(dtype=torch.int)
        axis = 1
        mini = 0
        maxi = 255

        for i in range(20):
            X1 = torch.randn(4, 5).to(torch.float16)
            Y1 = torch.fake_quantize_per_channel_affine(X1, scale, zero, axis, mini, maxi)
            Y1r = _fake_quantize_per_channel_affine_reference(X1, scale, zero, axis, mini, maxi)
            self.assertEqual(Y1, Y1r, rtol=tolerance, atol=tolerance)

        # to force overflow
        X2 = torch.randn(4, 5).to(torch.float16)
        X2[0, 0] = 2**15 + .01
        Y2 = torch.fake_quantize_per_channel_affine(X2, scale, zero, axis, mini, maxi)
        Y2r = _fake_quantize_per_channel_affine_reference(X2, scale, zero, axis, mini, maxi)
        self.assertEqual(Y2, Y2r, rtol=tolerance, atol=tolerance)

        scale = torch.zeros(5) + 10

        # to force underflow
        X3 = torch.randn(4, 5).to(torch.float16)
        X3[0, 0] = 2**-24
        Y3 = torch.fake_quantize_per_channel_affine(X3, scale, zero, axis, mini, maxi)
        Y3r = _fake_quantize_per_channel_affine_reference(X3, scale, zero, axis, mini, maxi)
        self.assertEqual(Y3, Y3r, rtol=tolerance, atol=tolerance)

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_fake_quant_per_channel_qparam_range(self, X):
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']:
            X = to_tensor(X, device)
            scale = to_tensor(scale, device)

            # Ensure that zero_point < quant_min.
            zero_point = torch.full(zero_point.shape, -1 - quant_min).to(dtype=torch.int32, device=device)

            # For non-float zero_point, fakequant requires zero_point between quant_min and quant_max.
            with self.assertRaisesRegex(RuntimeError, "`zero_point` must be between `quant_min` and `quant_max`."):
                Y = torch.fake_quantize_per_channel_affine(X, scale, zero_point, axis, quant_min, quant_max)

            # For float zero_point, fakequant can be outside quant_min and quant_max.
            for zero_point_dtype in [torch.float32, torch.float16]:
                zero_point = zero_point.to(dtype=zero_point_dtype)
                Y = torch.fake_quantize_per_channel_affine(X, scale, zero_point, axis, quant_min, quant_max)
                Y_ref = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(),
                                                                    axis, quant_min, quant_max)
                np.testing.assert_allclose(Y.cpu().numpy(), Y_ref.cpu().numpy(), rtol=tolerance, atol=tolerance)

    def _test_learnable_forward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
        r"""Tests the forward path of the learnable FakeQuantizePerChannelAffine op.
        """
        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2 ** (n_bits) - 1

            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device)

            X_curr = X_base.clone()
            scale_curr = scale_base.clone()
            zero_point_curr = zero_point_base.clone()

            Y = _fake_quantize_per_channel_affine_reference(
                X_curr, scale_curr, zero_point_curr.round().clamp(quant_min, quant_max), axis, quant_min, quant_max).to(device)
            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                    X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, grad_factor).to(device)
                self.assertTrue(
                    torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                    "Expected kernel forward function to have results match the reference forward function")

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_forward_per_channel_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cpu')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
        self._test_learnable_forward_per_channel(
            X_base, 'cpu', scale_base, zero_point_base, axis)

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_learnable_forward_per_channel_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cuda')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
        self._test_learnable_forward_per_channel(
            X_base, 'cuda', scale_base, zero_point_base, axis)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip(
        "this is broken without changes to any relevant code, "
        "we need to remove hypothesis testing in CI")
    def test_backward_per_channel(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max
        zero_point_types = (torch.int, torch.float, torch.float16)

        for zero_point_type in zero_point_types:
            X = to_tensor(X, device)
            scale = to_tensor(scale, device)
            zero_point = to_tensor(zero_point, device).to(dtype=zero_point_type)
            X.requires_grad_()
            Y_prime = torch.fake_quantize_per_channel_affine(
                X, scale, zero_point, axis, quant_min, quant_max)
            dout = torch.rand_like(X, dtype=torch.float).to(device)
            dX = _fake_quantize_per_channel_affine_grad_reference(
                dout, X, scale, zero_point, axis, quant_min, quant_max)
            Y_prime.backward(dout)
            np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def _test_backward_per_channel_cachemask_impl(self, device):
        torch_types = (torch.qint8, torch.quint8)
        float_types = (torch.float32, torch.float16, torch.float64)
        zero_point_types = (torch.int, torch.float32, torch.float16)

        for torch_type, float_type, zero_point_type in itertools.product(torch_types, float_types, zero_point_types):
            X = torch.randn(1, 2, 4, 4, dtype=float_type).to(device)
            # pick the scale + zp so that some values get clipped
            axis = 1
            obs = torch.ao.quantization.PerChannelMinMaxObserver(axis, torch_type).to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
            zero_point = zero_point.to(zero_point_type)
            quant_min, quant_max = obs.quant_min, obs.quant_max
            X.requires_grad_()
            Y_prime = torch.fake_quantize_per_channel_affine(
                X, scale, zero_point, axis, quant_min, quant_max)
            dout = torch.rand_like(X, dtype=float_type).to(device)
            dX = _fake_quantize_per_channel_affine_grad_reference(
                dout, X, scale, zero_point, axis, quant_min, quant_max)
            Y_prime.backward(dout)
            np.testing.assert_allclose(
                dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)
            assert X.grad.dtype == float_type


    def test_backward_per_channel_cachemask_cpu(self):
        self._test_backward_per_channel_cachemask_impl('cpu')

    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_backward_per_channel_cachemask_cuda(self):
        self._test_backward_per_channel_cachemask_impl('cuda')
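
    # In the learnable per-channel backward helper below, the expected
    # dScale/dZeroPoint are multiplied by grad_factor because the learnable
    # kernels scale those gradients by grad_factor. The helper also shadows the
    # module-level `tolerance` with a looser 1e-4, presumably to absorb the
    # larger numerical error of the per-channel comparison.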
    def _test_learnable_backward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
        r"""Tests the backward path of the learnable FakeQuantizePerChannelAffine op.
        """
        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2 ** n_bits - 1

            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device=device)

            X_curr = X_base.clone()
            X_curr.requires_grad_()
            scale_curr = scale_base.clone()
            scale_curr.requires_grad_()
            zero_point_curr = zero_point_base.clone()
            zero_point_curr.requires_grad_()

            for grad_factor in [0.1, 1.0, 10.0]:
                Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                    X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, grad_factor).to(device)

                dout = torch.rand(X_curr.shape, dtype=torch.float).to(device)
                dX, dScale, dZeroPoint = _fake_quantize_learnable_per_channel_affine_grad_reference(
                    dout, X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, device)
                Y_prime.backward(dout)

                dX_expected = dX.to(device).detach()
                dX_actual = X_curr.to(device).grad.detach()
                dScale_expected = dScale.to(device).detach()
                dScale_actual = scale_curr.to(device).grad.detach()
                dZeroPoint_expected = dZeroPoint.to(device).detach()
                dZeroPoint_actual = zero_point_curr.to(device).grad.detach()
                tolerance = 1e-4

                self.assertTrue(
                    torch.allclose(dX_expected, dX_actual, rtol=tolerance, atol=tolerance),
                    f"Expected dX={dX_expected} to match X.grad={dX_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}")  # noqa: B950
                self.assertTrue(
                    torch.allclose(dScale_expected * grad_factor, dScale_actual, rtol=tolerance, atol=tolerance),
                    f"Expected dScale={dScale_expected * grad_factor} to match scale.grad={dScale_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}")  # noqa: B950
                self.assertTrue(
                    torch.allclose(dZeroPoint_expected * grad_factor, dZeroPoint_actual, rtol=tolerance, atol=tolerance),
                    f"Expected dZeroPoint={dZeroPoint_expected * grad_factor} to match zero_point.grad={dZeroPoint_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}")  # noqa: B950
                X_curr.grad.data.zero_()
                scale_curr.grad.data.zero_()
                zero_point_curr.grad.data.zero_()

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip(
        "this is broken without changes to any relevant code, "
        "we need to remove hypothesis testing in CI")
    def test_learnable_backward_per_channel_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cpu')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
        self._test_learnable_backward_per_channel(
            X_base, 'cpu', scale_base, zero_point_base, axis)

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
    def test_learnable_backward_per_channel_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        X_base = torch.tensor(X).to('cuda')
        scale_base = to_tensor(scale, 'cuda')
        zero_point_base = to_tensor(zero_point, 'cuda')
        self._test_learnable_backward_per_channel(
            X_base, 'cuda', scale_base, zero_point_base, axis)

    def test_numerical_consistency_per_tensor(self):
        self._test_numerical_consistency('per_tensor')

    def test_numerical_consistency_per_channel(self):
        self._test_numerical_consistency('per_channel')

    def _test_numerical_consistency(self, test_type):
        r"""Comparing numerical consistency between quantize/dequantize op and the fake quantize op across devices and dtypes
        """
        torch.random.manual_seed(NP_RANDOM_SEED)
        torch_types = [torch.qint8, torch.quint8]
        float_types = [torch.float, torch.float16, torch.float64]
        if test_type == "per_channel":
            zero_types = [torch.int, torch.float, torch.float16]
        else:
            zero_types = [torch.int]
        devices = [torch.device('cpu'), torch.device('cuda')] if torch.cuda.is_available() else [torch.device('cpu')]
        axis = 1
        for i in range(20):
            for torch_type, float_type, device, zero_type in itertools.product(torch_types, float_types, devices, zero_types):
                X = torch.randn(3, 3, device=device).to(float_type)
                scales = (10 * torch.randn(3, device=device)).abs()
                scale = scales.mean().to(float).item()
                zeros = (10 * torch.randn(3, device=device)).abs().to(dtype=zero_type)
                zero = zeros.max().view(1).item()
                quant_min = torch.iinfo(torch_type).min
                quant_max = torch.iinfo(torch_type).max

                test_was_run = False
                if test_type == "per_tensor":
                    test_was_run = True
                    Y = torch.dequantize(torch.quantize_per_tensor(X.to('cpu').to(torch.float),
                                                                   scale, zero, torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_tensor_affine(X, scale, zero, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime, "Difference found between dequant+quant_per_tensor and fake_quantize_per_tensor")

                if test_type == "per_channel":
                    test_was_run = True
                    Y = torch.dequantize(torch.quantize_per_channel(X.to('cpu').to(torch.float), scales.to(
                        'cpu'), zeros.to('cpu'), axis, torch_type)).to(device).to(float_type)
                    Y_prime = torch.fake_quantize_per_channel_affine(X, scales, zeros, axis, quant_min, quant_max)
                    self.assertEqual(
                        Y, Y_prime, "Difference found between dequant+quant_per_channel and fake_quantize_per_channel")
                self.assertTrue(test_was_run)

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_fake_quantize_per_channel_affine_scale_dtypes(self):
        """
        Ensure the error message is more helpful
        """
        dtype_list = [torch.float, torch.float64, torch.bfloat16, torch.half]
        for scale_dtype in dtype_list:
            input = torch.randn(3, 4, 5, 6)
            scale = torch.Tensor([0.1, 0.2, 0.3, 0.4]).to(scale_dtype)
            zero_point = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
            axis = 1
            quant_min = 0
            quant_max = 255
            if scale_dtype != torch.float:
                with self.assertRaises(RuntimeError):
                    torch.fake_quantize_per_channel_affine(
                        input, scale, zero_point, axis, quant_min, quant_max
                    )
            else:
                torch.fake_quantize_per_channel_affine(
                    input, scale, zero_point, axis, quant_min, quant_max
                )
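

# The tests below exercise torch.fused_moving_avg_obs_fake_quant, which fuses a
# moving-average min/max observer update with fake quantization in a single op.
# As the calls below suggest, its positional arguments are: input, an
# observer-enable flag, a fake-quant-enable flag, running min/max buffers
# (updated in place), scale, zero_point, the averaging constant, quant_min,
# quant_max, the channel axis, a per-row (per-channel) flag, and an optional
# symmetric-quant flag.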
class TestFusedObsFakeQuant(TestCase):
    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           symmetric_quant=st.booleans())
    @settings(deadline=None)
    def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None:
        """
        Tests the case where we call the fused_obs_fake_quant op multiple times
        and update the running_min and max of the activation tensors.
        """
        in_running_min_ref = out_running_min_ref = float("inf")
        in_running_min_op = torch.tensor(float("inf"), device=device)
        in_running_max_ref = out_running_max_ref = float("-inf")
        in_running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)
        observer_on = fake_quant_on = 0

        pt_op = torch.fused_moving_avg_obs_fake_quant
        # enable observer after 2 iterations and fake_quant after 4 iterations
        for i in range(10):
            if i > 2:
                observer_on = 1
            if i > 4:
                fake_quant_on = 1

            x = torch.randn(5, 5, device=device)
            out = pt_op(
                x,
                torch.tensor(observer_on, device=device),
                torch.tensor(fake_quant_on, device=device),
                in_running_min_op,
                in_running_max_op,
                scale,
                zero_point,
                avg_const,
                0,
                255,
                0,
                False,
                symmetric_quant,
            )
            if observer_on:
                (
                    in_running_min_ref,
                    in_running_max_ref,
                ) = _get_tensor_min_max(
                    x,
                    running_min=in_running_min_ref,
                    running_max=in_running_max_ref,
                    averaging_const=0.01,
                )

            if fake_quant_on:
                x_scale, x_zero_point = _get_scale_zp(
                    in_running_min_ref,
                    in_running_max_ref,
                    torch.quint8,
                    preserve_sparsity=symmetric_quant,
                )
                x_in = _fake_quantize_per_tensor_affine_reference(
                    x, x_scale, x_zero_point, 0, 255
                )
                self.assertEqual(scale, x_scale)
                self.assertEqual(zero_point, x_zero_point)
            else:
                x_in = x

            self.assertEqual(in_running_min_ref, in_running_min_op)
            self.assertEqual(in_running_max_ref, in_running_max_op)
            torch.testing.assert_close(out, x_in)

        # Test empty input works
        x = torch.empty(0, 5, device=device)
        out = pt_op(
            x,
            torch.tensor(1, device=device),
            torch.tensor(1, device=device),
            in_running_min_op,
            in_running_max_op,
            scale,
            zero_point,
            avg_const,
            0,
            255,
            0,
            False,
            symmetric_quant,
        )
        output_shape = (0, 5)
        self.assertEqual(out.shape, output_shape)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           symmetric_quant=st.booleans())
    @settings(deadline=None)
    def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_quant) -> None:
        """
        Tests the case where we call the fused_obs_fake_quant op multiple times
        and update the running_min and max of the activation tensors.
        """
        m = 5
        sizes = [[5, 5], [5, 4, 3]]
        for size in sizes:
            in_running_min_ref = torch.empty(m, device=device).fill_(float("inf"))
            in_running_min_op = torch.empty(m, device=device).fill_(float("inf"))
            in_running_max_ref = torch.empty(m, device=device).fill_(float("-inf"))
            in_running_max_op = torch.empty(m, device=device).fill_(float("-inf"))
            avg_const = 0.01

            scale = torch.empty(m, device=device).fill_(0.1)
            zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)

            observer_on = fake_quant_on = 0

            pt_op = torch.fused_moving_avg_obs_fake_quant
            # enable observer after 2 iterations and fake_quant after 4 iterations
            for i in range(10):
                if i > 2:
                    observer_on = 1
                if i > 4:
                    fake_quant_on = 1

                x = torch.randn(size, device=device)
                out = pt_op(
                    x,
                    torch.tensor(observer_on, device=device),
                    torch.tensor(fake_quant_on, device=device),
                    in_running_min_op,
                    in_running_max_op,
                    scale,
                    zero_point,
                    avg_const,
                    0,
                    255,
                    0,
                    True,  # per_channel_enabled
                    symmetric_quant,
                )
                if observer_on:
                    (
                        in_running_min_ref,
                        in_running_max_ref,
                    ) = _get_per_row_min_max(x, in_running_min_ref, in_running_max_ref)
                if fake_quant_on:
                    x_scale = torch.empty(m, device=device)
                    x_zero_point = torch.empty(m, dtype=torch.int, device=device)

                    for i in range(x_scale.numel()):
                        x_scale[i], x_zero_point[i] = _get_scale_zp(
                            in_running_min_ref[i].item(),
                            in_running_max_ref[i].item(),
                            torch.quint8,
                            preserve_sparsity=symmetric_quant,
                        )
                    x_in = _fake_quantize_per_channel_affine_reference(
                        x, x_scale, x_zero_point, 0, 0, 255
                    )
                    self.assertEqual(scale, x_scale)
                    self.assertEqual(zero_point, x_zero_point)
                else:
                    x_in = x
                self.assertEqual(in_running_min_ref, in_running_min_op)
                self.assertEqual(in_running_max_ref, in_running_max_op)
                torch.testing.assert_close(out, x_in)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),)
    @settings(deadline=None)
    def test_fused_obs_fake_quant_backward_op(self, device) -> None:
        n = m = k = 10
        input_shape = (m, n)
        output_shape = (m, n)

        x = torch.randn(input_shape, device=device, requires_grad=True)

        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        x_min, x_max = _get_tensor_min_max(x)
        x_scale, x_zero_point = _get_scale_zp(
            x_min, x_max, torch.quint8
        )

        x_scale = torch.tensor(x_scale, device=device)
        x_zero_point = torch.tensor(x_zero_point, dtype=torch.int, device=device)
        x_fake_quant = torch.fake_quantize_per_tensor_affine(
            x, x_scale, x_zero_point, 0, 255
        )

        pt_op = torch.fused_moving_avg_obs_fake_quant
        out = pt_op(
            x,
            torch.tensor(1, device=device),
            torch.tensor(1, device=device),
            torch.tensor(x_min, device=device),
            torch.tensor(x_max, device=device),
            scale,
            zero_point,
            avg_const,
            0,
            255,
            0,
            False,
        )
        # verify the output matches
        torch.testing.assert_close(out, x_fake_quant)

        # verify the gradient matches expectation of fake_quant op
        dout = torch.rand_like(x, dtype=torch.float).to(device)
        out.backward(dout)

        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, x, x_scale, x_zero_point, 0, 255)
        self.assertEqual(dX, x.grad)
        self.assertTrue(x.grad.dtype == torch.float32)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),)
    @settings(deadline=None)
    def test_fused_backward_op_fake_quant_off(self, device) -> None:
        n = m = 4
        input_shape = (m, n)
        output_shape = (m, n)

        x = torch.randn(input_shape, device=device, requires_grad=True)

        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        x_min, x_max = _get_tensor_min_max(x)
        x_scale, x_zero_point = _get_scale_zp(
            x_min, x_max, torch.quint8
        )


        pt_op = torch.fused_moving_avg_obs_fake_quant
        out = pt_op(
            x,
            torch.tensor(0, device=device),
            torch.tensor(0, device=device),
            torch.tensor(x_min, device=device),
            torch.tensor(x_max, device=device),
            scale,
            zero_point,
            avg_const,
            0,
            255,
            0,
            False,
        )
        # verify the output matches
        torch.testing.assert_close(out, x)

        # verify the gradient matches expectation of fake_quant op
        dout = torch.rand_like(x, dtype=torch.float).to(device)
        out.backward(dout)

        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, x, x_scale, x_zero_point, 0, 255)
        self.assertEqual(dX, x.grad)
        self.assertTrue(x.grad.dtype == torch.float32)

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")