# Owner(s): ["oncall: quantization"]

# Torch
import torch
from torch.ao.quantization import (
    MinMaxObserver,
    PerChannelMinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    HistogramObserver,
    RecordingObserver,
    PlaceholderObserver,
    NoopObserver,
    FakeQuantize,
    FixedQParamsObserver,
    default_debug_qconfig,
    default_observer,
    default_histogram_observer,
    default_per_channel_weight_observer,
    prepare,
    prepare_qat,
    convert,
    QConfig,
    FusedMovingAvgObsFakeQuantize,
    get_embedding_qat_module_mappings,
    get_embedding_static_quant_module_mappings,
)
from torch.ao.quantization.quantize import _get_observer_dict

import torch.nn as nn

# Standard library
import copy
import io
import itertools
import unittest
import math
import numpy as np

# Testing utils
from hypothesis import given, settings
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    AnnotatedSingleLayerLinearModel,
    test_only_eval_fn,
    SingleLayerLinearModel,
)

from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
    override_qengines,
    _fake_quantize_per_channel_affine_reference,
    _fake_quantize_per_channel_affine_grad_reference,
    to_tensor,
)

from torch.testing._internal.common_quantization import (
    DeFusedEmbeddingBagLinear,
)

NP_RANDOM_SEED = 19
tolerance = 1e-6
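

# Illustrative sketch only (not exercised by the tests in this file; the helper
# name is hypothetical): the usual observer workflow is to run calibration data
# through the observer, ask it for qparams, and quantize with them.
def _example_observer_workflow():
    obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    x = torch.randn(4, 4)
    obs(x)  # record the running min/max of the calibration data
    scale, zero_point = obs.calculate_qparams()
    # quantize a tensor using the observed qparams
    return torch.quantize_per_tensor(x, float(scale), int(zero_point), torch.quint8)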


class TestObserver(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8, torch.qint32)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if (qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric) or qdtype == torch.qint32:
            reduce_range = False
        ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range),
                        MovingAverageMinMaxObserver(averaging_constant=0.5,
                                                    dtype=qdtype,
                                                    qscheme=qscheme,
                                                    reduce_range=reduce_range)]

        def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val):
            eps = torch.tensor([tolerance])
            if dtype == torch.qint8:
                if reduce_range:
                    quant_min, quant_max = -64, 63
                else:
                    quant_min, quant_max = -128, 127
            elif dtype == torch.quint8:
                if reduce_range:
                    quant_min, quant_max = 0, 127
                else:
                    quant_min, quant_max = 0, 255
            elif dtype == torch.qint32:
                quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1

            min_val_neg = torch.tensor([0.])
            max_val_pos = torch.tensor([input_scale * max_val]) if qdtype is torch.qint32 else torch.tensor([max_val])

            scale, zero_point = 1.0, 0
            if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
                scale = torch.max(-min_val_neg, max_val_pos) / (float(quant_max - quant_min) / 2)
                scale = torch.max(scale, eps)
                if dtype == torch.quint8:
                    zero_point = 128
            else:
                scale = torch.max((max_val_pos - min_val_neg) / float(quant_max - quant_min), eps)
                zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
                zero_point = torch.clamp(zero_point, quant_min, quant_max)
            return scale, zero_point

        for myobs in ObserverList:
            # Calculate Qparams should return with a warning for observers with no data
            qparams = myobs.calculate_qparams()
            input_scale = 2**16 if qdtype is torch.qint32 else 1
            if type(myobs) == MinMaxObserver:
                x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) * input_scale
                y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) * input_scale
            else:
                # Moving average of min/max for x and y matches that of
                # extreme values for x/y used for minmax observer
                x = torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) * input_scale
                y = torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0]) * input_scale

            result = myobs(x)
            result = myobs(y)
            self.assertEqual(result, y)
            self.assertEqual(myobs.min_val, 1.0 * input_scale)
            self.assertEqual(myobs.max_val, 8.0 * input_scale)
            qparams = myobs.calculate_qparams()
            ref_scale, ref_zero_point = _get_ref_params(reduce_range, qscheme, qdtype, input_scale, 1.0, 8.0)

            self.assertEqual(qparams[1].item(), ref_zero_point)
            self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            for weights_only in [True, False]:
                b.seek(0)
                loaded_dict = torch.load(b, weights_only=weights_only)
                for key in state_dict:
                    self.assertEqual(state_dict[key], loaded_dict[key])
                loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
                loaded_obs.load_state_dict(loaded_dict)
                loaded_qparams = loaded_obs.calculate_qparams()
                self.assertEqual(myobs.min_val, loaded_obs.min_val)
                self.assertEqual(myobs.max_val, loaded_obs.max_val)
                self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())
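
    # For reference on where the per-channel symmetric scales below come from:
    # with qint8/quint8 the quantization range spans 255 steps, so the symmetric
    # scale is max(|min_val|, max_val) / 127.5. For a channel with min -4.0 and
    # max 8.0 that is 8.0 / 127.5 ~= 0.0627451; for a channel with min 1.0 and
    # max 6.0 it is 6.0 / 127.5 ~= 0.04705882.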
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams)),
           ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
    def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qscheme == torch.per_channel_affine_float_qparams:
            reduce_range = False
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        ObserverList = [PerChannelMinMaxObserver(reduce_range=reduce_range,
                                                 ch_axis=ch_axis,
                                                 dtype=qdtype,
                                                 qscheme=qscheme),
                        MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5,
                                                              reduce_range=reduce_range,
                                                              ch_axis=ch_axis,
                                                              dtype=qdtype,
                                                              qscheme=qscheme)]

        for myobs in ObserverList:
            # Calculate qparams should work for empty observers
            qparams = myobs.calculate_qparams()
            x = torch.tensor(
                [
                    [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
                    [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
                ]
            )
            if type(myobs) == MovingAveragePerChannelMinMaxObserver:
                # Scaling the input tensor to model change in min/max values
                # across batches
                result = myobs(0.5 * x)
                result = myobs(1.5 * x)
                self.assertEqual(result, 1.5 * x)
            else:
                result = myobs(x)
                self.assertEqual(result, x)

            qparams = myobs.calculate_qparams()
            ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
            ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
            per_channel_symmetric_ref_scales = [
                [0.04705882, 0.06274509],
                [0.03921569, 0.0627451],
                [0.04705882, 0.0627451],
                [0.05490196, 0.0627451],
            ]
            per_channel_affine_ref_scales = [
                [0.02352941, 0.04705882],
                [0.03529412, 0.03137255],
                [0.03921569, 0.03137255],
                [0.04313726, 0.04313726],
            ]
            per_channel_affine_qint8_zp = [
                [-128, -43],
                [-15, -128],
                [-26, -128],
                [-35, -58],
            ]
            per_channel_affine_float_qparams_ref_scales = [
                [0.0196, 0.0471],
                [0.0353, 0.0196],
                [0.0392, 0.0235],
                [0.0431, 0.0431],
            ]
            per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

            self.assertEqual(myobs.min_val, ref_min_vals[ch_axis])
            self.assertEqual(myobs.max_val, ref_max_vals[ch_axis])
            if qscheme == torch.per_channel_symmetric:
                ref_scales = per_channel_symmetric_ref_scales[ch_axis]
                ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
            elif qscheme == torch.per_channel_affine_float_qparams:
                ref_scales = per_channel_affine_float_qparams_ref_scales[ch_axis]
                ref_zero_points = [-1 * ref_min_vals[ch_axis][i] / ref_scales[i] for i in range(len(ref_scales))]
            else:
                ref_scales = per_channel_affine_ref_scales[ch_axis]
                ref_zero_points = (
                    per_channel_affine_qint8_zp[ch_axis]
                    if qdtype is torch.qint8
                    else per_channel_affine_quint8_zp[ch_axis]
                )

            if reduce_range:
                ref_scales = [s * 255 / 127 for s in ref_scales]
                ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
            self.assertEqual(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), rtol=1e-5, atol=0.0001)
            if qscheme == torch.per_channel_affine_float_qparams:
                self.assertEqual(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), rtol=1e-5, atol=1)
            else:
                self.assertEqual(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype))

            # Test for serializability
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_qparams = loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_val, loaded_obs.min_val)
            self.assertEqual(myobs.max_val, loaded_obs.max_val)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_observer_scriptable(self):
        obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver()]
        for obs in obs_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_state_dict_respects_device_affinity(self):
        """
        Tests that loading from a state dict loads buffers to the correct
        device.
        """
        device_cpu = torch.device('cpu')
        device_cuda = torch.device('cuda:0')
        test_cases = itertools.product(
            [device_cpu, device_cuda],
            [device_cpu, device_cuda],
            [MinMaxObserver, MovingAverageMinMaxObserver,
             PerChannelMinMaxObserver,
             MovingAveragePerChannelMinMaxObserver,
             # TODO: enable this (separate PR)
             # HistogramObserver,
             PlaceholderObserver, RecordingObserver, NoopObserver,
             FakeQuantize])

        for device_source, device_target, obs_cls in test_cases:
            # calibrated source model
            model = obs_cls()
            model.to(device_source)
            model(torch.randn(4, 1, 4, 4, device=device_source))
            # target model
            model2 = obs_cls()
            model2.to(device_target)
            model2.load_state_dict(model.state_dict())
            # verify that buffers stayed on model2's device
            model_devices = {p.device for p in model2.parameters()} | \
                {p.device for p in model2.buffers()}
            # some observers do not have any buffers, so lessEqual instead of
            # Equal
            self.assertLessEqual(len(model_devices), 1)
            if len(model_devices) == 1:
                model_device = next(iter(model_devices))
                self.assertEqual(model_device, device_target)

    def test_histogram_observer_consistent_buffer_shape(self):
        """
        Ensures that the buffer shapes do not change from uninitialized to
        initialized states for HistogramObserver.
        """
        obs = HistogramObserver()
        min_shape_before = obs.min_val.shape
        max_shape_before = obs.max_val.shape
        for _ in range(2):
            obs(torch.randn(4, 4, 4, 4))
        self.assertEqual(min_shape_before, obs.min_val.shape)
        self.assertEqual(max_shape_before, obs.max_val.shape)

    def test_histogram_observer_ignore_infinity(self):
        """
        Ensures that HistogramObserver doesn't record values of infinity
        """
        obs = HistogramObserver()
        obs2 = HistogramObserver()
        x = torch.randn(4, 4, 4, 4)
        obs(x * torch.inf)
        obs(x)
        obs2(x)
        obs(x * torch.inf)
        self.assertTrue(obs.min_val != -torch.inf and obs.max_val != torch.inf)
        self.assertEqual(obs.histogram, obs2.histogram)

    def test_histogram_observer_handle_close_to_infinity(self):
        for sign in [-1, 1]:
            obser = HistogramObserver.with_args(reduce_range=False)()
            mask = torch.tensor([-3.4028234663852886 * 10**30, 0, 0, 0]) * sign
            obser(mask)
            obser(mask - sign)
            scale, zp = obser.calculate_qparams()

            input = torch.randn(1, 4)
            ref_result = torch.softmax(input + mask, dim=1)

            quant_mask = torch.quantize_per_tensor(mask, scale, zp, torch.quint8)
            dequant_mask = quant_mask.dequantize()
            result = torch.softmax(input + dequant_mask, dim=1)
            self.assertEqual(result, ref_result)

    def test_histogram_observer_handle_OOM_due_to_close_min_max_value(self):
        obser = HistogramObserver.with_args(reduce_range=False)()
        # A close min and max value in the first forward() pass of the observer
        # tends to cause OOM in the following pass.
        # This is due to the allocation of the histogram tensor during _combine_histograms().
        # With a sanity check on the size of the histogram tensor, we expect the
        # histogram observer to still work by resetting the histogram.
        x1 = torch.tensor([0, 1e-9])
        obser(x1)

        x2 = torch.tensor([2.0, 3.0])
        obser(x2)

    def test_histogram_observer_save_load_state_dict(self):
        """
        Smoke test on saving/loading state_dict
        """
        obs1 = HistogramObserver()
        obs1(torch.randn(4, 4, 4, 4))
        obs2 = HistogramObserver()
        obs2.load_state_dict(obs1.state_dict())
        self.assertEqual(obs2.min_val.shape, torch.Size([]))
        self.assertEqual(obs2.max_val.shape, torch.Size([]))

    def test_save_load_state_dict_script(self):
        """
        Tests that we can save and load state_dict for observers that are scripted
        in a quantized model.
        """
        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver]

        for obs in obs_list:
            model = SingleLayerLinearModel().eval()
            qconfig = QConfig(activation=default_observer, weight=obs)
            qconfig_dict = {'' : qconfig}
            scripted = torch.jit.script(model)
            scripted = torch.ao.quantization.prepare_jit(scripted, qconfig_dict)
            x = torch.rand(5, 5)
            scripted(x)
            obs_dict = torch.ao.quantization.get_observer_state_dict(scripted)

            # Load stats
            scripted_2 = torch.jit.script(model)
            scripted_2 = torch.ao.quantization.prepare_jit(scripted_2, qconfig_dict)
            torch.ao.quantization.load_observer_state_dict(scripted_2, obs_dict)
            # Verify that state_dict matches exactly with original one.
            self.assertEqual(scripted.state_dict(), scripted_2.state_dict())

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_observer_qparams_respects_device_affinity(self):
        """
        Ensure that the scale and zero_point returned by the observer
        are on the same device as the input tensor.
        """
        observerList = [MinMaxObserver(),
                        MovingAverageMinMaxObserver(),
                        PerChannelMinMaxObserver(),
                        MovingAveragePerChannelMinMaxObserver()]
        for obs in observerList:
            device = torch.device('cuda:1')
            x = torch.randn(1, 2, device=device)
            obs.to(device)
            result = obs(x)
            scale, zero_point = obs.calculate_qparams()

            self.assertEqual(x.device, scale.device)
            self.assertEqual(x.device, zero_point.device)

    def test_zero_numel(self):
        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver,
                    PerChannelMinMaxObserver,
                    MovingAveragePerChannelMinMaxObserver, HistogramObserver,
                    FakeQuantize, FixedQParamsObserver]
        for obs_cls in obs_list:
            if obs_cls is FixedQParamsObserver:
                obs = obs_cls(0.1, 0)
            else:
                obs = obs_cls()
            x = torch.tensor([])
            # verify no crash
            x = obs(x)

    def test_dynamic_quant_observer(self):
        obs = MovingAverageMinMaxObserver(averaging_constant=1, is_dynamic=True)
        x = torch.randn((3, 3))
        obs(x)
        params = obs.calculate_qparams()
        for _ in range(20):
            obs(10 * torch.randn((3, 3)))
            self.assertNotEqual(params, obs.calculate_qparams())
            obs(x)
            self.assertEqual(params, obs.calculate_qparams())

    def test_dynamic_quant_observer_matching_choose_qparams(self):
        obs = MovingAverageMinMaxObserver(averaging_constant=1, is_dynamic=True)
        for x in [torch.randn(3, 3), torch.rand(3, 3, 3), torch.randn(3, 3, 3, 3)]:
            obs(x)
            params = obs.calculate_qparams()
            scale, zero_point = torch._choose_qparams_per_tensor(x)
            self.assertEqual(scale, params[0])
            self.assertEqual(zero_point, params[1])

    def test_per_channel_observers_load_state_dict(self):
        observer_list = [PerChannelMinMaxObserver, MovingAveragePerChannelMinMaxObserver]

        for obs_cls in observer_list:
            obs = obs_cls()
            obs(torch.randn((32, 32)))
            new_obs = obs_cls()
            # make sure the state_dict can be loaded
            new_obs.load_state_dict(obs.state_dict())
            self.assertTrue(torch.equal(obs.min_val, new_obs.min_val))
            self.assertTrue(torch.equal(obs.max_val, new_obs.max_val))


# HistogramObserver that works like it does on master
class _ReferenceHistogramObserver(HistogramObserver):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.jit.ignore
    def _non_linear_param_search(self):
        r"""Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.
        This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
        caffe2/quantization/server/norm_minimization.cc
        """
        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm
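
        # Worked example of the closed form above (comment only, not used by the
        # search): with density 1.0 and a bin spanning delta_begin=-0.5 to
        # delta_end=0.5, the L2 norm is (0.5**3 - (-0.5)**3) / 3 = 1/12, which
        # matches integrating x^2 over [-0.5, 0.5].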

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

            norm = 0.0
            dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning and
                # end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width

                # which dst_bins the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_begin / dst_bin_width))
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_end / dst_bin_width))
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )

                density = self.histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)

                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )

                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )

                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        assert self.histogram.size()[0] == self.bins, "bins mismatch"
        bin_width = (self.max_val - self.min_val) / self.bins

        # cumulative sum
        total = torch.sum(self.histogram).item()
        cSum = torch.cumsum(self.histogram, dim=0)

        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")

        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize

            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1

            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta

            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue

            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")

            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        new_min = self.min_val + bin_width * start_bin
        new_max = self.min_val + bin_width * (end_bin + 1)
        return new_min, new_max


class TestRecordHistogramObserver(QuantizationTestCase):
    # TODO: move this to quantize.py
    def test_record_observer(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedSingleLayerLinearModel()
                model.qconfig = default_debug_qconfig
                model = prepare(model)
                # run the evaluation and dump all tensors
                test_only_eval_fn(model, self.calib_data)
                test_only_eval_fn(model, self.calib_data)
                observer_dict = {}
                _get_observer_dict(model, observer_dict)

                self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(),
                                'observer is not recorded in the dict')
                self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()),
                                 2 * len(self.calib_data))
                self.assertEqual(observer_dict['fc1.module.activation_post_process'].get_tensor_value()[0],
                                 model(self.calib_data[0][0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)))
    def test_observer_scriptable(self, qdtype):
        obs = RecordingObserver(dtype=qdtype)
        scripted = torch.jit.script(obs)

        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], scripted.get_tensor_value()[0]))
        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))


class TestHistogramObserver(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from(
               (torch.per_tensor_affine, torch.per_tensor_symmetric))
           )
    def test_observer_scriptable(self, qdtype, qscheme):
        ob_list = [
            HistogramObserver(dtype=qdtype, qscheme=qscheme),
            default_histogram_observer()
        ]
        for obs in ob_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertTrue(torch.equal(obs.histogram, scripted.histogram))
            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertTrue(torch.equal(obs.histogram, scripted.histogram))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    @settings(max_examples=10)
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        # Calculate qparams should work for empty observers
        qparams = myobs.calculate_qparams()
        x = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        out_x = myobs(x)
        self.assertTrue(out_x.requires_grad)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294
                ref_zero_point = -128 if qdtype is torch.qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
        # Test for serializability
        state_dict = myobs.state_dict()
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
        loaded_obs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        loaded_obs.load_state_dict(loaded_dict)
        loaded_qparams = loaded_obs.calculate_qparams()
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.histogram, loaded_obs.histogram)
        self.assertEqual(myobs.bins, loaded_obs.bins)
        self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_histogram_observer_one_sided(self):
        myobs = HistogramObserver(bins=8, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
        x = torch.tensor([0.0, 0.3, 1.2, 1.7])
        y = torch.tensor([0.1, 1.3, 2.0, 2.7])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 0)
        qparams = myobs.calculate_qparams()
        self.assertEqual(qparams[1].item(), 0)

    def test_histogram_observer_same_inputs(self):
        myobs = HistogramObserver(bins=3, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
                                  reduce_range=False)
        w = torch.ones(4, requires_grad=True)
        x = torch.zeros(4, requires_grad=True)
        y = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        z = torch.tensor([5.0, 6.0, 7.0, 8.0])
        myobs(w)
        myobs(x)
        myobs(x)
        myobs(y)
        myobs(z)
        qparams = myobs.calculate_qparams()
        self.assertEqual(myobs.min_val, 0.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [13.25, 3.75, 3.])

    @skipIfTorchDynamo("too slow")
    @given(N=st.sampled_from([10, 1000]),
           bins=st.sampled_from([256, 512, 1024, 2048]),
           dtype=st.sampled_from([torch.qint8, torch.quint8]),
           qscheme=st.sampled_from([torch.per_tensor_affine, torch.per_tensor_symmetric]),
           reduce_range=st.booleans())
    def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, reduce_range):

        ref_obs = _ReferenceHistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)
        my_obs = HistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)

        for _ in range(10):
            X = torch.randn(N)
            my_obs(X)
            ref_obs(X)
            self.assertEqual(my_obs.histogram, ref_obs.histogram)
            self.assertEqual(my_obs.min_val, ref_obs.min_val)
            self.assertEqual(my_obs.max_val, ref_obs.max_val)

        ref_qparams = ref_obs.calculate_qparams()
        my_qparams = my_obs.calculate_qparams()

        for i in range(0, bins, 200):
            for j in range(i + 5, bins, 200):
                ref_qe = ref_obs._compute_quantization_error(i, j)
                qe = my_obs._compute_quantization_error(i, j)
                self.assertEqual(ref_qe, qe)

        self.assertEqual(ref_qparams, my_qparams)

    def test_histogram_observer_extreme_inputs(self):
        """
        Ensures that the HistogramObserver is able to work correctly in
        a rare case: extremely small max values
        """
        obs = HistogramObserver()
        test_input = torch.tensor(
            [0.0, 0.0, 4.58e-41, 4.58e-41]
        )
        # Make sure it runs; two passes are required based on the behavior of the forward function.
        # The first pass initializes min_val & max_val, and the second pass calls _adjust_min_max.
        obs(test_input)
        obs(test_input)

    def test_histogram_observer_correct_numel(self):
        for i in range(1, 10):
            obs = HistogramObserver()
            obs(torch.randn(i, i))
            self.assertEqual(obs.histogram.sum().item(), i**2)

    def test_histogram_observer_single_inputs(self):
        # Make sure that if we pass single valued tensors to the observer, the code runs
        observer = HistogramObserver(bins=10)
        a = torch.FloatTensor([1])
        b = torch.FloatTensor([3])
        c = torch.FloatTensor([2])
        d = torch.FloatTensor([4])

        observer(a)
        observer(b)
        observer(c)
        observer(d)

        self.assertEqual(observer.min_val, 1)
        self.assertEqual(observer.max_val, 4)
        self.assertEqual(torch.sum(observer.histogram), 4)

    def test_histogram_observer_update_within_range_succeeds(self):
        # test if an update within the existing range actually updates
        myobs = HistogramObserver(bins=10)
        x = torch.tensor([0.0, 3.0, 4.0, 9.0])
        y = torch.tensor([2.0, 3.0, 7.0, 8.0])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 0.0)
        self.assertEqual(myobs.max_val, 9.0)
        self.assertEqual(myobs.histogram, [1., 0., 1., 2., 1., 0., 0., 1., 1., 1.])


class TestFakeQuantize(TestCase):
    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.qint8)))
    def test_fq_module_per_channel(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(X, fq_module.scale,
                                                        fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand_like(X, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(dout, X, fq_module.scale,
                                                              fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_fq_serializable_per_channel(self):
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
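        # default_per_channel_weight_observer is per-channel symmetric qint8, so
        # each row's scale is max(|row min|, row max) / 127.5: 7.0 / 127.5 ~= 0.054902
        # and 10.0 / 127.5 ~= 0.078431, with zero_point fixed at 0.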
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])

    def test_quant_min_max_override(self):
        observer = default_per_channel_weight_observer
        # test no override
        fq_module = FakeQuantize(observer)
        self.assertEqual(fq_module.activation_post_process.quant_min, -128)
        self.assertEqual(fq_module.activation_post_process.quant_max, 127)
        # test quant_min/quant_max override
        fq_module = FakeQuantize(observer, quant_min=0, quant_max=127)
        self.assertEqual(fq_module.activation_post_process.quant_min, 0)
        self.assertEqual(fq_module.activation_post_process.quant_max, 127)


def _get_buffer_ids(module):
    """
    Object addresses stay constant if and only if all modifications are in-place
    """
    return [id(v) for k, v in module._buffers.items()]


class TestDistributed(QuantizationTestCase):

    def test_observers_preserve_buffers(self):
        """
        Tests that observers only modify buffers in place. Note: this is important
        because nn.DataParallel depends on this assumption to work correctly.
        However, DataParallel does not expose IDs of the replicas, so we test it
        without DataParallel in order to easily access the object IDs.
        """
        observer_types = [
            torch.ao.quantization.MinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.PerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.HistogramObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.RecordingObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.PlaceholderObserver.with_args(dtype=torch.float16),
        ]

        for observer_type in observer_types:
            observer = observer_type()
            buffer_ids_before = _get_buffer_ids(observer)
            for _i in range(5):
                inputs = torch.rand((4, 4, 4))
                observer(inputs)
            buffer_ids_after = _get_buffer_ids(observer)
            self.assertEqual(
                buffer_ids_before,
                buffer_ids_after,
                msg=f"{str(observer)}: Buffers must be modified in place")

    def test_fake_quant_preserves_buffers(self):
        """
        Tests that fake quant only modifies buffers in place. Note: this is important
        because nn.DataParallel depends on this assumption to work correctly.
        However, DataParallel does not expose IDs of the replicas, so we test it
        without DataParallel in order to easily access the object IDs.
        """
        model = torch.ao.quantization.FakeQuantize()
        buffer_ids_before = _get_buffer_ids(model)
        for _i in range(5):
            inputs = torch.rand((4, 4, 4))
            model(inputs)
        model.apply(torch.ao.quantization.enable_fake_quant)
        model.apply(torch.ao.quantization.disable_fake_quant)
        model.apply(torch.ao.quantization.enable_observer)
        model.apply(torch.ao.quantization.disable_observer)
        buffer_ids_after = _get_buffer_ids(model)
        self.assertEqual(
            buffer_ids_before,
            buffer_ids_after,
            msg="FakeQuant: Buffers must be modified in place")

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_data_parallel(self):
        """
        Tests that doing QAT in nn.DataParallel does not crash.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            return
        with override_quantized_engine('fbgemm'):
            device = torch.device('cuda')

            model = nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.Conv2d(3, 1, 1, bias=False),
                nn.BatchNorm2d(1),
                nn.ReLU(),
                nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(2),
                nn.AvgPool2d(14),
                nn.Sigmoid(),
                torch.ao.quantization.DeQuantStub(),
            )

            torch.ao.quantization.fuse_modules_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)

            model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
            torch.ao.quantization.prepare_qat(model, inplace=True)
            model = nn.DataParallel(model, device_ids=[0, 1])
            model.to(device)
            model.train()

            for epoch in range(3):
                inputs = torch.rand(2, 3, 28, 28).to(device)
                model(inputs)
                if epoch >= 1:
                    model.apply(torch.ao.quantization.disable_observer)
                if epoch >= 2:
                    model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                quant_model = copy.deepcopy(model.module)
                quant_model = torch.ao.quantization.convert(quant_model.eval().cpu(), inplace=False)
                with torch.no_grad():
                    out = quant_model(torch.rand(1, 3, 28, 28))

    def test_qat_convbn_fused_syncbn_replacement(self):
        """
        Tests that SyncBatchNorm replacement works for fused ConvBN.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            return
        with override_quantized_engine('fbgemm'):
            # create conv-bn
            class Model(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv = nn.Conv2d(4, 1, 3, padding=1)
                    self.bn = nn.BatchNorm2d(1)

                def forward(self, x):
                    x = self.conv(x)
                    x = self.bn(x)
                    return x

            model = Model()
            # fuse it
            fused_model = torch.ao.quantization.fuse_modules_qat(
                model,
                [['conv', 'bn']],
            )
            # convert to QAT
            fused_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
            torch.ao.quantization.prepare_qat(fused_model, inplace=True)
            # replace with DDP
            fused_model = nn.SyncBatchNorm.convert_sync_batchnorm(fused_model)
            self.assertTrue(
                isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
                "Expected BN to be converted to SyncBN")

    def test_syncbn_preserves_qconfig(self):
        """
        Makes sure that if a BatchNorm is not fused and a qconfig exists,
        converting the module to SyncBatchNorm preserves the qconfig.
        """
        m = nn.Sequential(
            nn.Conv2d(1, 1, 1),
            nn.BatchNorm2d(1),
        )
        m[1].qconfig = torch.ao.quantization.default_qconfig
        m = torch.nn.SyncBatchNorm.convert_sync_batchnorm(m)
        self.assertTrue(
            hasattr(m[1], "qconfig"),
            "missing qconfig after SyncBatchNorm conversion")

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_device_affinity(self):
        """
        Tests that converting a model to QAT respects device affinity
        """
        class Model(nn.Module):

            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        model = Model()
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(torch.backends.quantized.engine)
        device = torch.device('cuda:0')
        model.to(device)
        torch.ao.quantization.prepare_qat(model, inplace=True)
        model_devices = {p.device for p in model.parameters()} | \
            {p.device for p in model.buffers()}
        self.assertEqual(len(model_devices), 1)
        model_device = next(iter(model_devices))
        self.assertEqual(model_device, device)

        # ensure that running an input on CUDA works without any needed changes
        input = torch.randn(4, 1, 4, 4, device=device)
        model(input)


class TestFusedObsFakeQuantModule(TestCase):
    @given(
        device=st.sampled_from(
            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        )
    )
    @settings(deadline=None)
    def test_fused_obs_fq_module(self, device):
        # Set up the parameters
        x = torch.randn(5, 5, device=device)
        running_min_op = torch.tensor(float("inf"), device=device)
        running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.01
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        # Run the forward on the Module
        mod = FusedMovingAvgObsFakeQuantize()
        torch.ao.quantization.enable_fake_quant(mod)
        torch.ao.quantization.enable_observer(mod)
        mod.to(device)
        out = mod(x)

        # Run the operator directly
        pt_op = torch.fused_moving_avg_obs_fake_quant

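        # Annotation (added for readability; based on the documented signature of
        # torch.fused_moving_avg_obs_fake_quant): the trailing positional
        # arguments below are quant_min (0), quant_max (255), ch_axis (0) and
        # the per-row fake-quant flag (False).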
        out_ref = pt_op(
            x,
            mod.observer_enabled,
            mod.fake_quant_enabled,
            running_min_op,
            running_max_op,
            scale,
            zero_point,
            avg_const,
            0,
            255,
            0,
            False,
        )

        # Compare params with reference
        torch.testing.assert_close(out, out_ref)
        torch.testing.assert_close(
            running_min_op, mod.activation_post_process.min_val
        )
        torch.testing.assert_close(
            running_max_op, mod.activation_post_process.max_val
        )

    @given(
        device=st.sampled_from(
            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        )
    )
    @settings(deadline=None)
    def test_fused_obs_fq_moving_avg_module(self, device):
        # Set up the parameters
        running_min_op = torch.tensor(float("inf"), device=device)
        running_max_op = torch.tensor(float("-inf"), device=device)
        avg_const = 0.001
        scale = torch.tensor([1.0], device=device)
        zero_point = torch.tensor([0], dtype=torch.int, device=device)

        mod = FusedMovingAvgObsFakeQuantize(averaging_constant=0.001)
        mod.to(device)
        mod.observer_enabled[0] = 0
        mod.fake_quant_enabled[0] = 0

        for i in range(10):
            x = torch.randn(5, 5, device=device)
            if i > 2:
                mod.observer_enabled[0] = 1
            if i > 4:
                mod.fake_quant_enabled[0] = 1
            # Run the forward on the Module
            out = mod(x)

            # Run the operator directly
            pt_op = torch.fused_moving_avg_obs_fake_quant

            out_ref = pt_op(
                x,
                mod.observer_enabled,
                mod.fake_quant_enabled,
                running_min_op,
                running_max_op,
                scale,
                zero_point,
                avg_const,
                0,
                255,
                0,
                False,
            )

            # Compare params with reference
            torch.testing.assert_close(out, out_ref)
            torch.testing.assert_close(
                running_min_op, mod.activation_post_process.min_val
            )
            torch.testing.assert_close(
                running_max_op, mod.activation_post_process.max_val
            )

    @given(
        device=st.sampled_from(
            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        )
    )
    @settings(deadline=None)
    def test_compare_fused_obs_fq_oss_module(self, device):
        mod = FusedMovingAvgObsFakeQuantize()
        torch.ao.quantization.enable_fake_quant(mod)
        torch.ao.quantization.enable_observer(mod)
        mod.to(device)

        mod_ref = FakeQuantize()
        torch.ao.quantization.enable_fake_quant(mod_ref)
        torch.ao.quantization.enable_observer(mod_ref)
        mod_ref.to(device)

        for i in range(10):
            x = torch.randn(5, 5, device=device)
            out = mod(x)
            out_ref = mod_ref(x)
            torch.testing.assert_close(out, out_ref)
            torch.testing.assert_close(
                mod_ref.activation_post_process.min_val,
                mod.activation_post_process.min_val,
            )
            torch.testing.assert_close(
                mod_ref.activation_post_process.max_val,
                mod.activation_post_process.max_val,
            )

    def test_fused_mod_per_channel(self):
        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        m = 5
        n = 10
        for device in devices:
            running_min_op = torch.empty(m, device=device).fill_(float("inf"))
            running_max_op = torch.empty(m, device=device).fill_(float("-inf"))
            avg_const = 0.001
            scale = torch.empty(m, device=device).fill_(0.1)
            zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
            obs = FusedMovingAvgObsFakeQuantize.with_args(
                averaging_constant=avg_const,
                observer=MovingAveragePerChannelMinMaxObserver,
            )
            mod = obs()
            mod = torch.jit.script(mod)
            mod.to(device)

            for i in range(10):
                x = torch.randn(m, n, device=device)
                if i > 2:
                    mod.observer_enabled[0] = 1
                if i > 4:
                    mod.fake_quant_enabled[0] = 1
                # Run the forward on the Module
                out = mod(x)

                # Run the operator directly
                pt_op = torch.fused_moving_avg_obs_fake_quant

                out_ref = pt_op(
                    x,
                    mod.observer_enabled,
                    mod.fake_quant_enabled,
                    running_min_op,
                    running_max_op,
                    scale,
                    zero_point,
                    avg_const,
                    0,
                    255,
                    0,
                    True,
                    False,
                )
                # Compare params with reference
                torch.testing.assert_close(out, out_ref)
                if mod.observer_enabled[0]:
                    torch.testing.assert_close(
                        running_min_op, mod.activation_post_process.min_val
                    )
                    torch.testing.assert_close(
                        running_max_op, mod.activation_post_process.max_val
                    )
                if mod.fake_quant_enabled:
                    torch.testing.assert_close(scale, mod.scale)
                    torch.testing.assert_close(zero_point, mod.zero_point)

            torch.testing.assert_close(mod.state_dict()['activation_post_process.min_val'], running_min_op)
            torch.testing.assert_close(mod.state_dict()['activation_post_process.max_val'], running_max_op)

    def test_fused_mod_reduce_range(self):
        obs = FusedMovingAvgObsFakeQuantize(quant_min=0, quant_max=255, dtype=torch.quint8, reduce_range=True)
        self.assertEqual(obs.activation_post_process.quant_min, 0)
        self.assertEqual(obs.activation_post_process.quant_max, 127)

    def test_embedding_bag_qat_config(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
                self.emb2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')

            def forward(self, indices):
                return torch.cat((self.emb1(indices), self.emb2(indices)))

        qconfigs = [torch.ao.quantization.default_embedding_qat_qconfig,
                    torch.ao.quantization.default_embedding_qat_qconfig_4bit]
        for qconfig in qconfigs:
            model = Model().train()
            indices = torch.randint(0, 10, (5, 12))

            model.qconfig = qconfig

            quant_model = prepare_qat(model,
                                      mapping=get_embedding_qat_module_mappings())

            count_fake_quant = 0
            for name, mod in quant_model.named_modules():
                if name.endswith('weight_fake_quant'):
                    count_fake_quant += 1
                    self.assertEqual(type(mod), FakeQuantize)
            self.assertEqual(count_fake_quant, 2)

            quant_model(indices)

            # Ensure that EmbeddingBags have float zero_point values
            self.assertEqual(quant_model.emb1.weight_fake_quant.zero_point.dtype, torch.float32)
            self.assertEqual(quant_model.emb2.weight_fake_quant.zero_point.dtype, torch.float32)

            inference_gm = convert(quant_model.eval().cpu(),
                                   mapping=get_embedding_static_quant_module_mappings())

            # Ensure that EmbeddingBags are now quantized with the appropriate bitwidth.
            self.assertEqual(type(inference_gm.emb1), torch.ao.nn.quantized.EmbeddingBag)
            self.assertEqual(type(inference_gm.emb2), torch.ao.nn.quantized.EmbeddingBag)
            self.assertEqual(inference_gm.emb1.dtype, qconfig.weight().dtype)
            self.assertEqual(inference_gm.emb2.dtype, qconfig.weight().dtype)

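    # The default embedding QAT qconfigs use per_channel_affine_float_qparams for
    # the weight fake-quant, which is why the embedding tests above and below
    # check for a torch.float32 zero_point.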
    def test_embedding_qat_config(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = DeFusedEmbeddingBagLinear()
                indices = torch.randint(0, 10, (5, 12))
                quant_model = prepare_qat(model,
                                          mapping=get_embedding_qat_module_mappings())

                count_fake_quant = 0
                count_activation_postproc = 0
                for name, mod in quant_model.named_modules():
                    if name.endswith('weight_fake_quant'):
                        count_fake_quant += 1
                    if name.count('activation_post_process') == 1 and 'weight_fake_quant' not in name:
                        count_activation_postproc += 1
                # One for embeddings, one for the linear layer.
                self.assertEqual(count_fake_quant, 2)
                # One for embeddings (but it is a NoOp), one for quantize, one for the linear layer.
                self.assertEqual(count_activation_postproc, 3)

                self.assertEqual(type(quant_model.emb.weight_fake_quant), FakeQuantize)
                self.assertEqual(quant_model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
                self.assertEqual(type(quant_model.emb.activation_post_process), NoopObserver)
                self.assertEqual(type(quant_model.linear.weight_fake_quant), FusedMovingAvgObsFakeQuantize)
                self.assertEqual(type(quant_model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize)

                quant_model(indices)
                inference_gm = convert(quant_model,
                                       mapping=get_embedding_static_quant_module_mappings())
                # Ensure that Embedding is now quantized
                self.assertEqual(type(inference_gm.emb), torch.ao.nn.quantized.Embedding)
                # Ensure that Linear is now quantized
                self.assertEqual(type(inference_gm.linear), torch.ao.nn.quantized.Linear)

    def test_default_fused_qat_config(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(2, 2)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.linear(x)
                x = self.relu(x)
                return x

        for qengine in ["fbgemm", "qnnpack"]:
            model = Model()
            model.linear.weight = torch.nn.Parameter(torch.randn(2, 2))
            sample_input = torch.randn(2, 2)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine, version=1)
            ref_model = torch.ao.quantization.QuantWrapper(model)
            ref_model = torch.ao.quantization.prepare_qat(ref_model)
            ref_model(sample_input)
            count_fake_quant = 0
            for name, mod in ref_model.named_modules():
                if name.endswith('weight_fake_quant'):
                    count_fake_quant += 1
                    self.assertEqual(type(mod), FusedMovingAvgObsFakeQuantize)

                if name.count('activation_post_process') == 1 and 'weight_fake_quant' not in name:
                    count_fake_quant += 1
                    self.assertEqual(type(mod), FusedMovingAvgObsFakeQuantize)

            self.assertEqual(count_fake_quant, 3)

            if qengine == "fbgemm":
                lower_bnd = 0
                upper_bnd = 127
                obs2match = MovingAveragePerChannelMinMaxObserver

            else:
                lower_bnd = 0
                upper_bnd = 255
                obs2match = MovingAverageMinMaxObserver

            self.assertEqual(ref_model.quant.activation_post_process.activation_post_process.quant_min, lower_bnd)
            self.assertEqual(ref_model.quant.activation_post_process.activation_post_process.quant_max, upper_bnd)
            self.assertEqual(type(ref_model.module.linear.weight_fake_quant.activation_post_process),
                             obs2match)


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")