# Owner(s): ["oncall: quantization"]


import copy
import itertools
import numpy as np
import operator
import random
import unittest
from typing import NamedTuple, List

import torch
from torch import _VF
import torch.jit
import torch.nn.functional as F
from torch.nn.modules.utils import _single, _pair

from hypothesis import settings, HealthCheck
from hypothesis import assume, given, note
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
    override_quantized_engine, supported_qengines, override_qengines, _snr
from torch.testing._internal.common_quantized import (
    qengine_is_qnnpack,
    qengine_is_onednn,
)
from torch.ao.quantization import PerChannelMinMaxObserver
from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDNN_VERSION, TEST_CUDA
from torch.testing._internal.optests import opcheck
import torch.backends.xnnpack

from torch.utils.cpp_extension import ROCM_HOME

from typing import Optional

np_dtype = {
    torch.quint8 : np.uint8,
    torch.qint8 : np.int8,
    torch.qint32 : np.int32
}

TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None

class PointwisePostOp(NamedTuple):
    binary_attr : str = "none"
    alpha : float = 1.0
    unary_attr : str = "none"
    scalars : List = []
    algorithm : str = ""

# Make sure we won't have overflows from the vpmaddubsw instruction used in FBGEMM.
# On the current Intel x86 architecture, we need to use the vpmaddubsw instruction
# for 8-bit int multiplication. This instruction vertically multiplies each
# unsigned 8-bit integer from a with the corresponding signed 8-bit integer from
# b, producing intermediate signed 16-bit integers. This function modifies the
# weights to eliminate overflow of those signed 16-bit intermediates.
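# Illustrative worst case: vpmaddubsw multiplies two adjacent unsigned 8-bit
# values from a by two adjacent signed 8-bit values from b and accumulates the
# pair into one signed 16-bit lane, so 255 * 127 + 255 * 127 = 64770, which
# overflows the int16 maximum of 32767; the adjustment below prevents that.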
def avoid_vpmaddubsw_overflow_linear(
    batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
):
    for i, j in np.ndindex((batch_size, output_channels)):
        for k in range(0, input_channels // 2 * 2, 2):
            x0 = X[i, k] - X_min
            x1 = X[i, k + 1] - X_min
            w0 = W[j, k] - 128 - W_min
            w1 = W[j, k + 1] - 128 - W_min
            if x0 * w0 + x1 * w1 < -(1 << 15):
                w1_adjusted = (-(1 << 15) - float(x0) * w0) / x1
                W[j, k + 1] = int(w1_adjusted) + 128 + W_min
            elif x0 * w0 + x1 * w1 > (1 << 15) - 1:
                w1_adjusted = ((1 << 15) - 1 - float(x0) * w0) / x1
                W[j, k + 1] = int(w1_adjusted) + 128 + W_min

    # Go through the same loop again to double check we don't have any overflow
    for i, j in np.ndindex((batch_size, output_channels)):
        for k in range(0, input_channels // 2 * 2, 2):
            x0 = X[i, k] - X_min
            x1 = X[i, k + 1] - X_min
            w0 = W[j, k] - 128 - W_min
            w1 = W[j, k + 1] - 128 - W_min
            assert -(1 << 15) <= x0 * w0 + x1 * w1 < (1 << 15)


# Reference quantized Linear operator
def qlinear_ref(X_q, X_scale, X_zp, W_q, W_scale, W_zp, b_q, Y_scale, Y_zp, dtype=np.uint8):
    X_q = np.reshape(X_q, (-1, X_q.shape[X_q.ndim - 1]))
    row_offsets_ref = X_q.sum(axis=1).astype(np.int32).reshape((-1, 1))
    col_offsets_ref = W_q.sum(axis=1).astype(np.int32).reshape((1, -1))
    assert X_q.ndim == 2
    batch_size, input_channels = X_q.shape
    Prod_XqWq_ref = (
        np.matmul(X_q.astype(np.int32), W_q.astype(np.int32).T)
        - W_zp * row_offsets_ref
        - X_zp * col_offsets_ref
        + input_channels * X_zp * W_zp
    )
    if b_q is not None:
        Prod_XqWq_ref += b_q
    Y_q_ref = _quantize(Prod_XqWq_ref, Y_scale / (X_scale * W_scale), Y_zp, dtype=dtype)
    return Y_q_ref

"""Computes the output shape given pooling parameters."""
def pool_output_shape(input_size, kernel_size, padding, stride,
                      dilation, ceiling_mode=False):
    if stride is None:
        stride = kernel_size
    output_size = (
        (input_size + 2 * padding - dilation * (kernel_size - 1) - 1
         + (stride - 1 if ceiling_mode else 0)) // stride + 1)
    if (ceiling_mode and
            ((output_size - 1) * stride >= input_size + padding)):
        output_size -= 1
    return output_size
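# A quick sanity check of the formula above: with input_size=10, kernel_size=3,
# padding=1, stride=2, dilation=1 and ceiling_mode=False:
#   (10 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 9 // 2 + 1 = 5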

"""
Util for creating a random tensor and quantization params when Hypothesis
is undesirable.
"""
def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type):
    X = (torch.rand(*shapes, dtype=torch.float) - 0.5) * rand_scale
    # Calculate reasonable quantization params
    min_val = torch.min(X)
    max_val = torch.max(X)
    if torch_type == torch.qint32:
        X_zero_point = int(torch.randint(-1 * (2 ** 31), 2 ** 31 - 1, (1,)))
        num_bins = 2 ** 32
        X_scale = float(max_val - min_val) / num_bins
    elif torch_type == torch.qint8:
        X_zero_point = int(torch.randint(-128, 127, (1,)))
        num_bins = 2 ** 8
        X_scale = float(max_val - min_val) / num_bins
    else:  # torch.quint8
        X_zero_point = 127
        num_bins = 2 ** 8
        X_scale = float(max_val - min_val) / num_bins
    if X_scale == 0:
        X_scale = 1e-10
    return X, X_scale, X_zero_point

class TestQuantizedOps(TestCase):

    """Helper function to test quantized activation functions."""
    def _test_activation_function(self, X, fn_name, test_configs):
        r"""
        When writing a unit test for an activation function, instead of
        re-implementing the test routine for each function, use this helper,
        which provides the general testing logic.
        To use the helper, a test config must be provided.
        A test config is a list that contains metadata about the quantized activation
        functions that will be tested and how the tests need to be set up; it allows simpler and
        more concise unit tests to be written by specifying the configurations needed
        and calling the provided helper function _test_activation_function.
        Inside the list, each config (a dictionary) represents a suite of tests that assert the
        correctness of various quantization functions.
        See test_qrelu, test_qrelu6, test_qsigmoid, and test_qhardsigmoid for
        examples of how test configs are specified.
        Here's a list of the fields that can be included in a test config:
        quantized_fn: a list of the quantized functions to be tested
        reference_fn: the original reference function to be called on
        the dequantized X
        extra_kwargs: additional keyword arguments to pass to the ops.
        Each test entry must have at least the quantized_fn and
        reference_fn fields.
        output_range: the output range the operator will map to. By default, if it is
        not specified, the range is not controlled and depends on Xmin and Xmax.
        change_zero_point: a boolean flag indicating if the zero point parameter should
        be determined based on torch_type during quantization (see sigmoid/hardsigmoid for
        examples). By default, if it is not specified, change_zero_point is assumed to be
        False and the zero point will just take on the default value from X.
        `output_is_observed`: if specified and True, we'll append extra
        output_scale/output_zero_point keyword arguments when calling the quantized op
        """
        # Retrieves the default parameters from X.
        X, (scale, zero_point, torch_type) = X
        if not isinstance(X, torch.Tensor):
            X = torch.from_numpy(X)
        if (X.device.type == 'cuda') and (torch.backends.quantized.engine == 'qnnpack'):
            return
        # Quantizes the reference to account for max error.
        # q_min and q_max only depend on the initial torch_type.
        q_min, q_max = torch.iinfo(torch_type).min, torch.iinfo(torch_type).max

        for op_group in test_configs:
            ref_op = op_group['reference_fn']
            for q_op in op_group['quantized_fn']:

                for memory_format in (torch.channels_last, torch.contiguous_format):
                    if memory_format == torch.channels_last and len(X.shape) != 4:
                        continue
                    X = X.to(memory_format=memory_format)

                    # Retrieves the inplace keyword arguments;
                    # some functions require inplace=True to test in-place.
                    # copy.copy is needed because these are modified in place.
                    extra_kwargs = \
                        copy.copy(op_group.get('extra_kwargs', {}))
                    output_is_observed = \
                        copy.copy(op_group.get('output_is_observed', False))

                    # Quantizes and dequantizes to account for max error.
                    qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                                   dtype=torch_type)
                    dqX = qX.dequantize()
                    dqY_hat = ref_op(dqX.clone(), **extra_kwargs)

                    # Adjusts output_scale if needed.
                    # The output_scale determines the quantization scale for functions that
                    # have a constrained output range, e.g. sigmoid ranges from 0 to 1.
                    output_scale = scale
                    if 'output_range' in op_group:
                        (f_min, f_max) = op_group['output_range']
                        output_scale = (f_max - f_min) / (q_max - q_min + 1.0)
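                    # For example, sigmoid declares output_range=(0.0, 1.0), so for
                    # quint8 (q_min = 0, q_max = 255) this gives output_scale = 1.0 / 256.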

                    # Adjusts output_zero_point if needed (see explanation for the
                    # change_zero_point parameter above).
                    # output_zero_point determines the additional offset that will be
                    # added to a scaled value during quantization.
                    if op_group.get('change_zero_point', False):
                        output_zero_point = 0 if torch_type == torch.qint32 else q_min
                    else:
                        output_zero_point = zero_point

                    # Quantizes the dequantized version of Y_hat.
                    qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale,
                                                       zero_point=output_zero_point,
                                                       dtype=torch_type)

                    if output_is_observed:
                        extra_kwargs.update({'output_scale': output_scale, 'output_zero_point': output_zero_point})

                    # Finds qY using in-place or non-in-place quantized operators.
                    qY = q_op(qX, **extra_kwargs)

                    self.assertEqual(qY, qY_hat, msg=f'{fn_name} - {q_op} failed: ({qY} vs. {qY_hat})')

    """Tests the correctness of the quantized::relu op."""
    @override_qengines
    def test_qrelu(self):
        relu_test_configs = [
            {
                'quantized_fn': [
                    torch.relu,
                    torch.relu_,
                    torch.nn.functional.relu,
                    torch.nn.functional.relu,
                ],
                'reference_fn': torch.nn.functional.relu
            },
            {
                'quantized_fn': [
                    torch.nn.functional.relu,
                    torch.nn.functional.relu,
                ],
                'reference_fn': torch.nn.functional.relu,
                'extra_kwargs': {
                    'inplace': True
                }
            }
        ]
        devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
        for device in devices:
            shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
            dtypes = (torch.quint8, torch.qint8)
            scales = (0.05, 0.1)
            zero_points = (0, 5)
            test_cases = itertools.product(shapes, dtypes, scales, zero_points)
            for shape, dtype, scale, zero_point in test_cases:
                X = torch.randn(*shape, device=device)
                X = (X, (scale, zero_point, dtype))
                self._test_activation_function(X, 'relu', relu_test_configs)

    """Tests the correctness of the quantized::relu6 op."""
    def test_qrelu6(self):
        relu6_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.relu6,
                    torch.ao.nn.quantized.ReLU6(inplace=False),
                    torch.ao.nn.quantized.ReLU6(inplace=True)
                ],
                'reference_fn': torch.nn.functional.relu6
            }
        ]
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        scales = (0.05, 0.1)
        zero_points = (0, 5)
        test_cases = itertools.product(shapes, dtypes, scales, zero_points)
        for shape, dtype, scale, zero_point in test_cases:
            X = torch.randn(*shape) * 10
            X = (X, (scale, zero_point, dtype))
            self._test_activation_function(X, 'relu6', relu6_test_configs)

    """Tests the correctness of the quantized::sigmoid op."""
    @override_qengines
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    def test_sigmoid_non_observed(self, X):
        sigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.sigmoid
                ],
                'reference_fn': torch.sigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True
            }
        ]
        self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

    """Tests the correctness of the quantized::sigmoid op."""
    # TODO: enable after observed output is supported in qnnpack
    # @override_qengines
    @skipIfNoFBGEMM
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    def test_sigmoid(self, X):
        sigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.sigmoid
                ],
                'reference_fn': torch.sigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
                'output_is_observed': True,
            }
        ]
        self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

    @skipIfNoFBGEMM
    def test_sigmoid_dequantize_rounding_error(self):
        # issue #107030
        sigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.sigmoid
                ],
                'reference_fn': torch.sigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
                'output_is_observed': True,
            }
        ]
        X = (np.full(64, 514., dtype=np.float32), (1028.02, 255, torch.quint8))
        self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

    """Tests the correctness of the quantized::hardsigmoid op."""
    @override_qengines
    def test_qhardsigmoid(self):
        hardsigmoid_test_configs = [
            {
                'quantized_fn': [
                    torch.ao.nn.quantized.functional.hardsigmoid,
                ],
                'reference_fn': torch.nn.functional.hardsigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
            },
            {
                'quantized_fn': [
                    torch.ao.nn.quantized.functional.hardsigmoid,
                ],
                'reference_fn': torch.nn.functional.hardsigmoid,
                'output_range': (0.0, 1.0),
                'change_zero_point': True,
                'extra_kwargs': {
                    'inplace': True,
                },
            },
        ]
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        test_cases = itertools.product(shapes, dtypes)
        for shape, dtype in test_cases:
            X = (np.random.rand(*shape).astype(np.float32), (1.0, 0, dtype))
            self._test_activation_function(X, 'hardsigmoid', hardsigmoid_test_configs)

    @override_qengines
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    def test_leaky_relu_observed_output(self, X):
        leaky_relu_test_configs = [
            {
                'quantized_fn': [
                    torch.ops.quantized.leaky_relu
                ],
                'reference_fn': torch.nn.functional.leaky_relu,
                'extra_kwargs': {
                    'negative_slope': 0.1,
                    'inplace': False,
                },
                'output_is_observed': True,
            }
        ]
        self._test_activation_function(X, 'leaky_relu', leaky_relu_test_configs)

    """Tests the correctness of the quantized leaky_relu op."""
    def test_leaky_relu(self):
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        memory_formats = (torch.channels_last, torch.contiguous_format)
        test_cases = itertools.product(shapes, dtypes, memory_formats)
        for shape, dtype, memory_format in test_cases:
            if memory_format == torch.channels_last and len(shape) != 4:
                continue
            X, scale, zero_point, torch_type, alpha = \
                torch.randn(*shape), 0.1, 0, dtype, 0.01
            X = X.to(memory_format=memory_format)

            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            dqX = qX.dequantize()

            # torch.nn.functional
            op = torch.nn.functional.leaky_relu
            dqY = op(dqX, negative_slope=alpha)
            qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            qY_hat = op(qX, negative_slope=alpha)
            self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
                             msg=f"F.leaky_relu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::elu op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams()),
           alpha=st.floats(0.01, 10.0, allow_nan=False, allow_infinity=False))
    def test_qelu(self, X, alpha):
        X, (scale, zero_point, torch_type) = X
        output_scale = 0.5
        output_zero_point = 1
        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        # calculate ELU(dqX) and quantize
        dqX = qX.dequantize()
        dqY_hat = dqX.clone()
        dqY_hat = torch.nn.functional.elu(dqX, alpha)
        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale, zero_point=output_zero_point,
                                           dtype=torch_type)

        qY = torch.ao.nn.quantized.functional.elu(qX, output_scale, output_zero_point, alpha=alpha)
        self.assertEqual(qY, qY_hat,
                         msg=f"F.elu failed ({qY} vs {qY_hat})")


    """Tests the correctness of the quantized::celu op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       elements=hu.floats(-1e2, 1e2, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams(scale_max=9.999999747378752e-06)),
           alpha=st.floats(0.01, 100.0, allow_nan=False, allow_infinity=False))
    def test_qcelu(self, X, alpha):
        X, (scale, zero_point, torch_type) = X
        output_scale = 0.5
        output_zero_point = 1

        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        # calculate CELU(dqX) and quantize
        dqX = qX.dequantize()
        dqY_hat = torch.nn.functional.celu(dqX, alpha)
        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=output_scale, zero_point=output_zero_point,
                                           dtype=torch_type)

        # test regular
        qY = torch.ops.quantized.celu(qX, output_scale, output_zero_point, alpha=alpha)
        self.assertEqual(qY, qY_hat,
                         msg=f"F.celu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::gelu op."""
    def test_qgelu(self):
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        dtypes = (torch.quint8, torch.qint8)
        memory_formats = (torch.channels_last, torch.contiguous_format)
        approximation = ['none', 'tanh']
        test_cases = itertools.product(shapes, dtypes, memory_formats, approximation)
        devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
        for shape, dtype, memory_format, approximate in test_cases:
            if memory_format == torch.channels_last and len(shape) != 4:
                continue

            X, scale, zero_point, torch_type = \
                torch.randn(*shape), 0.1, 0, dtype
            X = X.to(memory_format=memory_format)
            for device in devices:
                X = X.to(device=device)
                qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                               dtype=torch_type)
                dqX = qX.dequantize()

                op = torch.nn.functional.gelu
                dqY = op(dqX, approximate=approximate)
                qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                               dtype=torch_type)
                qY_hat = op(qX)
                self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
                                 msg=f"F.gelu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::prelu op."""
    def test_qprelu(self):
        shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4))
        num_params = (0, 1)  # 0: num_parameter = num_channels
        dtypes = (torch.quint8, torch.qint8)
        memory_formats = (torch.channels_last, torch.contiguous_format)
        test_cases = itertools.product(shapes, num_params, dtypes, memory_formats)
        for shape, num_param, dtype, memory_format in test_cases:
            if memory_format == torch.channels_last and len(shape) != 4:
                continue
            X, scale, zero_point, torch_type = \
                torch.randn(*shape), 0.1, 0, dtype
            X = X.to(memory_format=memory_format)
            num_parameter = 1 if num_param == 1 or len(shape) == 1 else shape[1]
            W = torch.randn(num_parameter)
            W, w_scale, w_zero_point = \
                torch.randn(num_parameter), 0.2, 0

            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            dqX = qX.dequantize()
            qW = torch.quantize_per_tensor(W, scale=w_scale, zero_point=w_zero_point,
                                           dtype=torch_type)
            dqW = qW.dequantize()

            op = torch.nn.functional.prelu
            qop = torch.ops.quantized.prelu
            dqY = op(dqX, dqW)
            qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)
            qY_hat = qop(qX, qW, scale, zero_point)
            self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
                             msg=f"F.prelu failed ({qY} vs {qY_hat})")

    """Tests the correctness of the quantized::qlayer_norm op."""
    @skipIfNoFBGEMM
    def test_qlayer_norm(self):
        # hypothesis is flaky for this test, create test cases manually
        side_lens = (1, 8, 11)
        torch_types = (torch.qint8, torch.quint8)
        y_scales = (0.1, 4.23)
        y_zero_points = (0, 1)
        channels_last_list = (True, False)
        affine_list = (True, False)
        combined = [side_lens, torch_types, y_scales, y_zero_points,
                    channels_last_list, affine_list]
        test_cases = itertools.product(*combined)

        with override_quantized_engine("fbgemm"):
            for test_case in test_cases:

                side_len, torch_type, Y_scale, Y_zero_point, channels_last, \
                    affine = test_case
                shapes = [side_len] * 4

                # In the FP kernel, mean and variance are calculated in floating point.
                # In the quantized kernel, they are calculated in integer arithmetic.
                # Because of this, the numerics do not always match exactly, which is
                # expected and acceptable. We do two things to allow for this
                # in this test:
                # 1. do not use Hypothesis to generate the input tensor. Hypothesis
                #    favors homogeneous inputs in its search strategies, which isn't
                #    representative of the inputs we care about and tends to maximize
                #    this particular numerics difference.
                # 2. allow a small % of off-by-Y_scale errors. Even when the
                #    variance of the input is high, there can be off-by-one errors
                #    in the result if the input value happens to fall exactly on
                #    the bin boundary of the output scale.
                #
                # If we want the numerics to match, we could switch to calculating
                # mean+var in floating point in the future, at the cost of speed.
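                # Concretely, with Y_scale = 0.1 a one-bin mismatch appears as a
                # 0.1 difference in the dequantized output, and the checks below
                # tolerate up to 1% of elements being off by at most one bin.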
                X, X_scale, X_zero_point = \
                    _get_random_tensor_and_q_params(shapes, 1.0, torch_type)

                qX = torch.quantize_per_tensor(X, scale=X_scale,
                                               zero_point=X_zero_point,
                                               dtype=torch_type)
                if channels_last:
                    qX = qX.contiguous(memory_format=torch.channels_last)
                dqX = qX.dequantize()

                # Enforce non-homogeneous inputs
                enough_unique_vals_in_each_layer = sum(
                    1 if (
                        dqX[i].shape[0] < 5 or
                        float(torch.unique(dqX[i]).shape[0]) / dqX[i].shape[0] > 0.01
                    ) else 0
                    for i in range(dqX.shape[0])
                ) == dqX.shape[0]
                assume(enough_unique_vals_in_each_layer)

                # Initialize the weights non-randomly for reproducibility, to avoid
                # flaky tests
                if affine:
                    weight = torch.ones(*qX.size()[1:], dtype=torch.float) * 0.5
                    bias = torch.ones(*qX.size()[1:], dtype=torch.float) * 1
                else:
                    weight = None
                    bias = None
                epsilon = 1e-5

                qY = torch.ops.quantized.layer_norm(
                    qX, qX.size()[1:], weight=weight, bias=bias, eps=epsilon,
                    output_scale=Y_scale, output_zero_point=Y_zero_point)

                Y_hat = F.layer_norm(
                    dqX, dqX.size()[1:], weight=weight, bias=bias, eps=epsilon)
                qY_hat = torch.quantize_per_tensor(
                    Y_hat, scale=Y_scale, zero_point=Y_zero_point, dtype=torch_type)

                # Due to the numerics difference mentioned above between calculating
                # the variance in float vs int, the results can still be slightly
                # different.
                dqY = qY.dequantize()
                dqY_hat = qY_hat.dequantize()
                diff = dqY - dqY_hat

                # off-by-one errors are magnitude of Y_scale
                num_diff = torch.sum(diff > Y_scale * 1.0001)
                pct_diff = float(num_diff) / (diff.numel() + 1e-5)
                num_diff_off_by_one = torch.sum((diff > 0) * (diff <= Y_scale))
                pct_diff_off_by_one = float(num_diff_off_by_one) / (diff.numel() + 1e-5)

                self.assertTrue(pct_diff < 1e-6)
                self.assertTrue(pct_diff_off_by_one < 0.01)


    """Tests the correctness of the quantized::qnnpack_tanh op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       qparams=hu.qparams()))
    @unittest.skip(
        "this is broken without changes to any relevant code, "
        "we need to remove hypothesis testing in CI")
    def test_qtanh(self, X):
        # Note: QNNPACK is tested separately in TestQNNPackOps
        X, (scale, zero_point, torch_type) = X

        X = torch.from_numpy(X)
        Y = torch.tanh(X)

        qX = torch.quantize_per_tensor(X, scale=scale,
                                       zero_point=zero_point,
                                       dtype=torch_type)

        # Quantize the reference to account for max error.
        # Note that the output scale has +1, because we use scale of 2.0/2^BITS
        # in the implementations.
        f_min, f_max = -1.0, 1.0
        q_min, q_max = torch.iinfo(torch_type).min, torch.iinfo(torch_type).max
        output_scale = (f_max - f_min) / (q_max - q_min + 1.0)
        output_zero_point = int(round((q_max + q_min) / 2.0))
        qY = torch.quantize_per_tensor(Y, scale=output_scale,
                                       zero_point=output_zero_point,
                                       dtype=torch_type)
        qY_hat = torch.tanh(qX)
        self.assertEqual(qY, qY_hat,
                         msg=f"TanH failed: {qY} vs. {qY_hat}")

    """Tests the correctness of the quantized::threshold op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams()),
           threshold=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
           value=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False))
    def test_qthreshold(self, X, threshold, value):
        X, (scale, zero_point, torch_type) = X
        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        # calculate threshold(dqX) and quantize
        dqX = qX.dequantize()
        dqY_hat = dqX.clone()
        dqY_hat = torch.nn.functional.threshold(dqY_hat, threshold, value)
        qY_hat = torch.quantize_per_tensor(dqY_hat, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)

        ops_under_test = {
            'native': torch.threshold,
            'nn.functional': torch.nn.functional.threshold,
            'ao.nn.quantized.functional': torch.ao.nn.quantized.functional.threshold,
        }

        for name, op in ops_under_test.items():
            qY = op(qX, threshold, value)
            self.assertEqual(qY, qY_hat, msg=f"{name} qthreshold failed")

    """Tests the correctness of the quantized::clamp op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8, max_numel=10**5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False),
                       qparams=hu.qparams()),
           min_val=hu.floats(-1e6, 1e6, allow_nan=False),
           max_val=hu.floats(-1e6, 1e6, allow_nan=False))
    def test_qclamp(self, X, min_val, max_val):
        X, (scale, zero_point, torch_type) = X

        assume(min_val <= max_val)
        Y_clamp = torch.clamp(torch.from_numpy(X), min=min_val, max=max_val)
        qY_clamp = torch.quantize_per_tensor(Y_clamp, scale=scale,
                                             zero_point=zero_point, dtype=torch_type)

        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)
        ops_under_test = {
            'ops.quantized': torch.ops.quantized.clamp,
        }

        for name, op in ops_under_test.items():
            qY_clamp_hat = op(qX, min=min_val, max=max_val)
            self.assertEqual(qY_clamp, qY_clamp_hat, msg=f"{name} qclamp failed")

        if torch.backends.quantized.engine == 'fbgemm':
            with override_quantized_engine('fbgemm'):
                Y_min_clamp = torch.clamp(X, min=min_val)
                Y_max_clamp = torch.clamp(X, max=max_val)

                qY_min_clamp = torch.quantize_per_tensor(Y_min_clamp, scale=scale,
                                                         zero_point=zero_point, dtype=torch_type)
                qY_max_clamp = torch.quantize_per_tensor(Y_max_clamp, scale=scale,
                                                         zero_point=zero_point, dtype=torch_type)


                for name, op in ops_under_test.items():
                    qY_min_clamp_hat = op(qX, min=min_val)
                    self.assertEqual(qY_min_clamp, qY_min_clamp_hat, msg=f"{name} qclamp failed")
                    qY_max_clamp_hat = op(qX, max=max_val)
                    self.assertEqual(qY_max_clamp, qY_max_clamp_hat, msg=f"{name} qclamp failed")

    """Tests the correctness of the quantized::hardtanh op."""
    @skipIfNoFBGEMM
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8, max_numel=10**5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams()),
           min_val=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False),
           max_val=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
    def test_hardtanh(self, X, min_val, max_val):
        with override_quantized_engine('fbgemm'):
            X, (scale, zero_point, torch_type) = X

            assume(min_val <= max_val)
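            # Build the fp32 reference by clamping X to [min_val, max_val]; the
            # quantized hardtanh below should match it once both sides are
            # quantized with the same scale and zero_point.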
            Y = X.copy()
            Y[Y < min_val] = min_val
            Y[Y > max_val] = max_val
            qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale,
                                           zero_point=zero_point, dtype=torch_type)
            X = torch.from_numpy(X)
            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)

            ops_under_test = {
                'ao.nn.quantized.functional.hardtanh':
                    torch.ao.nn.quantized.functional.hardtanh,
            }

            for name, op in ops_under_test.items():
                qY_hat = op(qX, min_val, max_val)
                self.assertEqual(qY, qY_hat, msg=f"{name} hardtanh failed")

            ops_under_test_inplace = {
                'inplace ao.nn.quantized.functional.hardtanh':
                    torch.ao.nn.quantized.functional.hardtanh,
            }

            for name, op_ in ops_under_test_inplace.items():
                qY_hat = qX.clone()
                op_(qY_hat, min_val, max_val, inplace=True)
                self.assertEqual(qY, qY_hat, msg=f"{name} hardtanh failed")

    """Tests the correctness of the quantized::hardswish op."""
    @override_qengines
    def test_hardswish(self):
        max_sides = (3, 4)
        side_lens = (1, 7)
        torch_types = (torch.quint8, torch.qint8)
        y_scales = (0.1, )
        y_zero_points = (1,)
        combined = [max_sides, side_lens, torch_types, y_scales, y_zero_points]
        test_cases = itertools.product(*combined)
        for test_case in test_cases:
            max_side, side_len, torch_type, Y_scale, Y_zero_point = test_case

            if torch.backends.quantized.engine == 'qnnpack' and torch_type != torch.quint8:
                continue

            shapes = [side_len] * max_side
            X, X_scale, X_zero_point = \
                _get_random_tensor_and_q_params(shapes, 2.0, torch_type)
            for memory_format in torch.channels_last, torch.contiguous_format:
                if memory_format == torch.channels_last and len(shapes) == 4:
                    X = X.to(memory_format=memory_format)
                qX = torch.quantize_per_tensor(X, scale=X_scale, zero_point=X_zero_point,
                                               dtype=torch_type)
                dqX = qX.dequantize()

                dqY_hat = F.hardswish(dqX)
                qY_hat = torch.quantize_per_tensor(dqY_hat, scale=Y_scale,
                                                   zero_point=Y_zero_point,
                                                   dtype=torch_type)

                qY = torch.ao.nn.quantized.functional.hardswish(
                    qX, scale=Y_scale, zero_point=Y_zero_point)
                self.assertEqual(
                    qY, qY_hat,
                    msg=f"Hardswish failed: {qY} vs {qY_hat}, {torch.backends.quantized.engine}")

    """Tests the correctness of the binary op + scalar."""
    def _test_binary_op_scalar_relu(self, A, b, binary_op_name, binary_op, quantized_op, quantized_op_relu):
        import copy
        op_scalar = quantized_op
        op_scalar_relu = quantized_op_relu

        A, (scale, zero_point, dtype) = A
        A = A.astype(np.float32)
        qA = torch.quantize_per_tensor(torch.from_numpy(A), scale, zero_point, dtype)

        if binary_op_name == 'add':
            C = binary_op(qA.dequantize(), round(b / scale) * scale)
        else:
            C = binary_op(qA.dequantize(), b)
        C_relu = copy.deepcopy(C)
        C_relu[C_relu < 0] = 0

        C_hat = op_scalar(qA, b)
        C_ref = torch.quantize_per_tensor(C, C_hat.q_scale(), C_hat.q_zero_point(), dtype)
        C_relu_hat = op_scalar_relu(qA, b)
        C_relu_ref = torch.quantize_per_tensor(
            C_relu, C_relu_hat.q_scale(), C_relu_hat.q_zero_point(), dtype)

        self.assertEqual(C_ref.dequantize(), C_hat.dequantize(),
                         msg=f"{binary_op_name}_scalar results don't match: "
                         f"{C_ref.dequantize()} vs {C_hat.dequantize()}")
        self.assertEqual(C_relu_ref.dequantize(), C_relu_hat.dequantize(),
                         msg=f"{binary_op_name}_scalar_relu results don't match: "
                         f"{C_relu_ref.dequantize()} vs {C_relu_hat.dequantize()}")

    @unittest.skipIf(IS_MACOS, "skipping macos test")
    @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False),
                       qparams=hu.qparams()),
           b=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
    def test_add_scalar_relu(self, A, b):
        self._test_binary_op_scalar_relu(A, b, "add", operator.add, torch.ops.quantized.add, torch.ops.quantized.add_relu)

    @unittest.skipIf(IS_MACOS, "skipping macos test")
    @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False),
                       qparams=hu.qparams()),
           b=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
    def test_mul_scalar_relu(self, A, b):
        self._test_binary_op_scalar_relu(A, b, "mul", operator.mul, torch.ops.quantized.mul, torch.ops.quantized.mul_relu)

    """Tests the correctness of the add and add_relu op."""
    def test_qadd_relu_same_qparams(self):
        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
            add_relu = torch.ops.quantized.add_relu
            add = torch.ops.quantized.add
            add_out = torch.ops.quantized.add
            add_relu_out = torch.ops.quantized.add_relu

            # NB: This is a strange size so that we exercise both the vectorized
            # implementation (64-element chunks at a time) as well as the scalar
            # implementation
            A = torch.arange(-128, 130, dtype=torch.float)
            B = torch.arange(-128, 130, dtype=torch.float)
            scale = 2.0
            zero_point = 127
            qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                           dtype=dtype)
            qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                           dtype=dtype)

            # Add ReLU ground truth
            C = (qA.dequantize() + qB.dequantize()).numpy()
            qC = _quantize(C, scale, zero_point, dtype=np_dtype[dtype])
            qC_hat = add(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qC, qC_hat.int_repr(),
                                    "Quantized addition failed.")
            qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                                       scale=scale,
                                                       zero_point=zero_point,
                                                       dtype=dtype)
            add_out(qA, qB, out=qC_out_hat)
            self.assertEqual(qC_hat, qC_out_hat, msg="Add.out failed")

            # Add + ReLU ground truth
            Crelu = C.copy()
            Crelu[C < 0] = 0
            qCrelu = _quantize(Crelu, scale, zero_point, dtype=np_dtype[dtype])
            qCrelu_hat = add_relu(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                    "Quantized addition with ReLU failed.")
            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
                                                           scale=scale,
                                                           zero_point=zero_point,
                                                           dtype=dtype)
            add_relu_out(qA, qB, out=qCrelu_out_hat)
            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                             msg="AddReLU.out failed")

    """Tests the correctness of the cudnn add and add_relu op
    (Similar to test_qadd_relu_different_qparams, will probably merge in the future)"""
    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
    @unittest.skip("not currently working and feature isn't used")
    def test_qadd_relu_cudnn(self):
        dtype = torch.qint8
        add_relu = torch.ops.quantized.add_relu
        add = torch.ops.quantized.add

        A = torch.arange(-128, 130, dtype=torch.float).to(torch.device("cuda"))
        B = torch.arange(-128, 130, dtype=torch.float).to(torch.device("cuda"))
        scale_A = 2.5
        scale_B = 6.3
        scale_C = 12.9
        zero_point = 0
        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point,
                                       dtype=dtype)
        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point,
                                       dtype=dtype)
        # Add ground truth
        C = (qA.dequantize() + qB.dequantize()).to(device="cpu").numpy()
        qC = _quantize(C, scale_C, zero_point, dtype=np_dtype[dtype])
        qC_hat = add(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qC, qC_hat.int_repr(),
                                "Quantized addition failed.")

        # Add + ReLU ground truth
        Crelu = C.copy()
        Crelu[C < 0] = 0
        qCrelu = _quantize(Crelu, scale_C, zero_point, dtype=np_dtype[dtype])
        qCrelu_hat = add_relu(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                "Quantized addition with ReLU failed.")

    """Tests the correctness of the cudnn add and add_relu op for nhwc format"""
    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
    @unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
    @unittest.skip("not currently working and feature isn't used")
    def test_qadd_relu_cudnn_nhwc(self):
        dtype = torch.qint8
        add_relu = torch.ops.quantized.add_relu
        add = torch.ops.quantized.add

        A = torch.rand(16, 8, 4, 12).to(device="cuda")
        B = torch.rand(16, 8, 4, 12).to(device="cuda")
        scale_A = 2.5
        scale_B = 6.3
        scale_C = 12.9
        zero_point = 0
        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point,
                                       dtype=dtype)
        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point,
                                       dtype=dtype)
        # Add ground truth
        C = (qA.dequantize() + qB.dequantize()).to(device="cpu").numpy()
        qC = _quantize(C, scale_C, zero_point, dtype=np_dtype[dtype])
        qC_hat = add(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qC, qC_hat.int_repr(),
                                "Quantized addition failed.")

        # Add + ReLU ground truth
        Crelu = C.copy()
        Crelu[C < 0] = 0
        qCrelu = _quantize(Crelu, scale_C, zero_point, dtype=np_dtype[dtype])
        qCrelu_hat = add_relu(qA, qB, scale=scale_C, zero_point=zero_point).to(device="cpu")
        np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                "Quantized addition with ReLU failed.")

    """Tests the correctness of the add and add_relu op."""
    def test_qadd_relu_different_qparams(self):
        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
            add_relu = torch.ops.quantized.add_relu
            add = torch.ops.quantized.add
            add_out = torch.ops.quantized.add
            add_relu_out = torch.ops.quantized.add_relu

            # NB: This is a strange size so that we exercise both the vectorized
            # implementation (64-element chunks at a time) as well as the scalar
            # implementation
            A = torch.arange(-128, 130, dtype=torch.float)
            B = torch.arange(-128, 130, dtype=torch.float)
            scale_A = 3.0
            zero_point_A = 7
            scale_B = 5.0
            zero_point_B = 127

            scale_C = 0.5
            zero_point_C = 5

            qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
                                           dtype=dtype)
            qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
                                           dtype=dtype)

            # Add ground truth
            C = (qA.dequantize() + qB.dequantize()).numpy()
            qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype[dtype])
            qC_hat = add(qA, qB, scale=scale_C, zero_point=zero_point_C)
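            # The ground truth above dequantizes A and B with their own qparams,
            # adds in fp32, and requantizes with the requested output qparams
            # (scale_C, zero_point_C); the quantized add is expected to match that
            # reference exactly on the integer representation.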
            np.testing.assert_equal(qC, qC_hat.int_repr(),
                                    "Quantized addition failed.")
            qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                                       scale=scale_C,
                                                       zero_point=zero_point_C,
                                                       dtype=dtype)
            add_out(qA, qB, out=qC_out_hat)
            self.assertEqual(qC_hat, qC_out_hat, msg="Add.out failed")

            # Add + ReLU ground truth
            Crelu = C.copy()
            Crelu[C < 0] = 0
            qCrelu = _quantize(Crelu, scale_C, zero_point_C, dtype=np_dtype[dtype])
            qCrelu_hat = add_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                    "Quantized addition with ReLU failed.")
            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
                                                           scale=scale_C,
                                                           zero_point=zero_point_C,
                                                           dtype=dtype)
            add_relu_out(qA, qB, out=qCrelu_out_hat)
            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                             msg="AddReLU.out failed")

    """Tests the correctness of the mul and mul_relu op."""
    def test_qmul_relu_same_qparams(self):
        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
            mul_relu = torch.ops.quantized.mul_relu
            mul = torch.ops.quantized.mul
            mul_out = torch.ops.quantized.mul
            mul_relu_out = torch.ops.quantized.mul_relu

            A = torch.arange(-100, 100, dtype=torch.float)
            B = torch.arange(-100, 100, dtype=torch.float)
            scale = 2
            zero_point = 127
            qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                           dtype=dtype)
            qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                           dtype=dtype)

            # mul ReLU ground truth
            C = (qA.dequantize() * qB.dequantize()).numpy()
            qC = _quantize(C, scale, zero_point, dtype=np_dtype[dtype])
            qC_hat = mul(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qC, qC_hat.int_repr(),
                                    "Quantized multiplication failed.")
            qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                                       scale=scale,
                                                       zero_point=zero_point,
                                                       dtype=dtype)
            mul_out(qA, qB, out=qC_out_hat)
            self.assertEqual(qC_hat, qC_out_hat, msg="mul.out failed")

            # mul + ReLU ground truth
            Crelu = C.copy()
            Crelu[C < 0] = 0
            qCrelu = _quantize(Crelu, scale, zero_point, dtype=np_dtype[dtype])
            qCrelu_hat = mul_relu(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                    "Quantized multiplication with ReLU failed.")
            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
                                                           scale=scale,
                                                           zero_point=zero_point,
                                                           dtype=dtype)
            mul_relu_out(qA, qB, out=qCrelu_out_hat)
            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                             msg="mulReLU.out failed")

            # Scalar multiplication
            for b in B:
                C_ref = qA.dequantize().numpy() * b.item()
                qC_hat = torch.ops.quantized.mul(qA, b.item())

                self.assertEqual(C_ref, qC_hat.dequantize())

            # Scalar multiplication + relu
            for b in B:
                C_ref = qA.dequantize().numpy() * b.item()
                C_ref[C_ref < 0] = 0
                qC_hat = torch.ops.quantized.mul_relu(qA, b.item())

                self.assertEqual(C_ref, qC_hat.dequantize())

    """Tests the correctness of the mul and mul_relu op."""
    def test_qmul_relu_different_qparams(self):
        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
            mul_relu = torch.ops.quantized.mul_relu
            mul = torch.ops.quantized.mul
            mul_out = torch.ops.quantized.mul
            mul_relu_out = torch.ops.quantized.mul_relu

            A = torch.arange(-100, 100, dtype=torch.float)
            B = torch.arange(-100, 100, dtype=torch.float)
            scale_A = 3.0
            zero_point_A = 7
            scale_B = 5.0
            zero_point_B = 127

            scale_C = 0.5
            zero_point_C = 5

            qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
                                           dtype=dtype)
            qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
                                           dtype=dtype)

            # mul ground truth
            C = (qA.dequantize() * qB.dequantize()).numpy()
            qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype[dtype])
            qC_hat = mul(qA, qB, scale=scale_C, zero_point=zero_point_C)
            np.testing.assert_equal(qC, qC_hat.int_repr(),
                                    "Quantized multiplication failed.")
            qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                                       scale=scale_C,
                                                       zero_point=zero_point_C,
                                                       dtype=dtype)
            mul_out(qA, qB, out=qC_out_hat)
            self.assertEqual(qC_hat, qC_out_hat, msg="mul.out failed")

            # mul + ReLU ground truth
            Crelu = C.copy()
            Crelu[C < 0] = 0
            qCrelu = _quantize(Crelu, scale_C, zero_point_C, dtype=np_dtype[dtype])
            qCrelu_hat = mul_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
            np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                                    "Quantized multiplication with ReLU failed.")
            qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
                                                           scale=scale_C,
                                                           zero_point=zero_point_C,
                                                           dtype=dtype)
            mul_relu_out(qA, qB, out=qCrelu_out_hat)
            self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                             msg="mulReLU.out failed")

    """Tests the correctness of the matmul op."""
    @given(num_dims=st.integers(2, 5),
           outer_dims=st.lists(st.integers(2, 6), min_size=3, max_size=3),
           m=st.integers(2, 6),
           k=st.integers(2, 6),
           n=st.integers(2, 6),
           dtypes=st.sampled_from(((torch.qint8, np.int8),
                                   (torch.quint8, np.uint8))))
    def test_qmatmul(self, num_dims, outer_dims, m, k, n, dtypes):
        (torch_dtype, np_dtype) = dtypes

        size_a = outer_dims[:num_dims - 2] + [m, k]
        size_b = outer_dims[:num_dims - 2] + [k, n]
        A = torch.randn(size=size_a, dtype=torch.float32) * 3
        B = torch.randn(size=size_b, dtype=torch.float32) * 3

        scale_A = 3.1
        zero_point_A = 7
        scale_B = 5.3
        zero_point_B = 127

        scale_C = 1.3
        zero_point_C = 5

        qA = torch.quantize_per_tensor(A,
                                       scale=scale_A,
                                       zero_point=zero_point_A,
                                       dtype=torch_dtype)
        qB = torch.quantize_per_tensor(B,
                                       scale=scale_B,
                                       zero_point=zero_point_B,
                                       dtype=torch_dtype)

        # matmul ground truth
        C = torch.matmul(qA.dequantize(), qB.dequantize()).numpy()
        qC = _quantize(C, scale_C, zero_point_C, dtype=(np_dtype))
        qC_hat = torch.ops.quantized.matmul(qA,
                                            qB,
                                            scale=scale_C,
                                            zero_point=zero_point_C)
        np.testing.assert_equal(qC, qC_hat.int_repr(),
                                "Quantized multiplication failed.")

        # Using per channel quantization fails
        axis = 0
        scales_A = torch.rand(size=(A.shape[axis],))
        zero_points_A = torch.randint(low=0, high=5, size=(A.shape[axis],))
        scales_B = torch.rand(size=(B.shape[axis],))
        zero_points_B = torch.randint(low=0, high=5, size=(B.shape[axis],))

        qA = torch.quantize_per_channel(A,
                                        scales=scales_A,
                                        zero_points=zero_points_A,
                                        axis=axis,
                                        dtype=torch.qint8)
        qB = torch.quantize_per_channel(B,
                                        scales=scales_B,
                                        zero_points=zero_points_B,
                                        axis=axis,
                                        dtype=torch.qint8)
        np.testing.assert_raises_regex(RuntimeError,
                                       ".*per-tensor.*",
                                       torch.ops.quantized.matmul,
                                       qA,
                                       qB,
                                       scale_C,
                                       zero_point_C)


    """Tests the correctness of the quantized softmax op."""
    @given(dims=st.lists(st.integers(2, 5), min_size=5, max_size=5))
    def test_qsoftmax(self, dims):
        for (num_dims, dim, memory_format) in [
            (2, 1, torch.contiguous_format),  # 2d softmax over last dim
            (4, 3, torch.contiguous_format),  # >2 dims, softmax along last dim
            (5, 2, torch.contiguous_format),  # >2 dims, softmax along not last dim (requires permute)
            (4, 3, torch.channels_last),  # >2 dims, softmax along last dim, but not contiguous
            (4, 1, torch.channels_last),  # Channels Last, doesn't require permute
            (5, 1, torch.channels_last_3d),  # Channels Last 3D, doesn't require permute
        ]:
            size = dims[:num_dims]
            torch_dtype = torch.quint8
            np_dtype = np.uint8

            scale_X = 1.3
            zero_point_X = 5
            X = torch.rand(size=size, dtype=torch.float32) * 8 + zero_point_X
            X = X.to(memory_format=memory_format)

            scale_Y = 1 / 256
            zero_point_Y = 0

            qX = torch.quantize_per_tensor(X,
                                           scale=scale_X,
                                           zero_point=zero_point_X,
                                           dtype=torch_dtype)


            # softmax ground truth
            Y = torch.softmax(qX.dequantize(), dim=dim).numpy()
            qY = _quantize(Y, scale_Y, zero_point_Y, dtype=np_dtype)
            qY_hat = torch.ops.quantized.softmax(qX,
                                                 dim=dim,
                                                 output_scale=scale_Y,
                                                 output_zero_point=zero_point_Y)

            np.testing.assert_equal(qY, qY_hat.int_repr(),
                                    "Quantized softmax failed.")

    """Tests the correctness of the quantized softmax op using qnnpack."""
    @skipIfNoQNNPACK
    def test_qsoftmax_qnnpack(self):
        with override_quantized_engine('qnnpack'):
            self.test_qsoftmax()

    """Tests the correctness of the mul and mul_relu op."""
    def test_qmul_broadcast(self):
        mul_relu = torch.ops.quantized.mul_relu
        mul = torch.ops.quantized.mul
        mul_out = torch.ops.quantized.mul
        mul_relu_out = torch.ops.quantized.mul_relu

        # A = torch.arange(-25, 25, dtype=torch.float)
        # B = torch.arange(-25, 25, dtype=torch.float)
        A = torch.randn(8, 1, 6, 1)
        B = torch.randn(7, 1, 5)
        scale_A = 3.0
        zero_point_A = 7
        scale_B = 5.0
        zero_point_B = 127

        scale_C = 0.5
        zero_point_C = 5

        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
                                       dtype=torch.quint8)
        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
                                       dtype=torch.quint8)

        # mul ground truth
        C = (qA.dequantize() * qB.dequantize()).numpy()
        qC = _quantize(C, scale_C, zero_point_C)
        qC_hat = mul(qA, qB, scale=scale_C, zero_point=zero_point_C)
        np.testing.assert_equal(qC, qC_hat.int_repr(),
                                "Quantized multiplication failed.")

    """Tests that quantized add works with broadcasting"""
    def test_qadd_broadcast(self):
        A = torch.randn(1, 1, 4, 4)
        B = torch.randn(2, 1, 4, 4)
        qA = torch.quantize_per_tensor(A, 0.02, 0, torch.quint8)
        qB = torch.quantize_per_tensor(B, 0.04, 2, torch.quint8)

        output_scale = 0.01
        output_zp = 1

        # ground truth
        C = qA.dequantize() + qB.dequantize()
        qC = torch.quantize_per_tensor(C, output_scale, output_zp, torch.quint8)

        # quantized
        qC_hat_1 = torch.ops.quantized.add(qA, qB, output_scale, output_zp)
        qC_hat_2 = torch.ops.quantized.add(qB, qA, output_scale, output_zp)

        self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_1.dequantize()))
        self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_2.dequantize()))

    """Tests channel shuffle operation on quantized tensors."""
    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                              min_side=2, max_side=32, max_numel=10**5),
                       qparams=hu.qparams(dtypes=[torch.quint8])),
           groups=st.integers(2, 6))
    def test_channel_shuffle(self, X, groups):
        X, (scale, zero_point, torch_type) = X
        channels = X.shape[-3]
        iH, iW = X.shape[-2:]
        assume(channels % groups == 0)

        a = torch.from_numpy(X)
        a = torch.rand(a.shape)
        a_out = torch.nn.functional.channel_shuffle(a, groups)

        a_ref = torch.quantize_per_tensor(a_out, scale=scale,
                                          zero_point=zero_point, dtype=torch_type)
        a_ref = a_ref.dequantize()
        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        a_hat = torch.nn.functional.channel_shuffle(qa, groups)
        self.assertEqual(a_ref, a_hat.dequantize(),
                         msg="torch.nn.functional.channel_shuffle results are off")

    """Tests 1D max pool operation on quantized tensors."""
    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=3,
                                              min_side=1, max_side=10),
                       qparams=hu.qparams()),
           kernel=st.sampled_from((3, 5, 7)),
           stride=st.sampled_from((None, 1, 2)),
           dilation=st.integers(1, 2),
           padding=st.integers(0, 2),
           ceil_mode=st.booleans())
    def test_max_pool1d(self, X, kernel, stride, dilation, padding, ceil_mode):
        X, (scale, zero_point, torch_type) = X
        # Check constraints
        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
        iW = X.shape[-1]
        oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
        assume(oW > 0)

        a = torch.from_numpy(X)
        a_pool = torch.nn.functional.max_pool1d(a, kernel_size=kernel,
                                                stride=stride,
                                                padding=padding,
                                                dilation=dilation,
                                                ceil_mode=ceil_mode)
        a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
                                          zero_point=zero_point, dtype=torch_type)
        a_ref = a_ref.dequantize()
        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        ops_under_test = {
            "torch": torch.max_pool1d,
            "nn.functional": torch.nn.functional.max_pool1d,
            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.max_pool1d,
        }

        for name, op in ops_under_test.items():
            a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
                       dilation=dilation, ceil_mode=ceil_mode)
            self.assertEqual(a_ref, a_hat.dequantize(),
                             msg=f"{name} results are off")
        # Test the ops.quantized separately, because None is not treated.
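        # The ops.quantized.max_pool1d overload does not handle a None stride the
        # way the Python wrappers do, so the default (stride == kernel_size) is
        # passed explicitly here.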
        a_hat = torch.ops.quantized.max_pool1d(
            qa, kernel_size=_single(kernel),
            stride=_single(kernel if stride is None else stride),
            padding=_single(padding), dilation=_single(dilation),
            ceil_mode=ceil_mode)
        self.assertEqual(a_ref, a_hat.dequantize(),
                         msg="ops.quantized.max_pool1d results are off")

    # TODO: merge this test with test_max_pool2d
    """Tests 2D cudnn max pool operation on quantized tensors."""
    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                              min_side=1, max_side=10),
                       # cudnn's support for quantized pooling is limited to
                       # int8 currently
                       qparams=hu.qparams(dtypes=[torch.qint8])),
           kernel=st.sampled_from((3, 5, 7)),
           stride=st.sampled_from((None, 1, 2)),
           # currently there is no support for dilation for cudnn
           # pooling
           dilation=st.integers(1, 1),
           padding=st.integers(0, 2),
           ceil_mode=st.booleans())
    @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
    @unittest.skipIf(TEST_CUDNN_VERSION <= 90100, "cuDNN maxpool2d mishandles -128 before v90100")
    @unittest.skipIf(TEST_ROCM, "not supported on rocm.")
    def test_max_pool2d_cudnn(self, X, kernel, stride, dilation, padding, ceil_mode):
        X, (scale, zero_point, torch_type) = X
        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
        iH, iW = X.shape[-2:]
        oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
        assume(oH > 0)
        oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
        assume(oW > 0)

        a = torch.from_numpy(X).to(device="cuda")
        a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel,
                                                stride=stride,
                                                padding=padding, dilation=dilation,
                                                ceil_mode=ceil_mode)
        a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
                                          zero_point=zero_point, dtype=torch_type)
        a_ref = a_ref.dequantize()
        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        # Test the ops.quantized separately, because None is not treated.
        a_hat = torch.ops.quantized.max_pool2d(
            qa, kernel_size=_pair(kernel),
            stride=_pair(kernel if stride is None else stride),
            padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode)
        self.assertEqual(a_ref, a_hat.dequantize(),
                         msg="ops.quantized.max_pool2d results are off")

    """Tests 2D max pool operation on quantized tensors."""
    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                              min_side=1, max_side=10),
                       qparams=hu.qparams()),
           kernel=st.sampled_from((3, 5, 7)),
           stride=st.sampled_from((None, 1, 2)),
           dilation=st.integers(1, 2),
           padding=st.integers(0, 2),
           ceil_mode=st.booleans())
    def test_max_pool2d(self, X, kernel, stride, dilation, padding, ceil_mode):
        X, (scale, zero_point, torch_type) = X
        # Check constraints
        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
        iH, iW = X.shape[-2:]
        oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
        assume(oH > 0)
        oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
        assume(oW > 0)

        a = torch.from_numpy(X)
        a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel,
                                                stride=stride,
                                                padding=padding, dilation=dilation,
                                                ceil_mode=ceil_mode)
        a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
                                          zero_point=zero_point, dtype=torch_type)
        a_ref = a_ref.dequantize()
        qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

        ops_under_test = {
            "torch": torch.max_pool2d,
            "nn.functional": torch.nn.functional.max_pool2d,
            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.max_pool2d,
        }

        for name, op in ops_under_test.items():
            a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
                       dilation=dilation, ceil_mode=ceil_mode)
            self.assertEqual(a_ref, a_hat.dequantize(),
                             msg=f"{name} results are off")
        # Test the ops.quantized separately, because None is not treated.
        a_hat = torch.ops.quantized.max_pool2d(
            qa, kernel_size=_pair(kernel),
            stride=_pair(kernel if stride is None else stride),
            padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode)
        self.assertEqual(a_ref, a_hat.dequantize(),
                         msg="ops.quantized.max_pool2d results are off")


    def test_max_pool2d_pt2e(self):
        kernel_list = [2, 3]
        stride_list = [1, 2]
        padding_list = [0, 2]
        dilation_list = [1, 2]
        ceil_mode_list = [False, True]
        channels_last_input = [False, True]
        options = itertools.product(kernel_list, stride_list, padding_list, dilation_list, ceil_mode_list, channels_last_input)
        for kernel, stride, padding, dilation, ceil_mode, channels_last in options:
            if padding >= (kernel // 2):
                # Continue with invalid input
                continue
            input = torch.randint(0, 8, (1, 3, 8, 8), dtype=torch.uint8)
            if channels_last:
                input = input.contiguous(memory_format=torch.channels_last)
            a_pool = torch.nn.functional.max_pool2d(input.to(torch.float32), kernel_size=kernel,
                                                    stride=stride, padding=padding, dilation=dilation,
                                                    ceil_mode=ceil_mode).to(torch.uint8)
            a_hat = torch.ops.quantized.max_pool2d(input, kernel_size=_pair(kernel),
                                                   stride=_pair(stride), padding=_pair(padding),
                                                   dilation=_pair(dilation), ceil_mode=ceil_mode)
            self.assertEqual(input.is_contiguous(), a_hat.is_contiguous(),
                             msg="ops.quantized.max_pool2d input output diff memory format")
            self.assertEqual(a_pool, a_hat,
                             msg="ops.quantized.max_pool2d results are off")


    """Tests 3D max pool operation on quantized tensors."""
    def test_max_pool3d(self):
        torch_types = [torch.qint8, torch.quint8]
        kernels = [1, 3]
        strides = [1, 3]
        dilations = [1, 3]
        paddings = [1, 3]
        ceil_modes = [True, False]
        options = itertools.product(torch_types, kernels, strides, dilations, paddings, ceil_modes)
        for torch_type, kernel, stride, dilation, padding, ceil_mode in options:
            X = torch.randint(20, 40, (2, 3, 16, 10, 10)).to(torch.float)
            scale = 15
            zero_point = 20
            # Check constraints for invalid input
            if not (kernel // 2 >= padding):
                continue
            iT, iH, iW = X.shape[-3:]
            oT = pool_output_shape(iT, kernel, padding, stride, dilation, ceil_mode)
            if not (oT > 0):
                continue
1529 oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode) 1530 if not (oH > 0): 1531 continue 1532 oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode) 1533 if not (oW > 0): 1534 continue 1535 1536 a_pool = torch.nn.functional.max_pool3d(X, kernel_size=kernel, 1537 stride=stride, 1538 padding=padding, dilation=dilation, 1539 ceil_mode=ceil_mode) 1540 a_ref = torch.quantize_per_tensor(a_pool, scale=scale, 1541 zero_point=zero_point, dtype=torch_type) 1542 a_ref = a_ref.dequantize() 1543 qa = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 1544 dtype=torch_type) 1545 ops_under_test = { 1546 "torch": torch.max_pool3d, 1547 "nn.functional": torch.nn.functional.max_pool3d, 1548 } 1549 for name, op in ops_under_test.items(): 1550 a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding, 1551 dilation=dilation, ceil_mode=ceil_mode) 1552 self.assertEqual(a_ref, a_hat.dequantize(), 1553 msg=f"{name} results are off") 1554 1555 """Tests max pool operation on NHWC quantized tensors.""" 1556 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4, 1557 min_side=1, max_side=10), 1558 qparams=hu.qparams()), 1559 kernel=st.sampled_from((3, 5, 7)), 1560 stride=st.sampled_from((None, 1, 2)), 1561 dilation=st.integers(1, 2), 1562 padding=st.integers(0, 2), 1563 ceil_mode=st.booleans()) 1564 def test_max_pool2d_nhwc(self, X, kernel, stride, dilation, padding, ceil_mode): 1565 X, (scale, zero_point, torch_type) = X 1566 # Ensure we hit the vectorized paths 1567 # 176 = 128 + 32 + 16 1568 # 128 hits the interleaved path 1569 # 32 hits the non-interleaved path 1570 # 16 hits the scalar path 1571 if X.shape[1] < 176: 1572 X = np.repeat(X, 176 / X.shape[1], 1) 1573 # Check constraints 1574 assume(kernel // 2 >= padding) # Kernel cannot be overhanging! 1575 iH, iW = X.shape[-2:] 1576 oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode) 1577 assume(oH > 0) 1578 oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode) 1579 assume(oW > 0) 1580 1581 X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1])) 1582 a = torch.from_numpy(X_nchw).permute([0, 3, 1, 2]) 1583 a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel, 1584 stride=stride, 1585 padding=padding, dilation=dilation, 1586 ceil_mode=ceil_mode) 1587 a_ref = torch.quantize_per_tensor(a_pool, scale=scale, 1588 zero_point=zero_point, dtype=torch_type) 1589 a_ref = a_ref.dequantize() 1590 qa = torch.quantize_per_tensor(torch.from_numpy(X_nchw), scale=scale, zero_point=zero_point, 1591 dtype=torch_type).permute([0, 3, 1, 2]) 1592 self.assertTrue(qa.stride() != sorted(qa.stride())) 1593 1594 ops_under_test = { 1595 "torch": torch.max_pool2d, 1596 "nn.functional": torch.nn.functional.max_pool2d, 1597 "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.max_pool2d, 1598 } 1599 1600 for name, op in ops_under_test.items(): 1601 a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding, 1602 dilation=dilation, ceil_mode=ceil_mode) 1603 self.assertTrue(a_hat.stride() != sorted(a_hat.stride())) 1604 self.assertEqual(a_ref, a_hat.dequantize(), 1605 msg=f"{name} results are off") 1606 # Test the ops.quantized separately, because None is not treated. 
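        # The low-level quantized op does not accept stride=None, so the kernel size is
        # substituted explicitly to reproduce the default-stride behavior of the
        # Python-level pooling APIs.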
1607 a_hat = torch.ops.quantized.max_pool2d( 1608 qa, kernel_size=_pair(kernel), 1609 stride=_pair(kernel if stride is None else stride), 1610 padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode) 1611 self.assertEqual(a_ref, a_hat.dequantize(), 1612 msg="ops.quantized.max_pool2d results are off") 1613 1614 """Tests 3D max pool operation on quantized channel_last tensors.""" 1615 def test_max_pool3d_nhwc(self): 1616 torch_types = [torch.qint8, torch.quint8] 1617 kernels = [1, 3] 1618 strides = [1, 3] 1619 dilations = [1, 3] 1620 paddings = [1, 3] 1621 ceil_modes = [True, False] 1622 options = itertools.product(torch_types, kernels, strides, dilations, paddings, ceil_modes) 1623 for torch_type, kernel, stride, dilation, padding, ceil_mode in options: 1624 X = torch.randint(20, 40, (2, 67, 16, 10, 10)).to(torch.float) 1625 X_copy = copy.deepcopy(X) 1626 X = X.contiguous(memory_format=torch.channels_last_3d) 1627 scale = 15 1628 zero_point = 20 1629 # Check constraints for invalid input 1630 if not (kernel // 2 >= padding): 1631 continue 1632 iT, iH, iW = X.shape[-3:] 1633 oT = pool_output_shape(iT, kernel, padding, stride, dilation, ceil_mode) 1634 if not (oT > 0): 1635 continue 1636 oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode) 1637 if not (oH > 0): 1638 continue 1639 oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode) 1640 if not (oW > 0): 1641 continue 1642 1643 a_pool = torch.nn.functional.max_pool3d(X, kernel_size=kernel, 1644 stride=stride, 1645 padding=padding, dilation=dilation, 1646 ceil_mode=ceil_mode) 1647 a_ref = torch.quantize_per_tensor(a_pool, scale=scale, 1648 zero_point=zero_point, dtype=torch_type) 1649 a_ref = a_ref.dequantize() 1650 qa = torch.quantize_per_tensor(X_copy, scale=scale, zero_point=zero_point, 1651 dtype=torch_type) 1652 qa = qa.contiguous(memory_format=torch.channels_last_3d) 1653 ops_under_test = { 1654 "torch": torch.max_pool3d, 1655 "nn.functional": torch.nn.functional.max_pool3d, 1656 } 1657 for name, op in ops_under_test.items(): 1658 a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding, 1659 dilation=dilation, ceil_mode=ceil_mode) 1660 self.assertEqual(a_ref, a_hat.dequantize(), 1661 msg=f"{name} results are off") 1662 1663 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4, 1664 min_side=5, max_side=10), 1665 qparams=hu.qparams(dtypes=torch.quint8)), 1666 kernel=st.sampled_from((3, 5)), 1667 stride=st.sampled_from((None, 1, 2)), 1668 padding=st.integers(0, 2), 1669 ceil_mode=st.sampled_from((True, False)), 1670 count_include_pad=st.sampled_from((True, False)), 1671 divisor_override=st.sampled_from((None, None))) 1672 def test_avg_pool2d(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override): 1673 """ 1674 Note: we currently cannot test the divisor_override, because quantized op will clamp the result 1675 within range. However, the float op will not. 1676 """ 1677 X, (scale, zero_point, torch_type) = X 1678 1679 assume(kernel // 2 >= padding) # Kernel cannot be overhanging! 
        iH, iW = X.shape[-2:]
        oH = pool_output_shape(iH, kernel, padding, stride, dilation=1)
        assume(oH > 0)
        oW = pool_output_shape(iW, kernel, padding, stride, dilation=1)
        assume(oW > 0)
        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)
        X = qX.dequantize()
        # Run reference on float tensor and then quantize the result for comparison
        X_ref = torch.nn.functional.avg_pool2d(
            X, kernel_size=kernel, stride=stride, padding=padding,
            ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override)
        ops_under_test = {
            "nn.functional": torch.nn.functional.avg_pool2d,
            "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool2d,
        }
        error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
        for name, op in ops_under_test.items():
            qX_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode,
                        count_include_pad=count_include_pad, divisor_override=divisor_override)
            qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(),
                                               dtype=torch_type)

            self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0, rtol=0,
                             msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr()))
            self.assertEqual(scale, qX_hat.q_scale(),
                             msg=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
            self.assertEqual(zero_point, qX_hat.q_zero_point(),
                             msg=error_message.format(name + '.zero_point', scale,
                                                      qX_hat.q_zero_point()))

    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                              min_side=5, max_side=10),
                       qparams=hu.qparams(dtypes=torch.qint8)),
           kernel=st.sampled_from((4, 5)),
           stride=st.sampled_from((None, 1, 2)),
           padding=st.integers(0, 2),
           ceil_mode=st.sampled_from((True, False)),
           count_include_pad=st.sampled_from((True, False)),
           divisor_override=st.sampled_from((None, None)))
    def test_avg_pool2d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override):
        """
        Note: 1) we currently cannot test divisor_override, because the quantized op clamps
        the result to the quantized range while the float op does not.
        2) we cannot test qint32, since floating-point precision is much lower than int32
        precision for large values, which would make the test very flaky.
        """
        X, (scale, zero_point, torch_type) = X
        H, W = X.shape[-2:]

        # Widen the channel dimension so that the vectorized NHWC code paths are also exercised.
        if X.shape[1] < 176:
            X = np.repeat(X, 176 / X.shape[1], 1)

        assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
        iH, iW = X.shape[-2:]
        oH = pool_output_shape(iH, kernel, padding, stride, dilation=1)
        assume(oH > 0)
        oW = pool_output_shape(iW, kernel, padding, stride, dilation=1)
        assume(oW > 0)

        X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))

        qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw), scale=scale,
                                       zero_point=zero_point, dtype=torch_type).permute([0, 3, 1, 2])
        X = qX.dequantize()

        # The reference below is computed on the dequantized tensor; the int_repr comparison
        # then uses atol=1 to absorb the extra rounding step.
1749 X_ref = torch.nn.functional.avg_pool2d( 1750 X, kernel_size=kernel, stride=stride, padding=padding, 1751 ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override) 1752 1753 self.assertTrue(qX.stride() != sorted(qX.stride())) 1754 ops_under_test = { 1755 "nn.functional": torch.nn.functional.avg_pool2d, 1756 "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool2d, 1757 } 1758 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 1759 for name, op in ops_under_test.items(): 1760 X_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode, 1761 count_include_pad=count_include_pad, divisor_override=divisor_override) 1762 self.assertTrue(X_hat.stride() != sorted(X_hat.stride())) 1763 qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(), 1764 dtype=torch_type) 1765 1766 self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0, rtol=0, 1767 msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr())) 1768 self.assertEqual(scale, X_hat.q_scale(), 1769 msg=error_message.format(name + '.scale', scale, X_hat.q_scale())) 1770 self.assertEqual(zero_point, X_hat.q_zero_point(), 1771 msg=error_message.format(name + '.zero_point', scale, 1772 X_hat.q_zero_point())) 1773 1774 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5, 1775 min_side=5, max_side=10), 1776 qparams=hu.qparams(dtypes=torch.quint8)), 1777 kernel=st.sampled_from((3, 5)), 1778 stride=st.sampled_from((None, 1, 2)), 1779 padding=st.integers(0, 2), 1780 ceil_mode=st.sampled_from((True, False)), 1781 count_include_pad=st.sampled_from((True, False)), 1782 divisor_override=st.sampled_from((None, None))) 1783 def test_avg_pool3d(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override): 1784 """ 1785 Note: we currently cannot test the divisor_override, because quantized op will clamp the result 1786 within range. However, the float op will not. 1787 """ 1788 X, (scale, zero_point, torch_type) = X 1789 1790 assume(kernel // 2 >= padding) # Kernel cannot be overhanging! 
1791 iD, iH, iW = X.shape[-3:] 1792 oD = pool_output_shape(iD, kernel, padding, stride, dilation=1) 1793 assume(oD > 0) 1794 oH = pool_output_shape(iH, kernel, padding, stride, dilation=1) 1795 assume(oH > 0) 1796 oW = pool_output_shape(iW, kernel, padding, stride, dilation=1) 1797 assume(oW > 0) 1798 1799 X = torch.from_numpy(X) 1800 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 1801 dtype=torch_type) 1802 X = qX.dequantize() 1803 # Run reference on float tensor and then quantize the result for comparison 1804 X_ref = torch.nn.functional.avg_pool3d( 1805 X, kernel_size=kernel, stride=stride, padding=padding, 1806 ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override) 1807 1808 ops_under_test = { 1809 "nn.functional": torch.nn.functional.avg_pool3d, 1810 "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool3d, 1811 } 1812 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 1813 for name, op in ops_under_test.items(): 1814 qX_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode, 1815 count_include_pad=count_include_pad, divisor_override=divisor_override) 1816 qX_ref = torch.quantize_per_tensor(X_ref, scale=qX_hat.q_scale(), zero_point=qX_hat.q_zero_point(), 1817 dtype=torch_type) 1818 self.assertEqual(qX_ref.int_repr().to(torch.double), qX_hat.int_repr().to(torch.double), atol=1.0, rtol=0, 1819 msg=error_message.format(name, qX_ref.int_repr(), qX_hat.int_repr())) 1820 self.assertEqual(scale, qX_hat.q_scale(), 1821 msg=error_message.format(name + '.scale', scale, qX_hat.q_scale())) 1822 self.assertEqual(zero_point, qX_hat.q_zero_point(), 1823 msg=error_message.format(name + '.zero_point', scale, 1824 qX_hat.q_zero_point())) 1825 1826 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5, 1827 min_side=5, max_side=10), 1828 qparams=hu.qparams(dtypes=torch.qint8)), 1829 kernel=st.sampled_from((4, 5)), 1830 stride=st.sampled_from((None, 1, 2)), 1831 padding=st.integers(0, 2), 1832 ceil_mode=st.sampled_from((True, False)), 1833 count_include_pad=st.sampled_from((True, False)), 1834 divisor_override=st.sampled_from((None, None))) 1835 def test_avg_pool3d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override): 1836 """ 1837 Note: 1) we currently cannot test the divisor_override, because quantized op will clamp the result 1838 within range. However, the float op will not. 1839 2) we cannot test the qint32, since the float point precision is much lower than int32 for big number, 1840 which will make the test be very flaky. 1841 """ 1842 X, (scale, zero_point, torch_type) = X 1843 D, H, W = X.shape[-3:] 1844 1845 1846 if X.shape[1] < 176: 1847 X = np.repeat(X, 176 / X.shape[1], 1) 1848 1849 assume(kernel // 2 >= padding) # Kernel cannot be overhanging! 1850 iD, iH, iW = X.shape[-3:] 1851 oD = pool_output_shape(iD, kernel, padding, stride, dilation=1) 1852 assume(oD > 0) 1853 oH = pool_output_shape(iH, kernel, padding, stride, dilation=1) 1854 assume(oH > 0) 1855 oW = pool_output_shape(iW, kernel, padding, stride, dilation=1) 1856 assume(oW > 0) 1857 1858 X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 4, 1])) 1859 1860 qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw), scale=scale, 1861 zero_point=zero_point, dtype=torch_type).permute([0, 4, 1, 2, 3]) 1862 X = qX.dequantize() 1863 1864 # Run reference on int_repr + round to avoid double rounding error. 
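        # (As in the 2D NHWC test above, the reference below is actually computed on the
        # dequantized tensor; the int_repr comparison then uses atol=1 to absorb the extra
        # rounding step.)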
1865 X_ref = torch.nn.functional.avg_pool3d( 1866 X, kernel_size=kernel, stride=stride, padding=padding, 1867 ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override) 1868 1869 self.assertTrue(qX.stride() != sorted(qX.stride())) 1870 ops_under_test = { 1871 "nn.functional": torch.nn.functional.avg_pool3d, 1872 "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.avg_pool3d, 1873 } 1874 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 1875 for name, op in ops_under_test.items(): 1876 X_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode, 1877 count_include_pad=count_include_pad, divisor_override=divisor_override) 1878 self.assertTrue(X_hat.stride() != sorted(X_hat.stride())) 1879 qX_ref = torch.quantize_per_tensor(X_ref, scale=X_hat.q_scale(), zero_point=X_hat.q_zero_point(), 1880 dtype=torch_type) 1881 1882 self.assertEqual(qX_ref.int_repr().to(torch.double), X_hat.int_repr().to(torch.double), atol=1.0, rtol=0, 1883 msg=error_message.format(name, qX_ref.int_repr(), X_hat.int_repr())) 1884 self.assertEqual(scale, X_hat.q_scale(), 1885 msg=error_message.format(name + '.scale', scale, X_hat.q_scale())) 1886 self.assertEqual(zero_point, X_hat.q_zero_point(), 1887 msg=error_message.format(name + '.zero_point', scale, 1888 X_hat.q_zero_point())) 1889 1890 """Tests adaptive average pool operation on NHWC quantized tensors.""" 1891 def test_adaptive_avg_pool2d_nhwc(self): 1892 side_lens = (range(1, 10)) 1893 dim_lens = (range(3, 4)) 1894 torch_type = torch.qint8 1895 zero_points = (0, 1) 1896 combined = [side_lens, dim_lens, zero_points] 1897 test_cases = itertools.product(*combined) 1898 for test_case in test_cases: 1899 output_size_h = random.randint(1, 10) 1900 output_size_w = random.randint(1, 10) 1901 side_len, dim_len, zero_point = test_case 1902 shapes = [side_len] * dim_len 1903 X, X_scale, X_zero_point = \ 1904 _get_random_tensor_and_q_params(shapes, 1.0, zero_point) 1905 X = np.array(X) 1906 scale = 1 1907 H, W = X.shape[-2:] 1908 output_size_h = min(output_size_h, H) 1909 output_size_w = min(output_size_w, W) 1910 if output_size_h == output_size_w: 1911 output_size = output_size_h 1912 else: 1913 output_size = (output_size_h, output_size_w) 1914 1915 if X.shape[1] < 176: 1916 X = np.repeat(X, 176 / X.shape[1], 1) 1917 1918 if X.ndim == 4: 1919 X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1])) 1920 X = torch.from_numpy(X_nchw).permute([0, 3, 1, 2]) 1921 qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw), 1922 scale=scale, 1923 zero_point=zero_point, 1924 dtype=torch_type).permute([0, 3, 1, 2]) 1925 else: # ndim == 3 1926 X_nchw = np.ascontiguousarray(X.transpose([1, 2, 0])) 1927 X = torch.from_numpy(X_nchw).permute([2, 0, 1]) 1928 qX = torch.quantize_per_tensor(torch.from_numpy(X_nchw), 1929 scale=scale, 1930 zero_point=zero_point, 1931 dtype=torch_type).permute([2, 0, 1]) 1932 1933 # Run reference on int_repr + round to avoid double rounding error. 
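            # With scale fixed at 1 here, averaging the raw int_repr values and rounding
            # once should stay within one unit of the quantized kernel's output, which is
            # what the atol=1 comparison below allows for.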
1934 X_ref = torch.nn.functional.adaptive_avg_pool2d(qX.int_repr().to(torch.double), output_size).round() 1935 1936 self.assertTrue(qX.stride() != sorted(qX.stride())) 1937 1938 ops_under_test = { 1939 "nn.functional": torch.nn.functional.adaptive_avg_pool2d, 1940 "ao.nn.quantized.functional": 1941 torch.ao.nn.quantized.functional.adaptive_avg_pool2d, 1942 } 1943 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 1944 for name, op in ops_under_test.items(): 1945 X_hat = op(qX, output_size=output_size) 1946 self.assertTrue(X_hat.stride() != sorted(X_hat.stride())) 1947 self.assertEqual(X_ref, X_hat.int_repr(), atol=1.0, rtol=0, 1948 msg=error_message.format(name, X_ref, X_hat.int_repr()), 1949 exact_dtype=False) 1950 self.assertEqual(scale, X_hat.q_scale(), 1951 msg=error_message.format(name + '.scale', scale, X_hat.q_scale())) 1952 self.assertEqual(zero_point, X_hat.q_zero_point(), 1953 msg=error_message.format(name + '.zero_point', scale, 1954 X_hat.q_zero_point())) 1955 1956 @unittest.skip("not currently working and feature isn't used") 1957 def test_adaptive_avg_pool(self): 1958 1959 side_lens = (range(1, 10)) 1960 dim_lens = (range(3, 5)) 1961 torch_type = torch.qint8 1962 zero_points = (0, 1) 1963 combined = [side_lens, dim_lens, zero_points] 1964 test_cases = itertools.product(*combined) 1965 for test_case in test_cases: 1966 output_size_d = random.randint(1, 10) 1967 output_size_h = random.randint(1, 10) 1968 output_size_w = random.randint(1, 10) 1969 side_len, dim_len, zero_point = test_case 1970 shapes = [side_len] * dim_len 1971 X, X_scale, X_zero_point = \ 1972 _get_random_tensor_and_q_params(shapes, 1.0, zero_point) 1973 X = np.array(X) 1974 scale = 1 1975 ndim = X.ndim 1976 dim_to_check = [] 1977 if ndim <= 4: 1978 dim_to_check.append(2) 1979 if ndim >= 4: 1980 dim_to_check.append(3) 1981 1982 D, H, W = X.shape[-3:] 1983 output_size_d = min(output_size_d, D) 1984 output_size_h = min(output_size_h, H) 1985 output_size_w = min(output_size_w, W) 1986 1987 X = torch.from_numpy(X) 1988 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 1989 dtype=torch_type) 1990 1991 for dim in dim_to_check: 1992 if dim == 2: 1993 if output_size_h == output_size_w: 1994 output_size = output_size_h 1995 else: 1996 output_size = (output_size_h, output_size_w) 1997 elif dim == 3: 1998 if output_size_d == output_size_h == output_size_w: 1999 output_size = output_size_h 2000 else: 2001 output_size = (output_size_d, output_size_h, output_size_w) 2002 2003 # Run reference on int_repr + round to avoid double rounding error. 
2004 ref_op = getattr(torch.nn.functional, f'adaptive_avg_pool{dim}d') 2005 X_ref = ref_op(qX.int_repr().to(torch.float), output_size).round() 2006 2007 ops_under_test = { 2008 "nn.functional": 2009 getattr(torch.nn.functional, f'adaptive_avg_pool{dim}d'), 2010 "nn.quantized.functional": 2011 getattr(torch.ao.nn.quantized.functional, f'adaptive_avg_pool{dim}d'), 2012 "ao.nn.quantized.functional": 2013 getattr(torch.ao.nn.quantized.functional, f'adaptive_avg_pool{dim}d') 2014 } 2015 2016 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 2017 2018 for name, op in ops_under_test.items(): 2019 # TODO: torch.cuda.is_available() should be swapped for a flag that checks if cudnn 2020 # is enabled in the build when cudnn supports adaptive average pooling 2021 devices = ["cpu", "cuda"] if (dim == 2 and torch.cuda.is_available()) else ["cpu"] 2022 for device in devices: 2023 qX_hat = op(qX.to(device=device), output_size=output_size) 2024 self.assertEqual( 2025 X_ref, qX_hat.int_repr(), atol=1.0, 2026 rtol=0, msg=error_message.format(name, X_ref, qX_hat), exact_dtype=False) 2027 self.assertEqual( 2028 scale, qX_hat.q_scale(), 2029 msg=error_message.format(name + '.scale', scale, 2030 qX_hat.q_scale())) 2031 self.assertEqual( 2032 zero_point, qX_hat.q_zero_point(), 2033 msg=error_message.format(name + '.zero_point', scale, 2034 qX_hat.q_zero_point())) 2035 2036 """Tests adaptive average pool operation on NHWC quantized tensors.""" 2037 def test_adaptive_avg_pool3d_ndhwc(self): 2038 side_lens = (range(1, 10)) 2039 dim_lens = (range(4, 5)) 2040 torch_type = torch.qint8 2041 zero_point = 0 2042 combined = [side_lens, dim_lens] 2043 test_cases = itertools.product(*combined) 2044 for test_case in test_cases: 2045 output_size_d = random.randint(1, 10) 2046 output_size_h = random.randint(1, 10) 2047 output_size_w = random.randint(1, 10) 2048 side_len, dim_len = test_case 2049 shapes = [side_len] * dim_len 2050 X, X_scale, X_zero_point = \ 2051 _get_random_tensor_and_q_params(shapes, 1.0, zero_point) 2052 X = np.array(X) 2053 scale = 1 2054 D, H, W = X.shape[-3:] 2055 output_size_d = min(output_size_d, D) 2056 output_size_h = min(output_size_h, H) 2057 output_size_w = min(output_size_w, W) 2058 if output_size_d == output_size_h == output_size_w: 2059 output_size = output_size_h 2060 else: 2061 output_size = (output_size_d, output_size_h, output_size_w) 2062 2063 if X.shape[1] < 176: 2064 X = np.repeat(X, 176 / X.shape[1], 1) 2065 2066 if X.ndim == 5: 2067 X_ncdhw = np.ascontiguousarray(X.transpose([0, 2, 3, 4, 1])) 2068 X = torch.from_numpy(X_ncdhw).permute([0, 4, 1, 2, 3]) 2069 qX = torch.quantize_per_tensor(torch.from_numpy(X_ncdhw), 2070 scale=scale, 2071 zero_point=zero_point, 2072 dtype=torch_type).permute([0, 4, 1, 2, 3]) 2073 else: # ndim == 4 2074 X_ncdhw = np.ascontiguousarray(X.transpose([1, 2, 3, 0])) 2075 X = torch.from_numpy(X_ncdhw).permute([3, 0, 1, 2]) 2076 qX = torch.quantize_per_tensor(torch.from_numpy(X_ncdhw), 2077 scale=scale, 2078 zero_point=zero_point, 2079 dtype=torch_type).permute([3, 0, 1, 2]) 2080 2081 # Run reference on int_repr + round to avoid double rounding error. 
2082 X_ref = torch.nn.functional.adaptive_avg_pool3d( 2083 qX.int_repr().to(torch.double), output_size).round() 2084 2085 self.assertTrue(qX.stride() != sorted(qX.stride())) 2086 2087 ops_under_test = { 2088 "nn.functional": torch.nn.functional.adaptive_avg_pool3d, 2089 "ao.nn.quantized.functional": 2090 torch.ao.nn.quantized.functional.adaptive_avg_pool3d, 2091 } 2092 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 2093 for name, op in ops_under_test.items(): 2094 X_hat = op(qX, output_size=output_size) 2095 self.assertTrue(X_hat.stride() != sorted(X_hat.stride())) 2096 self.assertEqual(X_ref, X_hat.int_repr(), atol=1.0, rtol=0, 2097 msg=error_message.format(name, X_ref, X_hat.int_repr()), 2098 exact_dtype=False) 2099 self.assertEqual(scale, X_hat.q_scale(), 2100 msg=error_message.format(name + '.scale', scale, X_hat.q_scale())) 2101 self.assertEqual(zero_point, X_hat.q_zero_point(), 2102 msg=error_message.format(name + '.zero_point', scale, 2103 X_hat.q_zero_point())) 2104 2105 def test_qtopk(self): 2106 x_dims = [3, 4] # Num elements in the shape 2107 sides = [3, 5] # Side of the tensor generated 2108 dims = [0, 1, 2, 3] # dimension over which to perform topk 2109 largest = [False, True] # Return largest or smallest element 2110 sorted = [False, True] # Return sorted or not 2111 dtypes = [torch.qint8, torch.quint8] 2112 is_nhwc = [False, True] # Is input in the NHWC format? 2113 2114 test_cases = itertools.product(x_dims, sides, dims, largest, sorted, dtypes, is_nhwc) 2115 k = 2 2116 for x_dim, side, dim, larg, sort, dtype, nhwc in test_cases: 2117 if nhwc and x_dim != 4: # NHWC requires 4 dimensions 2118 continue 2119 if dim >= x_dim: # Dimension to find top-k for should exist 2120 continue 2121 shape = [side] * x_dim 2122 X, scale, zp = _get_random_tensor_and_q_params(shape, 1.0, dtype) 2123 qX = torch.quantize_per_tensor(X, scale, zp, dtype) 2124 2125 if nhwc: 2126 qX = qX.permute([0, 3, 1, 2]) 2127 X = np.transpose(X, [0, 3, 1, 2]) 2128 2129 unquantized_out = torch.topk(qX.dequantize(), k, dim=dim, largest=larg, sorted=sort) 2130 2131 values = torch.quantize_per_tensor(X, scale, zp, dtype) 2132 indices = torch.tensor(X).long() 2133 2134 quantized_out = torch.topk(qX, k, dim=dim, largest=larg, sorted=sort) 2135 2136 assert len(unquantized_out) == len(quantized_out) 2137 torch.testing.assert_close(quantized_out[0].dequantize(), unquantized_out[0]) 2138 torch.testing.assert_close(quantized_out[1], unquantized_out[1]) 2139 2140 """Tests quantize concatenation (both fused and not).""" 2141 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4, 2142 min_side=1, max_side=10), 2143 qparams=hu.qparams()), 2144 num=st.integers(1, 4), 2145 dim=st.integers(1, 4), 2146 relu=st.booleans()) 2147 def test_cat(self, X, num, dim, relu): 2148 tensors_q = [] 2149 tensors_ref = [] 2150 X, (scale, zero_point, torch_type) = X 2151 assume(dim < X.ndim) 2152 X = torch.from_numpy(X) 2153 new_shape = np.array(X.shape) 2154 new_shape[dim] = 0 2155 for idx in range(num): 2156 tensors_q.append(torch.quantize_per_tensor(X, scale, zero_point, 2157 torch_type)) 2158 tensors_ref.append(X) 2159 new_shape[dim] += tensors_ref[-1].shape[dim] 2160 2161 cat_ref = torch.cat(tensors_ref, dim=dim) 2162 cat_ref = torch.quantize_per_tensor(cat_ref, scale, zero_point, torch_type) 2163 cat_ref = cat_ref.dequantize() 2164 2165 if relu: 2166 cat_ref = F.relu(cat_ref) 2167 q_cat_op = torch.ops.quantized.cat_relu 2168 q_cat_out_op = torch.ops.quantized.cat_relu_out 2169 else: 2170 q_cat_op = 
torch.ops.quantized.cat 2171 q_cat_out_op = torch.ops.quantized.cat_out 2172 2173 cat_q = q_cat_op(tensors_q, dim=dim, scale=scale, 2174 zero_point=zero_point) 2175 cat_q = cat_q.dequantize() 2176 np.testing.assert_equal(cat_ref.numpy(), cat_q.numpy()) 2177 2178 cat_q_out = torch._empty_affine_quantized( 2179 list(new_shape), scale=scale, 2180 zero_point=zero_point, dtype=torch_type) 2181 q_cat_out_op(tensors_q, dim=dim, out=cat_q_out) 2182 cat_q_out = cat_q_out.dequantize() 2183 np.testing.assert_equal(cat_ref.numpy(), cat_q_out.numpy()) 2184 2185 # Test the cat on per-channel quantized tensor. 2186 ch_axis = 1 2187 scales = torch.from_numpy(np.array([1.0] * X.shape[ch_axis])) 2188 scales = scales.to(torch.float64) 2189 zero_points = torch.from_numpy(np.array([0] * X.shape[ch_axis])) 2190 zero_points = zero_points.to(torch.long) 2191 tensors_q[0] = torch.quantize_per_channel( 2192 X, scales, zero_points, axis=ch_axis, dtype=torch_type) 2193 with self.assertRaisesRegex(RuntimeError, "supported.*cat"): 2194 cat_q = q_cat_op(tensors_q, dim=ch_axis, scale=scale, 2195 zero_point=zero_point) 2196 2197 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4, 2198 min_side=5, max_side=10), 2199 qparams=hu.qparams()), 2200 size=st.sampled_from((1, 3, 5, 10)), 2201 mode=st.sampled_from(("bilinear", "nearest", "nearest-exact")), 2202 scale_factor=st.sampled_from((None, 1.5, 2.0)), 2203 align_corners=st.sampled_from((True, False)), 2204 nhwc_layout=st.sampled_from((True, False))) 2205 def test_interpolate(self, X, size, mode, scale_factor, align_corners, nhwc_layout): 2206 """ 2207 This test cover upsample_nearest2d and upsample_bilinear2d 2208 """ 2209 X, (scale, zero_point, torch_type) = X 2210 2211 if scale_factor is not None: 2212 size = None 2213 if mode in ("nearest", "nearest-exact"): 2214 align_corners = None 2215 2216 if nhwc_layout: 2217 if X.shape[1] < 176: 2218 X = np.repeat(X, 176 / X.shape[1], 1) 2219 2220 X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1])) 2221 X = torch.from_numpy(X_nchw).permute([0, 3, 1, 2]) 2222 2223 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 2224 dtype=torch_type).permute([0, 3, 1, 2]) 2225 else: 2226 X = torch.from_numpy(X) 2227 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 2228 dtype=torch_type) 2229 2230 X_ref = torch.nn.functional.interpolate( 2231 qX.int_repr().to(torch.float), size=size, scale_factor=scale_factor, 2232 mode=mode, align_corners=align_corners) 2233 2234 ops_under_test = { 2235 "nn.functional": torch.nn.functional.interpolate, 2236 "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.interpolate, 2237 } 2238 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 2239 for name, op in ops_under_test.items(): 2240 qX_hat = op(qX, size=size, scale_factor=scale_factor, 2241 mode=mode, align_corners=align_corners) 2242 self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0, 2243 msg=f"{name} results are off: qX_hat={qX_hat.int_repr()} X_ref={X_ref}", 2244 exact_dtype=False) 2245 self.assertEqual(scale, qX_hat.q_scale(), 2246 msg=error_message.format(name + '.scale', scale, qX_hat.q_scale())) 2247 self.assertEqual(zero_point, qX_hat.q_zero_point(), 2248 msg=error_message.format(name + '.zero_point', scale, 2249 qX_hat.q_zero_point())) 2250 2251 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5, 2252 min_side=5, max_side=10), 2253 qparams=hu.qparams()), 2254 size=st.sampled_from((1, 3, 5, 5, 10)), 2255 mode=st.sampled_from(("nearest", 
"nearest-exact")), 2256 scale_factor=st.sampled_from((None, 1.5, 2.0)), 2257 align_corners=st.sampled_from((True, False)), 2258 nhwc_layout=st.sampled_from((True, False))) 2259 def test_interpolate3d(self, X, size, mode, scale_factor, align_corners, nhwc_layout): 2260 """ 2261 This test cover upsample_nearest3d 2262 """ 2263 X, (scale, zero_point, torch_type) = X 2264 if scale_factor is not None: 2265 size = None 2266 2267 align_corners = None 2268 2269 if nhwc_layout: 2270 if X.shape[1] < 176: 2271 X = np.repeat(X, 176 / X.shape[1], 1) 2272 2273 X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 4, 1])) 2274 X = torch.from_numpy(X_nchw).permute([0, 4, 1, 2, 3]) 2275 2276 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 2277 dtype=torch_type).permute([0, 4, 1, 2, 3]) 2278 else: 2279 X = torch.from_numpy(X) 2280 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 2281 dtype=torch_type) 2282 X_ref = torch.nn.functional.interpolate( 2283 qX.int_repr().to(torch.float), size=size, scale_factor=scale_factor, 2284 mode=mode, align_corners=align_corners) 2285 2286 ops_under_test = { 2287 "nn.functional": torch.nn.functional.interpolate, 2288 "ao.nn.quantized.functional": torch.ao.nn.quantized.functional.interpolate, 2289 } 2290 2291 error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}" 2292 for name, op in ops_under_test.items(): 2293 qX_hat = op(qX, size=size, scale_factor=scale_factor, 2294 mode=mode, align_corners=align_corners) 2295 self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0, 2296 msg=f"{name} results are off: qX_hat={qX_hat.int_repr()}, X_ref={X_ref}", exact_dtype=False) 2297 self.assertEqual(scale, qX_hat.q_scale(), 2298 msg=error_message.format(name + '.scale', scale, qX_hat.q_scale())) 2299 self.assertEqual(zero_point, qX_hat.q_zero_point(), 2300 msg=error_message.format(name + '.zero_point', scale, 2301 qX_hat.q_zero_point())) 2302 2303 """Tests quantize concatenation (both fused and not).""" 2304 @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4, 2305 min_side=1, max_side=10), 2306 qparams=hu.qparams()), 2307 relu=st.booleans()) 2308 def test_cat_nhwc(self, X, relu): 2309 # X is NHWC 2310 X, (scale, zero_point, torch_type) = X 2311 2312 # Tile out X so # channels is > 64 2313 X = np.repeat(X, 70 / X.shape[3], 3) 2314 X = torch.from_numpy(np.ascontiguousarray(X)) 2315 Y = X.clone() 2316 Y = torch.from_numpy(np.ascontiguousarray(Y)) 2317 # We add a fast path in qcat: when inputs share the same scale and zero_point, 2318 # it will go direct memcpy instead of dequant-cat-quant. 2319 for scaleX, scaleY in ((scale, scale), (scale, scale * 1.1)): 2320 # Here, we quantize and get quantized tensors in NHWC for both dims and strides. The 2321 # permute switches it so that the tensor looks like NCHW but it laid out in memory as 2322 # NHWC. 
2323 qX = torch.quantize_per_tensor(X, scaleX, zero_point, torch_type).permute([0, 3, 1, 2]) 2324 qY = torch.quantize_per_tensor(Y, scaleY, zero_point, torch_type).permute([0, 3, 1, 2]) 2325 2326 ref = torch.cat([qX.dequantize(), qY.dequantize()], dim=1) 2327 if relu: 2328 ref[ref < 0] = 0.0 2329 ref = torch.quantize_per_tensor(ref, scale=scale, zero_point=zero_point, dtype=torch_type) 2330 2331 if relu: 2332 out = torch.ops.quantized.cat_relu( 2333 [qX, qY], dim=1, scale=scale, zero_point=zero_point) 2334 else: 2335 out = torch.ops.quantized.cat([qX, qY], dim=1, scale=scale, zero_point=zero_point) 2336 2337 torch.testing.assert_close(out.dequantize(), ref.dequantize()) 2338 self.assertNotEqual(out.stride(), sorted(out.stride())) 2339 2340 @override_qengines 2341 def test_mean(self): 2342 scale_list = (1, 0.25) 2343 zero_point_list = (0, 2) 2344 shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4, 4)) 2345 dtypes = (torch.quint8, torch.qint8) 2346 dims = ((), (-1,), (0,), (1,), (2,), (3,), (0, 1), (1, 2), (3, 4)) 2347 test_cases = itertools.product(scale_list, zero_point_list, shapes, dtypes, dims) 2348 op = torch.mean 2349 for scale, zp, shape, dtype, dim in test_cases: 2350 if not all(d < len(shape) for d in dim): 2351 continue 2352 X = torch.randn(*shape) * 10 2353 qX = torch.quantize_per_tensor(X, scale, zp, dtype) 2354 Y = op(qX.dequantize(), dim) 2355 Y = torch.quantize_per_tensor(Y, scale, zp, dtype).dequantize() 2356 qY = op(qX, dim) 2357 self.assertEqual(Y, qY.dequantize()) 2358 2359 @skipIfNoQNNPACK 2360 @given(keep=st.booleans()) 2361 def test_quantized_mean_qnnpack(self, keep): 2362 with override_quantized_engine("qnnpack"): 2363 # using multiple of 4 sizes to satisfy pytorch_q8gavgpool_ukernel_up8xm__sse2() 4-byte alignment demand under ASAN 2364 in_dim = (4, 4, 4, 4) 2365 if keep: 2366 out_dim = (4, 4, 1, 1) 2367 else: 2368 out_dim = (4, 4) 2369 X = torch.ones(in_dim) 2370 Y = torch.ones(out_dim) 2371 XQ = torch.quantize_per_tensor(X, scale=0.2, zero_point=0, dtype=torch.quint8) 2372 YQ = torch.quantize_per_tensor(Y, scale=0.2, zero_point=0, dtype=torch.quint8) 2373 MQ = XQ.mean((2, 3), keepdim=keep) 2374 self.assertTrue(torch.equal(MQ, YQ)) 2375 2376 @override_qengines 2377 def test_std(self): 2378 scale_list = (1, 0.25) 2379 zero_point_list = (0, 2) 2380 shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4), (4, 4, 4, 4, 4)) 2381 dtypes = (torch.quint8, torch.qint8) 2382 dims = ((), (-1,), (0,), (1,), (2,), (3,), (0, 1), (1, 2), (3, 4)) 2383 unbiased_list = (True, False) 2384 keep_dim_list = (True, False) 2385 test_cases = itertools.product(scale_list, zero_point_list, shapes, 2386 dtypes, dims, unbiased_list, keep_dim_list) 2387 op = torch.std 2388 for scale, zp, shape, dtype, dim, unbiased, keep_dim in test_cases: 2389 if not all(d < len(shape) for d in dim): 2390 continue 2391 X = torch.randn(*shape) * 10 2392 qX = torch.quantize_per_tensor(X, scale, zp, dtype) 2393 Y = op(qX.dequantize(), dim, unbiased, keep_dim) 2394 Y = torch.quantize_per_tensor(Y, scale, zp, dtype).dequantize() 2395 qY = op(qX, dim, unbiased, keep_dim) 2396 self.assertEqual(Y, qY.dequantize()) 2397 2398 """Tests the correctness of the quantized equal op.""" 2399 @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), 2400 qparams=hu.qparams()), 2401 X2=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), 2402 qparams=hu.qparams()), 2403 X_per_channel=st.booleans(), 2404 X2_per_channel=st.booleans()) 2405 def test_equal(self, X, X2, X_per_channel, X2_per_channel): 2406 X, X_params = X 2407 (scale, 
zero_point, torch_type) = X_params 2408 X2, X2_params = X2 2409 (scale2, zero_point2, torch_type2) = X2_params 2410 2411 X = torch.from_numpy(X) 2412 if X_per_channel: 2413 X_scheme = 'per_channel' 2414 channels = X.shape[-1] 2415 qX = torch.quantize_per_channel( 2416 X, 2417 scales=torch.tensor([scale] * channels), 2418 zero_points=torch.tensor([zero_point] * channels), 2419 dtype=torch_type, 2420 axis=X.ndim - 1) 2421 else: 2422 X_scheme = 'per_tensor' 2423 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 2424 dtype=torch_type) 2425 X2 = torch.from_numpy(X2) 2426 if X2_per_channel: 2427 X2_scheme = 'per_channel' 2428 channels = X2.shape[-1] 2429 qX2 = torch.quantize_per_channel( 2430 X2, 2431 scales=torch.tensor([scale2] * channels), 2432 zero_points=torch.tensor([zero_point2] * channels), 2433 dtype=torch_type2, 2434 axis=X2.ndim - 1) 2435 else: 2436 X2_scheme = 'per_tensor' 2437 qX2 = torch.quantize_per_tensor(X2, scale=scale2, zero_point=zero_point2, 2438 dtype=torch_type2) 2439 2440 def equal_ref(qX, qX2): 2441 if qX.qscheme() != qX2.qscheme(): 2442 return False 2443 if qX.shape != qX2.shape: 2444 return False 2445 if qX.dtype != qX2.dtype: 2446 return False 2447 if qX.qscheme() == torch.per_tensor_affine: 2448 if qX.q_scale() != qX2.q_scale(): 2449 return False 2450 if qX.q_zero_point() != qX2.q_zero_point(): 2451 return False 2452 elif qX.qscheme() == torch.per_channel_affine: 2453 if (qX.q_per_channel_scales() != 2454 qX2.q_per_channel_scales()).any(): 2455 return False 2456 if (qX.q_per_channel_zero_points() != 2457 qX2.q_per_channel_zero_points()).any(): 2458 return False 2459 else: 2460 raise NotImplementedError("Don't know what to do with", 2461 qX.qscheme()) 2462 if (qX.int_repr().to(float) != qX2.int_repr().to(float)).any(): 2463 return False 2464 return True 2465 2466 self.assertEqual(qX.equal(qX), equal_ref(qX, qX)) 2467 self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2)) 2468 2469 """Tests quantized equal op with input of non-quantized tensor.""" 2470 def test_quantized_equal(self,): 2471 x = torch.rand(1) 2472 y = torch.quantize_per_tensor(x, scale=0.5, zero_point=0, dtype=torch.qint8) 2473 self.assertTrue(not torch.equal(x, y)) 2474 self.assertTrue(not torch.equal(y, x)) 2475 2476 @skipIfNoFBGEMM 2477 def test_group_norm(self): 2478 # hypothesis is flaky for this test, create test cases manually 2479 batches_list = (1, 7) 2480 num_groups_list = (1, 4) 2481 channels_per_groups = (1, 36, 72) 2482 elements_per_channels = (8, 128, 1024) 2483 torch_types = (torch.qint8, torch.quint8) 2484 y_scales = (0.1, 4.23) 2485 y_zero_points = (0, 1) 2486 channels_last_list = [True, False] 2487 affine_list = [True, False] 2488 combined = [batches_list, num_groups_list, channels_per_groups, elements_per_channels, 2489 torch_types, y_scales, y_zero_points, channels_last_list, affine_list] 2490 test_cases = itertools.product(*combined) 2491 2492 with override_quantized_engine("fbgemm"): 2493 for test_case in test_cases: 2494 2495 batches, num_groups, channels_per_group, elements_per_channel, \ 2496 torch_type, Y_scale, Y_zero_point, channels_last, \ 2497 affine = test_case 2498 num_channels = num_groups * channels_per_group 2499 # minimum rank for channels_last 2500 shapes = (batches, num_channels, elements_per_channel, 1) 2501 2502 # In the FP kernel, sums and sums of squares are calculated in floating point. 2503 # In the int8 and uint8 versions of the quantized kernel, they are 2504 # calculated in integer arithmetic (which is exact). 
2505 # Because of this, the numerics do not always match exactly which is 2506 # expected and acceptable. We do the following to allow this failure 2507 # in this test: 2508 # 1. do not use Hypothesis to generate the input tensor. Hypothesis 2509 # favors homogeneous inputs in its search strategies which isn't 2510 # representative of the inputs we care about, and tends to maximize 2511 # this particular numerics difference. 2512 # 2. allow a small % of off by Y_scale errors. Even when the 2513 # variance of the input is high, there can be off by one errors 2514 # in the result if the input value happens to fall exactly on 2515 # the bin boundary of the output scale. 2516 # 2517 # If we want the numerics to match we could switch to calculating 2518 # mean+var in floating point in the future, at the cost of speed. 2519 X, X_scale, X_zero_point = \ 2520 _get_random_tensor_and_q_params(shapes, 1.0, torch_type) 2521 2522 # Initialize the weights non-randomly for reproducibility 2523 if affine: 2524 weight = torch.ones(num_channels).float() * 0.5 2525 bias = torch.ones(num_channels).float() 2526 for i in range(num_channels): 2527 weight[i] *= i 2528 bias[i] *= i 2529 else: 2530 weight = None 2531 bias = None 2532 2533 eps = 0.001 2534 2535 qX = torch.quantize_per_tensor(X, X_scale, X_zero_point, torch_type) 2536 if channels_last: 2537 qX = qX.contiguous(memory_format=torch.channels_last) 2538 dqX = qX.dequantize() 2539 2540 # Enforce non-homogeneous inputs 2541 for batch_idx in range(batches): 2542 for group_idx in range(num_groups): 2543 ch_start = group_idx * channels_per_group 2544 ch_end = ch_start + channels_per_group 2545 group_vals = dqX[batch_idx][ch_start:ch_end] 2546 assume( 2547 float(torch.unique(group_vals).shape[0]) / group_vals.numel() > 0.001 2548 or group_vals.numel() < 5) 2549 2550 qY = torch.ops.quantized.group_norm(qX, num_groups, weight, bias, eps, Y_scale, Y_zero_point) 2551 2552 dqY_hat = F.group_norm(dqX, num_groups=num_groups, weight=weight, bias=bias, eps=eps) 2553 qY_hat = torch.quantize_per_tensor(dqY_hat, Y_scale, Y_zero_point, torch_type) 2554 2555 # Due to the numerics difference mentioned above between calculating 2556 # the variance in float vs int, the results can still be slightly 2557 # different. 
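                # Concretely, the checks below require that essentially no element exceeds
                # the reference by more than one output bin (Y_scale), while up to 1% of
                # elements may land up to one bin high due to boundary rounding.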
2558 dqY = qY.dequantize() 2559 dqY_hat = qY_hat.dequantize() 2560 diff = dqY - dqY_hat 2561 2562 # off-by-one errors are magnitude of Y_scale 2563 num_diff = torch.sum(diff > Y_scale * 1.0001) 2564 pct_diff = float(num_diff) / (diff.numel() + 1e-5) 2565 num_diff_off_by_one = torch.sum((diff > 0) * (diff <= Y_scale)) 2566 pct_diff_off_by_one = float(num_diff_off_by_one) / (diff.numel() + 1e-5) 2567 2568 self.assertTrue(pct_diff < 1e-6) 2569 self.assertTrue(pct_diff_off_by_one < 0.01) 2570 2571 @skipIfNoFBGEMM 2572 def test_instance_norm(self): 2573 max_sides = (4, 5) 2574 shape_list = ([2, 2, 2, 2], [8, 8, 8, 8], [11, 11, 11, 11]) 2575 torch_types = (torch.qint8, torch.quint8) 2576 y_scales = (0.1, 4.23) 2577 y_zero_points = (0, 1) 2578 channels_last_list = (True, False) 2579 affine_list = (True, False) 2580 combined = [shape_list, torch_types, y_scales, y_zero_points, channels_last_list, affine_list] 2581 test_cases_product = itertools.product(*combined) 2582 test_cases = list(test_cases_product) 2583 # NB: Add just one test case to test overflow, but this case is too slow to run 2584 # internally in @fbcode//mode/dev, the long pole is the 4x calls to torch.sort 2585 # inside torch.unique current implementation 2586 if not IS_SANDCASTLE: 2587 test_cases.append([ 2588 [1, 4, 224, 224, 160], # shape, 2589 torch.qint8, # torch_type 2590 0.1, # scale 2591 0, # zero_point 2592 False, # channels_last 2593 True, # affine 2594 ]) 2595 with override_quantized_engine("fbgemm"): 2596 for test_case in test_cases: 2597 2598 shapes, torch_type, Y_scale, Y_zero_point, channels_last, affine = test_case 2599 if channels_last and shapes.__len__() >= 5: 2600 # required rank 4 tensor to use channels_last format 2601 continue 2602 2603 # In the FP kernel, sums and sums of squares are calculated in floating point. 2604 # In the int8 and uint8 versions of the quantized kernel, they are 2605 # calculated in integer arithmetic (which is exact). 2606 # Because of this, the numerics do not always match exactly which is 2607 # expected and acceptable. We do the following to allow this failure 2608 # in this test: 2609 # 1. do not use Hypothesis to generate the input tensor. Hypothesis 2610 # favors homogeneous inputs in its search strategies which isn't 2611 # representative of the inputs we care about, and tends to maximize 2612 # this particular numerics difference. 2613 # 2. allow a small % of off by Y_scale errors. Even when the 2614 # variance of the input is high, there can be off by one errors 2615 # in the result if the input value happens to fall exactly on 2616 # the bin boundary of the output scale. 2617 # 2618 # If we want the numerics to match we could switch to calculating 2619 # mean+var in floating point in the future, at the cost of speed. 
2620 X, X_scale, X_zero_point = \ 2621 _get_random_tensor_and_q_params(shapes, 1.0, torch_type) 2622 2623 num_channels = shapes[1] 2624 if affine: 2625 weight = torch.rand(num_channels).float() * 0.5 2626 bias = torch.rand(num_channels).float() 2627 for i in range(num_channels): 2628 weight[i] *= i 2629 bias[i] *= i 2630 else: 2631 weight = None 2632 bias = None 2633 eps = 0.001 2634 2635 qX = torch.quantize_per_tensor(X, X_scale, X_zero_point, torch_type) 2636 if channels_last: 2637 qX = qX.contiguous(memory_format=torch.channels_last) 2638 dqX = qX.dequantize() 2639 2640 # Enforce non-homogeneous inputs 2641 batches = shapes[0] 2642 for batch_idx in range(batches): 2643 for ch_idx in range(num_channels): 2644 ch_vals = dqX[batch_idx][ch_idx] 2645 assume( 2646 float(torch.unique(ch_vals).shape[0]) / ch_vals.numel() > 0.01 2647 or ch_vals.numel() < 5 or ch_vals.numel() > 25600) 2648 2649 qY = torch.ops.quantized.instance_norm(qX, weight, bias, eps, Y_scale, Y_zero_point) 2650 2651 dqY_hat = F.instance_norm(dqX, weight=weight, bias=bias, eps=eps) 2652 qY_hat = torch.quantize_per_tensor(dqY_hat, Y_scale, Y_zero_point, torch_type) 2653 2654 # Due to the numerics difference mentioned above between calculating 2655 # the variance in float vs int, the results can still be slightly 2656 # different. 2657 dqY = qY.dequantize() 2658 dqY_hat = qY_hat.dequantize() 2659 diff = dqY - dqY_hat 2660 2661 # off-by-one errors are magnitude of Y_scale 2662 num_diff = torch.sum(diff > Y_scale * 1.0001) 2663 pct_diff = float(num_diff) / (diff.numel() + 1e-5) 2664 num_diff_off_by_one = torch.sum((diff > 0) * (diff <= Y_scale)) 2665 pct_diff_off_by_one = float(num_diff_off_by_one) / (diff.numel() + 1e-5) 2666 2667 self.assertTrue(pct_diff < 1e-6) 2668 self.assertTrue(pct_diff_off_by_one < 0.01) 2669 2670 @skipIfNoFBGEMM 2671 def test_batch_norm_relu(self): 2672 # hypothesis too slow for this test, create test cases manually 2673 max_sides = (2, 3, 4, 5) 2674 side_lens = (1, 8, 11) 2675 torch_types = (torch.qint8, torch.quint8) 2676 combined = [max_sides, side_lens, torch_types] 2677 test_cases = itertools.product(*combined) 2678 2679 with override_quantized_engine("fbgemm"): 2680 for test_case in test_cases: 2681 max_side, side_len, torch_type = test_case 2682 Y_zero_point = 1 2683 Y_scale = 0.5 2684 2685 shapes = [side_len] * max_side 2686 X, scale_x, zero_point_x = \ 2687 _get_random_tensor_and_q_params(shapes, 1.0, torch_type) 2688 dtype_x = torch_type 2689 2690 c = X.shape[1] 2691 mean = torch.rand(c).float() 2692 var = torch.rand(c).float() 2693 weight = torch.rand(c).float() 2694 bias = torch.rand(c).float() 2695 eps = 0.001 2696 qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x) 2697 if len(X.shape) == 2 or len(X.shape) == 3: 2698 qy = torch.ops.quantized.batch_norm1d_relu( 2699 qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point) 2700 elif len(X.shape) == 4: 2701 qy = torch.ops.quantized.batch_norm2d_relu( 2702 qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point) 2703 else: 2704 qy = torch.ops.quantized.batch_norm3d_relu( 2705 qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point) 2706 2707 2708 float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias, 2709 running_mean=mean, running_var=var, 2710 training=False, momentum=0, eps=eps).numpy() 2711 2712 float_ref_relu = float_ref.copy() 2713 float_ref_relu[float_ref < 0] = 0 2714 quantize_ref = torch.quantize_per_tensor( 2715 torch.from_numpy(float_ref_relu), Y_scale, Y_zero_point, dtype_x) 2716 self.assertEqual( 2717 
qy.int_repr().numpy(), 2718 quantize_ref.int_repr().numpy(), 2719 msg=f"{qy} vs {quantize_ref}") 2720 2721 @skipIfNoFBGEMM 2722 def test_batch_norm(self): 2723 # hypothesis too slow for this test, create test cases manually 2724 max_sides = (2, 3, 4, 5) 2725 side_lens = (1, 8, 11) 2726 torch_types = (torch.qint8, torch.quint8) 2727 combined = [max_sides, side_lens, torch_types] 2728 test_cases = itertools.product(*combined) 2729 2730 with override_quantized_engine("fbgemm"): 2731 for test_case in test_cases: 2732 max_side, side_len, torch_type = test_case 2733 Y_zero_point = 1 2734 Y_scale = 0.5 2735 2736 shapes = [side_len] * max_side 2737 X, scale_x, zero_point_x = \ 2738 _get_random_tensor_and_q_params(shapes, 1.0, torch_type) 2739 dtype_x = torch_type 2740 2741 c = X.shape[1] 2742 mean = torch.rand(c).float() 2743 var = torch.rand(c).float() 2744 weight = torch.rand(c).float() 2745 bias = torch.rand(c).float() 2746 eps = 0.001 2747 qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x) 2748 if len(X.shape) == 2 or len(X.shape) == 3: 2749 qy = torch.ops.quantized.batch_norm1d( 2750 qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point) 2751 elif len(X.shape) == 4: 2752 qy = torch.ops.quantized.batch_norm2d( 2753 qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point) 2754 elif len(X.shape) == 5: 2755 qy = torch.ops.quantized.batch_norm3d( 2756 qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point) 2757 2758 float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias, 2759 running_mean=mean, running_var=var, training=False, 2760 momentum=0, eps=eps) 2761 quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x) 2762 self.assertEqual( 2763 qy.int_repr().numpy(), quantize_ref.int_repr().numpy(), 2764 msg=f"{qy} vs {quantize_ref}") 2765 2766 @override_qengines 2767 def test_empty_batch(self): 2768 scale = 1.0 2769 zero_point = 0 2770 X = torch.ones((0, 2, 4, 4), dtype=torch.float32) 2771 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 2772 dtype=torch.quint8) 2773 2774 # upsample_nearest2d 2775 qY = torch.nn.functional.upsample_nearest(qX, scale_factor=2) 2776 np.testing.assert_equal(qY.size(), (0, 2, 8, 8), 2777 "Quantized upsample_nearsest2d with batch size 0 failed.") 2778 2779 # relu 2780 qY = torch.nn.functional.relu(qX) 2781 np.testing.assert_equal(qY.size(), qX.size(), 2782 "Quantized relu with batch size 0 failed.") 2783 2784 # tanh 2785 qY = torch.tanh(qX) 2786 np.testing.assert_equal(qY.size(), qX.size(), 2787 "Quantized tanh with batch size 0 failed.") 2788 # sigmoid 2789 qY = torch.sigmoid(qX) 2790 np.testing.assert_equal(qY.size(), qX.size(), 2791 "Quantized sigmoid with batch size 0 failed.") 2792 2793 # interpolate 2794 op = torch.ao.nn.quantized.functional.interpolate 2795 for mode in ["nearest", "bilinear", "nearest-exact"]: 2796 qY = op(qX, scale_factor=2, mode=mode) 2797 np.testing.assert_equal(qY.size(), (0, 2, 8, 8), 2798 "Quantized interpolate with batch size 0 failed.") 2799 2800 # avg_pool 2801 kernel = (2, 2) 2802 stride = (1, 1) 2803 padding = (0, 0) 2804 op = torch.ao.nn.quantized.functional.avg_pool2d 2805 qY = op(qX, kernel, stride, padding) 2806 np.testing.assert_equal(qY.size(), (0, 2, 3, 3), 2807 "Quantized avg_pool2d with batch size 0 failed.") 2808 2809 # adaptive_avg_pool 2810 op = torch.ao.nn.quantized.functional.adaptive_avg_pool2d 2811 qY = op(qX, (3, 3)) 2812 np.testing.assert_equal(qY.size(), (0, 2, 3, 3), 2813 "Quantized adaptive_avg_pool2d with batch size 0 failed.") 2814 2815 # 
max_pool 2816 dilation = (1, 1) 2817 qY = torch.ops.quantized.max_pool2d(qX, kernel, stride, padding, dilation, ceil_mode=False) 2818 oH = pool_output_shape(4, 2, 0, 1, 1) 2819 oW = pool_output_shape(4, 2, 0, 1, 1) 2820 np.testing.assert_equal(qY.size(), (0, 2, oH, oW), 2821 "Quantized maxpool2d with batch size 0 failed.") 2822 2823 # hardtanh 2824 qY = torch.ao.nn.quantized.functional.hardtanh(qX, -1, 6) 2825 np.testing.assert_equal(qY.size(), qX.size(), 2826 "Quantized hardtanh with batch size 0 failed.") 2827 2828 # mul 2829 qY = torch.ops.quantized.mul(qX, qX, 1.0, 0) 2830 np.testing.assert_equal(qY.size(), qX.size(), 2831 "Quantized mul with batch size 0 failed.") 2832 # add 2833 qY = torch.ops.quantized.add(qX, qX, 1.0, 0) 2834 np.testing.assert_equal(qY.size(), qX.size(), 2835 "Quantized addition with batch size 0 failed.") 2836 2837 # conv 2838 w = torch.randn((2, 2, 2, 2), dtype=torch.float) 2839 qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8) 2840 bias_float = torch.ones(2, dtype=torch.float) 2841 strides = [1, 1] 2842 pads = [0, 0] 2843 dilations = [1, 1] 2844 2845 w_packed = torch.ops.quantized.conv2d_prepack(qw, bias_float, strides, pads, dilations, 1) 2846 result = torch.ops.quantized.conv2d(qX, w_packed, 1.0, 0) 2847 self.assertEqual(result.shape, (0, 2, 3, 3)) 2848 2849 # linear 2850 X = torch.ones((0, 2), dtype=torch.float32) 2851 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 2852 dtype=torch.quint8) 2853 w = torch.randn((2, 2), dtype=torch.float) 2854 qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8) 2855 w_packed = torch.ops.quantized.linear_prepack(qw, bias_float) 2856 result = torch.ops.quantized.linear(qX, w_packed, 1.0, 0) 2857 self.assertEqual(result.shape, (0, 2)) 2858 2859 # dynamic linear 2860 result = torch.ops.quantized.linear_dynamic(X, w_packed) 2861 self.assertEqual(result.shape, (0, 2)) 2862 2863 @override_qengines 2864 def test_linear_bias_unpack(self): 2865 """ 2866 Verifies the correctness of bias() and unpack() API for LinearPackedParamBase. 2867 """ 2868 bias_float = torch.ones(2, dtype=torch.float) 2869 w = torch.randn((2, 2), dtype=torch.float) 2870 qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8) 2871 w_packed = torch.ops.quantized.linear_prepack(qw, bias_float) 2872 # test bias() 2873 self.assertEqual(w_packed.bias(), bias_float) 2874 # test unpack() 2875 self.assertEqual(w_packed.unpack()[0], qw) 2876 2877 def test_advanced_indexing(self): 2878 """ 2879 Verifies that the x[:, [0], :, :] syntax works for quantized tensors. 
2880 """ 2881 for dtype in (torch.qint8, torch.quint8, torch.qint32): 2882 scale = 0.1 2883 zp = 0 2884 x_q = torch.quantize_per_tensor( 2885 torch.randn(1, 4, 4, 4), scale, zp, dtype) 2886 # reference 2887 x_fp32 = x_q.dequantize() 2888 2889 # single dim, single index 2890 x_q_s1 = x_q[:, [0], :, :] 2891 x_fp32_s1 = x_fp32[:, [0], :, :] 2892 x_fp32_s1_ref = \ 2893 torch.quantize_per_tensor(x_fp32_s1, scale, zp, dtype) 2894 self.assertEqual(x_q_s1, x_fp32_s1_ref) 2895 2896 # multiple dim, single index 2897 x_q_s2 = x_q[:, [0], [2], :] 2898 x_fp32_s2 = x_fp32[:, [0], [2], :] 2899 x_fp32_s2_ref = \ 2900 torch.quantize_per_tensor(x_fp32_s2, scale, zp, dtype) 2901 self.assertEqual(x_q_s2, x_fp32_s2_ref) 2902 2903 # single dim, multiple indices 2904 x_q_s3 = x_q[:, [2, 0, 1], :, :] 2905 x_fp32_s3 = x_fp32[:, [2, 0, 1], :, :] 2906 x_fp32_s3_ref = \ 2907 torch.quantize_per_tensor(x_fp32_s3, scale, zp, dtype) 2908 self.assertEqual(x_q_s3, x_fp32_s3_ref) 2909 2910 # multiple dim, multiple indices 2911 x_q_s4 = x_q[:, [2, 0, 1], :, [1]] 2912 x_fp32_s4 = x_fp32[:, [2, 0, 1], :, [1]] 2913 x_fp32_s4_ref = \ 2914 torch.quantize_per_tensor(x_fp32_s4, scale, zp, dtype) 2915 self.assertEqual(x_q_s4, x_fp32_s4_ref) 2916 2917 @override_qengines 2918 def test_custom_module_lstm(self): 2919 qengine = torch.backends.quantized.engine 2920 2921 batch_size = 4 2922 seq_len = 8 2923 input_size = 12 2924 2925 hidden_size = 8 2926 num_layers = 2 2927 2928 dropout = 0 # This is not supported 2929 2930 Bias = [False, True] 2931 Batch_first = [False, True] 2932 Bidirectional = [False, True] 2933 2934 dtype = np.uint8 2935 qtype = torch.quint8 2936 2937 x = np.random.randn(seq_len, batch_size, input_size) 2938 scale, zero_point = _calculate_dynamic_qparams(x, dtype=dtype) 2939 x = torch.from_numpy(x).to(torch.float) 2940 qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, 2941 dtype=qtype) 2942 x = qx.dequantize() 2943 2944 with torch.no_grad(): 2945 for bias, batch_first, bidirectional in itertools.product( 2946 Bias, Batch_first, Bidirectional): 2947 # Assume 12dB is sufficient for functional equivalence 2948 # Without the bias, linear performs poorly 2949 min_power = 10 if bias else 5 2950 max_mse = 5e-6 if bias else 5e-1 2951 2952 if batch_first: 2953 x = x.reshape(batch_size, seq_len, input_size) 2954 qx = qx.reshape(batch_size, seq_len, input_size) 2955 else: 2956 x = x.reshape(seq_len, batch_size, input_size) 2957 qx = qx.reshape(seq_len, batch_size, input_size) 2958 2959 lstm = torch.nn.Sequential( 2960 torch.nn.LSTM(input_size, hidden_size, 2961 num_layers=num_layers, 2962 bias=bias, batch_first=batch_first, 2963 dropout=dropout, 2964 bidirectional=bidirectional)) 2965 lstm.eval() 2966 y_ref = lstm(x) 2967 2968 # Prepare 2969 lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 2970 lstm_prepared = torch.ao.quantization.prepare(lstm) 2971 self.assertTrue(hasattr(lstm_prepared[0], 'layers')) 2972 self.assertEqual(num_layers, len(lstm_prepared[0].layers)) 2973 assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM 2974 2975 # Calibrate 2976 y = lstm_prepared(x) 2977 self.assertEqual(y_ref, y) 2978 2979 # Quantize 2980 lstm_quantized = torch.ao.quantization.convert(lstm_prepared) 2981 assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM 2982 qy = lstm_quantized(qx) 2983 2984 snr = _snr(y, qy) 2985 snr = [snr[0]] + snr[1] 2986 2987 for signal, mse, power in snr: 2988 self.assertTrue( 2989 power > min_power or mse < max_mse, 2990 msg=(f"Error is too high: SNR(dB): 
{power}, " 2991 f"Signal: {signal}, MSE: {mse}")) 2992 2993 # Trace 2994 jit_qmodule = torch.jit.trace(lstm_quantized, qx) 2995 2996 # Script 2997 jit_qmodule = torch.jit.script(lstm_quantized) 2998 2999 @override_qengines 3000 def test_custom_module_multi_head_attention(self): 3001 class MultiheadAttentionModel(torch.nn.Module): 3002 def __init__(self, *args, **kwargs): 3003 super().__init__() 3004 self.layer = torch.nn.MultiheadAttention(*args, **kwargs) 3005 3006 def forward( 3007 self, 3008 query, 3009 key, 3010 value, 3011 key_padding_mask: Optional[torch.Tensor] = None, 3012 need_weights: bool = True, 3013 attn_mask: Optional[torch.Tensor] = None, 3014 ): 3015 return self.layer(query, key, value, key_padding_mask, need_weights, attn_mask) 3016 3017 qengine = torch.backends.quantized.engine 3018 3019 min_power = 30 3020 max_mse = 2 3021 3022 num_heads = 16 3023 batch_size = 4 3024 target_seq_length = 128 3025 source_seq_length = 64 3026 qembed_dim = 512 # Must be divisible by the number of heads 3027 kembed_dim = 128 3028 vembed_dim = 256 3029 3030 dropout = 0.0 # This is not supported 3031 3032 Bias = [False, True] 3033 Add_bias_kv = [False, True] 3034 Add_zero_attn = [False, True] 3035 3036 dtype = np.uint8 3037 qtype = torch.quint8 3038 3039 for kdim, vdim in ((kembed_dim, vembed_dim), (None, None)): 3040 fp_data = [ 3041 torch.randn(target_seq_length, batch_size, qembed_dim), # Q 3042 torch.randn(source_seq_length, batch_size, 3043 qembed_dim if kdim is None else kembed_dim), # K 3044 torch.randn(source_seq_length, batch_size, 3045 qembed_dim if vdim is None else vembed_dim) # V 3046 ] 3047 3048 q_data = [] 3049 reduce_range = (qengine in ('x86', 'fbgemm', 'onednn')) 3050 for idx, x in enumerate(fp_data): 3051 scale, zero_point = _calculate_dynamic_qparams( 3052 x, dtype=dtype, reduce_range=reduce_range) 3053 x = x.to(torch.float) 3054 qx = torch.quantize_per_tensor(x, scale=scale, 3055 zero_point=zero_point, dtype=qtype) 3056 q_data.append(qx) 3057 3058 # Dequantize the data back for reference 3059 fp_data[idx] = qx.dequantize() 3060 3061 with torch.no_grad(): 3062 for bias, add_bias_kv, add_zero_attn in itertools.product( 3063 Bias, Add_bias_kv, Add_zero_attn): 3064 mha = MultiheadAttentionModel(qembed_dim, num_heads, dropout, 3065 bias, add_bias_kv, add_zero_attn, 3066 kdim=kdim, vdim=vdim) 3067 mha.eval() 3068 3069 # Prepare 3070 if qengine_is_onednn(): 3071 # `reduce_range` is False by default for ONEDNN backend 3072 # but the test fails on earlier CPUs without VNNI. 
3073 # So we use a default qconfig with `reduce_range=True` here 3074 mha.qconfig = torch.ao.quantization.get_default_qconfig() 3075 else: 3076 mha.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 3077 mha_prepared = torch.ao.quantization.prepare( 3078 mha) 3079 3080 # Calibrate 3081 y = mha_prepared(*fp_data) 3082 y_ref = mha(*fp_data) 3083 # Check the result of the prepare 3084 self.assertEqual(y_ref[0], y[0]) # Attention 3085 self.assertEqual(y_ref[1], y[1]) # Weight 3086 3087 # Quantize 3088 mha_quantized = torch.ao.quantization.convert(mha_prepared) 3089 3090 for name, param in mha_quantized.named_parameters(): 3091 self.assertTrue("in_proj_weight" not in name) 3092 3093 qy = mha_quantized(*q_data) 3094 3095 # Reference result 3096 mha.layer = mha_quantized.layer.dequantize() 3097 y_ref = mha(*fp_data) 3098 3099 snr = _snr(y, qy) 3100 for signal, mse, power in snr: 3101 self.assertTrue( 3102 power > min_power or mse < max_mse, 3103 msg=(f"Error is too high: SNR(dB): {power}, " 3104 f"Signal: {signal}, MSE: {mse}; " 3105 f"Run with bias={bias}, " 3106 f"add_bias_kv={add_bias_kv}, " 3107 f"add_zero_attn={add_zero_attn}")) 3108 3109 # Verify the result is scriptable 3110 mha_quantized_scripted = torch.jit.script(mha_quantized) 3111 3112 3113class TestDynamicQuantizedOps(TestCase): 3114 """Tests the correctness of the dynamic quantized linear and linear_relu op.""" 3115 @override_qengines 3116 @given( 3117 batch_size=st.integers(1, 4), 3118 input_channels=st.integers(16, 32), 3119 output_channels=st.integers(4, 8), 3120 use_bias=st.booleans(), 3121 use_relu=st.booleans(), 3122 use_multi_dim_input=st.booleans(), 3123 use_channelwise=st.booleans(), 3124 reduce_range=st.booleans()) 3125 def test_qlinear(self, batch_size, input_channels, output_channels, 3126 use_bias, use_relu, use_multi_dim_input, use_channelwise, reduce_range): 3127 if torch.backends.quantized.engine == 'qnnpack': 3128 reduce_range = False 3129 3130 qlinear_prepack = torch.ops.quantized.linear_prepack 3131 if use_relu: 3132 qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic 3133 else: 3134 qlinear_dynamic = torch.ops.quantized.linear_dynamic 3135 3136 if use_multi_dim_input: 3137 batch_size *= 3 # Test the multi-dim input tensor 3138 3139 X_scale = 1.0 3140 X_zp = 0 3141 X_value_min = 0 3142 X_value_max = 255 3143 if reduce_range: 3144 X_value_max = 127 3145 X_q0 = np.round(np.random.rand(batch_size, input_channels) * 3146 (X_value_max - X_value_min) + X_value_min).astype(np.uint8) 3147 X_q0[0, 0] = X_value_min 3148 X_q0[0, 1] = X_value_max 3149 3150 # W_scale = 1.0 3151 # W_zp = 0 3152 W_scales = np.ones(output_channels) 3153 W_zps = np.zeros(output_channels).astype(int) 3154 W_value_min = -128 3155 W_value_max = 127 3156 W_q0 = np.round( 3157 np.random.rand(output_channels, input_channels) 3158 * (W_value_max - W_value_min) 3159 + W_value_min 3160 ).astype(np.int8) 3161 W_q0[0, 0] = W_value_min 3162 W_q0[1, 0] = W_value_max 3163 3164 b_value_min = -10 3165 b_value_max = 10 3166 b_q0 = np.round( 3167 np.random.rand(output_channels) * 3168 (b_value_max - b_value_min) + b_value_min 3169 ).astype(np.int32) if use_bias else None 3170 3171 if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'): 3172 avoid_vpmaddubsw_overflow_linear( 3173 batch_size, 3174 input_channels, 3175 output_channels, 3176 X_q0, 3177 X_value_min, 3178 X_value_max, 3179 W_q0, 3180 W_value_min, 3181 W_value_max, 3182 ) 3183 3184 X_fp32 = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float) 3185 if 
use_multi_dim_input: 3186 X_fp32 = X_fp32.view(3, int(batch_size / 3), input_channels) 3187 3188 # W_scale, W_zp = _calculate_dynamic_qparams(W_fp32, torch.qint8) 3189 # We currently only check the case where W_scale = 1.0, W_zp = 0. 3190 3191 if use_channelwise: 3192 W_fp32 = torch.from_numpy(_dequantize(W_q0, W_scales.reshape( 3193 (-1, 1)), W_zps.reshape((-1, 1)))).to(dtype=torch.float) 3194 W_q = torch.quantize_per_channel(W_fp32, scales=torch.from_numpy(W_scales), 3195 zero_points=torch.from_numpy(W_zps), axis=0, dtype=torch.qint8) 3196 b_fp32 = torch.from_numpy( 3197 _dequantize(b_q0, X_scale * W_scales, 0) 3198 ).to(dtype=torch.float) if use_bias else None 3199 else: 3200 W_fp32 = torch.from_numpy(_dequantize( 3201 W_q0, W_scales[0], W_zps[0])).to(dtype=torch.float) 3202 W_q = torch.quantize_per_tensor(W_fp32, scale=W_scales[0], zero_point=( 3203 W_zps[0].astype(int).item()), dtype=torch.qint8) 3204 b_fp32 = torch.from_numpy( 3205 _dequantize(b_q0, X_scale * int(W_scales[0].item()), 0) 3206 ).to(dtype=torch.float) if use_bias else None 3207 3208 # Observe X_fp32 and determine X_scale and X_zero_point, this should match 3209 # internals of dynamic linear. 3210 X_scale, X_zp = _calculate_dynamic_qparams(X_fp32, torch.quint8, reduce_range) 3211 X_q = torch.quantize_per_tensor(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8) 3212 3213 # Weight prepacking operator for dynamic quantized Linear 3214 W_prepack = qlinear_prepack(W_q, b_fp32) 3215 # Dynamic quantized Linear operator with prepacked weight 3216 Y_fp32 = qlinear_dynamic(X_q.dequantize(), W_prepack, reduce_range) 3217 # Y_fp32 = qlinear_dynamic(X_fp32, W_prepack, b_fp32) 3218 3219 Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32) 3220 # Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) 3221 # if use_multi_dim_input: 3222 # Y_fp32_ref = Y_fp32_ref.view(3, int(batch_size / 3), output_channels) 3223 3224 if use_relu: 3225 Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0 3226 self.assertEqual(Y_fp32, Y_fp32_ref, 3227 msg="torch.ops.quantized.linear_dynamic results are off") 3228 3229 @skipIfNoFBGEMM 3230 @given( 3231 batch_size=st.integers(1, 4), 3232 input_channels=st.integers(16, 32), 3233 output_channels=st.integers(4, 8), 3234 ) 3235 def test_qlinear_legacy(self, batch_size, input_channels, output_channels): 3236 X_scale = 1.0 3237 X_zp = 0 3238 X_value_min = 0 3239 X_value_max = 255 3240 X_q0 = np.round(np.random.rand(batch_size, input_channels) * ( 3241 X_value_max - X_value_min) + X_value_min 3242 ).astype(np.uint8) 3243 X_q0[0, 0] = X_value_min 3244 X_q0[0, 1] = X_value_max 3245 3246 W_scale = 1.0 3247 W_zp = 0 3248 W_value_min = -128 3249 W_value_max = 127 3250 W_q0 = np.round( 3251 np.random.rand(output_channels, input_channels) 3252 * (W_value_max - W_value_min) 3253 + W_value_min 3254 ).astype(np.int8) 3255 W_q0[0, 0] = W_value_min 3256 W_q0[1, 0] = W_value_max 3257 3258 b_value_min = -10 3259 b_value_max = 10 3260 b_q0 = np.round( 3261 np.random.rand(output_channels) * (b_value_max - b_value_min) + 3262 b_value_min 3263 ).astype(np.int32) 3264 3265 avoid_vpmaddubsw_overflow_linear( 3266 batch_size, 3267 input_channels, 3268 output_channels, 3269 X_q0, 3270 X_value_min, 3271 X_value_max, 3272 W_q0, 3273 W_value_min, 3274 W_value_max, 3275 ) 3276 3277 X_fp32 = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float) 3278 W_fp32 = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float) 3279 b_fp32 = torch.from_numpy( 3280 _dequantize(b_q0, X_scale * W_scale, 0) 3281 
).to(dtype=torch.float) 3282 3283 W_scale, W_zp = _calculate_dynamic_qparams(W_fp32, torch.qint8) 3284 W_q = torch.quantize_per_tensor(W_fp32, scale=W_scale, zero_point=W_zp, dtype=torch.qint8) 3285 3286 # Observe X_fp32 and determine X_scale and X_zero_point, this should match 3287 # internals of dynamic linear. 3288 X_scale, X_zp = _calculate_dynamic_qparams(X_fp32, torch.quint8) 3289 X_q = torch.quantize_per_tensor(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8) 3290 3291 W_int8, col_offsets, W_scale, W_zp = torch.fbgemm_linear_quantize_weight(W_q.dequantize()) 3292 W_prepack = torch.fbgemm_pack_quantized_matrix(W_int8.clone(), W_int8.size(1), W_int8.size(0)) 3293 # Quantized Linear operator with prepacked weight 3294 Y_fp32 = torch.fbgemm_linear_int8_weight( 3295 X_q.dequantize(), W_q.dequantize(), W_prepack, col_offsets, 3296 W_scale, W_zp, b_fp32) 3297 3298 Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32) 3299 # Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) 3300 3301 self.assertEqual(Y_fp32, Y_fp32_ref, 3302 msg="torch.ops.quantized.fbgemm_linear_dynamic results are off") 3303 3304 @skipIfNoFBGEMM 3305 @given( 3306 input_channels=st.integers(16, 32), 3307 output_channels=st.integers(4, 8), 3308 exponent=st.integers(0, 8)) 3309 def test_linear_prepack_fp16_numerics(self, input_channels, output_channels, exponent): 3310 w = torch.randn(output_channels, input_channels) * 10**exponent 3311 bias = None 3312 w_packed_fp16 = torch.ops.quantized.linear_prepack_fp16(w, bias) 3313 w_unpacked_fp16 = torch.ops.quantized.linear_unpack_fp16(w_packed_fp16) 3314 w_fp16 = w.to(torch.float16).to(torch.float32) 3315 self.assertTrue(torch.equal(w_fp16, w_unpacked_fp16[0])) 3316 3317 @skipIfNoFBGEMM 3318 def test_qlinear_dynamic_fp16(self): 3319 3320 options = itertools.product( 3321 (2, 4), # batch_size 3322 (4, 5, 12), # input_channels 3323 (4, 7, 8), # output_channels 3324 (True, False), # use_bias 3325 (True, False), # use_relu 3326 ) 3327 for batch_size, input_channels, output_channels, use_bias, use_relu in options: 3328 qlinear_prepack = torch.ops.quantized.linear_prepack_fp16 3329 if use_relu: 3330 qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic_fp16 3331 else: 3332 qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16 3333 3334 x = torch.randn(batch_size, input_channels) 3335 w = torch.randn(output_channels, input_channels) 3336 bias = torch.randn(output_channels) if use_bias else None 3337 3338 w_packed = qlinear_prepack(w, bias) 3339 out = qlinear_dynamic(x, w_packed) 3340 3341 # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors 3342 # output is FP32 3343 w_fp16 = w.to(torch.float16).to(torch.float32) 3344 ref = F.linear(x, w_fp16, bias) 3345 if use_relu: 3346 ref.relu_() 3347 3348 self.assertEqual(out, ref) 3349 3350 @skipIfNoFBGEMM 3351 def test_unpacked_qlinear_dynamic_fp16(self): 3352 3353 options = itertools.product( 3354 (2, 4), # batch_size 3355 (4, 5, 12), # input_channels 3356 (4, 7, 8), # output_channels 3357 ) 3358 for batch_size, input_channels, output_channels in options: 3359 qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight 3360 3361 x = torch.randn(batch_size, input_channels) 3362 w = torch.randn(output_channels, input_channels) 3363 bias = torch.randn(output_channels) 3364 3365 out = qlinear_dynamic(x, w, bias) 3366 3367 # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors 3368 # output is FP32 3369 w_fp16 = w.to(torch.float16).to(torch.float32) 3370 ref = 
F.linear(x, w_fp16, bias) 3371 3372 self.assertEqual(out, ref) 3373 3374 3375 @skipIfNoFBGEMM 3376 def test_unpacked_qlinear_dynamic_fp16_opcheck(self): 3377 qlinear_dynamic = torch.ops.quantized.linear_dynamic_fp16_unpacked_weight.default 3378 3379 x = torch.randn(4, 4, device='cpu') 3380 w = torch.randn(4, 4, device='cpu') 3381 bias = torch.randn(4, device='cpu') 3382 3383 opcheck(qlinear_dynamic, (x, w, bias)) 3384 3385 @skipIfNoFBGEMM 3386 def test_wrapped_fbgemm_linear_fp16(self): 3387 options = itertools.product( 3388 (2, 4), # batch_size 3389 (4, 5), # input_channels 3390 (4, 7), # output_channels 3391 ) 3392 for batch_size, input_channels, output_channels in options: 3393 pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16 3394 linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight 3395 3396 x = torch.randn(batch_size, input_channels) 3397 w = torch.randn(output_channels, input_channels) 3398 bias = torch.randn(output_channels) 3399 3400 w_packed = pack_op(w) 3401 out = linear_op(x, w_packed, bias, output_channels) 3402 3403 w_fp16 = w.to(torch.float16).to(torch.float32) 3404 ref = F.linear(x, w_fp16, bias) 3405 3406 self.assertEqual(out, ref) 3407 3408 @skipIfNoFBGEMM 3409 def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self): 3410 # We are not using opcheck here because the output of the op we're testing 3411 # (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not deterministic 3412 # due to the C-struct it's producing. This would fail the check when we try 3413 # to match the results between the compiled and eager versions. 3414 # 3415 # This is only a temporary solution; long term, we should be able to support PT2 3416 # with torchbind natively. 3417 def func(X, W, B): 3418 packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W) 3419 return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0)) 3420 3421 x = torch.randn(1, 4, device="cpu") 3422 w = torch.randn(4, 4, device="cpu") 3423 b = torch.zeros(4, device="cpu") 3424 3425 ref_out = func(x, w, b) 3426 3427 compiled = torch.compile(func) 3428 compiled_out = compiled(x, w, b) 3429 3430 self.assertEqual(ref_out, compiled_out) 3431 3432 """Tests the correctness of the dynamic quantized lstm/gru.""" 3433 3434 def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range): 3435 # For Input (seq_len, batch, input_size) 3436 X = torch.randn(seq_len, num_batches, input_size) 3437 s, z = _calculate_dynamic_qparams(X, torch.quint8, reduce_range) 3438 Xq = torch.quantize_per_tensor(X, s, z, torch.quint8) 3439 3440 # For H and C: (num_layers(1) * num_directions, batch, hidden_size) 3441 3442 if num_directions == 1: 3443 H = torch.randn(num_directions, num_batches, hidden_size) 3444 C = torch.randn(num_directions, num_batches, hidden_size) 3445 else: 3446 H = torch.zeros(num_directions, num_batches, hidden_size) 3447 C = torch.zeros(num_directions, num_batches, hidden_size) 3448 3449 s, z = _calculate_dynamic_qparams(H, torch.quint8, reduce_range) 3450 Hq = torch.quantize_per_tensor(H, s, z, torch.quint8) 3451 s, z = _calculate_dynamic_qparams(C, torch.quint8, reduce_range) 3452 Cq = torch.quantize_per_tensor(C, s, z, torch.quint8) 3453 return Xq, Hq, Cq 3454 3455 def _get_rnn_weights_and_bias(self, input_size, hidden_size, num_directions, per_channel_quant, rnn_type): 3456 hidden_mult_map = {'LSTM': 4, 'LSTMCell': 4, 'GRU': 3, 'GRUCell': 3, 'RNNTanh': 2, 'RNNReLU': 2} 3457 hidden_mult = 
hidden_mult_map[rnn_type] 3458 weights1 = torch.randn(hidden_mult * hidden_size, input_size) 3459 weights2 = torch.randn(hidden_mult * hidden_size, hidden_size) 3460 scale1 = 0.1 * torch.ones([weights1.size()[0]]) 3461 scale2 = 0.3 * torch.ones([weights2.size()[0]]) 3462 zero_point1 = torch.zeros(scale1.size()).to(int) 3463 zero_point2 = torch.zeros(scale2.size()).to(int) 3464 b1 = torch.zeros(hidden_mult * hidden_size) 3465 if per_channel_quant: 3466 Wq1 = torch.quantize_per_channel(weights1, scale1, zero_point1, 0, torch.qint8) 3467 Wq2 = torch.quantize_per_channel(weights2, scale2, zero_point2, 0, torch.qint8) 3468 3469 else: 3470 Wq1 = torch.quantize_per_tensor(weights1, float(scale1[0]), int(zero_point1[0]), torch.qint8) 3471 Wq2 = torch.quantize_per_tensor(weights2, float(scale2[0]), int(zero_point2[0]), torch.qint8) 3472 return Wq1, Wq2, b1, b1 3473 3474 @given( 3475 num_batches=st.integers(1, 4), 3476 input_size=st.integers(16, 32), 3477 hidden_size=st.integers(4, 8), 3478 num_directions=st.integers(1, 2), 3479 per_channel_quant=st.booleans()) 3480 @override_qengines 3481 def test_qlstmGRU(self, num_batches, input_size, hidden_size, 3482 num_directions, per_channel_quant): 3483 # We test only for seq length of 1 and num layers of 1 as dynamic quantization occurs multiple times 3484 # within the LSTM op and we do not model the quantization between multiple calls of the linear op within the 3485 # lstm op 3486 seq_len = 1 3487 3488 for rnn_type in ['LSTM', 'GRU']: 3489 for dtype in [torch.qint8, torch.float16]: 3490 # Fp16 quantization is not supported for qnnpack or onednn 3491 if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16: 3492 continue 3493 3494 if torch.backends.quantized.engine == 'qnnpack': 3495 reduce_range = False 3496 else: 3497 reduce_range = True 3498 Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size, 3499 hidden_size, num_directions, reduce_range) 3500 Wq1, Wq2, b1, b2 = self._get_rnn_weights_and_bias(input_size, 3501 hidden_size, 3502 num_directions, 3503 per_channel_quant, 3504 rnn_type) 3505 if dtype == torch.qint8: 3506 packed_ih = torch.ops.quantized.linear_prepack(Wq1, b1) 3507 packed_hh = torch.ops.quantized.linear_prepack(Wq2, b2) 3508 cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic( 3509 packed_ih, packed_hh, b1, b2, reduce_range) 3510 W_ref1 = Wq1.dequantize() 3511 W_ref2 = Wq2.dequantize() 3512 3513 else: 3514 packed_ih = torch.ops.quantized.linear_prepack_fp16(Wq1.dequantize(), b1) 3515 packed_hh = torch.ops.quantized.linear_prepack_fp16(Wq2.dequantize(), b2) 3516 cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(packed_ih, packed_hh) 3517 W_ref1 = Wq1.dequantize().to(torch.float16).to(torch.float32) 3518 W_ref2 = Wq2.dequantize().to(torch.float16).to(torch.float32) 3519 3520 if rnn_type == 'LSTM': 3521 if num_directions > 1: 3522 result_ref = _VF.lstm(Xq.dequantize(), 3523 (Hq.dequantize(), Cq.dequantize()), 3524 [W_ref1, W_ref2, b1, b2, W_ref1, W_ref2, b1, b2], 3525 True, 3526 1, 3527 0, 3528 False, 3529 num_directions > 1, 3530 False) 3531 3532 result_dynamic = torch.quantized_lstm(Xq.dequantize(), 3533 (Hq.dequantize(), Cq.dequantize()), 3534 ([cell_params, cell_params]), 3535 True, 3536 1, 3537 0, 3538 False, 3539 True, 3540 False, 3541 dtype=torch.qint8, 3542 use_dynamic=True) 3543 else: 3544 result_ref = _VF.lstm(Xq.dequantize(), 3545 (Hq.dequantize(), Cq.dequantize()), 3546 [W_ref1, W_ref2, b1, b2], 3547 True, 3548 1, 3549 0, 3550 False, 3551 num_directions > 
1, 3552 False) 3553 3554 result_dynamic = torch.quantized_lstm(Xq.dequantize(), 3555 (Hq.dequantize(), Cq.dequantize()), 3556 ([cell_params]), 3557 True, 3558 1, 3559 0, 3560 False, 3561 num_directions > 1, 3562 False, 3563 dtype=torch.qint8, 3564 use_dynamic=True) 3565 3566 if rnn_type == 'GRU': 3567 if num_directions > 1: 3568 result_ref = _VF.gru(Xq.dequantize(), 3569 Hq.dequantize(), 3570 [W_ref1, W_ref2, b1, b2, W_ref1, W_ref2, b1, b2], 3571 True, 3572 1, 3573 0, 3574 False, 3575 True, 3576 False) 3577 3578 result_dynamic = torch.quantized_gru(Xq.dequantize(), 3579 Hq.dequantize(), 3580 ([cell_params, cell_params]), 3581 True, 3582 1, 3583 0, 3584 False, 3585 True, 3586 False) 3587 else: 3588 result_ref = _VF.gru(Xq.dequantize(), 3589 Hq.dequantize(), 3590 [W_ref1, W_ref2, b1, b2], 3591 True, 3592 1, 3593 0, 3594 False, 3595 False, 3596 False) 3597 3598 result_dynamic = torch.quantized_gru(Xq.dequantize(), 3599 Hq.dequantize(), 3600 ([cell_params]), 3601 True, 3602 1, 3603 0, 3604 False, 3605 False, 3606 False) 3607 3608 self.assertEqual(result_ref[0], result_dynamic[0], msg="torch.quantized_lstm results are off") 3609 3610 @given( 3611 num_batches=st.integers(1, 4), 3612 input_size=st.integers(16, 32), 3613 hidden_size=st.integers(4, 8), 3614 per_channel_quant=st.booleans()) 3615 @override_qengines 3616 def test_qrnncell(self, num_batches, input_size, hidden_size, per_channel_quant): 3617 # We test only for seq length of 1 and num layers of 1 as dynamic quantization occurs multiple times 3618 # within the LSTM op and we do not model the quantization between multiple calls of the linear op within the 3619 # lstm op 3620 seq_len = 1 3621 3622 for rnn_type in ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']: 3623 for dtype in [torch.qint8, torch.float16]: 3624 # Fp16 quantization is not supported for qnnpack or onednn 3625 if torch.backends.quantized.engine in ('qnnpack', 'onednn') and dtype == torch.float16: 3626 continue 3627 3628 if torch.backends.quantized.engine == 'qnnpack': 3629 reduce_range = False 3630 else: 3631 reduce_range = True 3632 3633 Xq, Hq, Cq = self._get_rnn_inputs(seq_len, num_batches, input_size, hidden_size, 1, reduce_range) 3634 Wq1, Wq2, b1, b2 = self._get_rnn_weights_and_bias( 3635 input_size, hidden_size, 1, per_channel_quant, rnn_type) 3636 if dtype == torch.qint8: 3637 packed_ih = torch.ops.quantized.linear_prepack(Wq1, b1) 3638 packed_hh = torch.ops.quantized.linear_prepack(Wq2, b2) 3639 W_ref1 = Wq1.dequantize() 3640 W_ref2 = Wq2.dequantize() 3641 else: 3642 packed_ih = torch.ops.quantized.linear_prepack_fp16(Wq1.dequantize(), b1) 3643 packed_hh = torch.ops.quantized.linear_prepack_fp16(Wq2.dequantize(), b2) 3644 W_ref1 = Wq1.dequantize().to(torch.float16).to(torch.float32) 3645 W_ref2 = Wq2.dequantize().to(torch.float16).to(torch.float32) 3646 3647 state = {'LSTMCell': (Hq.dequantize()[0], Cq.dequantize()[0]), 3648 'GRUCell': Hq.dequantize()[0], 3649 'RNNTanh': Hq.dequantize()[0], 3650 'RNNReLU': Hq.dequantize()[0]} 3651 fn_dict = {'LSTMCell': torch._VF.lstm_cell, 3652 'GRUCell': torch._VF.gru_cell, 3653 'RNNTanh': torch._VF.rnn_tanh_cell, 3654 'RNNReLU': torch._VF.rnn_relu_cell} 3655 qfn_dict = {'LSTMCell': torch.ops.quantized.quantized_lstm_cell_dynamic, 3656 'GRUCell': torch.ops.quantized.quantized_gru_cell_dynamic, 3657 'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic, 3658 'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic} 3659 W_ref_dict = {torch.float16: (Wq1.dequantize().to(torch.float16).to(torch.float32), 3660 
Wq2.dequantize().to(torch.float16).to(torch.float32)), 3661 torch.qint8: (Wq1.dequantize(), Wq2.dequantize())} 3662 3663 result_ref = fn_dict[rnn_type](Xq.dequantize()[0], state[rnn_type], W_ref1, W_ref2, b1, b2) 3664 result_dynamic = qfn_dict[rnn_type](Xq.dequantize()[0], state[rnn_type], packed_ih, packed_hh, b1, b2) 3665 self.assertEqual(result_ref[0], result_dynamic[0], msg="torch.quantized_rnncell results are off") 3666 3667 def _test_qconv_op_impl(self, q_mod, dq_op, dim, dtype): 3668 # The goal here is to show that the dynamic op is the same as 3669 # calc params->quantize_input->quantized op->dequantize output 3670 3671 if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN): 3672 return # not supported by QNNPACK 3673 3674 if qengine_is_qnnpack(): 3675 reduce_range = False 3676 else: 3677 reduce_range = True 3678 3679 X_fp32 = torch.randn(*([2] * dim)) 3680 s, z = _calculate_dynamic_qparams(X_fp32, dtype, reduce_range) 3681 3682 quantized_module = q_mod(2, 3, 1) 3683 packed_params = quantized_module._packed_params 3684 3685 quantized_module.scale, quantized_module.zero_point = s, z 3686 3687 X_q = torch.quantize_per_tensor(X_fp32, s, z, dtype) 3688 Y_q_ref = quantized_module(X_q) 3689 Y_ref = torch.dequantize(Y_q_ref) 3690 3691 X_dq = torch.dequantize(X_q) 3692 Y = dq_op(X_dq, packed_params, reduce_range) 3693 3694 self.assertEqual(Y, Y_ref) 3695 3696 @override_qengines 3697 def test_dynamic_conv1d(self): 3698 q_mod = torch.ao.nn.quantized.Conv1d 3699 dq_op = torch.ops.quantized.conv1d_dynamic 3700 dim = 3 3701 dtype = torch.quint8 3702 3703 self._test_qconv_op_impl(q_mod, dq_op, dim, dtype) 3704 3705 @override_qengines 3706 def test_dynamic_conv2d(self): 3707 q_mod = torch.ao.nn.quantized.Conv2d 3708 dq_op = torch.ops.quantized.conv2d_dynamic 3709 dim = 4 3710 dtype = torch.quint8 3711 3712 self._test_qconv_op_impl(q_mod, dq_op, dim, dtype) 3713 3714 @override_qengines 3715 def test_dynamic_conv3d(self): 3716 q_mod = torch.ao.nn.quantized.Conv3d 3717 dq_op = torch.ops.quantized.conv3d_dynamic 3718 dim = 5 3719 dtype = torch.quint8 3720 3721 self._test_qconv_op_impl(q_mod, dq_op, dim, dtype) 3722 3723 @override_qengines 3724 def test_dynamic_convtranspose1d(self): 3725 q_mod = torch.ao.nn.quantized.ConvTranspose1d 3726 dq_op = torch.ops.quantized.conv_transpose1d_dynamic 3727 dim = 3 3728 dtype = torch.quint8 3729 3730 self._test_qconv_op_impl(q_mod, dq_op, dim, dtype) 3731 3732 @override_qengines 3733 def test_dynamic_convtranspose2d(self): 3734 q_mod = torch.ao.nn.quantized.ConvTranspose2d 3735 dq_op = torch.ops.quantized.conv_transpose2d_dynamic 3736 dim = 4 3737 dtype = torch.quint8 3738 3739 self._test_qconv_op_impl(q_mod, dq_op, dim, dtype) 3740 3741 @override_qengines 3742 def test_dynamic_convtranspose3d(self): 3743 q_mod = torch.ao.nn.quantized.ConvTranspose3d 3744 dq_op = torch.ops.quantized.conv_transpose3d_dynamic 3745 dim = 5 3746 dtype = torch.quint8 3747 3748 if qengine_is_qnnpack(): 3749 return # TODO: fix MakeDeConvOutputShape overflowing for convT3d with qnnpack 3750 self._test_qconv_op_impl(q_mod, dq_op, dim, dtype) 3751 3752 3753class TestQuantizedLinear(TestCase): 3754 def _test_qlinear_impl(self, batch_size, input_channels, output_channels, use_bias, 3755 post_op, use_multi_dim_input, use_channelwise, **post_op_kwargs): 3756 decimal_val = 4 3757 dtypes = [torch.quint8] 3758 if torch.backends.quantized.engine == 'qnnpack': 3759 # QNNPACK supports uint8 in the kernels. In the op we shift the int8 3760 # weight values to uint8 to be on par with fbgemm. 
However, this causes 3761 # some rounding issues in rare cases. So, we relax the check to allow 3762 # off by one results. 3763 decimal_val = 0 3764 3765 # only qnnpack qengine supports qint8 when xnnpack is available 3766 if torch.backends.xnnpack.enabled: 3767 dtypes.append(torch.qint8) 3768 3769 for dtype in dtypes: 3770 # No support for channelwise in xnnpack (int8) 3771 # ONEDNN does not support qint8 3772 if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()): 3773 return 3774 3775 nptype = np_dtype[dtype] 3776 qlinear_prepack = torch.ops.quantized.linear_prepack 3777 if post_op == 'relu': 3778 qlinear = torch.ops.quantized.linear_relu 3779 elif post_op == 'leaky_relu': 3780 qlinear = torch.ops.quantized.linear_leaky_relu 3781 else: 3782 qlinear = torch.ops.quantized.linear 3783 if use_multi_dim_input: 3784 batch_size *= 3 # Test the multi-dim input tensor 3785 X_scale = 1.5 3786 X_zp = 5 3787 X_value_min = -128 if dtype == torch.qint8 else 0 3788 X_value_max = 127 if dtype == torch.qint8 else 255 3789 X_q0 = np.round( 3790 np.random.rand(batch_size, input_channels) * 3791 (X_value_max - X_value_min) 3792 + X_value_min 3793 ).astype(nptype) 3794 3795 W_scales = np.random.rand(output_channels) 3796 # xnnpack forces W_zp to 0 when using symmetric quantization 3797 # ONEDNN only supports symmetric quantization of weight 3798 if dtype == torch.qint8 or qengine_is_onednn(): 3799 W_zps = np.zeros(output_channels).astype(int) 3800 else: 3801 W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(int) 3802 # when using symmetric quantization 3803 # special restriction for xnnpack fully connected op weight 3804 # [-127, 127] instead of [-128, 127] 3805 W_value_min = -127 if dtype == torch.qint8 else -128 3806 W_value_max = 127 3807 W_q0 = np.round( 3808 np.random.rand(output_channels, input_channels) 3809 * (W_value_max - W_value_min) 3810 + W_value_min 3811 ).astype(np.int8) # weight is always int8_t 3812 b_value_min = -10 3813 b_value_max = 10 3814 b_q0 = np.round( 3815 np.random.rand(output_channels) * 3816 (b_value_max - b_value_min) + b_value_min 3817 ).astype(np.int32) if use_bias else None 3818 if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'): 3819 avoid_vpmaddubsw_overflow_linear( 3820 batch_size, 3821 input_channels, 3822 output_channels, 3823 X_q0, 3824 X_value_min, 3825 X_value_max, 3826 W_q0, 3827 W_value_min, 3828 W_value_max, 3829 ) 3830 X = torch.from_numpy(_dequantize( 3831 X_q0, X_scale, X_zp)).to(dtype=torch.float) 3832 X_q = torch.quantize_per_tensor( 3833 X, scale=X_scale, zero_point=X_zp, dtype=dtype) 3834 if use_channelwise: 3835 W = torch.from_numpy(_dequantize(W_q0, W_scales.reshape( 3836 (-1, 1)), W_zps.reshape((-1, 1)))).to(dtype=torch.float) 3837 W_q = torch.quantize_per_channel(W, scales=torch.from_numpy(W_scales), 3838 zero_points=torch.from_numpy(W_zps), axis=0, dtype=torch.qint8) 3839 b = torch.from_numpy(_dequantize( 3840 b_q0, X_scale * W_scales, 0)).to(dtype=torch.float) if use_bias else None 3841 b_q = torch.quantize_per_channel(b, scales=torch.from_numpy(X_scale * W_scales), 3842 zero_points=torch.zeros(output_channels, dtype=torch.long), 3843 axis=0, dtype=torch.qint32) if use_bias else None 3844 else: 3845 W = torch.from_numpy(_dequantize( 3846 W_q0, W_scales[0], W_zps[0])).to(dtype=torch.float) 3847 W_q = torch.quantize_per_tensor(W, scale=W_scales[0], zero_point=( 3848 W_zps[0].astype(int).item()), dtype=torch.qint8) 3849 b = torch.from_numpy(_dequantize( 3850 b_q0, X_scale * (W_scales[0].item()), 
0)).to(dtype=torch.float) if use_bias else None 3851 b_q = torch.quantize_per_tensor( 3852 b, scale=X_scale * (W_scales[0].item()), zero_point=0, dtype=torch.qint32) if use_bias else None 3853 # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with 3854 # Y_scale * 255 (max for uint8). 3855 Y_scale = 12.34 3856 Y_zp = 5 3857 # Weight prepacking operator for quantized Linear 3858 float_bias = b if use_bias else None 3859 W_prepack = qlinear_prepack(W_q, float_bias) 3860 if use_multi_dim_input: 3861 X_q = X_q.view(3, int(batch_size / 3), input_channels) 3862 # Quantized Linear operator with prepacked weight 3863 Y_q = qlinear(X_q, W_prepack, Y_scale, Y_zp, **post_op_kwargs) 3864 if not use_channelwise and post_op in ('none', 'relu'): 3865 # Test the per-tensor quantization only 3866 # Reference quantized Linear operator 3867 Y_q_ref = qlinear_ref(X_q0, X_scale, X_zp, W_q0, 3868 W_scales[0], W_zps[0], b_q0, Y_scale, Y_zp, dtype=nptype) 3869 if post_op == 'relu': 3870 Y_q_ref[Y_q_ref < Y_zp] = Y_zp 3871 if use_multi_dim_input: 3872 Y_q_ref = np.reshape( 3873 Y_q_ref, (3, int(batch_size / 3), output_channels)) 3874 # Assert equal 3875 np.testing.assert_array_almost_equal(Y_q_ref, Y_q.int_repr().numpy(), decimal=decimal_val) 3876 # Test both per-tensor and per-channel quantization 3877 # Reference quantized result from PyTorch Linear operator 3878 W_fp32 = W_q.dequantize().to(dtype=torch.float) 3879 X_fp32 = X_q.dequantize().to(dtype=torch.float) 3880 b_fp32 = b_q.dequantize().to(dtype=torch.float) if use_bias else None 3881 Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) 3882 if post_op == 'relu': 3883 Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0 3884 elif post_op == 'leaky_relu': 3885 Y_fp32_ref = F.leaky_relu(Y_fp32_ref, **post_op_kwargs) 3886 Y_q_ref2 = torch.quantize_per_tensor( 3887 Y_fp32_ref, Y_scale, Y_zp, dtype) 3888 # Assert equal 3889 np.testing.assert_array_almost_equal( 3890 Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=decimal_val) 3891 3892 """Tests the correctness of the quantized linear op.""" 3893 @override_qengines 3894 def test_qlinear(self): 3895 batch_size_list = [1, 4] 3896 input_channels_list = [16, 32] 3897 output_channels_list = [4, 8] 3898 use_bias_list = [True, False] 3899 use_multi_dim_input_list = [True, False] 3900 use_channelwise_list = [True, False] 3901 post_op = 'none' 3902 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, 3903 use_bias_list, use_multi_dim_input_list, use_channelwise_list) 3904 for batch_size, input_channels, output_channels, use_bias, \ 3905 use_multi_dim_input, use_channelwise in cases: 3906 self._test_qlinear_impl(batch_size, input_channels, output_channels, 3907 use_bias, post_op, use_multi_dim_input, use_channelwise) 3908 3909 """Tests the correctness of the quantized linear_relu op.""" 3910 @override_qengines 3911 def test_qlinear_relu(self): 3912 batch_size_list = [1, 4] 3913 input_channels_list = [16, 32] 3914 output_channels_list = [4, 8] 3915 use_bias_list = [True, False] 3916 use_multi_dim_input_list = [True, False] 3917 use_channelwise_list = [True, False] 3918 post_op = 'relu' 3919 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, 3920 use_bias_list, use_multi_dim_input_list, use_channelwise_list) 3921 for batch_size, input_channels, output_channels, use_bias, \ 3922 use_multi_dim_input, use_channelwise in cases: 3923 self._test_qlinear_impl(batch_size, input_channels, output_channels, 3924 use_bias, post_op, use_multi_dim_input, 
use_channelwise) 3925 3926 @given(batch_size=st.integers(1, 4), 3927 input_channels=st.integers(16, 32), 3928 output_channels=st.integers(4, 8), 3929 use_bias=st.booleans(), 3930 use_relu=st.booleans(), 3931 use_multi_dim_input=st.booleans(), 3932 use_channelwise=st.booleans()) 3933 @skipIfNoFBGEMM 3934 def test_qlinear_with_input_q_dq_qweight_dq_output_fp32( 3935 self, batch_size, input_channels, output_channels, use_bias, 3936 use_relu, use_multi_dim_input, use_channelwise): 3937 decimal_val = 4 3938 dtypes = [torch.quint8] 3939 for dtype in dtypes: 3940 # No support for channelwise in xnnpack (int8) 3941 # ONEDNN does not support qint8 3942 if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()): 3943 return 3944 3945 nptype = np_dtype[dtype] 3946 qlinear_prepack = torch.ops.quantized.linear_prepack 3947 if use_relu: 3948 qlinear = torch.ops.quantized.linear_with_input_q_dq_qweight_dq_relu_output_fp32 3949 else: 3950 qlinear = torch.ops.quantized.linear_with_input_q_dq_qweight_dq_output_fp32 3951 if use_multi_dim_input: 3952 batch_size *= 3 # Test the multi-dim input tensor 3953 X_scale = 1.5 3954 X_zp = 5 3955 X_value_min = -128 if dtype == torch.qint8 else 0 3956 X_value_max = 127 if dtype == torch.qint8 else 255 3957 X_q0 = np.round( 3958 np.random.rand(batch_size, input_channels) * 3959 (X_value_max - X_value_min) 3960 + X_value_min 3961 ).astype(nptype) 3962 3963 W_scales = np.random.rand(output_channels) 3964 # xnnpack forces W_zp to 0 when using symmetric quantization 3965 # ONEDNN only supports symmetric quantization of weight 3966 if dtype == torch.qint8 or qengine_is_onednn(): 3967 W_zps = np.zeros(output_channels).astype(int) 3968 else: 3969 W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(int) 3970 # when using symmetric quantization 3971 # special restriction for xnnpack fully connected op weight 3972 # [-127, 127] instead of [-128, 127] 3973 W_value_min = -127 if dtype == torch.qint8 else -128 3974 W_value_max = 127 3975 W_q0 = np.round( 3976 np.random.rand(output_channels, input_channels) 3977 * (W_value_max - W_value_min) 3978 + W_value_min 3979 ).astype(np.int8) # weight is always int8_t 3980 b_value_min = -10 3981 b_value_max = 10 3982 b_q0 = np.round( 3983 np.random.rand(output_channels) * 3984 (b_value_max - b_value_min) + b_value_min 3985 ).astype(np.int32) if use_bias else None 3986 if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'): 3987 avoid_vpmaddubsw_overflow_linear( 3988 batch_size, 3989 input_channels, 3990 output_channels, 3991 X_q0, 3992 X_value_min, 3993 X_value_max, 3994 W_q0, 3995 W_value_min, 3996 W_value_max, 3997 ) 3998 X = torch.from_numpy(_dequantize( 3999 X_q0, X_scale, X_zp)).to(dtype=torch.float) 4000 X_q = torch.quantize_per_tensor( 4001 X, scale=X_scale, zero_point=X_zp, dtype=dtype) 4002 if use_channelwise: 4003 W = torch.from_numpy(_dequantize(W_q0, W_scales.reshape( 4004 (-1, 1)), W_zps.reshape((-1, 1)))).to(dtype=torch.float) 4005 W_q = torch.quantize_per_channel(W, scales=torch.from_numpy(W_scales), 4006 zero_points=torch.from_numpy(W_zps), axis=0, dtype=torch.qint8) 4007 b = torch.from_numpy(_dequantize( 4008 b_q0, X_scale * W_scales, 0)).to(dtype=torch.float) if use_bias else None 4009 b_q = torch.quantize_per_channel(b, scales=torch.from_numpy(X_scale * W_scales), 4010 zero_points=torch.zeros(output_channels, dtype=torch.long), 4011 axis=0, dtype=torch.qint32) if use_bias else None 4012 else: 4013 W = torch.from_numpy(_dequantize( 4014 W_q0, W_scales[0], W_zps[0])).to(dtype=torch.float) 
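# Per-tensor path: the whole weight shares a single quantization pair,
# W_scales[0] / W_zps[0], unlike the per-channel branch above.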
4015 W_q = torch.quantize_per_tensor(W, scale=W_scales[0], zero_point=( 4016 W_zps[0].astype(int).item()), dtype=torch.qint8) 4017 b = torch.from_numpy(_dequantize( 4018 b_q0, X_scale * (W_scales[0].item()), 0)).to(dtype=torch.float) if use_bias else None 4019 b_q = torch.quantize_per_tensor( 4020 b, scale=X_scale * (W_scales[0].item()), zero_point=0, dtype=torch.qint32) if use_bias else None 4021 # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with 4022 # Y_scale * 255 (max for uint8). 4023 Y_scale = 125.1234 4024 Y_zp = 5 4025 # Weight prepacking operator for quantized Linear 4026 float_bias = b if use_bias else None 4027 W_prepack = qlinear_prepack(W_q, float_bias) 4028 if use_multi_dim_input: 4029 X = X.view(3, int(batch_size / 3), input_channels) 4030 X_q = X_q.view(3, int(batch_size / 3), input_channels) 4031 # Quantized Linear operator with prepacked weight 4032 Y_q_dq = qlinear(X, X_scale, X_zp, W_prepack) 4033 # Test both per-tensor and per-channel quantization 4034 # Reference quantized result from PyTorch Linear operator 4035 W_fp32 = W_q.dequantize().to(dtype=torch.float) 4036 X_fp32 = X_q.dequantize().to(dtype=torch.float) 4037 b_fp32 = b_q.dequantize().to(dtype=torch.float) if use_bias else None 4038 Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) 4039 if use_relu: 4040 Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0 4041 decimal_val = 1 4042 np.testing.assert_array_almost_equal(Y_fp32_ref.numpy(), Y_q_dq.numpy(), decimal=decimal_val) 4043 4044 @given(batch_size=st.integers(1, 4), 4045 # in cudnn v. 8.4.0, there is a limitation that input channels 4046 # should be a multiple of 4 for int8 tensors. in cudnn v.8.3.3 4047 # this should be a multiple of 16 4048 input_channels=st.sampled_from([4, 8, 12, 16, 32]), 4049 # constraints on output channels appear to be relaxed, as it seems we can use any positive integer here 4050 # except 1. It is not clear why 1 will not work. 
TODO: check with Yang 4051 output_channels=st.integers(2, 36), 4052 use_bias=st.booleans(), 4053 use_relu=st.booleans(), 4054 use_multi_dim_input=st.booleans(), 4055 use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn 4056 @skipIfNoFBGEMM 4057 @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") 4058 @unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0") 4059 @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") 4060 @unittest.skipIf(TEST_ROCM, "not supported on rocm.") 4061 # TODO: check with yang regarding CUDNN flags 4062 @unittest.skip("not currently working and feature isn't used") 4063 def test_qlinear_cudnn(self, batch_size, input_channels, output_channels, use_bias, 4064 use_relu, use_multi_dim_input, use_channelwise): 4065 qlinear_prepack = torch.ops.quantized.linear_prepack 4066 if use_relu: 4067 qlinear_op = torch.ops.quantized.linear_relu 4068 else: 4069 qlinear_op = torch.ops.quantized.linear 4070 X_scale = 1.5 4071 X_zp = 0 4072 X_value_min = -128 4073 X_value_max = 127 4074 X_q0 = np.round( 4075 np.random.rand(batch_size, input_channels) * 4076 (X_value_max - X_value_min) 4077 + X_value_min).astype(np.int8) 4078 W_scale = 2.5 4079 W_zp = 0 4080 W_value_min = -128 4081 W_value_max = 127 4082 W_q0 = np.round( 4083 np.random.rand(output_channels, input_channels) 4084 * (W_value_max - W_value_min) 4085 + W_value_min 4086 ).astype(np.int8) 4087 b_value_min = -10 4088 b_value_max = 10 4089 b_q0 = np.round( 4090 np.random.rand(output_channels) * 4091 (b_value_max - b_value_min) + b_value_min 4092 ).astype(np.int32) if use_bias else None 4093 if use_bias: 4094 b_value_min = -10 4095 b_value_max = 10 4096 b_q0 = np.round( 4097 np.random.rand(output_channels) * 4098 (b_value_max - b_value_min) + b_value_min 4099 ).astype(np.int32) 4100 else: 4101 bias = None 4102 avoid_vpmaddubsw_overflow_linear( 4103 batch_size, 4104 input_channels, 4105 output_channels, 4106 X_q0, 4107 X_value_min, 4108 X_value_max, 4109 W_q0, 4110 W_value_min, 4111 W_value_max, 4112 ) 4113 quant_dtype = torch.qint8 4114 X = torch.from_numpy(_dequantize( 4115 X_q0, X_scale, X_zp)).to(dtype=torch.float).to(device="cuda") 4116 X_q = torch.quantize_per_tensor( 4117 X, scale=X_scale, zero_point=X_zp, dtype=quant_dtype) 4118 W = torch.from_numpy(_dequantize( 4119 W_q0, W_scale, W_zp)).to(dtype=torch.float).to(device="cuda") 4120 W_q = torch.quantize_per_tensor(W, scale=W_scale, zero_point=W_zp, dtype=quant_dtype) 4121 b = torch.from_numpy(_dequantize( 4122 b_q0, X_scale * (W_zp), 0)).to(dtype=torch.float).to(device="cuda") if use_bias else None 4123 b_q = torch.quantize_per_tensor( 4124 b, scale=X_scale * W_scale, zero_point=0, dtype=quant_dtype) if use_bias else None 4125 Y_scale = 0.5 4126 Y_zp = 0 4127 # Weight prepacking operator for quantized Linear 4128 float_bias = b if use_bias else None 4129 W_prepack = qlinear_prepack(W_q, float_bias if use_bias else None) 4130 # Quantized Linear operator with prepacked weight 4131 Y_q = qlinear_op(X_q, W_prepack, Y_scale, Y_zp).to(device="cpu") 4132 Y_q_ref = qlinear_ref(X_q0, X_scale, X_zp, W_q0, 4133 W_scale, W_zp, b_q0, Y_scale, Y_zp, dtype=np.int8) 4134 if use_relu: 4135 Y_q_ref[Y_q_ref < Y_zp] = Y_zp 4136 decimal_val = 0 4137 np.testing.assert_array_almost_equal(Y_q_ref, Y_q.int_repr().numpy(), decimal=decimal_val) 4138 4139 """Tests the correctness of the quantized::linear_unpack op.""" 4140 @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,), 4141 
qparams=hu.qparams(dtypes=torch.qint8)), 4142 use_channelwise=st.booleans()) 4143 @override_qengines 4144 def test_qlinear_unpack(self, W, use_channelwise): 4145 W, (W_scale, W_zp, torch_type) = W 4146 if use_channelwise: 4147 output_channels = W.shape[0] 4148 W_scales = torch.rand(output_channels).to(torch.double) 4149 W_zps = torch.round(torch.rand(output_channels) 4150 * 100 - 50).to(torch.int64) 4151 qlinear_prepack = torch.ops.quantized.linear_prepack 4152 qlinear_unpack = torch.ops.quantized.linear_unpack 4153 4154 # ONEDNN only supports symmetric quantization of weight 4155 if qengine_is_onednn(): 4156 if use_channelwise: 4157 W_zps = torch.zeros(output_channels).to(torch.int64) 4158 else: 4159 W_zp = 0 4160 4161 W = torch.from_numpy(W) 4162 if use_channelwise: 4163 W_q = torch.quantize_per_channel( 4164 W, W_scales, W_zps, 0, dtype=torch_type) 4165 else: 4166 W_q = torch.quantize_per_tensor(W, scale=W_scale, zero_point=W_zp, 4167 dtype=torch_type) 4168 # Weight prepacking operator for quantized Linear 4169 W_prepack = qlinear_prepack(W_q) 4170 # Weight unpack operator for quantized Linear (Used for serialization) 4171 W_q_origin = qlinear_unpack(W_prepack)[0] 4172 # Assert equal 4173 np.testing.assert_equal(W_q.int_repr(), W_q_origin.int_repr().numpy()) 4174 if use_channelwise: 4175 np.testing.assert_array_almost_equal(np.float32(W_q.q_per_channel_scales().numpy()), 4176 np.float32( 4177 W_q_origin.q_per_channel_scales().numpy()), 4178 decimal=4) 4179 np.testing.assert_equal(W_q.q_per_channel_zero_points( 4180 ).numpy(), W_q_origin.q_per_channel_zero_points().numpy()) 4181 else: 4182 np.testing.assert_equal(np.float32( 4183 W_q.q_scale()), np.float32(W_q_origin.q_scale())) 4184 np.testing.assert_equal( 4185 W_q.q_zero_point(), W_q_origin.q_zero_point()) 4186 4187 """Tests the correctness of the _quantized::wrapped_quantized_linear op.""" 4188 @skipIfNoFBGEMM 4189 @given( 4190 m=st.integers(2, 6), 4191 k=st.integers(2, 6), 4192 n=st.integers(2, 6), 4193 ) 4194 def test_wrapped_quantized_linear(self, m, n, k): 4195 input = torch.randn(m, k, dtype=torch.float32) 4196 input_scale = torch.tensor(0.1) 4197 input_zero_point = torch.tensor(0) 4198 weight = torch.randn(n, k, dtype=torch.float32) 4199 weight_scale = torch.tensor(0.1) 4200 weight_zero_point = torch.tensor(0) 4201 bias = torch.randn(n, dtype=torch.float32) 4202 output_scale = torch.tensor(0.1) 4203 output_zero_point = torch.tensor(0) 4204 out_channel = n 4205 4206 ret = torch.ops._quantized.wrapped_quantized_linear( 4207 input, 4208 input_scale, 4209 input_zero_point, 4210 weight, 4211 weight_scale, 4212 weight_zero_point, 4213 bias, 4214 output_scale, 4215 output_zero_point, 4216 out_channel, 4217 ) 4218 4219 qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8) 4220 qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8) 4221 qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias) 4222 qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point) 4223 ret_ref = qlinear.dequantize() 4224 self.assertEqual(ret, ret_ref) 4225 4226 """Tests the correctness of the _quantized::_wrapped_linear_prepack and 4227 _quantized::_wrapped_quantized_linear_prepacked ops.""" 4228 @skipIfNoFBGEMM 4229 @given( 4230 m=st.integers(2, 6), 4231 k=st.integers(2, 6), 4232 n=st.integers(2, 6), 4233 ) 4234 def test_wrapped_quantized_linear_prepacked(self, m, n, k): 4235 input = torch.randn(m, k, dtype=torch.float32) 4236 input_scale = 
torch.tensor(0.1) 4237 input_zero_point = torch.tensor(0) 4238 weight = torch.randn(n, k, dtype=torch.float32) 4239 weight_scale = torch.tensor(0.1) 4240 weight_zero_point = torch.tensor(0) 4241 bias = torch.randn(n, dtype=torch.float32) 4242 output_scale = torch.tensor(0.1) 4243 output_zero_point = torch.tensor(0) 4244 out_channel = n 4245 4246 ret_1 = torch.ops._quantized._wrapped_linear_prepack( 4247 weight, 4248 weight_scale, 4249 weight_zero_point, 4250 bias 4251 ) 4252 ret_2 = torch.ops._quantized._wrapped_quantized_linear_prepacked( 4253 input, 4254 input_scale, 4255 input_zero_point, 4256 ret_1, 4257 output_scale, 4258 output_zero_point, 4259 out_channel 4260 ) 4261 qinput = torch.quantize_per_tensor(input, input_scale, input_zero_point, torch.quint8) 4262 qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, torch.qint8) 4263 qlinear_prepack = torch.ops.quantized.linear_prepack(qweight, bias) 4264 qlinear = torch.ops.quantized.linear(qinput, qlinear_prepack, output_scale, output_zero_point) 4265 ret_ref = qlinear.dequantize() 4266 self.assertEqual(ret_2, ret_ref) 4267 4268 """Tests the correctness of the quantized::linear_unpack after freeing original tensor op.""" 4269 @skipIfNoQNNPACK 4270 @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,), 4271 qparams=hu.qparams(dtypes=torch.qint8))) 4272 @override_qengines 4273 def test_qlinear_qnnpack_free_memory_and_unpack(self, W): 4274 assert qengine_is_qnnpack 4275 W, (W_scale, W_zp, torch_type) = W 4276 qlinear_prepack = torch.ops.quantized.linear_prepack 4277 qlinear_unpack = torch.ops.quantized.linear_unpack 4278 4279 W = torch.from_numpy(W) 4280 # ONEDNN only supports symmetric quantization of weight 4281 if qengine_is_onednn(): 4282 W_zp = 0 4283 W_q = torch.quantize_per_tensor(W, scale=W_scale, zero_point=W_zp, dtype=torch_type) 4284 # Weight prepacking operator for quantized Linear 4285 W_prepack = qlinear_prepack(W_q) 4286 dummy_input = torch.randn((1, W.shape[1])) 4287 # Make sure we free original tensor by running matrix multiplication in backend. 
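# The dynamic op is run twice: the first call lets the backend pack the weight
# (and potentially free the original quantized tensor), and the second runs on
# the already-packed weight; unpack() below must then reconstruct the original.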
4288 torch.ops.quantized.linear_dynamic(dummy_input, W_prepack) 4289 torch.ops.quantized.linear_dynamic(dummy_input, W_prepack) 4290 # At this step, original tensor should be recovered from a data_ptr 4291 W_q_origin = qlinear_unpack(W_prepack)[0] 4292 # Assert equal 4293 np.testing.assert_equal(W_q.int_repr(), W_q_origin.int_repr().numpy()) 4294 np.testing.assert_equal(np.float32( 4295 W_q.q_scale()), np.float32(W_q_origin.q_scale())) 4296 np.testing.assert_equal( 4297 W_q.q_zero_point(), W_q_origin.q_zero_point()) 4298 4299 @skipIfNoONEDNN 4300 def test_qlinear_leaky_relu(self): 4301 with override_quantized_engine('onednn'): 4302 batch_size_list = [1, 4] 4303 input_channels_list = [16, 32] 4304 output_channels_list = [4, 8] 4305 use_bias_list = [True, False] 4306 use_multi_dim_input_list = [True, False] 4307 use_channelwise_list = [True, False] 4308 negative_slopes_list = [0.01, 0.05] 4309 post_op = 'leaky_relu' 4310 cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, 4311 use_bias_list, use_multi_dim_input_list, 4312 use_channelwise_list, negative_slopes_list) 4313 for batch_size, input_channels, output_channels, use_bias, \ 4314 use_multi_dim_input, use_channelwise, neg_slope in cases: 4315 self._test_qlinear_impl(batch_size, input_channels, output_channels, 4316 use_bias, post_op, use_multi_dim_input, 4317 use_channelwise, negative_slope=neg_slope) 4318 4319 @skipIfNoONEDNN 4320 def test_qlinear_tanh(self): 4321 with override_quantized_engine('onednn'): 4322 batch_size_list = [1, 4] 4323 input_channels_list = [16, 32] 4324 output_channels_list = [4, 8] 4325 use_bias_list = [True, False] 4326 use_multi_dim_input_list = [True, False] 4327 use_channelwise_list = [True, False] 4328 post_op = 'tanh' 4329 cases = itertools.product(batch_size_list, input_channels_list, 4330 output_channels_list, use_bias_list, 4331 use_multi_dim_input_list, use_channelwise_list) 4332 for batch_size, input_channels, output_channels, use_bias, \ 4333 use_multi_dim_input, use_channelwise in cases: 4334 self._test_qlinear_impl(batch_size, input_channels, output_channels, 4335 use_bias, post_op, use_multi_dim_input, 4336 use_channelwise) 4337 4338 def _test_qlinear_pt2e_helper( 4339 self, 4340 qlinear_op, 4341 post_op="none", 4342 unary_post_op_args=(), 4343 post_op_algorithms=("none",), 4344 ): 4345 qlinear_prepack = torch.ops.onednn.qlinear_prepack 4346 linear_op = F.linear 4347 in_channels_list = [4, 8] 4348 out_channels_list = [16, 32] 4349 batch_size = 1 4350 use_bias_list = [True, False] 4351 weight_quant_per_channel_list = [True, False] 4352 output_dtype_list = [None, torch.float32, torch.bfloat16] 4353 x_scale, x_zp = 1.2, 1 4354 w_scale, w_zp = 0.8, 0 4355 y_scale, y_zp = 4.7, 2 4356 input_dim_list = [2, 3] 4357 cases = itertools.product( 4358 in_channels_list, out_channels_list, use_bias_list, 4359 weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list) 4360 with override_quantized_engine('onednn'): 4361 for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases: 4362 used_y_scale = y_scale 4363 used_y_zp = y_zp 4364 fp32_out = output_dtype == torch.float32 4365 bfloat16_out = output_dtype == torch.bfloat16 4366 if fp32_out or bfloat16_out: 4367 used_y_scale, used_y_zp = 1.0, 0 4368 x2_scale, x2_zp = 1.0, 0 4369 else: 4370 x2_scale, x2_zp = 2.3, 5 4371 x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10 4372 w = torch.rand(oc, ic) * 10 4373 qx = 
torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8) 4374 if weight_quant_per_channel: 4375 w_scales = torch.Tensor([w_scale] * oc) 4376 w_zps = torch.zeros(oc).to(dtype=torch.int) 4377 qw = torch.quantize_per_channel(w, w_scales, w_zps, 0, torch.qint8) 4378 else: 4379 w_scales = torch.Tensor([w_scale]) 4380 w_zps = torch.Tensor([w_zp]).to(dtype=torch.int) 4381 qw = torch.quantize_per_tensor(w, w_scale, w_zp, torch.qint8) 4382 if use_bias: 4383 b = torch.rand(oc) * 10 4384 else: 4385 b = None 4386 4387 x_ref = qx.dequantize() 4388 w_ref = qw.dequantize() 4389 y_ref = linear_op(x_ref, w_ref, b) 4390 4391 # compute with CPU tensors 4392 qx_cpu = qx.int_repr() 4393 qw_cpu = qw.int_repr() 4394 qw_packed = qlinear_prepack(qw_cpu, x.shape) 4395 4396 if post_op in ("none", "relu", "gelu"): 4397 qy_cpu = qlinear_op( 4398 qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, 4399 b, used_y_scale, used_y_zp, output_dtype, 4400 post_op, unary_post_op_args, post_op_algo 4401 ) 4402 if post_op == "relu": 4403 y_ref = F.relu(y_ref) 4404 elif post_op == "gelu": 4405 y_ref = F.gelu(y_ref, approximate=post_op_algo) 4406 qy_ref = torch.quantize_per_tensor(y_ref, used_y_scale, used_y_zp, torch.quint8) 4407 elif post_op in ("sum", "sum_relu"): 4408 x2_int8 = torch.randint(0, 4, y_ref.size()) 4409 x2 = x2_scale * ((x2_int8 - x2_zp).float()) 4410 qx2 = torch.quantize_per_tensor( 4411 x2, scale=x2_scale, zero_point=x2_zp, dtype=torch.quint8 4412 ) 4413 unary_post_op = "relu" if post_op == "sum_relu" else "none" 4414 binary_alpha = 1.0 # we only support alpha=1.0 now 4415 accum = qx2.int_repr() if output_dtype is None else qx2.dequantize() 4416 if bfloat16_out: 4417 accum = accum.bfloat16() 4418 qy_cpu = qlinear_op( 4419 qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, 4420 accum, b, used_y_scale, used_y_zp, output_dtype, 4421 x2_scale, x2_zp, "sum", binary_alpha, 4422 unary_post_op, unary_post_op_args, post_op_algo 4423 ) 4424 y_ref = y_ref + x2 * binary_alpha 4425 if unary_post_op == "relu": 4426 y_ref = F.relu(y_ref) 4427 qy_ref = torch.quantize_per_tensor(y_ref, used_y_scale, used_y_zp, torch.quint8) 4428 elif post_op in ("add", "add_relu"): 4429 used_y_scale, used_y_zp = 1.0, 0 4430 if output_dtype is not None: 4431 # Only support int8 output 4432 continue 4433 x2 = torch.randn(y_ref.size()) * 10 4434 unary_post_op = "relu" if post_op == "add_relu" else "none" 4435 binary_alpha = 1.0 # we only support alpha=1.0 now 4436 qy_cpu = qlinear_op( 4437 qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps, 4438 x2, b, used_y_scale, used_y_zp, output_dtype, 4439 1.0, 0, "add", binary_alpha, 4440 unary_post_op, unary_post_op_args, post_op_algo 4441 ) 4442 y_ref = y_ref + x2 * binary_alpha 4443 if unary_post_op == "relu": 4444 y_ref = F.relu(y_ref) 4445 qy_ref = torch.quantize_per_tensor(y_ref, used_y_scale, used_y_zp, torch.quint8) 4446 4447 # Compare results 4448 if fp32_out or bfloat16_out: 4449 qy_cpu = torch.quantize_per_tensor( 4450 qy_cpu.to(torch.float32), 4451 used_y_scale, 4452 used_y_zp, dtype=torch.quint8 4453 ).int_repr() 4454 4455 self.assertEqual(x.dim(), qy_cpu.dim()) 4456 4457 np.testing.assert_array_almost_equal( 4458 qy_ref.int_repr().cpu().numpy(), 4459 qy_cpu.cpu().numpy(), 4460 decimal=0, 4461 err_msg=f"""X: {x}, W: {w}, b: {b}, 4462 x_s: {x_scale}, x_zp: {x_zp}, 4463 w_s: {w_scale}, w_zp: {w_zp}, 4464 y_s: {y_scale}, y_zp: {y_zp}""", 4465 ) 4466 4467 @skipIfNoONEDNN 4468 def test_qlinear_pt2e(self): 4469 qlinear = torch.ops.onednn.qlinear_pointwise 4470 self._test_qlinear_pt2e_helper(qlinear, 
"none") 4471 4472 @skipIfNoONEDNN 4473 def test_qlinear_relu_pt2e(self): 4474 qlinear = torch.ops.onednn.qlinear_pointwise 4475 self._test_qlinear_pt2e_helper(qlinear, "relu") 4476 4477 @skipIfNoONEDNN 4478 def test_qlinear_gelu_pt2e(self): 4479 qlinear = torch.ops.onednn.qlinear_pointwise 4480 post_op_algorithms = ['none', 'tanh'] 4481 self._test_qlinear_pt2e_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms) 4482 4483 @skipIfNoONEDNN 4484 def test_qlinear_sum_pt2e(self): 4485 qlinear = torch.ops.onednn.qlinear_pointwise.binary 4486 self._test_qlinear_pt2e_helper(qlinear, "sum") 4487 4488 @skipIfNoONEDNN 4489 def test_qlinear_sum_relu_pt2e(self): 4490 qlinear = torch.ops.onednn.qlinear_pointwise.binary 4491 self._test_qlinear_pt2e_helper(qlinear, "sum_relu") 4492 4493 @skipIfNoONEDNN 4494 def test_qlinear_add_pt2e(self): 4495 qlinear = torch.ops.onednn.qlinear_pointwise.binary 4496 self._test_qlinear_pt2e_helper(qlinear, "add") 4497 4498 @skipIfNoONEDNN 4499 def test_qlinear_add_relu_pt2e(self): 4500 qlinear = torch.ops.onednn.qlinear_pointwise.binary 4501 self._test_qlinear_pt2e_helper(qlinear, "add_relu") 4502 4503 4504@unittest.skipIf(IS_MACOS, "Known test failure on Mac.") 4505class TestQuantizedEmbeddingOps(TestCase): 4506 4507 def _test_embedding_bag_unpack_impl(self, pack_fn, unpack_fn, bit_rate, optimized_qparams, weights): 4508 data_type = weights.dtype 4509 4510 qtype = torch.quint8 4511 if bit_rate == 8: 4512 w_packed = pack_fn(weights) 4513 else: 4514 w_packed = pack_fn(weights, optimized_qparams=optimized_qparams) 4515 w_unpacked = unpack_fn(w_packed) 4516 4517 if (bit_rate == 8 or bit_rate == 4) and data_type != torch.float16: 4518 # torch.quantize_per_channel does not support float16 yet. 4519 4520 obs_weights = weights 4521 # Combine 3D embeddings (e.g. stacked combination of embeddings) 4522 # in a dimension orthogonal to channels. 4523 if (len(obs_weights.shape) > 2): 4524 stacked_shape = list(weights.size()) 4525 stacked_shape[1] *= stacked_shape[0] 4526 obs_weights = weights.reshape(stacked_shape[1:]) 4527 4528 # Check numerics of prepack function that accepts qtensor as input. 4529 # We use min-max observer to mimic the quantization performed in the original function. 
4530 obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) 4531 obs(obs_weights) 4532 # Get the scale and zero point for the weight tensor 4533 qparams = obs.calculate_qparams() 4534 if bit_rate == 4: 4535 qtype = torch.quint4x2 4536 # Quantize the weights to 8bits 4537 qweight = torch.quantize_per_channel(obs_weights, qparams[0], qparams[1], axis=0, dtype=qtype) 4538 real_packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) 4539 self.assertEqual(isinstance(real_packed_weight, torch._C.ScriptObject), True) 4540 unpacked_weight = torch.ops.quantized.embedding_bag_unpack(real_packed_weight) 4541 self.assertEqual(unpacked_weight.int_repr().numpy(), qweight.int_repr().numpy()) 4542 self.assertEqual(unpacked_weight.q_per_channel_scales(), qweight.q_per_channel_scales()) 4543 self.assertEqual(unpacked_weight.q_per_channel_zero_points(), qweight.q_per_channel_zero_points()) 4544 4545 4546 4547 4548 def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, 4549 optimized_qparams, num_batches, data_type=np.float32): 4550 4551 # when num_batches = 1, it will create a 2D tensor 4552 unsplit_weight = torch.from_numpy((np.random.random_sample(( 4553 num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(np.float32)) 4554 4555 # test unsplit weight (memory format is `contiguous`) 4556 self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, unsplit_weight) 4557 4558 # test split weights (memory format is not `contiguous`) 4559 split_dim = len(unsplit_weight.shape) - 2 4560 split_weights = torch.split(unsplit_weight, 1, dim=split_dim) 4561 for weight in split_weights: 4562 self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight) 4563 4564 4565 4566 def embedding_bag_rowwise_offsets_run( 4567 self, bit_rate, num_embeddings, 4568 embedding_dim, num_offsets, 4569 use_32bit_indices, use_32bit_offsets, 4570 enable_per_sample_weights, 4571 include_last_offset, fallback_to_no_sparse, sparsity, atol, rtol): 4572 pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets 4573 pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack 4574 if bit_rate == 4: 4575 pt_op = torch.ops.quantized.embedding_bag_4bit_rowwise_offsets 4576 pt_prepack_op = torch.ops.quantized.embedding_bag_4bit_prepack 4577 elif bit_rate == 2: 4578 pt_op = torch.ops.quantized.embedding_bag_2bit_rowwise_offsets 4579 pt_prepack_op = torch.ops.quantized.embedding_bag_2bit_prepack 4580 4581 weights = torch.from_numpy((np.random.random_sample(( 4582 num_embeddings, embedding_dim)) + 1).astype(np.float32)) 4583 4584 max_segments = 5 4585 max_segment_length = 20 4586 num_lengths = np.random.randint(1, max_segments + 1) 4587 lengths = np.random.randint(0, max_segment_length + 1, 4588 size=num_lengths).astype(np.int32) 4589 num_indices = np.sum(lengths) 4590 4591 def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True): 4592 """ 4593 Convert lengths to offsets 4594 """ 4595 tt = np.zeros((t.shape[0] + 1,), dtype=offset_type) 4596 tt[1:] = t 4597 tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type)) 4598 if use_begin_offset: 4599 return tt[:-1] 4600 return tt[1:] 4601 4602 offsets = lengths_to_offsets(lengths) 4603 indices = torch.from_numpy(np.random.randint( 4604 low=0, high=num_embeddings, size=num_indices, dtype=np.int64)) 4605 4606 q_weights = pt_prepack_op(weights) 4607 per_sample_weights = torch.from_numpy(np.random.uniform( 4608 
low=0.01, high=0.5, size=[len(indices)]).astype(np.float32)) if \ 4609 enable_per_sample_weights else None 4610 if include_last_offset: 4611 offsets = torch.cat( 4612 (offsets, torch.tensor([indices.size(0)], dtype=torch.long)), 0 4613 ) 4614 4615 # Reference result will be the floating point torch.nn.EmbeddingBag. 4616 def get_reference_result( 4617 num_embeddings, embedding_dim, 4618 include_last_offset, weights, per_sample_weights, 4619 indices, offsets): 4620 embedding_bag = torch.nn.EmbeddingBag( 4621 num_embeddings=num_embeddings, 4622 embedding_dim=embedding_dim, 4623 include_last_offset=include_last_offset, _weight=weights, 4624 scale_grad_by_freq=False, mode='sum' 4625 ) 4626 return embedding_bag(indices, offsets, 4627 per_sample_weights=per_sample_weights) 4628 4629 mapping_table = np.zeros(num_embeddings, dtype=np.int32) 4630 pruned_weights = weights 4631 prune_weights = sparsity > 0 4632 if prune_weights: 4633 if fallback_to_no_sparse: 4634 # Testing that prune_weight with mapping_table {0} will 4635 # fallback to non sparse embedding look up kernel. 4636 mapping_table = np.zeros(1, dtype=np.int32) 4637 else: 4638 # Prune and generate mapping table 4639 num_compressed_rows = 0 4640 unpruned_ids = [] 4641 for i in range(num_embeddings): 4642 if np.random.uniform() < sparsity: 4643 mapping_table[i] = -1 4644 q_weights[i, :] = 0 4645 weights[i, :] = 0 4646 else: 4647 mapping_table[i] = num_compressed_rows 4648 num_compressed_rows += 1 4649 unpruned_ids.append(i) 4650 q_weights = q_weights[unpruned_ids] 4651 pruned_weights = weights[unpruned_ids] 4652 4653 result = pt_op(q_weights, 4654 indices.int() if use_32bit_indices else indices, 4655 offsets.int() if use_32bit_offsets else offsets, 4656 mode=0, 4657 pruned_weights=prune_weights, 4658 per_sample_weights=per_sample_weights, 4659 compressed_indices_mapping=torch.tensor(mapping_table), 4660 include_last_offset=include_last_offset) 4661 4662 reference_result = get_reference_result( 4663 num_embeddings, embedding_dim, include_last_offset, weights, 4664 per_sample_weights, indices, offsets) 4665 4666 torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol) 4667 4668 4669 if bit_rate == 8 or bit_rate == 4: 4670 # Test operator that accepts TorchBind packed weights. 
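            # (Here "packed weights" refers to the TorchBind ScriptObject returned by
            # quantized.embedding_bag_prepack and consumed by embedding_bag_byte /
            # embedding_bag_4bit, rather than the plain byte-packed tensor used above.)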
4671 if bit_rate == 4: 4672 qdtype = torch.quint4x2 4673 op = torch.ops.quantized.embedding_bag_4bit 4674 else: 4675 qdtype = torch.quint8 4676 op = torch.ops.quantized.embedding_bag_byte 4677 obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) 4678 obs(pruned_weights) 4679 # Get the scale and zero point for the weight tensor 4680 qparams = obs.calculate_qparams() 4681 # Quantize the weights to 8bits 4682 qweight = torch.quantize_per_channel(pruned_weights, qparams[0], qparams[1], axis=0, dtype=qdtype) 4683 packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) 4684 result = op(packed_weight, indices, offsets, mode=0, 4685 pruned_weights=prune_weights, 4686 per_sample_weights=per_sample_weights, 4687 compressed_indices_mapping=torch.tensor(mapping_table), 4688 include_last_offset=include_last_offset) 4689 torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol) 4690 4691 """ Tests the correctness of the embedding_bag_8bit quantized operator """ 4692 @given(num_embeddings=st.integers(10, 100), 4693 embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), 4694 num_offsets=st.integers(1, 20), 4695 use_32bit_indices=st.booleans(), 4696 use_32bit_offsets=st.booleans(), 4697 enable_per_sample_weights=st.booleans(), 4698 include_last_offset=st.booleans(), 4699 fallback_to_no_sparse=st.booleans(), 4700 sparsity=st.sampled_from([0.0, 0.5, 0.7])) 4701 def test_embedding_bag_byte(self, num_embeddings, 4702 embedding_dim, num_offsets, 4703 use_32bit_indices, 4704 use_32bit_offsets, 4705 enable_per_sample_weights, 4706 include_last_offset, 4707 fallback_to_no_sparse, 4708 sparsity): 4709 self.embedding_bag_rowwise_offsets_run( 4710 8, num_embeddings, embedding_dim, num_offsets, 4711 use_32bit_indices, use_32bit_offsets, 4712 enable_per_sample_weights, include_last_offset, 4713 fallback_to_no_sparse, 4714 sparsity=sparsity, atol=0.005, rtol=1e-3) 4715 4716 """ Tests the correctness of the embedding_bag_4bit quantized operator """ 4717 @given(num_embeddings=st.integers(10, 100), 4718 embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0), 4719 num_offsets=st.integers(1, 20), 4720 use_32bit_indices=st.booleans(), 4721 use_32bit_offsets=st.booleans(), 4722 enable_per_sample_weights=st.booleans(), 4723 include_last_offset=st.booleans(), 4724 fallback_to_no_sparse=st.booleans(), 4725 sparsity=st.sampled_from([0.0, 0.5, 0.7])) 4726 def test_embedding_bag_4bit(self, num_embeddings, 4727 embedding_dim, num_offsets, 4728 use_32bit_indices, 4729 use_32bit_offsets, 4730 enable_per_sample_weights, 4731 include_last_offset, 4732 fallback_to_no_sparse, 4733 sparsity): 4734 self.embedding_bag_rowwise_offsets_run(4, num_embeddings, 4735 embedding_dim, num_offsets, 4736 use_32bit_indices, use_32bit_offsets, 4737 enable_per_sample_weights, 4738 include_last_offset, 4739 fallback_to_no_sparse, 4740 sparsity=sparsity, 4741 atol=0.1, rtol=1e-2) 4742 4743 """ Tests the correctness of the embedding_bag_2bit quantized operator """ 4744 @given(num_embeddings=st.integers(10, 100), 4745 embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0), 4746 num_offsets=st.integers(1, 20), 4747 use_32bit_indices=st.booleans(), 4748 use_32bit_offsets=st.booleans(), 4749 enable_per_sample_weights=st.booleans(), 4750 include_last_offset=st.booleans(), 4751 fallback_to_no_sparse=st.booleans(), 4752 sparsity=st.sampled_from([0.0, 0.5, 0.7])) 4753 def test_embedding_bag_2bit(self, num_embeddings, 4754 embedding_dim, num_offsets, 4755 
use_32bit_indices, 4756 use_32bit_offsets, 4757 enable_per_sample_weights, 4758 include_last_offset, 4759 fallback_to_no_sparse, 4760 sparsity): 4761 self.embedding_bag_rowwise_offsets_run(2, num_embeddings, 4762 embedding_dim, num_offsets, 4763 use_32bit_indices, use_32bit_offsets, 4764 enable_per_sample_weights, 4765 include_last_offset, 4766 fallback_to_no_sparse, 4767 sparsity=sparsity, 4768 atol=1.0, rtol=1e-1) 4769 4770 """ Tests the correctness of the quantized 8 bit embedding lookup operator """ 4771 @given(num_embeddings=st.integers(10, 100), 4772 embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0)) 4773 def test_embedding(self, num_embeddings, embedding_dim): 4774 dtypes = [torch.quint8, torch.quint4x2] 4775 quant_ops = [torch.ops.quantized.embedding_byte, torch.ops.quantized.embedding_4bit] 4776 atols = [0.005, 0.1] 4777 rtols = [1e-3, 1e-2] 4778 prepack_op = torch.ops.quantized.embedding_bag_prepack 4779 for quant_op, dtype, atol, rtol in zip(quant_ops, dtypes, atols, rtols): 4780 weights = torch.from_numpy((np.random.random_sample(( 4781 num_embeddings, embedding_dim)) + 1).astype(np.float32)) 4782 4783 obs = PerChannelMinMaxObserver(dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) 4784 obs(weights) 4785 # Get the scale and zero point for the weight tensor 4786 qparams = obs.calculate_qparams() 4787 4788 # Quantize the weights to 8bits 4789 qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype) 4790 max_segments = 5 4791 max_segment_length = 20 4792 num_lengths = np.random.randint(1, max_segments + 1) 4793 lengths = np.random.randint(1, max_segment_length + 1, 4794 size=num_lengths).astype(np.int32) 4795 num_indices = np.sum(lengths) 4796 indices = torch.from_numpy(np.random.randint( 4797 low=0, high=num_embeddings, size=num_indices, dtype=np.int64)) 4798 4799 packed_weight = prepack_op(qweight) 4800 qresult = quant_op(packed_weight, indices, pruned_weights=False) 4801 4802 ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) 4803 torch.testing.assert_close(ref, qresult, atol=atol, rtol=rtol) 4804 4805 def test_embedding_2d_indices(self): 4806 """ 4807 Tests the case where 2D indices are passed into the operator 4808 In this case the operator computes the correct offsets argument. 4809 Output shape is dependent on the indices dimension. 4810 """ 4811 quant_op = torch.ops.quantized.embedding_byte 4812 prepack_op = torch.ops.quantized.embedding_bag_prepack 4813 4814 indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]]) 4815 weights = torch.randn(10, 12, dtype=torch.float32) 4816 4817 ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False) 4818 obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) 4819 obs(weights) 4820 qparams = obs.calculate_qparams() 4821 4822 qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) 4823 packed_weight = prepack_op(qweight) 4824 qresult = quant_op(packed_weight, indices, pruned_weights=False) 4825 torch.testing.assert_close(ref, qresult, atol=0.05, rtol=1e-3) 4826 4827 def test_embedding_bag_2d_indices(self): 4828 """ 4829 Tests the case where 2D indices are passed into the operator 4830 In this case the operator computes the correct offsets argument. 
4831 """ 4832 indices = torch.tensor([[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8], [3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]]) 4833 weights = torch.randn(10, 12, dtype=torch.float32) 4834 4835 embedding_bag = torch.nn.EmbeddingBag( 4836 num_embeddings=10, 4837 embedding_dim=12, 4838 include_last_offset=False, _weight=weights, 4839 scale_grad_by_freq=False, mode='sum' 4840 ) 4841 result = embedding_bag(indices) 4842 4843 pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets 4844 pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack 4845 q_weights = pt_prepack_op(weights) 4846 qresult = pt_op(q_weights, indices, mode=0, pruned_weights=False) 4847 torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3) 4848 4849 # Test TorchBind based embedding_bag operator 4850 obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0) 4851 obs(weights) 4852 # Get the scale and zero point for the weight tensor 4853 qparams = obs.calculate_qparams() 4854 4855 # Quantize the weights to 8bits 4856 qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8) 4857 4858 packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight) 4859 qresult = torch.ops.quantized.embedding_bag_byte(packed_weight, indices, mode=0) 4860 4861 torch.testing.assert_close(result, qresult, atol=0.05, rtol=1e-3) 4862 4863 4864class TestQuantizedConv(TestCase): 4865 def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs, 4866 strides, i_pads, o_pads, channelwise): 4867 (X_data, W_data, bias_data, groups, transposed) = inputs 4868 (X, (X_scale, X_zero_point, X_qtype)) = X_data 4869 (W, (W_scale, W_zero_point, W_qtype)) = W_data 4870 (bias, (bias_scale, bias_zero_point, bias_qtype)) = bias_data 4871 4872 W = torch.from_numpy(W).float() 4873 bias = torch.from_numpy(bias).float() 4874 if channelwise and transposed: 4875 # currently transposed conv and per-channel per quantization does not work 4876 return 4877 # ONEDNN only supports symmetric quantization of weight and zero output padding 4878 if qengine_is_onednn(): 4879 W_zero_point = 0 4880 o_pads = len(o_pads) * [0] if o_pads is not None else None 4881 if channelwise: 4882 if transposed: 4883 output_channels = W.shape[1] # IC OC/G 4884 else: 4885 output_channels = W.shape[0] # OC IC/G 4886 W_scale = torch.tensor([W_scale] * output_channels) 4887 W_zero_point = torch.tensor([W_zero_point] * output_channels) 4888 W_q = torch.quantize_per_channel( 4889 W, scales=W_scale, zero_points=W_zero_point, 4890 axis=int(transposed), dtype=W_qtype) 4891 else: 4892 W_q = torch.quantize_per_tensor( 4893 W, scale=W_scale, zero_point=W_zero_point, dtype=W_qtype) 4894 4895 if isinstance(strides, int): 4896 dilations = [1] 4897 else: 4898 dilations = (1,) * len(strides) 4899 4900 if transposed: 4901 W_packed = qconv_prepack_fn(W_q, bias, strides, i_pads, o_pads, 4902 dilations, groups) 4903 else: 4904 W_packed = qconv_prepack_fn(W_q, bias, strides, i_pads, dilations, 4905 groups) 4906 (W_unpacked, bias) = qconv_unpack_fn(W_packed) 4907 4908 # Assert equal 4909 np.testing.assert_equal(W_q.int_repr().numpy(), 4910 W_unpacked.int_repr().numpy()) 4911 if channelwise: 4912 np.testing.assert_array_almost_equal( 4913 np.float32(W_q.q_per_channel_scales().numpy()), 4914 np.float32(W_unpacked.q_per_channel_scales().numpy()), 4915 decimal=4) 4916 np.testing.assert_equal(W_q.q_per_channel_zero_points( 4917 ).numpy(), W_unpacked.q_per_channel_zero_points().numpy()) 
        else:
            np.testing.assert_equal(np.float32(
                W_q.q_scale()), np.float32(W_unpacked.q_scale()))
            np.testing.assert_equal(
                W_q.q_zero_point(), W_unpacked.q_zero_point())

    def _make_qconv_tensors(
        self, batch_size, input_channels_per_group, input_feature_map_shape,
        output_channels_per_group, groups, kernels, strides, pads, dilations,
        X_scale, X_zero_point, W_scale, W_zero_point,
        use_bias, use_channelwise, use_transpose,
        device=torch.device("cpu"),
        input_dtype=torch.quint8,
        weight_dtype=torch.qint8,
    ):
        assert not (use_channelwise and use_transpose), \
            "Cannot generate channelwise qconv_transpose_tensors"
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        # Padded input size should be at least as big as the dilated kernel
        kernels = _single(kernels)
        strides = _single(strides)
        pads = _single(pads)
        dilations = _single(dilations)
        for i in range(len(kernels)):
            assume(input_feature_map_shape[i] + 2 * pads[i]
                   >= dilations[i] * (kernels[i] - 1) + 1)
        W_scale = W_scale * output_channels
        W_zero_point = W_zero_point * output_channels
        # Trim the W_scale and W_zero_point lists to length output_channels
        W_scale = W_scale[:output_channels]
        W_zero_point = W_zero_point[:output_channels]
        # For testing, we use small values for weights and for activations
        # so that no overflow occurs in the vpmaddubsw instruction. If the
        # overflow occurs in the qconv implementation but not in the
        # reference, we can't exactly match the reference results.
        # Please see the comment in the qconv implementation file
        # aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
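        # Rough sanity check of the ranges chosen below (an illustrative
        # assumption, not a formal bound): with activation values drawn from
        # [0, 4) and weight values from [-5, 5), each pair of adjacent products
        # accumulated by vpmaddubsw is on the order of 2 * 4 * 5 = 40 (plus a
        # small zero-point offset), far below the int16 limit of 2 ** 15 - 1.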
4957 (W_value_min, W_value_max) = (-5, 5) 4958 # the operator expects them in the format 4959 # (output_channels, input_channels/groups, kernel_d, kernel_h, kernel_w) 4960 # (input_channels, output_channels/groups, kernel_d, kernel_h, kernel_w) 4961 if use_transpose: 4962 output_shape = (input_channels, output_channels_per_group,) 4963 else: 4964 output_shape = (output_channels, input_channels_per_group,) 4965 W_init = torch.randint( 4966 W_value_min, 4967 W_value_max, 4968 output_shape + kernels, 4969 device=device, 4970 ) 4971 b_init = torch.randint(0, 10, (output_channels,), device=device) 4972 4973 (X_value_min, X_value_max) = (0, 4) 4974 X_init = torch.randint( 4975 X_value_min, 4976 X_value_max, 4977 (batch_size, input_channels,) + input_feature_map_shape, 4978 device=device 4979 ) 4980 X = X_scale * (X_init - X_zero_point).float() 4981 4982 if use_channelwise: 4983 W_shape = (-1, 1) + (1,) * len(kernels) 4984 W_scales_tensor = torch.tensor(W_scale, dtype=torch.float, device=device) 4985 W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float, device=device) 4986 W = W_scales_tensor.reshape(*W_shape) * ( 4987 W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() 4988 b = X_scale * W_scales_tensor * b_init.float() 4989 else: 4990 W = W_scale[0] * (W_init - W_zero_point[0]).float() 4991 b = X_scale * W_scale[0] * b_init.float() 4992 4993 X_q = torch.quantize_per_tensor( 4994 X, scale=X_scale, zero_point=X_zero_point, dtype=input_dtype) 4995 if use_channelwise: 4996 W_q = torch.quantize_per_channel( 4997 W, W_scales_tensor, W_zero_points_tensor.long(), 0, 4998 dtype=weight_dtype) 4999 else: 5000 W_q = torch.quantize_per_tensor( 5001 W, scale=W_scale[0], zero_point=W_zero_point[0], 5002 dtype=weight_dtype) 5003 5004 bias_float = b if use_bias else None 5005 5006 return (X, W), (X_q, W_q), bias_float 5007 5008 def _test_qconv_impl( 5009 self, qconv_fn, qconv_prepack_fn, conv_op, batch_size, 5010 input_channels_per_group, input_feature_map_shape, 5011 output_channels_per_group, groups, kernels, strides, pads, o_pads, 5012 dilations, X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, 5013 Y_zero_point, use_bias, post_op, use_channelwise, use_transpose, 5014 device=torch.device("cpu"), 5015 input_dtype=torch.quint8, 5016 weight_dtype=torch.qint8, 5017 output_dtype=torch.quint8, 5018 X2_scale=1.0, 5019 X2_zero_point=128 5020 ): 5021 # ONEDNN only supports symmetric quantization of weight 5022 if qengine_is_onednn() and W_zero_point is not None: 5023 W_zero_point = len(W_zero_point) * [0] 5024 (X, W), (X_q, W_q), bias_float = self._make_qconv_tensors( 5025 batch_size, input_channels_per_group, input_feature_map_shape, 5026 output_channels_per_group, groups, kernels, 5027 strides, pads, dilations, X_scale, X_zero_point, W_scale, 5028 W_zero_point, use_bias, use_channelwise, use_transpose, 5029 device=device, input_dtype=input_dtype, weight_dtype=weight_dtype) 5030 if bias_float is not None: 5031 bias_float = bias_float.to(device) 5032 # Assign weights 5033 W = W_q.dequantize() 5034 X = X_q.dequantize() 5035 conv_op.weight = torch.nn.Parameter(W, requires_grad=False) 5036 conv_op.bias = torch.nn.Parameter( 5037 bias_float, requires_grad=False) if use_bias else None 5038 result_ref = conv_op(X) 5039 if post_op == 'relu': 5040 assert not use_transpose, "Cannot fuse ReLU with ConvTranspose" 5041 relu = torch.nn.ReLU() 5042 result_ref = relu(result_ref) 5043 elif post_op == 'add': 5044 (X_value_min, X_value_max) = (0, 4) 5045 X2_init = torch.randint( 5046 X_value_min, 5047 
X_value_max, 5048 result_ref.size(), 5049 device=device 5050 ) 5051 X2 = X2_scale * (X2_init - X2_zero_point).float() 5052 X2_q = torch.quantize_per_tensor( 5053 X2, scale=X2_scale, zero_point=X2_zero_point, dtype=input_dtype) 5054 result_ref = result_ref + X2 5055 elif post_op == 'add_relu': 5056 (X_value_min, X_value_max) = (0, 4) 5057 X2_init = torch.randint( 5058 X_value_min, 5059 X_value_max, 5060 result_ref.size(), 5061 device=device 5062 ) 5063 X2 = X2_scale * (X2_init - X2_zero_point).float() 5064 X2_q = torch.quantize_per_tensor( 5065 X2, scale=X2_scale, zero_point=X2_zero_point, dtype=input_dtype) 5066 result_ref = result_ref + X2 5067 relu = torch.nn.ReLU() 5068 result_ref = relu(result_ref) 5069 # Quantize reference results for comparison 5070 result_ref_q = torch.quantize_per_tensor( 5071 result_ref, scale=Y_scale, zero_point=Y_zero_point, 5072 dtype=output_dtype) 5073 5074 if qconv_prepack_fn is not None: 5075 if use_transpose: 5076 W_prepack = qconv_prepack_fn( 5077 W_q, bias_float, strides, pads, o_pads, dilations, groups) 5078 else: 5079 W_prepack = qconv_prepack_fn( 5080 W_q, bias_float, strides, pads, dilations, groups) 5081 if post_op == 'add' or post_op == 'add_relu': 5082 Y_q = qconv_fn( 5083 X_q, 5084 X2_q, 5085 W_prepack, 5086 Y_scale, 5087 Y_zero_point, 5088 ) 5089 else: 5090 Y_q = qconv_fn( 5091 X_q, 5092 W_prepack, 5093 Y_scale, 5094 Y_zero_point, 5095 ) 5096 else: 5097 # quantized conv op without prepacking 5098 Y_q = qconv_fn(X_q, W_q, bias_float, strides, pads, dilations, groups, Y_scale, Y_zero_point) 5099 5100 # Make sure the results match 5101 # assert_array_almost_equal compares using the following formula: 5102 # abs(desired-actual) < 1.5 * 10**(-decimal) 5103 # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html) 5104 # We use decimal = 0 to ignore off-by-1 differences between 5105 # reference and test. Off-by-1 differences arise due to the order of 5106 # round and zero_point addition operation, i.e., if addition 5107 # followed by round is used by reference and round followed by 5108 # addition is used by test, the results may differ by 1. 5109 # For example, the result of round(2.5) + 1 is 3 while 5110 # round(2.5 + 1) is 4 assuming the rounding mode is 5111 # round-to-nearest, ties-to-even. 
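        # A quick illustration of that off-by-1 case (shown only as a comment,
        # not executed by the test; values assume torch.round's half-to-even rule):
        #   torch.round(torch.tensor(2.5)) + 1   -> tensor(3.)
        #   torch.round(torch.tensor(2.5) + 1)   -> tensor(4.)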
5112 np.testing.assert_array_almost_equal( 5113 result_ref_q.int_repr().cpu().numpy(), Y_q.int_repr().cpu().numpy(), decimal=0, 5114 err_msg=f'''X: {X_q}, W: {W_q}, b: {bias_float}, strides: {strides}, 5115 pads: {pads}, o_pads: {o_pads}, dilations: {dilations}, 5116 groups: {groups}, y_s: {Y_scale}, y_zp: {Y_zero_point}''') 5117 5118 # Return the quantized data for later reuse 5119 return X_q, W_q, bias_float 5120 5121 """Tests the correctness of quantized convolution op.""" 5122 @given(batch_size=st.integers(1, 3), 5123 input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 5124 height=st.integers(10, 16), 5125 width=st.integers(7, 14), 5126 output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 5127 groups=st.integers(1, 300), 5128 kernel_h=st.integers(1, 7), 5129 kernel_w=st.integers(1, 7), 5130 stride_h=st.integers(1, 2), 5131 stride_w=st.integers(1, 2), 5132 pad_h=st.integers(0, 2), 5133 pad_w=st.integers(0, 2), 5134 dilation=st.integers(1, 2), 5135 X_scale=st.floats(1.2, 1.6), 5136 X_zero_point=st.integers(0, 4), 5137 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 5138 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 5139 Y_scale=st.floats(4.2, 5.6), 5140 Y_zero_point=st.integers(0, 4), 5141 use_bias=st.booleans(), 5142 use_channelwise=st.booleans()) 5143 @override_qengines 5144 def test_qconv2d( 5145 self, 5146 batch_size, 5147 input_channels_per_group, 5148 height, 5149 width, 5150 output_channels_per_group, 5151 groups, 5152 kernel_h, 5153 kernel_w, 5154 stride_h, 5155 stride_w, 5156 pad_h, 5157 pad_w, 5158 dilation, 5159 X_scale, 5160 X_zero_point, 5161 W_scale, 5162 W_zero_point, 5163 Y_scale, 5164 Y_zero_point, 5165 use_bias, 5166 use_channelwise, 5167 ): 5168 input_channels = input_channels_per_group * groups 5169 output_channels = output_channels_per_group * groups 5170 kernels = (kernel_h, kernel_w) 5171 strides = (stride_h, stride_w) 5172 pads = (pad_h, pad_w) 5173 dilations = (dilation, dilation) 5174 5175 qconv = torch.ops.quantized.conv2d 5176 qconv_prepack = torch.ops.quantized.conv2d_prepack 5177 conv_op = torch.nn.Conv2d( 5178 input_channels, 5179 output_channels, 5180 kernels, 5181 strides, 5182 pads, 5183 dilations, 5184 groups, 5185 ) 5186 5187 act_qdtypes = [torch.quint8] 5188 # Only qnnpack qengine supportes qint8 5189 if qengine_is_qnnpack() and torch.backends.xnnpack.enabled: 5190 act_qdtypes.append(torch.qint8) 5191 5192 for X_qdtype in act_qdtypes: 5193 if X_qdtype == torch.qint8: 5194 W_zero_point = [0 for i in range(len(W_zero_point))] 5195 5196 self._test_qconv_impl( 5197 qconv, qconv_prepack, conv_op, batch_size, 5198 input_channels_per_group, (height, width), 5199 output_channels_per_group, groups, kernels, strides, pads, None, 5200 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5201 Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False, input_dtype=X_qdtype, output_dtype=X_qdtype) 5202 5203 @given(batch_size=st.integers(1, 3), 5204 input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 5205 height=st.integers(10, 16), 5206 width=st.integers(7, 14), 5207 output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 5208 groups=st.integers(1, 300), 5209 kernel_h=st.integers(1, 7), 5210 kernel_w=st.integers(1, 7), 5211 stride_h=st.integers(1, 2), 5212 stride_w=st.integers(1, 2), 5213 pad_h=st.integers(0, 2), 5214 pad_w=st.integers(0, 2), 5215 dilation=st.integers(1, 2), 5216 X_scale=st.floats(1.2, 1.6), 5217 X_zero_point=st.integers(0, 4), 5218 
W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 5219 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 5220 Y_scale=st.floats(4.2, 5.6), 5221 Y_zero_point=st.integers(0, 4), 5222 use_bias=st.booleans(), 5223 use_channelwise=st.booleans()) 5224 @override_qengines 5225 def test_qconv2d_relu( 5226 self, 5227 batch_size, 5228 input_channels_per_group, 5229 height, 5230 width, 5231 output_channels_per_group, 5232 groups, 5233 kernel_h, 5234 kernel_w, 5235 stride_h, 5236 stride_w, 5237 pad_h, 5238 pad_w, 5239 dilation, 5240 X_scale, 5241 X_zero_point, 5242 W_scale, 5243 W_zero_point, 5244 Y_scale, 5245 Y_zero_point, 5246 use_bias, 5247 use_channelwise, 5248 ): 5249 input_channels = input_channels_per_group * groups 5250 output_channels = output_channels_per_group * groups 5251 kernels = (kernel_h, kernel_w) 5252 strides = (stride_h, stride_w) 5253 pads = (pad_h, pad_w) 5254 dilations = (dilation, dilation) 5255 5256 qconv = torch.ops.quantized.conv2d_relu 5257 qconv_prepack = torch.ops.quantized.conv2d_prepack 5258 conv_op = torch.nn.Conv2d( 5259 input_channels, 5260 output_channels, 5261 kernels, 5262 strides, 5263 pads, 5264 dilations, 5265 groups, 5266 ) 5267 5268 act_qdtypes = [torch.quint8] 5269 # Only qnnpack qengine supportes qint8 5270 if qengine_is_qnnpack() and torch.backends.xnnpack.enabled: 5271 act_qdtypes.append(torch.qint8) 5272 5273 for X_qdtype in act_qdtypes: 5274 if X_qdtype == torch.qint8: 5275 W_zero_point = [0 for i in range(len(W_zero_point))] 5276 5277 self._test_qconv_impl( 5278 qconv, qconv_prepack, conv_op, batch_size, 5279 input_channels_per_group, (height, width), 5280 output_channels_per_group, groups, kernels, strides, pads, None, 5281 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5282 Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False, input_dtype=X_qdtype, output_dtype=X_qdtype) 5283 5284 @skipIfNoONEDNN 5285 def test_qconv2d_add(self): 5286 batch_size = 3 5287 groups_list = [1, 10] 5288 input_channels_per_group = 2 5289 output_channels_per_group = 2 5290 height = 10 5291 width = 10 5292 kernel_h = 3 5293 kernel_w = 3 5294 stride_h = 2 5295 stride_w = 2 5296 pad_h = 1 5297 pad_w = 1 5298 dilation = 1 5299 X_scale = 1.5 5300 X_zero_point = 2 5301 W_scale = [1.5] 5302 W_zero_point = [-3] 5303 Y_scale = 4.2 5304 Y_zero_point = 0 5305 use_bias_list = [False, True] 5306 use_channelwise_list = [False, True] 5307 X2_scale = 1.2 5308 X2_zero_point_list = [0, 4] 5309 options = itertools.product(groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list) 5310 for groups, use_bias, use_channelwise, X2_zero_point in options: 5311 with override_quantized_engine('onednn'): 5312 input_channels = input_channels_per_group * groups 5313 output_channels = output_channels_per_group * groups 5314 kernels = (kernel_h, kernel_w) 5315 strides = (stride_h, stride_w) 5316 pads = (pad_h, pad_w) 5317 dilations = (dilation, dilation) 5318 5319 qconv = torch.ops.quantized.conv2d_add 5320 qconv_prepack = torch.ops.quantized.conv2d_prepack 5321 conv_op = torch.nn.Conv2d( 5322 input_channels, 5323 output_channels, 5324 kernels, 5325 strides, 5326 pads, 5327 dilations, 5328 groups, 5329 ) 5330 5331 X_qdtype = torch.quint8 5332 self._test_qconv_impl( 5333 qconv, qconv_prepack, conv_op, batch_size, 5334 input_channels_per_group, (height, width), 5335 output_channels_per_group, groups, kernels, strides, pads, None, 5336 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5337 Y_scale, Y_zero_point, use_bias, "add", 
use_channelwise, False, 5338 input_dtype=X_qdtype, output_dtype=X_qdtype, X2_scale=X2_scale, X2_zero_point=X2_zero_point) 5339 5340 @skipIfNoONEDNN 5341 def test_qconv2d_add_relu(self): 5342 batch_size = 3 5343 height = 10 5344 width = 10 5345 groups_list = [1, 10] 5346 input_channels_per_group = 2 5347 output_channels_per_group = 2 5348 kernel_h = 3 5349 kernel_w = 3 5350 stride_h = 2 5351 stride_w = 2 5352 pad_h = 1 5353 pad_w = 1 5354 dilation = 1 5355 X_scale = 1.5 5356 X_zero_point = 2 5357 W_scale = [1.5] 5358 W_zero_point = [-3] 5359 Y_scale = 4.2 5360 Y_zero_point = 0 5361 use_bias_list = [False, True] 5362 use_channelwise_list = [False, True] 5363 X2_scale = 1.2 5364 X2_zero_point_list = [0, 4] 5365 5366 options = itertools.product(groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list) 5367 for groups, use_bias, use_channelwise, X2_zero_point in options: 5368 with override_quantized_engine('onednn'): 5369 input_channels = input_channels_per_group * groups 5370 output_channels = output_channels_per_group * groups 5371 kernels = (kernel_h, kernel_w) 5372 strides = (stride_h, stride_w) 5373 pads = (pad_h, pad_w) 5374 dilations = (dilation, dilation) 5375 5376 qconv = torch.ops.quantized.conv2d_add_relu 5377 qconv_prepack = torch.ops.quantized.conv2d_prepack 5378 conv_op = torch.nn.Conv2d( 5379 input_channels, 5380 output_channels, 5381 kernels, 5382 strides, 5383 pads, 5384 dilations, 5385 groups, 5386 ) 5387 5388 X_qdtype = torch.quint8 5389 self._test_qconv_impl( 5390 qconv, qconv_prepack, conv_op, batch_size, 5391 input_channels_per_group, (height, width), 5392 output_channels_per_group, groups, kernels, strides, pads, None, 5393 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5394 Y_scale, Y_zero_point, use_bias, "add_relu", use_channelwise, False, 5395 input_dtype=X_qdtype, output_dtype=X_qdtype, X2_scale=X2_scale, X2_zero_point=X2_zero_point) 5396 5397 # TODO: merge this test with test_qconv2d when CUDNN runtime flags becomes available 5398 """Tests the correctness of quantized 2D convolution cudnn op.""" 5399 @given(batch_size=st.integers(1, 3), 5400 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 5401 input_channels_per_group=st.integers(1, 32), 5402 height=st.integers(10, 16), 5403 width=st.integers(7, 14), 5404 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 5405 output_channels_per_group=st.integers(1, 32), 5406 groups=st.integers(1, 1), # currently padding only supports groups=1 5407 kernel_h=st.integers(1, 7), 5408 kernel_w=st.integers(1, 7), 5409 stride_h=st.integers(1, 2), 5410 stride_w=st.integers(1, 2), 5411 pad_h=st.integers(0, 2), 5412 pad_w=st.integers(0, 2), 5413 # result for dilation == 2 is not correct 5414 # dilation=st.integers(1, 2), 5415 # currently cudnn has only been verified to work for dilation = 1 5416 # TODO: check backend works for dilation > 1 5417 dilation=st.integers(1, 1), 5418 X_scale=st.floats(1.2, 1.6), 5419 X_zero_point=st.sampled_from([0]), 5420 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 5421 W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2), 5422 Y_scale=st.floats(4.2, 5.6), 5423 Y_zero_point=st.sampled_from([0]), 5424 use_bias=st.booleans(), 5425 # TODO: enable channelwise 5426 use_channelwise=st.sampled_from([False])) 5427 @skipIfNoFBGEMM 5428 @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") 5429 @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") 5430 @unittest.skipIf(TEST_ROCM, 
"not supported on rocm.") 5431 @unittest.skip("not currently working and feature isn't used") 5432 def test_qconv2d_cudnn( 5433 self, 5434 batch_size, 5435 input_channels_per_group, 5436 height, 5437 width, 5438 output_channels_per_group, 5439 groups, 5440 kernel_h, 5441 kernel_w, 5442 stride_h, 5443 stride_w, 5444 pad_h, 5445 pad_w, 5446 dilation, 5447 X_scale, 5448 X_zero_point, 5449 W_scale, 5450 W_zero_point, 5451 Y_scale, 5452 Y_zero_point, 5453 use_bias, 5454 use_channelwise, 5455 ): 5456 input_channels = input_channels_per_group * groups 5457 output_channels = output_channels_per_group * groups 5458 kernels = (kernel_h, kernel_w) 5459 strides = (stride_h, stride_w) 5460 pads = (pad_h, pad_w) 5461 dilations = (dilation, dilation) 5462 5463 qconv = torch.ops.quantized.conv2d 5464 conv_op = torch.nn.Conv2d( 5465 input_channels, 5466 output_channels, 5467 kernels, 5468 strides, 5469 pads, 5470 dilations, 5471 groups, 5472 ).to(torch.device("cuda")) 5473 self._test_qconv_impl( 5474 qconv, torch.ops.quantized.conv2d_prepack, conv_op, batch_size, 5475 input_channels_per_group, (height, width), 5476 output_channels_per_group, groups, kernels, strides, pads, None, 5477 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5478 Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False, 5479 device=torch.device("cuda"), 5480 input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8) 5481 5482 @given(batch_size=st.integers(1, 3), 5483 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 5484 input_channels_per_group=st.integers(1, 32), 5485 height=st.integers(10, 16), 5486 width=st.integers(7, 14), 5487 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 5488 output_channels_per_group=st.integers(1, 32), 5489 groups=st.integers(1, 1), # currently padding only supports groups=1 5490 kernel_h=st.integers(1, 7), 5491 kernel_w=st.integers(1, 7), 5492 stride_h=st.integers(1, 2), 5493 stride_w=st.integers(1, 2), 5494 pad_h=st.integers(0, 2), 5495 pad_w=st.integers(0, 2), 5496 # result for dilation == 2 is not correct 5497 # dilation=st.integers(1, 2), 5498 # currently cudnn has only been verified to work for dilation = 1 5499 # TODO: check backend works for dilation > 1 5500 dilation=st.integers(1, 1), 5501 X_scale=st.floats(1.2, 1.6), 5502 X_zero_point=st.sampled_from([0]), 5503 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 5504 W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2), 5505 Y_scale=st.floats(4.2, 5.6), 5506 Y_zero_point=st.sampled_from([0]), 5507 use_bias=st.booleans(), 5508 # TODO: enable channelwise 5509 use_channelwise=st.sampled_from([False])) 5510 @skipIfNoFBGEMM 5511 @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") 5512 @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") 5513 @unittest.skipIf(TEST_ROCM, "not supported on rocm.") 5514 @unittest.skip("not currently working and feature isn't used") 5515 def test_qconv2d_relu_cudnn( 5516 self, 5517 batch_size, 5518 input_channels_per_group, 5519 height, 5520 width, 5521 output_channels_per_group, 5522 groups, 5523 kernel_h, 5524 kernel_w, 5525 stride_h, 5526 stride_w, 5527 pad_h, 5528 pad_w, 5529 dilation, 5530 X_scale, 5531 X_zero_point, 5532 W_scale, 5533 W_zero_point, 5534 Y_scale, 5535 Y_zero_point, 5536 use_bias, 5537 use_channelwise, 5538 ): 5539 input_channels = input_channels_per_group * groups 5540 output_channels = output_channels_per_group * groups 5541 kernels = (kernel_h, 
                   kernel_w)
        strides = (stride_h, stride_w)
        pads = (pad_h, pad_w)
        dilations = (dilation, dilation)

        qconv = torch.ops.quantized.conv2d_relu
        conv_op = torch.nn.Conv2d(
            input_channels,
            output_channels,
            kernels,
            strides,
            pads,
            dilations,
            groups,
        ).to(torch.device("cuda"))
        self._test_qconv_impl(
            qconv, torch.ops.quantized.conv2d_prepack, conv_op, batch_size,
            input_channels_per_group, (height, width),
            output_channels_per_group, groups, kernels, strides, pads, None,
            dilations, X_scale, X_zero_point, W_scale, W_zero_point,
            Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False,
            device=torch.device("cuda"),
            input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8)

    @unittest.skip("used for local benchmarking; comment this skip out to run it")
    def test_benchmark(self):
        batch_size = 16
        in_channel = 64
        out_channel = 64
        kernel_size = 3
        height = 256
        width = 256
        print(
            "parameters:",
            "batch_size:", batch_size,
            "in_channel:", in_channel,
            "out_channel:", out_channel,
            "kernel_size:", kernel_size,
            "height:", height,
            "width:", width
        )
        conv = torch.nn.Conv2d(in_channel, out_channel, kernel_size).cuda()
        input = torch.randn((batch_size, in_channel, height, width), device='cuda')
        weight = conv.weight.detach()
        stride = (1, 1)
        padding = (0, 0)
        dilation = (1, 1)
        groups = 1
        conv_op = torch.nn.functional.conv2d
        # profile
        from torch.profiler import profile, ProfilerActivity

        def trace_handler(p):
            output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)
            p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

        my_schedule = torch.profiler.schedule(
            wait=5,
            warmup=5,
            active=20)

        # fp32 benchmark
        with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=my_schedule,
                on_trace_ready=trace_handler) as prof:
            for i in range(30):
                conv_op(input, weight, None, stride, padding, dilation, groups)
                prof.step()

        print("fp32 benchmark result:")
        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

        # fp16 benchmark
        input_fp16 = input.to(torch.float16)
        weight_fp16 = weight.to(torch.float16)

        with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=my_schedule,
                on_trace_ready=trace_handler) as prof:
            for i in range(30):
                conv_op(input_fp16, weight_fp16, None, stride, padding, dilation, groups)
                prof.step()

        print("fp16 benchmark result:")
        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

        input_int8 = torch.quantize_per_tensor(input, 1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
        weight_int8 = torch.quantize_per_tensor(weight, 1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
        scale = 1.0
        zero_point = 0
        conv_op = torch.ops.quantized.conv2d
        weight_prepacked = torch.ops.quantized.conv2d_prepack(weight_int8, None, stride, padding, dilation, groups)
        with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=my_schedule,
                on_trace_ready=trace_handler) as prof:
            for i in range(30):
                conv_op(input_int8, weight_prepacked, scale, zero_point)
                prof.step()

print("int8 benchmark result:") 5644 print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)) 5645 5646 """Tests the correctness of quantized convolution op.""" 5647 @override_qengines 5648 def test_qconv_transpose1d(self): 5649 if not qengine_is_qnnpack(): 5650 return # Currently only the QNNPACK is supported 5651 if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN): 5652 return # QNNPACK doesn't support these 5653 batch_size = 2 5654 input_channels_per_group_list = [2, 32] 5655 width = 14 5656 output_channels_per_group_list = [2, 8] 5657 groups_list = [1, 3] 5658 kernel_list = [1, 7] 5659 stride_list = [1, 2] 5660 pad = 2 5661 o_pad = 0 5662 dilation = 1 5663 X_scale = 1.2 5664 X_zero_point = 1 5665 W_scale = [1.2] 5666 W_zero_point = [1] 5667 Y_scale = 4.2 5668 Y_zero_point = 2 5669 use_bias_list = [True, False] 5670 5671 test_cases = itertools.product( 5672 input_channels_per_group_list, output_channels_per_group_list, 5673 groups_list, kernel_list, stride_list, use_bias_list) 5674 for input_channels_per_group, output_channels_per_group, \ 5675 groups, kernel, stride, use_bias in test_cases: 5676 5677 input_channels = input_channels_per_group * groups 5678 output_channels = output_channels_per_group * groups 5679 kernels = (kernel,) 5680 strides = (stride,) 5681 pads = (pad,) 5682 o_pads = (o_pad,) 5683 dilations = (dilation,) 5684 5685 qconv = torch.ops.quantized.conv_transpose1d 5686 qconv_prepack = torch.ops.quantized.conv_transpose1d_prepack 5687 conv_op = torch.nn.ConvTranspose1d( 5688 in_channels=input_channels, 5689 out_channels=output_channels, 5690 kernel_size=kernels, 5691 stride=strides, 5692 padding=pads, 5693 output_padding=o_pads, 5694 groups=groups, 5695 dilation=dilations, 5696 bias=use_bias 5697 ) 5698 5699 act_qdtypes = [torch.quint8] 5700 # Only qnnpack qengine supportes qint8 5701 if qengine_is_qnnpack() and torch.backends.xnnpack.enabled: 5702 act_qdtypes.append(torch.qint8) 5703 5704 for X_qdtype in act_qdtypes: 5705 if X_qdtype == torch.qint8: 5706 W_zero_point = [0 for i in range(len(W_zero_point))] 5707 5708 X_q, W_q, bias_float = self._test_qconv_impl( 5709 qconv, qconv_prepack, conv_op, batch_size, 5710 input_channels_per_group, (width, ), 5711 output_channels_per_group, groups, kernels, strides, pads, o_pads, 5712 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5713 Y_scale, Y_zero_point, use_bias, post_op="none", 5714 use_channelwise=False, use_transpose=True, input_dtype=X_qdtype, output_dtype=X_qdtype) 5715 5716 # check that this doesn't error 5717 test_conv = torch.ao.nn.quantized.ConvTranspose1d(input_channels, output_channels, 1) 5718 test_conv.scale = Y_scale 5719 test_conv(X_q) 5720 5721 # Test the module implementation 5722 qconv_op = torch.ao.nn.quantized.ConvTranspose1d( 5723 in_channels=input_channels, 5724 out_channels=output_channels, 5725 kernel_size=kernels, 5726 stride=strides, 5727 padding=pads, 5728 output_padding=o_pads, 5729 groups=groups, 5730 dilation=dilations, 5731 bias=use_bias 5732 ) 5733 qconv_op.scale = Y_scale 5734 qconv_op.zero_point = Y_zero_point 5735 qconv_op.set_weight_bias(W_q, bias_float) 5736 5737 Y_dq_ref = conv_op(X_q.dequantize()) 5738 Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale, 5739 zero_point=Y_zero_point, 5740 dtype=X_qdtype) 5741 Y_q = qconv_op(X_q) 5742 self.assertEqual(Y_q_ref, Y_q) 5743 5744 5745 """Tests the correctness of quantized convolution op.""" 5746 @given(batch_size=st.integers(1, 3), 5747 input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 
5748 height=st.integers(10, 16), 5749 width=st.integers(7, 14), 5750 output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 5751 groups=st.integers(1, 300), 5752 kernel_h=st.integers(1, 7), 5753 kernel_w=st.integers(1, 7), 5754 stride_h=st.integers(1, 2), 5755 stride_w=st.integers(1, 2), 5756 pad_h=st.integers(0, 2), 5757 pad_w=st.integers(0, 2), 5758 o_pad_h=st.integers(0, 2), 5759 o_pad_w=st.integers(0, 2), 5760 dilation=st.integers(1, 2), 5761 X_scale=st.floats(1.2, 1.6), 5762 X_zero_point=st.integers(0, 4), 5763 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 5764 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 5765 Y_scale=st.floats(4.2, 5.6), 5766 Y_zero_point=st.integers(0, 4), 5767 use_bias=st.booleans()) 5768 @override_qengines 5769 @unittest.skip( 5770 "this is broken without changes to any relevant code, " 5771 "we need to remove hypothesis testing in CI") 5772 def test_qconv_transpose2d( 5773 self, 5774 batch_size, 5775 input_channels_per_group, 5776 height, 5777 width, 5778 output_channels_per_group, 5779 groups, 5780 kernel_h, 5781 kernel_w, 5782 stride_h, 5783 stride_w, 5784 pad_h, 5785 pad_w, 5786 o_pad_h, 5787 o_pad_w, 5788 dilation, 5789 X_scale, 5790 X_zero_point, 5791 W_scale, 5792 W_zero_point, 5793 Y_scale, 5794 Y_zero_point, 5795 use_bias): 5796 if qengine_is_qnnpack() and (IS_PPC or TEST_WITH_UBSAN): 5797 return # QNNPACK doesn't support these 5798 # ONEDNN does not support output paddings 5799 if qengine_is_onednn() and (o_pad_h, o_pad_w) != (0, 0): 5800 return 5801 assume(o_pad_h < stride_h and o_pad_h < dilation) 5802 assume(o_pad_w < stride_w and o_pad_w < dilation) 5803 5804 input_channels = input_channels_per_group * groups 5805 output_channels = output_channels_per_group * groups 5806 kernels = (kernel_h, kernel_w) 5807 strides = (stride_h, stride_w) 5808 pads = (pad_h, pad_w) 5809 o_pads = (o_pad_h, o_pad_w) 5810 dilations = (dilation, dilation) 5811 5812 qconv = torch.ops.quantized.conv_transpose2d 5813 qconv_prepack = torch.ops.quantized.conv_transpose2d_prepack 5814 conv_op = torch.nn.ConvTranspose2d( 5815 in_channels=input_channels, 5816 out_channels=output_channels, 5817 kernel_size=kernels, 5818 stride=strides, 5819 padding=pads, 5820 output_padding=o_pads, 5821 groups=groups, 5822 dilation=dilations, 5823 bias=use_bias 5824 ) 5825 act_qdtypes = [torch.quint8] 5826 # Only qnnpack qengine supportes qint8 5827 if qengine_is_qnnpack() and torch.backends.xnnpack.enabled: 5828 act_qdtypes.append(torch.qint8) 5829 5830 for X_qdtype in act_qdtypes: 5831 if X_qdtype == torch.qint8: 5832 W_zero_point = [0 for i in range(len(W_zero_point))] 5833 5834 X_q, W_q, bias_float = self._test_qconv_impl( 5835 qconv, qconv_prepack, conv_op, batch_size, 5836 input_channels_per_group, (height, width), 5837 output_channels_per_group, groups, kernels, strides, pads, o_pads, 5838 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5839 Y_scale, Y_zero_point, use_bias, post_op="none", 5840 use_channelwise=False, use_transpose=True, input_dtype=X_qdtype, output_dtype=X_qdtype) 5841 5842 # check that this doesn't error 5843 test_conv = torch.ao.nn.quantized.ConvTranspose2d(input_channels, output_channels, 1) 5844 test_conv.scale = Y_scale 5845 test_conv(X_q) 5846 5847 # Test the module implementation 5848 qconv_op = torch.ao.nn.quantized.ConvTranspose2d( 5849 in_channels=input_channels, 5850 out_channels=output_channels, 5851 kernel_size=kernels, 5852 stride=strides, 5853 padding=pads, 5854 output_padding=o_pads, 5855 
groups=groups, 5856 dilation=dilations, 5857 bias=use_bias 5858 ) 5859 qconv_op.scale = Y_scale 5860 qconv_op.zero_point = Y_zero_point 5861 qconv_op.set_weight_bias(W_q, bias_float) 5862 5863 Y_dq_ref = conv_op(X_q.dequantize()) 5864 Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale, 5865 zero_point=Y_zero_point, 5866 dtype=X_qdtype) 5867 Y_q = qconv_op(X_q) 5868 self.assertEqual(Y_q_ref, Y_q) 5869 5870 """Tests the correctness of quantized convolution op.""" 5871 @given(batch_size=st.integers(1, 3), 5872 input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 5873 time=st.integers(2, 5), 5874 height=st.integers(10, 16), 5875 width=st.integers(7, 14), 5876 output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]), 5877 groups=st.integers(1, 300), 5878 kernel_t=st.integers(1, 7), 5879 kernel_h=st.integers(1, 7), 5880 kernel_w=st.integers(1, 7), 5881 stride_t=st.integers(1, 2), 5882 stride_h=st.integers(1, 2), 5883 stride_w=st.integers(1, 2), 5884 pad_t=st.integers(0, 2), 5885 pad_h=st.integers(0, 2), 5886 pad_w=st.integers(0, 2), 5887 o_pad_t=st.integers(0, 2), 5888 o_pad_h=st.integers(0, 2), 5889 o_pad_w=st.integers(0, 2), 5890 dilation=st.integers(1, 2), 5891 X_scale=st.floats(1.2, 1.6), 5892 X_zero_point=st.integers(0, 4), 5893 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 5894 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 5895 Y_scale=st.floats(4.2, 5.6), 5896 Y_zero_point=st.integers(0, 4), 5897 use_bias=st.booleans()) 5898 @override_qengines 5899 @unittest.skip( 5900 "this is broken without changes to any relevant code, " 5901 "we need to remove hypothesis testing in CI") 5902 def test_qconv_transpose3d( 5903 self, 5904 batch_size, 5905 input_channels_per_group, 5906 time, 5907 height, 5908 width, 5909 output_channels_per_group, 5910 groups, 5911 kernel_t, 5912 kernel_h, 5913 kernel_w, 5914 stride_t, 5915 stride_h, 5916 stride_w, 5917 pad_t, 5918 pad_h, 5919 pad_w, 5920 o_pad_t, 5921 o_pad_h, 5922 o_pad_w, 5923 dilation, 5924 X_scale, 5925 X_zero_point, 5926 W_scale, 5927 W_zero_point, 5928 Y_scale, 5929 Y_zero_point, 5930 use_bias): 5931 if qengine_is_qnnpack(): 5932 return # QNNPACK doesn't support this 5933 # ONEDNN doesn't support output paddings 5934 if qengine_is_onednn() and (o_pad_t, o_pad_h, o_pad_w) != (0, 0, 0): 5935 return 5936 assume(o_pad_t < stride_t or o_pad_t < dilation) 5937 assume(o_pad_h < stride_h or o_pad_h < dilation) 5938 assume(o_pad_w < stride_w or o_pad_w < dilation) 5939 5940 input_channels = input_channels_per_group * groups 5941 output_channels = output_channels_per_group * groups 5942 kernels = (kernel_t, kernel_h, kernel_w) 5943 strides = (stride_t, stride_h, stride_w) 5944 pads = (pad_t, pad_h, pad_w) 5945 o_pads = (o_pad_t, o_pad_h, o_pad_w) 5946 dilations = (dilation, dilation, dilation) 5947 5948 qconv = torch.ops.quantized.conv_transpose3d 5949 qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack 5950 conv_op = torch.nn.ConvTranspose3d( 5951 in_channels=input_channels, 5952 out_channels=output_channels, 5953 kernel_size=kernels, 5954 stride=strides, 5955 padding=pads, 5956 output_padding=o_pads, 5957 groups=groups, 5958 dilation=dilations, 5959 bias=use_bias 5960 ) 5961 X_q, W_q, bias_float = self._test_qconv_impl( 5962 qconv, qconv_prepack, conv_op, batch_size, 5963 input_channels_per_group, (time, height, width), 5964 output_channels_per_group, groups, kernels, strides, pads, o_pads, 5965 dilations, X_scale, X_zero_point, W_scale, W_zero_point, 5966 Y_scale, Y_zero_point, 
use_bias, post_op="none", 5967 use_channelwise=False, use_transpose=True) 5968 5969 # check that this doesn't error 5970 test_conv = torch.ao.nn.quantized.ConvTranspose3d(input_channels, output_channels, 1) 5971 test_conv.scale = Y_scale 5972 test_conv(X_q) 5973 5974 # Test the module implementation 5975 qconv_op = torch.ao.nn.quantized.ConvTranspose3d( 5976 in_channels=input_channels, 5977 out_channels=output_channels, 5978 kernel_size=kernels, 5979 stride=strides, 5980 padding=pads, 5981 output_padding=o_pads, 5982 groups=groups, 5983 dilation=dilations, 5984 bias=use_bias 5985 ) 5986 qconv_op.scale = Y_scale 5987 qconv_op.zero_point = Y_zero_point 5988 qconv_op.set_weight_bias(W_q, bias_float) 5989 5990 Y_dq_ref = conv_op(X_q.dequantize()) 5991 Y_q_ref = torch.quantize_per_tensor(Y_dq_ref, scale=Y_scale, 5992 zero_point=Y_zero_point, 5993 dtype=torch.quint8) 5994 Y_q = qconv_op(X_q) 5995 self.assertEqual(Y_q_ref, Y_q) 5996 5997 @given( 5998 inputs=hu.tensor_conv( 5999 spatial_dim=1, batch_size_range=(1, 3), 6000 input_channels_per_group_range=(1, 4), 6001 output_channels_per_group_range=(1, 4), feature_map_range=(4, 8), 6002 kernel_range=(1, 4), max_groups=4, 6003 can_be_transposed=False, 6004 qparams=[hu.qparams(dtypes=torch.quint8, 6005 zero_point_min=0, 6006 zero_point_max=0), 6007 hu.qparams(dtypes=torch.qint8, 6008 zero_point_min=0, 6009 zero_point_max=0), 6010 hu.qparams(dtypes=torch.qint32, 6011 zero_point_min=0, 6012 zero_point_max=0)]), 6013 stride=st.integers(1, 3), 6014 pad=st.integers(1, 2), 6015 o_pad=st.integers(1, 2), 6016 channelwise=st.booleans()) 6017 @override_qengines 6018 def test_qconv1d_unpack(self, inputs, stride, pad, o_pad, channelwise): 6019 transposed = inputs[-1] 6020 qengine = torch.backends.quantized.engine 6021 if qengine not in supported_qengines: 6022 return 6023 if qengine == 'qnnpack': 6024 assume(not channelwise) # QNNPACK doesn't support channelwise 6025 else: 6026 assume(not transposed) # Only QNNPACK supports transposed conv 6027 if transposed: 6028 qconv_prepack = torch.ops.quantized.conv_transpose1d_prepack 6029 qconv_unpack = torch.ops.quantized.conv_transpose1d_unpack 6030 else: 6031 qconv_prepack = torch.ops.quantized.conv1d_prepack 6032 qconv_unpack = torch.ops.quantized.conv1d_unpack 6033 self._test_qconv_unpack_impl( 6034 qconv_prepack, qconv_unpack, inputs, [stride], 6035 [pad], [o_pad], channelwise) 6036 6037 @given( 6038 inputs=hu.tensor_conv( 6039 spatial_dim=2, batch_size_range=(1, 3), 6040 input_channels_per_group_range=(1, 4), 6041 output_channels_per_group_range=(1, 4), feature_map_range=(4, 8), 6042 kernel_range=(1, 4), max_groups=4, 6043 can_be_transposed=True, 6044 qparams=[hu.qparams(dtypes=torch.quint8, 6045 zero_point_min=0, 6046 zero_point_max=0), 6047 hu.qparams(dtypes=torch.qint8, 6048 zero_point_min=0, 6049 zero_point_max=0), 6050 hu.qparams(dtypes=torch.qint32, 6051 zero_point_min=0, 6052 zero_point_max=0)]), 6053 stride=st.integers(1, 3), 6054 pad=st.integers(0, 2), 6055 o_pad=st.integers(0, 2), 6056 channelwise=st.booleans()) 6057 @override_qengines 6058 def test_qconv2d_unpack(self, inputs, stride, pad, o_pad, channelwise): 6059 transposed = inputs[-1] 6060 qengine = torch.backends.quantized.engine 6061 if qengine not in supported_qengines: 6062 return 6063 if qengine == 'qnnpack': 6064 assume(not channelwise) # QNNPACK doesn't support channelwise 6065 if transposed: 6066 qconv_prepack = torch.ops.quantized.conv_transpose2d_prepack 6067 qconv_unpack = torch.ops.quantized.conv_transpose2d_unpack 6068 else: 6069 
qconv_prepack = torch.ops.quantized.conv2d_prepack 6070 qconv_unpack = torch.ops.quantized.conv2d_unpack 6071 self._test_qconv_unpack_impl( 6072 qconv_prepack, qconv_unpack, inputs, [stride, stride], 6073 [pad, pad], [o_pad, o_pad], channelwise) 6074 6075 """Tests the correctness of quantized 1D convolution op.""" 6076 @given(batch_size=st.integers(1, 6), 6077 input_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)), 6078 output_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)), 6079 groups=st.integers(1, 3), 6080 length=st.integers(4, 16), 6081 kernel=st.integers(1, 7), 6082 stride=st.integers(1, 2), 6083 pad=st.integers(0, 2), 6084 dilation=st.integers(1, 2), 6085 X_scale=st.floats(1.2, 1.6), 6086 X_zero_point=st.integers(0, 4), 6087 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 6088 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 6089 Y_scale=st.floats(4.2, 5.6), 6090 Y_zero_point=st.integers(0, 4), 6091 use_bias=st.booleans(), 6092 use_channelwise=st.booleans()) 6093 @override_qengines 6094 def test_qconv1d( 6095 self, 6096 batch_size, 6097 input_channels_per_group, 6098 output_channels_per_group, 6099 groups, 6100 length, 6101 kernel, 6102 stride, 6103 pad, 6104 dilation, 6105 X_scale, 6106 X_zero_point, 6107 W_scale, 6108 W_zero_point, 6109 Y_scale, 6110 Y_zero_point, 6111 use_bias, 6112 use_channelwise, 6113 ): 6114 input_channels = input_channels_per_group * groups 6115 output_channels = output_channels_per_group * groups 6116 if torch.backends.quantized.engine == 'qnnpack': 6117 use_channelwise = False 6118 conv1d = torch.nn.Conv1d( 6119 input_channels, 6120 output_channels, 6121 kernel, 6122 stride, 6123 pad, 6124 dilation, 6125 groups, 6126 ) 6127 qconv_prepack = torch.ops.quantized.conv1d_prepack 6128 qconv = torch.ops.quantized.conv1d 6129 6130 act_qdtypes = [torch.quint8] 6131 # Only qnnpack qengine supportes qint8 6132 if qengine_is_qnnpack() and torch.backends.xnnpack.enabled: 6133 act_qdtypes.append(torch.qint8) 6134 6135 for X_qdtype in act_qdtypes: 6136 if X_qdtype == torch.qint8: 6137 W_zero_point = [0 for i in range(len(W_zero_point))] 6138 6139 self._test_qconv_impl( 6140 qconv, qconv_prepack, conv1d, batch_size, 6141 input_channels_per_group, (length, ), 6142 output_channels_per_group, groups, kernel, [stride], [pad], None, 6143 [dilation], X_scale, X_zero_point, W_scale, W_zero_point, 6144 Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False, 6145 input_dtype=X_qdtype, output_dtype=X_qdtype) 6146 6147 @given(batch_size=st.integers(1, 6), 6148 input_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)), 6149 output_channels_per_group=st.sampled_from((2, 4, 5, 8, 16, 32)), 6150 groups=st.integers(1, 3), 6151 length=st.integers(4, 16), 6152 kernel=st.integers(1, 7), 6153 stride=st.integers(1, 2), 6154 pad=st.integers(0, 2), 6155 dilation=st.integers(1, 2), 6156 X_scale=st.floats(1.2, 1.6), 6157 X_zero_point=st.integers(0, 4), 6158 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 6159 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 6160 Y_scale=st.floats(4.2, 5.6), 6161 Y_zero_point=st.integers(0, 4), 6162 use_bias=st.booleans(), 6163 use_channelwise=st.booleans()) 6164 @override_qengines 6165 def test_qconv1d_relu( 6166 self, 6167 batch_size, 6168 input_channels_per_group, 6169 output_channels_per_group, 6170 groups, 6171 length, 6172 kernel, 6173 stride, 6174 pad, 6175 dilation, 6176 X_scale, 6177 X_zero_point, 6178 W_scale, 6179 W_zero_point, 6180 Y_scale, 6181 
Y_zero_point, 6182 use_bias, 6183 use_channelwise, 6184 ): 6185 input_channels = input_channels_per_group * groups 6186 output_channels = output_channels_per_group * groups 6187 if torch.backends.quantized.engine == 'qnnpack': 6188 use_channelwise = False 6189 conv1d = torch.nn.Conv1d( 6190 input_channels, 6191 output_channels, 6192 kernel, 6193 stride, 6194 pad, 6195 dilation, 6196 groups, 6197 ) 6198 qconv_prepack = torch.ops.quantized.conv1d_prepack 6199 qconv = torch.ops.quantized.conv1d_relu 6200 6201 act_qdtypes = [torch.quint8] 6202 # Only qnnpack qengine supportes qint8 6203 if qengine_is_qnnpack() and torch.backends.xnnpack.enabled: 6204 act_qdtypes.append(torch.qint8) 6205 6206 for X_qdtype in act_qdtypes: 6207 if X_qdtype == torch.qint8: 6208 W_zero_point = [0 for i in range(len(W_zero_point))] 6209 6210 self._test_qconv_impl( 6211 qconv, qconv_prepack, conv1d, batch_size, 6212 input_channels_per_group, (length, ), 6213 output_channels_per_group, groups, kernel, [stride], [pad], None, 6214 [dilation], X_scale, X_zero_point, W_scale, W_zero_point, 6215 Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False, 6216 input_dtype=X_qdtype, output_dtype=X_qdtype) 6217 6218 # TODO: merge this test with test_qconv1d when CUDNN runtime flags becomes available 6219 """Tests the correctness of quantized 1D convolution cudnn op.""" 6220 @given(batch_size=st.integers(1, 6), 6221 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 6222 input_channels_per_group=st.integers(1, 32), 6223 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 6224 output_channels_per_group=st.integers(1, 32), 6225 groups=st.integers(1, 1), # currently padding only supports groups=1 6226 length=st.integers(4, 16), 6227 kernel=st.integers(1, 7), 6228 stride=st.integers(1, 2), 6229 pad=st.integers(0, 2), 6230 # currently cudnn has only been verified to work for dilation = 1 6231 # TODO: check backend works for dilation > 1 6232 dilation=st.integers(1, 1), 6233 X_scale=st.floats(1.2, 1.6), 6234 # currently conv cudnn backend is only implemented for int8 symmetric 6235 X_zero_point=st.sampled_from([0]), 6236 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 6237 # currently conv cudnn backend is only implemented for int8 symmetric 6238 W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2), 6239 Y_scale=st.floats(4.2, 5.6), 6240 # currently conv cudnn backend is only implemented for int8 symmetric 6241 Y_zero_point=st.sampled_from([0]), 6242 use_bias=st.booleans(), 6243 # TODO: enable channelwise 6244 use_channelwise=st.sampled_from([False])) 6245 @skipIfNoFBGEMM 6246 @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") 6247 @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") 6248 @unittest.skipIf(TEST_ROCM, "not supported on rocm.") 6249 @unittest.skip("not currently working and feature isn't used") 6250 def test_qconv1d_cudnn( 6251 self, 6252 batch_size, 6253 input_channels_per_group, 6254 output_channels_per_group, 6255 groups, 6256 length, 6257 kernel, 6258 stride, 6259 pad, 6260 dilation, 6261 X_scale, 6262 X_zero_point, 6263 W_scale, 6264 W_zero_point, 6265 Y_scale, 6266 Y_zero_point, 6267 use_bias, 6268 use_channelwise, 6269 ): 6270 input_channels = input_channels_per_group * groups 6271 output_channels = output_channels_per_group * groups 6272 6273 conv1d = torch.nn.Conv1d( 6274 input_channels, 6275 output_channels, 6276 kernel, 6277 stride, 6278 pad, 6279 dilation, 6280 groups, 6281 
).to(torch.device("cuda")) 6282 qconv_prepack = torch.ops.quantized.conv1d_prepack 6283 qconv = torch.ops.quantized.conv1d 6284 6285 self._test_qconv_impl( 6286 qconv, qconv_prepack, conv1d, batch_size, 6287 input_channels_per_group, (length, ), 6288 output_channels_per_group, groups, kernel, [stride], [pad], None, 6289 [dilation], X_scale, X_zero_point, W_scale, W_zero_point, 6290 Y_scale, Y_zero_point, use_bias, "none", use_channelwise, False, 6291 device=torch.device("cuda"), 6292 input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8) 6293 6294 @given(batch_size=st.integers(1, 6), 6295 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 6296 input_channels_per_group=st.integers(1, 32), 6297 # cudnn only supports multiples of 4, but we have explicitly added padding on the backend 6298 output_channels_per_group=st.integers(1, 32), 6299 groups=st.integers(1, 1), # currently padding only supports groups=1 6300 length=st.integers(4, 16), 6301 kernel=st.integers(1, 7), 6302 stride=st.integers(1, 2), 6303 pad=st.integers(0, 2), 6304 # currently cudnn has only been verified to work for dilation = 1 6305 # TODO: check backend works for dilation > 1 6306 dilation=st.integers(1, 1), 6307 X_scale=st.floats(1.2, 1.6), 6308 # currently conv cudnn backend is only implemented for int8 symmetric 6309 X_zero_point=st.sampled_from([0]), 6310 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 6311 # currently conv cudnn backend is only implemented for int8 symmetric 6312 W_zero_point=st.lists(st.integers(0, 0), min_size=1, max_size=2), 6313 Y_scale=st.floats(4.2, 5.6), 6314 # currently conv cudnn backend is only implemented for int8 symmetric 6315 Y_zero_point=st.sampled_from([0]), 6316 use_bias=st.booleans(), 6317 # TODO: enable channelwise 6318 use_channelwise=st.sampled_from([False])) 6319 @skipIfNoFBGEMM 6320 @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") 6321 @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") 6322 @unittest.skipIf(TEST_ROCM, "not supported on rocm.") 6323 @unittest.skip("not currently working and feature isn't used") 6324 def test_qconv1d_relu_cudnn( 6325 self, 6326 batch_size, 6327 input_channels_per_group, 6328 output_channels_per_group, 6329 groups, 6330 length, 6331 kernel, 6332 stride, 6333 pad, 6334 dilation, 6335 X_scale, 6336 X_zero_point, 6337 W_scale, 6338 W_zero_point, 6339 Y_scale, 6340 Y_zero_point, 6341 use_bias, 6342 use_channelwise, 6343 ): 6344 input_channels = input_channels_per_group * groups 6345 output_channels = output_channels_per_group * groups 6346 6347 conv1d = torch.nn.Conv1d( 6348 input_channels, 6349 output_channels, 6350 kernel, 6351 stride, 6352 pad, 6353 dilation, 6354 groups, 6355 ).to(torch.device("cuda")) 6356 qconv_prepack = torch.ops.quantized.conv1d_prepack 6357 qconv = torch.ops.quantized.conv1d_relu 6358 6359 self._test_qconv_impl( 6360 qconv, qconv_prepack, conv1d, batch_size, 6361 input_channels_per_group, (length, ), 6362 output_channels_per_group, groups, kernel, [stride], [pad], None, 6363 [dilation], X_scale, X_zero_point, W_scale, W_zero_point, 6364 Y_scale, Y_zero_point, use_bias, "relu", use_channelwise, False, 6365 device=torch.device("cuda"), 6366 input_dtype=torch.qint8, weight_dtype=torch.qint8, output_dtype=torch.qint8) 6367 6368 @given(batch_size=st.integers(1, 4), 6369 input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]), 6370 D=st.integers(4, 8), 6371 H=st.integers(4, 8), 6372 W=st.integers(4, 8), 6373 
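# D, H and W above sample the depth, height and width of the 3-D input feature map.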
output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]), 6374 groups=st.integers(1, 3), 6375 kernel_d=st.integers(1, 4), 6376 kernel_h=st.integers(1, 4), 6377 kernel_w=st.integers(1, 4), 6378 stride_d=st.integers(1, 2), 6379 stride_h=st.integers(1, 2), 6380 stride_w=st.integers(1, 2), 6381 pad_d=st.integers(0, 2), 6382 pad_h=st.integers(0, 2), 6383 pad_w=st.integers(0, 2), 6384 dilation=st.integers(1, 2), 6385 X_scale=st.floats(1.2, 1.6), 6386 X_zero_point=st.integers(0, 4), 6387 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 6388 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 6389 Y_scale=st.floats(4.2, 5.6), 6390 Y_zero_point=st.integers(0, 4), 6391 use_bias=st.booleans(), 6392 use_channelwise=st.booleans(), 6393 qengine=st.sampled_from(("qnnpack", "fbgemm"))) 6394 def test_qconv3d( 6395 self, 6396 batch_size, 6397 input_channels_per_group, 6398 D, 6399 H, 6400 W, 6401 output_channels_per_group, 6402 groups, 6403 kernel_d, 6404 kernel_h, 6405 kernel_w, 6406 stride_d, 6407 stride_h, 6408 stride_w, 6409 pad_d, 6410 pad_h, 6411 pad_w, 6412 dilation, 6413 X_scale, 6414 X_zero_point, 6415 W_scale, 6416 W_zero_point, 6417 Y_scale, 6418 Y_zero_point, 6419 use_bias, 6420 use_channelwise, 6421 qengine 6422 ): 6423 if qengine not in supported_qengines: 6424 return 6425 6426 input_channels = input_channels_per_group * groups 6427 output_channels = output_channels_per_group * groups 6428 kernels = (kernel_d, kernel_h, kernel_w) 6429 strides = (stride_d, stride_h, stride_w) 6430 pads = (pad_d, pad_h, pad_w) 6431 dilations = (dilation, dilation, dilation) 6432 6433 with override_quantized_engine(qengine): 6434 qconv = torch.ops.quantized.conv3d 6435 qconv_prepack = torch.ops.quantized.conv3d_prepack 6436 conv_op = torch.nn.Conv3d( 6437 input_channels, 6438 output_channels, 6439 kernels, 6440 strides, 6441 pads, 6442 dilations, 6443 groups, 6444 ) 6445 self._test_qconv_impl( 6446 qconv, qconv_prepack, conv_op, batch_size, 6447 input_channels_per_group, (D, H, W), output_channels_per_group, 6448 groups, kernels, strides, pads, None, dilations, X_scale, 6449 X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, 6450 use_bias, "none", use_channelwise, use_transpose=False) 6451 6452 @given(batch_size=st.integers(1, 4), 6453 input_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]), 6454 D=st.integers(4, 8), 6455 H=st.integers(4, 8), 6456 W=st.integers(4, 8), 6457 output_channels_per_group=st.sampled_from([2, 4, 5, 8, 16]), 6458 groups=st.integers(1, 3), 6459 kernel_d=st.integers(1, 4), 6460 kernel_h=st.integers(1, 4), 6461 kernel_w=st.integers(1, 4), 6462 stride_d=st.integers(1, 2), 6463 stride_h=st.integers(1, 2), 6464 stride_w=st.integers(1, 2), 6465 pad_d=st.integers(0, 2), 6466 pad_h=st.integers(0, 2), 6467 pad_w=st.integers(0, 2), 6468 dilation=st.integers(1, 2), 6469 X_scale=st.floats(1.2, 1.6), 6470 X_zero_point=st.integers(0, 4), 6471 W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), 6472 W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2), 6473 Y_scale=st.floats(4.2, 5.6), 6474 Y_zero_point=st.integers(0, 4), 6475 use_bias=st.booleans(), 6476 use_channelwise=st.booleans(), 6477 qengine=st.sampled_from(("qnnpack", "fbgemm"))) 6478 def test_qconv3d_relu( 6479 self, 6480 batch_size, 6481 input_channels_per_group, 6482 D, 6483 H, 6484 W, 6485 output_channels_per_group, 6486 groups, 6487 kernel_d, 6488 kernel_h, 6489 kernel_w, 6490 stride_d, 6491 stride_h, 6492 stride_w, 6493 pad_d, 6494 pad_h, 6495 pad_w, 6496 dilation, 6497 
X_scale, 6498 X_zero_point, 6499 W_scale, 6500 W_zero_point, 6501 Y_scale, 6502 Y_zero_point, 6503 use_bias, 6504 use_channelwise, 6505 qengine 6506 ): 6507 if qengine not in supported_qengines: 6508 return 6509 6510 input_channels = input_channels_per_group * groups 6511 output_channels = output_channels_per_group * groups 6512 kernels = (kernel_d, kernel_h, kernel_w) 6513 strides = (stride_d, stride_h, stride_w) 6514 pads = (pad_d, pad_h, pad_w) 6515 dilations = (dilation, dilation, dilation) 6516 6517 with override_quantized_engine(qengine): 6518 qconv = torch.ops.quantized.conv3d_relu 6519 qconv_prepack = torch.ops.quantized.conv3d_prepack 6520 conv_op = torch.nn.Conv3d( 6521 input_channels, 6522 output_channels, 6523 kernels, 6524 strides, 6525 pads, 6526 dilations, 6527 groups, 6528 ) 6529 self._test_qconv_impl( 6530 qconv, qconv_prepack, conv_op, batch_size, 6531 input_channels_per_group, (D, H, W), output_channels_per_group, 6532 groups, kernels, strides, pads, None, dilations, X_scale, 6533 X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, 6534 use_bias, "relu", use_channelwise, use_transpose=False) 6535 6536 """Tests the correctness of the quantized::qconv3d_unpack op.""" 6537 @given( 6538 inputs=hu.tensor_conv( 6539 spatial_dim=3, batch_size_range=(1, 3), 6540 input_channels_per_group_range=(1, 3), 6541 output_channels_per_group_range=(1, 3), feature_map_range=(3, 6), 6542 kernel_range=(1, 3), max_groups=3, 6543 qparams=[hu.qparams(dtypes=torch.quint8, 6544 zero_point_min=0, 6545 zero_point_max=0), 6546 hu.qparams(dtypes=torch.qint8, 6547 zero_point_min=0, 6548 zero_point_max=0), 6549 hu.qparams(dtypes=torch.qint32, 6550 zero_point_min=0, 6551 zero_point_max=0)]), 6552 stride_d=st.integers(1, 2), stride_h=st.integers(1, 2), 6553 stride_w=st.integers(1, 2), 6554 pad_d=st.integers(1, 2), pad_h=st.integers(1, 2), 6555 pad_w=st.integers(1, 2), 6556 o_pad=st.integers(0, 2), 6557 channelwise=st.booleans()) 6558 @override_qengines 6559 def test_qconv3d_unpack( 6560 self, inputs, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, o_pad, 6561 channelwise 6562 ): 6563 if qengine_is_qnnpack(): 6564 return # QNNPACK doesn't support this 6565 transposed = inputs[-1] 6566 if transposed: 6567 qconv_prepack = torch.ops.quantized.conv_transpose3d_prepack 6568 qconv_unpack = torch.ops.quantized.conv_transpose3d_unpack 6569 else: 6570 qconv_prepack = torch.ops.quantized.conv3d_prepack 6571 qconv_unpack = torch.ops.quantized.conv3d_unpack 6572 self._test_qconv_unpack_impl( 6573 qconv_prepack, qconv_unpack, inputs, 6574 (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), (o_pad, o_pad, o_pad), 6575 channelwise) 6576 6577 def test_conv_reorder_issue_onednn(self): 6578 """ Ensure reorder failure issue in conv is fixed for onednn backend. 6579 Onednn backend used to encounter reorder failure 6580 when running conv with dynamic input shapes. 
6581 Solved by https://github.com/pytorch/pytorch/pull/86876 6582 """ 6583 if 'onednn' not in supported_qengines: 6584 return 6585 with override_quantized_engine('onednn'): 6586 bs = 1 6587 ic, oc = 128, 512 6588 kh, kw = 1, 1 6589 bias = None 6590 strides, paddings, dilates = (1, 1), (0, 0), (1, 1) 6591 for groups in [1, 2]: 6592 ih, iw = 28, 28 6593 w = torch.randn((oc * groups, ic, kh, kw)) 6594 qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8) 6595 x = torch.randn((bs, ic * groups, ih, iw)) 6596 qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8) 6597 w_packed = torch.ops.quantized.conv2d_prepack( 6598 qw, bias, strides, paddings, dilates, groups 6599 ) 6600 torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0) 6601 ih, iw = 5, 4 6602 x = torch.randn((bs, ic * groups, ih, iw)) 6603 qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8) 6604 # The following should pass when input shape is changed 6605 torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0) 6606 6607 @skipIfNoONEDNN 6608 def test_conv_transpose_reorder_issue_onednn(self): 6609 with override_quantized_engine('onednn'): 6610 bs = 1 6611 ic, oc = 16, 33 6612 kh, kw = 3, 3 6613 ih, iw = 50, 100 6614 bias = None 6615 strides, paddings, output_paddings, dilates, groups = [2, 2], [0, 0], [0, 0], [1, 1], 1 6616 w = torch.randn((ic, oc, kh, kw)) 6617 qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8) 6618 x = torch.randn((bs, ic, ih, iw)) 6619 qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8) 6620 w_packed = torch.ops.quantized.conv_transpose2d_prepack( 6621 qw, bias, strides, paddings, output_paddings, dilates, groups 6622 ) 6623 torch.ops.quantized.conv_transpose2d(qx, w_packed, output_scale=1.0, output_zero_point=0) 6624 ih, iw = 5, 4 6625 x = torch.randn((bs, ic, ih, iw)) 6626 qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8) 6627 # The following should pass when input shape is changed 6628 torch.ops.quantized.conv_transpose2d(qx, w_packed, output_scale=1.0, output_zero_point=0) 6629 6630 def _test_qconv_impl_cpu_tensor( 6631 self, 6632 qconv, 6633 qconv_prepack, 6634 conv_op, 6635 input_channels_per_group=2, 6636 input_feature_map_shape=(), 6637 output_channels_per_group=2, 6638 groups=1, 6639 kernels=3, 6640 strides=(), 6641 pads=(), 6642 dilations=(), 6643 X_scale=1.3, 6644 X_zero_point=2, 6645 W_scale=(1.0,), 6646 W_zero_point=(0,), 6647 Y_scale=3.2, 6648 Y_zero_point=0, 6649 use_bias=True, 6650 post_op=PointwisePostOp(), 6651 use_channelwise=True, 6652 X2_scale=1.2, 6653 X2_zero_point=0, 6654 qconv_output_dtype=None, # None, torch.float32, torch.bfloat16 6655 weight_in_channel_last_format=False, 6656 qconv_x2_dtype=None, 6657 ): 6658 # ONEDNN only supports symmetric quantization of weight 6659 if W_zero_point is not None: 6660 W_zero_point = len(W_zero_point) * [0] 6661 fp32_output = True if qconv_output_dtype is torch.float32 else False 6662 bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False 6663 if fp32_output or bfloat16_output: 6664 Y_scale = 1.0 6665 Y_zero_point = 0 6666 X2_scale = 1.0 6667 X2_zero_point = 0 6668 batch_size = 3 6669 o_pads = None 6670 device = torch.device("cpu") 6671 input_dtype = torch.quint8 6672 weight_dtype = torch.qint8 6673 output_dtype = torch.quint8 6674 use_transpose = False 6675 (X, W), (X_q, W_q), bias_float = self._make_qconv_tensors( 6676 batch_size, 
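# _make_qconv_tensors returns the fp32 activations and weights, their quantized counterparts, and an optional fp32 bias.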
6677 input_channels_per_group, 6678 input_feature_map_shape, 6679 output_channels_per_group, 6680 groups, 6681 kernels, 6682 strides, 6683 pads, 6684 dilations, 6685 X_scale, 6686 X_zero_point, 6687 W_scale, 6688 W_zero_point, 6689 use_bias, 6690 use_channelwise, 6691 use_transpose, 6692 device=device, 6693 input_dtype=input_dtype, 6694 weight_dtype=weight_dtype, 6695 ) 6696 if bias_float is not None: 6697 bias_float = bias_float.to(device) 6698 # Assign weights 6699 W = W_q.dequantize() 6700 X = X_q.dequantize() 6701 conv_op.weight = torch.nn.Parameter(W, requires_grad=False) 6702 conv_op.bias = ( 6703 torch.nn.Parameter(bias_float, requires_grad=False) if use_bias else None 6704 ) 6705 result_ref = conv_op(X) 6706 X2_q = None 6707 6708 if post_op.binary_attr == "sum": 6709 (X_value_min, X_value_max) = (0, 4) 6710 X2_init = torch.randint( 6711 X_value_min, X_value_max, result_ref.size(), device=device 6712 ) 6713 X2 = X2_scale * ((X2_init - X2_zero_point).float()) 6714 X2_q = torch.quantize_per_tensor( 6715 X2, scale=X2_scale, zero_point=X2_zero_point, dtype=input_dtype 6716 ) 6717 result_ref = result_ref + X2 6718 if post_op.unary_attr == "relu": 6719 relu = torch.nn.ReLU() 6720 result_ref = relu(result_ref) 6721 elif post_op.unary_attr == "relu": 6722 assert not use_transpose, "Cannot fuse ReLU with ConvTranspose" 6723 relu = torch.nn.ReLU() 6724 result_ref = relu(result_ref) 6725 elif post_op.unary_attr == "hardtanh": 6726 assert not use_transpose, "Cannot fuse hardtanh with ConvTranspose" 6727 assert len(post_op.scalars) == 2, "For post op hardtanh, expect 2 parameters passed in" 6728 hardtanh = torch.nn.Hardtanh(min_val=post_op.scalars[0], max_val=post_op.scalars[1]) 6729 result_ref = hardtanh(result_ref) 6730 elif post_op.unary_attr == "hardswish": 6731 assert not use_transpose, "Cannot fuse hardswish with ConvTranspose" 6732 hardswish = torch.nn.Hardswish() 6733 result_ref = hardswish(result_ref) 6734 elif post_op.unary_attr == "swish": 6735 assert not use_transpose, "Cannot fuse silu with ConvTranspose" 6736 silu = torch.nn.SiLU() 6737 result_ref = silu(result_ref) 6738 6739 # Quantize reference results for comparison 6740 result_ref_q = torch.quantize_per_tensor( 6741 result_ref, scale=Y_scale, zero_point=Y_zero_point, dtype=output_dtype 6742 ) 6743 6744 # Calculate the result for 2.X path 6745 X_q_cpu_tensor = X_q.int_repr() 6746 W_q_cpu_tensor = W_q.int_repr() 6747 6748 weight_scale = ( 6749 W_q.q_per_channel_scales() 6750 if use_channelwise 6751 else torch.tensor(W_q.q_scale(), dtype=torch.double, device=device) 6752 ) 6753 weight_zero_point = ( 6754 W_q.q_per_channel_zero_points() 6755 if use_channelwise 6756 else torch.tensor(W_q.q_zero_point(), dtype=torch.int64, device=device) 6757 ) 6758 6759 if weight_in_channel_last_format: 6760 if W_q_cpu_tensor.dim() == 5: 6761 W_q_cpu_tensor = W_q_cpu_tensor.to(memory_format=torch.channels_last_3d) 6762 elif W_q_cpu_tensor.dim() == 4: 6763 W_q_cpu_tensor = W_q_cpu_tensor.to(memory_format=torch.channels_last) 6764 6765 packed_weight = qconv_prepack( 6766 W_q_cpu_tensor, 6767 weight_scale, 6768 X_scale, 6769 X_zero_point, 6770 strides, 6771 pads, 6772 dilations, 6773 groups, 6774 X_q_cpu_tensor.size(), 6775 ) 6776 6777 if post_op.binary_attr == "sum": 6778 X2_cpu_tensor = ( 6779 X2_q.int_repr() 6780 if qconv_output_dtype is None 6781 else X2_q.dequantize().to(qconv_x2_dtype) 6782 ).contiguous(memory_format=torch.channels_last) 6783 Y_q_cpu_tensor = qconv( 6784 X_q_cpu_tensor, 6785 X_scale, 6786 X_zero_point, 6787 X2_cpu_tensor, 6788 
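# The binary ("sum") overload also takes the second input X2 and its scale/zero point before the packed weight.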
X2_scale, 6789 X2_zero_point, 6790 packed_weight, 6791 weight_scale, 6792 weight_zero_point, 6793 bias_float, 6794 strides, 6795 pads, 6796 dilations, 6797 groups, 6798 Y_scale, 6799 Y_zero_point, 6800 qconv_output_dtype, 6801 post_op.binary_attr, 6802 post_op.alpha, 6803 post_op.unary_attr, 6804 post_op.scalars, 6805 post_op.algorithm, 6806 ) 6807 else: 6808 Y_q_cpu_tensor = qconv( 6809 X_q_cpu_tensor, 6810 X_scale, 6811 X_zero_point, 6812 packed_weight, 6813 weight_scale, 6814 weight_zero_point, 6815 bias_float, 6816 strides, 6817 pads, 6818 dilations, 6819 groups, 6820 Y_scale, 6821 Y_zero_point, 6822 qconv_output_dtype, 6823 post_op.unary_attr, 6824 post_op.scalars, 6825 post_op.algorithm, 6826 ) 6827 if fp32_output or bfloat16_output: 6828 self.assertTrue(Y_q_cpu_tensor.dtype == qconv_output_dtype) 6829 Y_q_cpu_tensor = torch.quantize_per_tensor( 6830 Y_q_cpu_tensor 6831 if fp32_output 6832 else Y_q_cpu_tensor.to(torch.float32), scale=Y_scale, zero_point=Y_zero_point, dtype=output_dtype 6833 ).int_repr() 6834 6835 # Make sure the results match 6836 # assert_array_almost_equal compares using the following formula: 6837 # abs(desired-actual) < 1.5 * 10**(-decimal) 6838 # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html) 6839 # We use decimal = 0 to ignore off-by-1 differences between 6840 # reference and test. Off-by-1 differences arise due to the order of 6841 # round and zero_point addition operation, i.e., if addition 6842 # followed by round is used by reference and round followed by 6843 # addition is used by test, the results may differ by 1. 6844 # For example, the result of round(2.5) + 1 is 3 while 6845 # round(2.5 + 1) is 4 assuming the rounding mode is 6846 # round-to-nearest, ties-to-even. 6847 6848 np.testing.assert_array_almost_equal( 6849 result_ref_q.int_repr().cpu().numpy(), 6850 Y_q_cpu_tensor.cpu().numpy(), 6851 decimal=0, 6852 err_msg=f"""X: {X_q}, W: {W_q}, b: {bias_float}, strides: {strides}, 6853 pads: {pads}, o_pads: {o_pads}, dilations: {dilations}, 6854 groups: {groups}, y_s: {Y_scale}, y_zp: {Y_zero_point}, X2: {X2_q}""", 6855 ) 6856 6857 # Return the quantized data for later reuse 6858 return X_q, W_q, bias_float 6859 6860 @skipIfNoONEDNN 6861 def test_qconv1d_pt2e(self): 6862 groups_list = [1, 3] 6863 input_channels_per_group = 2 6864 output_channels_per_group = 2 6865 length = 4 6866 kernel = 3 6867 stride = 1 6868 pad = 1 6869 dilation = 1 6870 W_scale = [1.5] 6871 W_zero_point = [0] 6872 use_bias_list = [False, True] 6873 use_channelwise_list = [False, True] 6874 output_dtype_list = [None, torch.float32, torch.bfloat16] 6875 options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) 6876 for groups, use_bias, use_channelwise, output_dtype in options: 6877 if output_dtype is not None and not (use_bias and use_channelwise): 6878 # Remove some test combination to reduce UT test time 6879 continue 6880 conv1d = torch.nn.Conv1d( 6881 input_channels_per_group * groups, 6882 output_channels_per_group * groups, 6883 kernel, 6884 stride, 6885 pad, 6886 dilation, 6887 groups, 6888 ) 6889 qconv = torch.ops.onednn.qconv1d_pointwise 6890 qconv_prepack = torch.ops.onednn.qconv_prepack 6891 pointwise_post_op = PointwisePostOp() 6892 self._test_qconv_impl_cpu_tensor( 6893 qconv, 6894 qconv_prepack, 6895 conv1d, 6896 input_channels_per_group=input_channels_per_group, 6897 input_feature_map_shape=(length,), 6898 output_channels_per_group=output_channels_per_group, 6899 groups=groups, 6900 
kernels=kernel, 6901 strides=[stride], 6902 pads=[pad], 6903 dilations=[dilation], 6904 W_scale=W_scale, 6905 W_zero_point=W_zero_point, 6906 use_bias=use_bias, 6907 post_op=pointwise_post_op, 6908 use_channelwise=use_channelwise, 6909 qconv_output_dtype=output_dtype, 6910 ) 6911 6912 @skipIfNoONEDNN 6913 def test_qconv2d_pt2e(self): 6914 groups_list = [1, 3] 6915 input_channels_per_group = 2 6916 output_channels_per_group = 2 6917 input_feature_map_shape = (10, 10) 6918 kernels = (3, 3) 6919 strides = (2, 2) 6920 pads = (1, 1) 6921 dilations = (1, 1) 6922 W_scale = [1.5] 6923 W_zero_point = [0] 6924 use_bias_list = [False, True] 6925 use_channelwise_list = [False, True] 6926 channel_last_weight_format_list = [False, True] 6927 output_dtype_list = [None, torch.float32, torch.bfloat16] 6928 options = itertools.product( 6929 groups_list, 6930 use_bias_list, 6931 use_channelwise_list, 6932 channel_last_weight_format_list, 6933 output_dtype_list, 6934 ) 6935 for groups, use_bias, use_channelwise, channel_last_weight_format, output_dtype in options: 6936 if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise): 6937 # Remove some test combination to reduce UT test time 6938 continue 6939 qconv = torch.ops.onednn.qconv2d_pointwise 6940 qconv_prepack = torch.ops.onednn.qconv_prepack 6941 conv_op = torch.nn.Conv2d( 6942 input_channels_per_group * groups, 6943 output_channels_per_group * groups, 6944 kernels, 6945 strides, 6946 pads, 6947 dilations, 6948 groups, 6949 ) 6950 pointwise_post_op = PointwisePostOp() 6951 self._test_qconv_impl_cpu_tensor( 6952 qconv, 6953 qconv_prepack, 6954 conv_op, 6955 input_channels_per_group=input_channels_per_group, 6956 input_feature_map_shape=input_feature_map_shape, 6957 output_channels_per_group=output_channels_per_group, 6958 groups=groups, 6959 kernels=kernels, 6960 strides=strides, 6961 pads=pads, 6962 dilations=dilations, 6963 W_scale=W_scale, 6964 W_zero_point=W_zero_point, 6965 use_bias=use_bias, 6966 post_op=pointwise_post_op, 6967 use_channelwise=use_channelwise, 6968 qconv_output_dtype=output_dtype, 6969 weight_in_channel_last_format=channel_last_weight_format, 6970 ) 6971 6972 @skipIfNoONEDNN 6973 def test_qconv3d_pt2e(self): 6974 input_channels_per_group = 2 6975 input_feature_map_shape = (6, 6, 6) 6976 output_channels_per_group = 2 6977 groups_list = [1, 3] 6978 kernels = (3, 3, 3) 6979 strides = (2, 2, 2) 6980 pads = (1, 1, 1) 6981 dilations = (1, 1, 1) 6982 W_scale = [1.5] 6983 W_zero_point = [0] 6984 use_bias_list = [False, True] 6985 use_channelwise_list = [False, True] 6986 channel_last_weight_format_list = [False, True] 6987 output_dtype_list = [None, torch.float32, torch.bfloat16] 6988 options = itertools.product( 6989 groups_list, 6990 use_bias_list, 6991 use_channelwise_list, 6992 channel_last_weight_format_list, 6993 output_dtype_list, 6994 ) 6995 for groups, use_bias, use_channelwise, channel_last_weight_format, output_dtype in options: 6996 if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise): 6997 # Remove some test combination to reduce UT test time 6998 continue 6999 qconv = torch.ops.onednn.qconv3d_pointwise 7000 qconv_prepack = torch.ops.onednn.qconv_prepack 7001 conv_op = torch.nn.Conv3d( 7002 input_channels_per_group * groups, 7003 output_channels_per_group * groups, 7004 kernels, 7005 strides, 7006 pads, 7007 dilations, 7008 groups, 7009 ) 7010 pointwise_post_op = PointwisePostOp() 7011 self._test_qconv_impl_cpu_tensor( 7012 qconv, 7013 
qconv_prepack, 7014 conv_op, 7015 input_channels_per_group=input_channels_per_group, 7016 input_feature_map_shape=input_feature_map_shape, 7017 output_channels_per_group=output_channels_per_group, 7018 groups=groups, 7019 kernels=kernels, 7020 strides=strides, 7021 pads=pads, 7022 dilations=dilations, 7023 W_scale=W_scale, 7024 W_zero_point=W_zero_point, 7025 use_bias=use_bias, 7026 post_op=pointwise_post_op, 7027 use_channelwise=use_channelwise, 7028 qconv_output_dtype=output_dtype, 7029 weight_in_channel_last_format=channel_last_weight_format, 7030 ) 7031 7032 # Test qconv with post op relu 7033 @skipIfNoONEDNN 7034 def test_qconv2d_relu_pt2e(self): 7035 input_channels_per_group = 2 7036 output_channels_per_group = 2 7037 groups_list = [1, 10] 7038 input_feature_map_shape = (10, 10) 7039 kernels = (3, 3) 7040 strides = (2, 2) 7041 pads = (1, 1) 7042 dilations = (1, 1) 7043 W_scale = [1.5] 7044 W_zero_point = [0] 7045 use_bias_list = [False, True] 7046 use_channelwise_list = [False, True] 7047 output_dtype_list = [None, torch.float32, torch.bfloat16] 7048 options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) 7049 for groups, use_bias, use_channelwise, output_dtype in options: 7050 qconv = torch.ops.onednn.qconv2d_pointwise 7051 qconv_prepack = torch.ops.onednn.qconv_prepack 7052 conv_op = torch.nn.Conv2d( 7053 input_channels_per_group * groups, 7054 output_channels_per_group * groups, 7055 kernels, 7056 strides, 7057 pads, 7058 dilations, 7059 groups, 7060 ) 7061 pointwise_post_op = PointwisePostOp(unary_attr="relu") 7062 self._test_qconv_impl_cpu_tensor( 7063 qconv, 7064 qconv_prepack, 7065 conv_op, 7066 input_channels_per_group=input_channels_per_group, 7067 input_feature_map_shape=input_feature_map_shape, 7068 output_channels_per_group=output_channels_per_group, 7069 groups=groups, 7070 kernels=kernels, 7071 strides=strides, 7072 pads=pads, 7073 dilations=dilations, 7074 W_scale=W_scale, 7075 W_zero_point=W_zero_point, 7076 use_bias=use_bias, 7077 post_op=pointwise_post_op, 7078 use_channelwise=use_channelwise, 7079 qconv_output_dtype=output_dtype, 7080 ) 7081 7082 # Test qconv with post op hardtanh 7083 @skipIfNoONEDNN 7084 def test_qconv2d_hardtanh_pt2e(self): 7085 input_channels_per_group = 2 7086 output_channels_per_group = 2 7087 groups_list = [1, 10] 7088 input_feature_map_shape = (10, 10) 7089 kernels = (3, 3) 7090 strides = (2, 2) 7091 pads = (1, 1) 7092 dilations = (1, 1) 7093 W_scale = [1.5] 7094 W_zero_point = [0] 7095 use_bias_list = [False, True] 7096 use_channelwise_list = [False, True] 7097 output_dtype_list = [None, torch.float32, torch.bfloat16] 7098 options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) 7099 for groups, use_bias, use_channelwise, output_dtype in options: 7100 qconv = torch.ops.onednn.qconv2d_pointwise 7101 qconv_prepack = torch.ops.onednn.qconv_prepack 7102 conv_op = torch.nn.Conv2d( 7103 input_channels_per_group * groups, 7104 output_channels_per_group * groups, 7105 kernels, 7106 strides, 7107 pads, 7108 dilations, 7109 groups, 7110 ) 7111 pointwise_post_op = PointwisePostOp(unary_attr="hardtanh", scalars=[0.0, 6.0]) 7112 self._test_qconv_impl_cpu_tensor( 7113 qconv, 7114 qconv_prepack, 7115 conv_op, 7116 input_channels_per_group=input_channels_per_group, 7117 input_feature_map_shape=input_feature_map_shape, 7118 output_channels_per_group=output_channels_per_group, 7119 groups=groups, 7120 kernels=kernels, 7121 strides=strides, 7122 pads=pads, 7123 
dilations=dilations, 7124 W_scale=W_scale, 7125 W_zero_point=W_zero_point, 7126 use_bias=use_bias, 7127 post_op=pointwise_post_op, 7128 use_channelwise=use_channelwise, 7129 qconv_output_dtype=output_dtype, 7130 ) 7131 7132 # Test qconv with post op silu 7133 @skipIfNoONEDNN 7134 def test_qconv2d_silu_pt2e(self): 7135 input_channels_per_group = 2 7136 output_channels_per_group = 2 7137 groups_list = [1, 10] 7138 input_feature_map_shape = (10, 10) 7139 kernels = (3, 3) 7140 strides = (2, 2) 7141 pads = (1, 1) 7142 dilations = (1, 1) 7143 W_scale = [1.5] 7144 W_zero_point = [0] 7145 use_bias_list = [False, True] 7146 use_channelwise_list = [False, True] 7147 output_dtype_list = [None, torch.float32, torch.bfloat16] 7148 options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) 7149 for groups, use_bias, use_channelwise, output_dtype in options: 7150 qconv = torch.ops.onednn.qconv2d_pointwise 7151 qconv_prepack = torch.ops.onednn.qconv_prepack 7152 conv_op = torch.nn.Conv2d( 7153 input_channels_per_group * groups, 7154 output_channels_per_group * groups, 7155 kernels, 7156 strides, 7157 pads, 7158 dilations, 7159 groups, 7160 ) 7161 pointwise_post_op = PointwisePostOp(unary_attr="swish") 7162 self._test_qconv_impl_cpu_tensor( 7163 qconv, 7164 qconv_prepack, 7165 conv_op, 7166 input_channels_per_group=input_channels_per_group, 7167 input_feature_map_shape=input_feature_map_shape, 7168 output_channels_per_group=output_channels_per_group, 7169 groups=groups, 7170 kernels=kernels, 7171 strides=strides, 7172 pads=pads, 7173 dilations=dilations, 7174 W_scale=W_scale, 7175 W_zero_point=W_zero_point, 7176 use_bias=use_bias, 7177 post_op=pointwise_post_op, 7178 use_channelwise=use_channelwise, 7179 qconv_output_dtype=output_dtype, 7180 ) 7181 7182 # Test qconv with post op hardswish 7183 @skipIfNoONEDNN 7184 def test_qconv2d_hardswish_pt2e(self): 7185 input_channels_per_group = 2 7186 output_channels_per_group = 2 7187 groups_list = [1, 10] 7188 input_feature_map_shape = (10, 10) 7189 kernels = (3, 3) 7190 strides = (2, 2) 7191 pads = (1, 1) 7192 dilations = (1, 1) 7193 W_scale = [1.5] 7194 W_zero_point = [0] 7195 use_bias_list = [False, True] 7196 use_channelwise_list = [False, True] 7197 output_dtype_list = [None, torch.float32, torch.bfloat16] 7198 options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) 7199 7200 for groups, use_bias, use_channelwise, output_dtype in options: 7201 qconv = torch.ops.onednn.qconv2d_pointwise 7202 qconv_prepack = torch.ops.onednn.qconv_prepack 7203 conv_op = torch.nn.Conv2d( 7204 input_channels_per_group * groups, 7205 output_channels_per_group * groups, 7206 kernels, 7207 strides, 7208 pads, 7209 dilations, 7210 groups, 7211 ) 7212 pointwise_post_op = PointwisePostOp(unary_attr="hardswish") 7213 self._test_qconv_impl_cpu_tensor( 7214 qconv, 7215 qconv_prepack, 7216 conv_op, 7217 input_channels_per_group=input_channels_per_group, 7218 input_feature_map_shape=input_feature_map_shape, 7219 output_channels_per_group=output_channels_per_group, 7220 groups=groups, 7221 kernels=kernels, 7222 strides=strides, 7223 pads=pads, 7224 dilations=dilations, 7225 W_scale=W_scale, 7226 W_zero_point=W_zero_point, 7227 use_bias=use_bias, 7228 post_op=pointwise_post_op, 7229 use_channelwise=use_channelwise, 7230 qconv_output_dtype=output_dtype, 7231 ) 7232 7233 # Test qconv with post op sum 7234 @skipIfNoONEDNN 7235 def test_qconv2d_sum_pt2e(self): 7236 groups_list = [1, 3] 7237 input_channels_per_group = 2 
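# The "sum" post op fuses a residual add: the conv output is accumulated into a second quantized input X2 before requantization.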
7238 output_channels_per_group = 2 7239 input_feature_map_shape = (10, 10) 7240 kernels = (3, 3) 7241 strides = (2, 2) 7242 pads = (1, 1) 7243 dilations = (1, 1) 7244 W_scale = [1.5] 7245 W_zero_point = [-3] 7246 use_bias_list = [False, True] 7247 use_channelwise_list = [False, True] 7248 output_dtype_list = [None, torch.float32, torch.bfloat16] 7249 X2_zero_point_list = [0, 1] 7250 options = itertools.product( 7251 groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list, output_dtype_list 7252 ) 7253 for groups, use_bias, use_channelwise, X2_zero_point, output_dtype in options: 7254 qconv = torch.ops.onednn.qconv2d_pointwise.binary 7255 qconv_prepack = torch.ops.onednn.qconv_prepack 7256 conv_op = torch.nn.Conv2d( 7257 input_channels_per_group * groups, 7258 output_channels_per_group * groups, 7259 kernels, 7260 strides, 7261 pads, 7262 dilations, 7263 groups, 7264 ) 7265 pointwise_post_op = PointwisePostOp(binary_attr="sum") 7266 self._test_qconv_impl_cpu_tensor( 7267 qconv, 7268 qconv_prepack, 7269 conv_op, 7270 input_channels_per_group=input_channels_per_group, 7271 input_feature_map_shape=input_feature_map_shape, 7272 output_channels_per_group=output_channels_per_group, 7273 groups=groups, 7274 kernels=kernels, 7275 strides=strides, 7276 pads=pads, 7277 dilations=dilations, 7278 W_scale=W_scale, 7279 W_zero_point=W_zero_point, 7280 use_bias=use_bias, 7281 post_op=pointwise_post_op, 7282 use_channelwise=use_channelwise, 7283 X2_zero_point=X2_zero_point, 7284 qconv_output_dtype=output_dtype, 7285 qconv_x2_dtype=output_dtype, 7286 ) 7287 7288 # Test qconv with post op sum relu 7289 @skipIfNoONEDNN 7290 def test_qconv2d_sum_relu_pt2e(self): 7291 groups_list = [1, 3] 7292 input_channels_per_group = 2 7293 output_channels_per_group = 2 7294 input_feature_map_shape = (10, 10) 7295 kernels = (3, 3) 7296 strides = (2, 2) 7297 pads = (1, 1) 7298 dilations = (1, 1) 7299 W_scale = [1.5] 7300 W_zero_point = [-3] 7301 use_bias_list = [False, True] 7302 use_channelwise_list = [False, True] 7303 X2_zero_point_list = [0, 1] 7304 options = itertools.product( 7305 groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list 7306 ) 7307 for groups, use_bias, use_channelwise, X2_zero_point in options: 7308 qconv = torch.ops.onednn.qconv2d_pointwise.binary 7309 qconv_prepack = torch.ops.onednn.qconv_prepack 7310 conv_op = torch.nn.Conv2d( 7311 input_channels_per_group * groups, 7312 output_channels_per_group * groups, 7313 kernels, 7314 strides, 7315 pads, 7316 dilations, 7317 groups, 7318 ) 7319 pointwise_post_op = PointwisePostOp(binary_attr="sum", unary_attr="relu") 7320 self._test_qconv_impl_cpu_tensor( 7321 qconv, 7322 qconv_prepack, 7323 conv_op, 7324 input_channels_per_group=input_channels_per_group, 7325 input_feature_map_shape=input_feature_map_shape, 7326 output_channels_per_group=output_channels_per_group, 7327 groups=groups, 7328 kernels=kernels, 7329 strides=strides, 7330 pads=pads, 7331 dilations=dilations, 7332 W_scale=W_scale, 7333 W_zero_point=W_zero_point, 7334 use_bias=use_bias, 7335 post_op=pointwise_post_op, 7336 use_channelwise=use_channelwise, 7337 X2_zero_point=X2_zero_point, 7338 ) 7339 7340 # Test qconv with post op sum 7341 @skipIfNoONEDNN 7342 def test_qconv2d_sum_relu_float_output_pt2e(self): 7343 groups = 1 7344 input_channels_per_group = 2 7345 output_channels_per_group = 2 7346 input_feature_map_shape = (10, 10) 7347 kernels = (3, 3) 7348 strides = (2, 2) 7349 pads = (1, 1) 7350 dilations = (1, 1) 7351 W_scale = [1.5] 7352 W_zero_point = [-3] 7353 
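# For float32/bfloat16 outputs the helper pins Y_scale to 1.0 and Y_zero_point to 0, so this test focuses on the sum (+ optional relu) fusion itself.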
use_bias_list = [False, True] 7354 use_channelwise = True 7355 output_dtype_list = [torch.float32, torch.bfloat16] 7356 X2_zero_point = 0 7357 use_relu_list = [True, False] 7358 options = itertools.product( 7359 use_bias_list, output_dtype_list, use_relu_list 7360 ) 7361 for use_bias, output_dtype, use_relu in options: 7362 qconv_x2_dtype = output_dtype 7363 qconv = torch.ops.onednn.qconv2d_pointwise.binary 7364 qconv_prepack = torch.ops.onednn.qconv_prepack 7365 conv_op = torch.nn.Conv2d( 7366 input_channels_per_group * groups, 7367 output_channels_per_group * groups, 7368 kernels, 7369 strides, 7370 pads, 7371 dilations, 7372 groups, 7373 ) 7374 pointwise_post_op = ( 7375 PointwisePostOp(binary_attr="sum", unary_attr="relu") 7376 if use_relu 7377 else PointwisePostOp(binary_attr="sum") 7378 ) 7379 self._test_qconv_impl_cpu_tensor( 7380 qconv, 7381 qconv_prepack, 7382 conv_op, 7383 input_channels_per_group=input_channels_per_group, 7384 input_feature_map_shape=input_feature_map_shape, 7385 output_channels_per_group=output_channels_per_group, 7386 groups=groups, 7387 kernels=kernels, 7388 strides=strides, 7389 pads=pads, 7390 dilations=dilations, 7391 W_scale=W_scale, 7392 W_zero_point=W_zero_point, 7393 use_bias=use_bias, 7394 post_op=pointwise_post_op, 7395 use_channelwise=use_channelwise, 7396 X2_zero_point=X2_zero_point, 7397 qconv_output_dtype=output_dtype, 7398 qconv_x2_dtype=qconv_x2_dtype, 7399 ) 7400 7401class TestPadding(TestCase): 7402 @given(batch_size=st.integers(1, 64), 7403 channels=st.integers(1, 64), 7404 width=st.integers(16, 128), 7405 qtype=st.sampled_from(hu._ALL_QINT_TYPES)) 7406 def test_reflection_pad1d(self, batch_size, channels, width, qtype): 7407 padding = width // 4 7408 7409 x = torch.arange(batch_size * channels * width).to(torch.float) 7410 x = x.resize(batch_size, channels, width) 7411 # Per-Tensor test 7412 scale, zp = _calculate_dynamic_qparams(x, qtype) 7413 qx = torch.quantize_per_tensor(x, scale, zp, qtype) 7414 7415 padding_op = torch.nn.ReflectionPad1d(padding) 7416 7417 y_ref = padding_op(x) 7418 qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype) 7419 qy_hat = padding_op(qx) 7420 self.assertEqual(qy_ref, qy_hat) 7421 7422 # Out variant 7423 qy_hat = torch._C._nn.reflection_pad1d(qx, padding, out=qy_hat) 7424 self.assertEqual(qy_ref, qy_hat) 7425 7426 @given(batch_size=st.integers(1, 64), 7427 channels=st.integers(1, 64), 7428 height=st.integers(16, 128), 7429 width=st.integers(16, 128), 7430 qtype=st.sampled_from(hu._ALL_QINT_TYPES)) 7431 def test_reflection_pad2d(self, batch_size, channels, height, width, qtype): 7432 padding = (width // 4, width // 4, height // 4, height // 4) 7433 7434 x = torch.arange(batch_size * channels * height * width).to(torch.float) 7435 x = x.resize(batch_size, channels, height, width) 7436 # Per-Tensor test 7437 scale, zp = _calculate_dynamic_qparams(x, qtype) 7438 qx = torch.quantize_per_tensor(x, scale, zp, qtype) 7439 7440 padding_op = torch.nn.ReflectionPad2d(padding) 7441 7442 y_ref = padding_op(x) 7443 qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype) 7444 qy_hat = padding_op(qx) 7445 self.assertEqual(qy_ref, qy_hat) 7446 7447 # Out variant 7448 qy_hat = torch._C._nn.reflection_pad2d(qx, padding, out=qy_hat) 7449 self.assertEqual(qy_ref, qy_hat) 7450 7451 @given(batch_size=st.integers(1, 64), 7452 channels=st.integers(1, 64), 7453 hwd=st.integers(1, 16), # For 3D, max input size would be 16x16x16 7454 d=st.sampled_from([1, 2, 3]), 7455 value=st.floats(-5, 5, allow_nan=False, 
allow_infinity=False), 7456 qtype=st.sampled_from(hu._ALL_QINT_TYPES)) 7457 def test_constant_padNd(self, batch_size, channels, d, hwd, value, qtype): 7458 padding = hwd // 4 7459 7460 shape = [batch_size, channels, hwd] 7461 op = torch.nn.ConstantPad1d 7462 if d >= 2: 7463 shape.append(hwd) 7464 op = torch.nn.ConstantPad2d 7465 if d == 3: 7466 shape.append(hwd) 7467 op = torch.nn.ConstantPad3d 7468 numel = np.prod(shape) 7469 7470 x = torch.arange(numel).to(torch.float) 7471 x = x.resize(*shape) 7472 # Per-Tensor test 7473 scale, zp = _calculate_dynamic_qparams(x, qtype) 7474 qx = torch.quantize_per_tensor(x, scale, zp, qtype) 7475 7476 padding_op = op(padding, value) 7477 7478 y_ref = padding_op(x) 7479 qy_ref = torch.quantize_per_tensor(y_ref, scale, zp, qtype) 7480 qy_hat = padding_op(qx) 7481 7482 self.assertEqual(qy_ref, qy_hat) 7483 7484 7485@unittest.skipUnless('qnnpack' in supported_qengines, 7486 "This Pytorch Build has not been built with or does not support QNNPACK") 7487class TestQNNPackOps(TestCase): 7488 """Tests the correctness of the quantized::qnnpack_relu op.""" 7489 @given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), 7490 qparams=hu.qparams(dtypes=torch.quint8, 7491 zero_point_min=0, 7492 zero_point_max=0))) 7493 def test_qnnpack_relu(self, X): 7494 with override_quantized_engine('qnnpack'): 7495 X, (scale, zero_point, torch_type) = X 7496 relu = torch.nn.functional.relu 7497 X = torch.from_numpy(X) 7498 Y = X.clone() 7499 7500 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch_type) 7501 qY_hat = relu(qX) 7502 7503 Y[Y < 0] = 0 7504 qY = torch.quantize_per_tensor(Y, scale=scale, zero_point=zero_point, dtype=torch_type) 7505 self.assertEqual(qY, qY_hat) 7506 7507 """Tests the correctness of the quantized::qnnpack_tanh op.""" 7508 @skipIfNoFBGEMM 7509 def test_qnnpack_tanh(self): 7510 # Note: In QNNPACK the output scale and zero_point can only be 7511 # 2.0/256, 128 respectively, as it uses a LUT with 256 bins. 7512 7513 shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) 7514 memory_formats = (torch.channels_last, torch.contiguous_format) 7515 test_cases = itertools.product(shapes, memory_formats) 7516 for shape, memory_format in test_cases: 7517 X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8 7518 if memory_format == torch.channels_last and len(shape) != 4: 7519 continue 7520 X = X.to(memory_format=memory_format) 7521 qX = torch.quantize_per_tensor(X, scale=scale, 7522 zero_point=zero_point, 7523 dtype=torch_type) 7524 7525 # Floating point reference 7526 Y = torch.tanh(qX.dequantize()) 7527 qY = torch.quantize_per_tensor(Y, scale=1.0 / 128, zero_point=128, 7528 dtype=torch.quint8) 7529 with override_quantized_engine('fbgemm'): 7530 qYserver = torch.tanh(qX) 7531 with override_quantized_engine('qnnpack'): 7532 qY_hat = torch.tanh(qX) 7533 self.assertEqual( 7534 qY, qY_hat, 7535 msg=f"QNNPACK TanH failed (FP ref), memory_format {memory_format}") 7536 self.assertEqual( 7537 qYserver, qY_hat, 7538 msg=f"QNNPACK TanH failed (FBGEMM ref), memory_format {memory_format}") 7539 7540 """Tests the correctness of the quantized::qnnpack_sigmoid op.""" 7541 @skipIfNoFBGEMM 7542 def test_qnnpack_sigmoid(self): 7543 # Note: In QNNPACK the output scale and zero_point can only be 7544 # 1.0/256, 0 respectively, as it uses a LUT with 256 bins. 
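# For example, sigmoid(0) = 0.5 quantizes to round(0.5 / (1.0 / 256)) + 0 = 128 in quint8.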
7545 shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) 7546 memory_formats = (torch.channels_last, torch.contiguous_format) 7547 test_cases = itertools.product(shapes, memory_formats) 7548 for shape, memory_format in test_cases: 7549 X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8 7550 if memory_format == torch.channels_last and len(shape) != 4: 7551 continue 7552 X = X.to(memory_format=memory_format) 7553 qX = torch.quantize_per_tensor(X, scale=scale, 7554 zero_point=zero_point, 7555 dtype=torch_type) 7556 7557 # Floating point reference 7558 Y = torch.sigmoid(qX.dequantize()) 7559 qY = torch.quantize_per_tensor(Y, scale=1.0 / 256, zero_point=0, 7560 dtype=torch.quint8) 7561 with override_quantized_engine('fbgemm'): 7562 qYserver = torch.sigmoid(qX) 7563 with override_quantized_engine('qnnpack'): 7564 qY_hat = torch.sigmoid(qX) 7565 self.assertEqual( 7566 qY, qY_hat, 7567 msg=f"QNNPACK Sigmoid failed (FP ref), memory_format {memory_format}") 7568 self.assertEqual( 7569 qYserver, qY_hat, 7570 msg=f"QNNPACK Sigmoid failed (FBGEMM ref), memory_format {memory_format}") 7571 7572 @skipIfNoFBGEMM 7573 def test_qnnpack_sigmoid_sweep(self): 7574 # Input parameters 7575 f_min = -4.0 7576 f_max = 4.0 7577 scale = (f_max - f_min) / 256.0 7578 zero_point = 128 7579 dtype = torch.quint8 7580 7581 step = scale / 2.0 7582 x = np.arange(f_min, f_max + step, step) 7583 X = torch.from_numpy(x).to(torch.float32) 7584 qX = torch.quantize_per_tensor(X, scale=scale, 7585 zero_point=zero_point, 7586 dtype=dtype) 7587 7588 dqX = qX.dequantize() 7589 # Floating point reference 7590 Y = torch.sigmoid(dqX) 7591 qY = torch.quantize_per_tensor(Y, scale=1.0 / 256, zero_point=0, 7592 dtype=torch.quint8) 7593 with override_quantized_engine('fbgemm'): 7594 qYserver = torch.sigmoid(qX) 7595 with override_quantized_engine('qnnpack'): 7596 qY_hat = torch.sigmoid(qX) 7597 self.assertEqual(qY, qY_hat, 7598 msg="QNNPACK Sigmoid failed (FP ref)!") 7599 self.assertEqual(qYserver, qY_hat, 7600 msg="QNNPACK Sigmoid failed (FBGEMM ref)!") 7601 7602 """Tests the correctness of the quantized::add (qnnpack) op.""" 7603 @settings(suppress_health_check=(HealthCheck.filter_too_much,)) 7604 @given(A=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), 7605 qparams=hu.qparams(dtypes=[torch.quint8, torch.qint8])), 7606 zero_point=st.sampled_from([0, 2, 5, 15, 127]), 7607 scale_A=st.sampled_from([0.001, 0.057, 0.889, 12.3]), 7608 scale_B=st.sampled_from([0.008, 0.0821, 0.67, 7]), 7609 scale_C=st.sampled_from([0.003, 0.07821, 0.457, 7.34]),) 7610 def test_qnnpack_add(self, A, zero_point, scale_A, scale_B, scale_C): 7611 with override_quantized_engine('qnnpack'): 7612 A_temp = A 7613 for channels_last in [True, False]: 7614 if channels_last and len(A_temp[0].shape) != 4: 7615 continue 7616 A, (scale_a, zero_point_A, torch_type) = A_temp 7617 B, (scale_b, zero_point_B, torch_type) = A_temp 7618 A = torch.from_numpy(A) 7619 B = torch.from_numpy(B) 7620 7621 if torch_type == torch.qint8 and not torch.backends.xnnpack.enabled: 7622 continue 7623 7624 if channels_last: 7625 A = A.to(memory_format=torch.channels_last) 7626 B = B.to(memory_format=torch.channels_last) 7627 assume(scale_A // scale_C >= 2**-14) 7628 assume(scale_A // scale_C < 2**8) 7629 assume(scale_B // scale_C >= 2**-14) 7630 assume(scale_B // scale_C < 2**8) 7631 7632 zero_point_C = 127 7633 np_dtype = np.uint8 7634 7635 if torch_type == torch.qint8: 7636 zero_point_C = 0 7637 np_dtype = np.int8 7638 7639 qA = torch.quantize_per_tensor(A, scale=scale_A, 
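# Note: qA and qB below are built from the same raw data (both unpacked from A_temp) but use different scales.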
zero_point=zero_point, 7640 dtype=torch_type) 7641 qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point, 7642 dtype=torch_type) 7643 7644 # Add ground truth 7645 C = (qA.dequantize() + qB.dequantize()).numpy() 7646 7647 qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype) 7648 7649 qC_qnnp = torch.ops.quantized.add(qA, qB, scale_C, zero_point_C) 7650 7651 np.testing.assert_equal(qC, qC_qnnp.int_repr(), 7652 "Quantized addition failed.") 7653 7654 Crelu = C.copy() 7655 Crelu[C < 0] = 0 7656 qCrelu = torch.quantize_per_tensor(torch.from_numpy(Crelu), scale_C, 7657 zero_point_C, dtype=torch_type) 7658 qCrelu_hat = torch.ops.quantized.add_relu(qA, qB, scale=scale_C, zero_point=zero_point_C) 7659 np.testing.assert_equal(qCrelu.int_repr().numpy(), qCrelu_hat.int_repr(), 7660 "Quantized addition with ReLU failed.") 7661 7662 """Tests the correctness of the quantized::mul (qnnpack) op.""" 7663 @settings(suppress_health_check=(HealthCheck.filter_too_much,)) 7664 @given(A=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), 7665 qparams=hu.qparams(dtypes=[torch.quint8, torch.qint8])), 7666 zero_point=st.sampled_from([0, 2, 5, 15, 127]), 7667 scale_A=st.sampled_from([0.3, 0.57, 0.889]), 7668 scale_B=st.sampled_from([0.8, 0.821, 0.67]), 7669 scale_C=st.sampled_from([0.3, 0.7821, 0.457]),) 7670 def test_qnnpack_mul(self, A, zero_point, scale_A, scale_B, scale_C): 7671 with override_quantized_engine('qnnpack'): 7672 A_temp = A 7673 for channels_last in [True, False]: 7674 if channels_last and len(A_temp[0].shape) != 4: 7675 continue 7676 A, (scale_a, zero_point_A, torch_type) = A_temp 7677 B, (scale_b, zero_point_B, torch_type) = A_temp 7678 A = torch.from_numpy(A) 7679 B = torch.from_numpy(B) 7680 7681 if torch_type == torch.qint8 and not torch.backends.xnnpack.enabled: 7682 continue 7683 7684 if channels_last: 7685 A = A.to(memory_format=torch.channels_last) 7686 B = B.to(memory_format=torch.channels_last) 7687 assume(scale_A // scale_C >= 2**-14) 7688 assume(scale_A // scale_C < 2**8) 7689 assume(scale_B // scale_C >= 2**-14) 7690 assume(scale_B // scale_C < 2**8) 7691 7692 zero_point_C = 127 7693 np_dtype = np.uint8 7694 7695 if torch_type == torch.qint8: 7696 zero_point_C = 0 7697 np_dtype = np.int8 7698 7699 qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point, 7700 dtype=torch_type) 7701 qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point, 7702 dtype=torch_type) 7703 7704 # Multiply ground truth 7705 C = (qA.dequantize() * qB.dequantize()).numpy() 7706 7707 qC = _quantize(C, scale_C, zero_point_C, dtype=np_dtype) 7708 qC_qnnp = torch.ops.quantized.mul(qA, qB, scale_C, zero_point_C) 7709 7710 np.testing.assert_equal(qC, qC_qnnp.int_repr(), 7711 "Quantized multiplication failed.") 7712 7713 Crelu = C.copy() 7714 Crelu[C < 0] = 0 7715 qCrelu = torch.quantize_per_tensor(torch.from_numpy(Crelu), scale_C, 7716 zero_point_C, dtype=torch_type) 7717 qCrelu_hat = torch.ops.quantized.mul_relu(qA, qB, scale=scale_C, zero_point=zero_point_C) 7718 np.testing.assert_equal(qCrelu.int_repr().numpy(), qCrelu_hat.int_repr(), 7719 "Quantized multiplication with ReLU failed.") 7720 7721 7722 """Tests that quantized add works with broadcasting.""" 7723 def test_qnnpack_add_broadcast(self): 7724 def _run_test(A, B): 7725 qA = torch.quantize_per_tensor(A, 0.02, 0, dtype) 7726 qB = torch.quantize_per_tensor(B, 0.04, 2, dtype) 7727 7728 output_scale = 0.01 7729 output_zp = 1 7730 7731 # ground truth 7732 C = qA.dequantize() + qB.dequantize() 7733 qC = torch.quantize_per_tensor(C,
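# Ground truth: dequantize both inputs, add in fp32, then requantize at the requested output scale and zero point.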
output_scale, output_zp, dtype) 7734 7735 # quantized 7736 qC_hat_1 = torch.ops.quantized.add(qA, qB, output_scale, output_zp) 7737 qC_hat_2 = torch.ops.quantized.add(qB, qA, output_scale, output_zp) 7738 7739 self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_1.dequantize())) 7740 self.assertTrue(torch.allclose(qC.dequantize(), qC_hat_2.dequantize())) 7741 7742 with override_quantized_engine("qnnpack"): 7743 for dtype in (torch.qint8, torch.quint8): 7744 if dtype == torch.qint8 and not torch.backends.xnnpack.enabled: 7745 continue 7746 7747 for channels_last in [True, False]: 7748 # 4d 7749 A = torch.randn(1, 3, 4, 4) 7750 B = torch.randn(1, 1, 1, 1) 7751 if channels_last: 7752 A = A.to(memory_format=torch.channels_last) 7753 B = B.to(memory_format=torch.channels_last) 7754 _run_test(A, B) 7755 7756 # 5d 7757 C = torch.randn(1, 3, 4, 4, 4) 7758 D = torch.randn(1, 1, 1, 1, 1) 7759 if channels_last: 7760 C = C.to(memory_format=torch.channels_last_3d) 7761 D = D.to(memory_format=torch.channels_last_3d) 7762 _run_test(C, D) 7763 7764 """Tests the correctness of quantized::qnnpack_maxpool2d op.""" 7765 @given(A=hu.tensor(shapes=hu.array_shapes(4, 4, 3, 5), 7766 qparams=hu.qparams(dtypes=torch.quint8)), 7767 kernel=st.sampled_from([2, 4]), 7768 stride=st.sampled_from([1, 2]), 7769 padding=st.sampled_from([1, 2])) 7770 def test_qnnpack_maxpool2d(self, A, kernel, stride, padding): 7771 import torch.nn.functional as F 7772 7773 with override_quantized_engine('qnnpack'): 7774 A, (scale, zero_point, torch_type) = A 7775 X = torch.from_numpy(A) 7776 np_type = np.uint8 7777 dilation = 1 7778 7779 # Check constraints 7780 assume(kernel // 2 >= padding) # Kernel cannot be overhanging! 7781 7782 iH, iW = X.shape[-2:] 7783 7784 oH = pool_output_shape(iH, kernel, padding, stride, dilation) 7785 assume(oH > 0) 7786 oW = pool_output_shape(iW, kernel, padding, stride, dilation) 7787 assume(oW > 0) 7788 7789 k = (kernel, kernel) 7790 s = (stride, stride) 7791 d = (dilation, dilation) 7792 p = (padding, padding) 7793 7794 q_max_pool = torch.ops.quantized.max_pool2d 7795 7796 a = scale * (X - zero_point).to(dtype=torch.float) 7797 qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point, 7798 dtype=torch_type) 7799 7800 a_ref = qa.dequantize() 7801 7802 a_pool = F.max_pool2d(a_ref, kernel_size=k, stride=s, padding=p, 7803 dilation=d) 7804 7805 a_pool_nhwc = a_pool.permute([0, 2, 3, 1]) 7806 7807 qa_pool = q_max_pool(qa, k, s, p, d, ceil_mode=False) 7808 7809 qa_pool_int = qa_pool.dequantize() 7810 np.testing.assert_equal(a_pool.numpy(), qa_pool_int.numpy()) 7811 7812 @given(batch_size=st.integers(1, 5), 7813 channels=st.sampled_from([2, 4, 5, 8, 16, 32]), 7814 height=st.integers(4, 10), 7815 width=st.integers(4, 10), 7816 kernel=st.integers(2, 5), 7817 stride=st.integers(1, 2), 7818 padding=st.integers(1, 2), 7819 scale=st.floats(0.2, 1.6), 7820 zero_point=st.integers(0, 25) 7821 ) 7822 def test_avg_pool2d( 7823 self, 7824 batch_size, 7825 channels, 7826 height, 7827 width, 7828 kernel, 7829 stride, 7830 padding, 7831 scale, 7832 zero_point 7833 7834 ): 7835 with override_quantized_engine('qnnpack'): 7836 import torch.nn.functional as F 7837 X_init = torch.from_numpy(np.random.randint( 7838 0, 50, (batch_size, channels, height, width))) 7839 7840 X = scale * (X_init - zero_point).to(dtype=torch.float) 7841 7842 # Check constraints 7843 assume(kernel // 2 >= padding) # Kernel cannot be overhanging! 
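# Pre-compute the pooled output size so Hypothesis can discard draws that would produce an empty output.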
7844 7845 iH, iW = X.shape[-2:] 7846 7847 oH = pool_output_shape(iH, kernel, padding, stride, 1) 7848 assume(oH > 0) 7849 oW = pool_output_shape(iW, kernel, padding, stride, 1) 7850 assume(oW > 0) 7851 k = (kernel, kernel) 7852 s = (stride, stride) 7853 p = (padding, padding) 7854 7855 q_avg_pool = torch.ao.nn.quantized.functional.avg_pool2d 7856 7857 x_q = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 7858 dtype=torch.quint8) 7859 7860 a_pool = F.avg_pool2d(x_q.dequantize().to(torch.float), kernel_size=k, stride=s, padding=p) 7861 qa_pool = q_avg_pool(x_q, k, s, p) 7862 # Quantize Ref Output 7863 a_pool_q = torch.quantize_per_tensor(a_pool, scale=scale, zero_point=zero_point, 7864 dtype=torch.quint8) 7865 np.testing.assert_array_almost_equal(a_pool_q.int_repr().numpy(), 7866 qa_pool.int_repr().numpy(), decimal=0) 7867 7868 7869 @given(batch_size=st.integers(1, 5), 7870 channels=st.sampled_from([2, 4, 5, 8, 16, 32]), 7871 height=st.integers(4, 20), 7872 width=st.integers(4, 20), 7873 output_height=st.integers(2, 10), 7874 output_width=st.integers(2, 10), 7875 scale=st.floats(0.2, 1.6), 7876 zero_point=st.integers(0, 25) 7877 ) 7878 def test_adaptive_avg_pool2d( 7879 self, 7880 batch_size, 7881 channels, 7882 height, 7883 width, 7884 output_height, 7885 output_width, 7886 scale, 7887 zero_point 7888 7889 ): 7890 with override_quantized_engine('qnnpack'): 7891 # Check constraints 7892 assume(height >= output_height) 7893 assume(width >= output_width) 7894 7895 import torch.nn.functional as F 7896 X_init = torch.from_numpy(np.random.randint( 7897 0, 50, (batch_size, channels, height, width))) 7898 7899 X = scale * (X_init - zero_point).to(dtype=torch.float) 7900 7901 iH, iW = X.shape[-2:] 7902 7903 q_avg_pool = torch.ao.nn.quantized.functional.adaptive_avg_pool2d 7904 7905 x_q = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 7906 dtype=torch.quint8) 7907 7908 a_pool = F.adaptive_avg_pool2d(x_q.dequantize().to(torch.float), (output_height, output_width)) 7909 qa_pool = q_avg_pool(x_q, (output_height, output_width)) 7910 # Quantize Ref Output 7911 a_pool_q = torch.quantize_per_tensor(a_pool, scale=scale, zero_point=zero_point, 7912 dtype=torch.quint8) 7913 np.testing.assert_array_almost_equal(a_pool_q.int_repr().numpy(), 7914 qa_pool.int_repr().numpy(), decimal=0) 7915 7916 7917 @given(batch_size=st.integers(1, 5), 7918 channels=st.sampled_from([2, 4, 5, 8, 16, 32]), 7919 height=st.integers(4, 10), 7920 width=st.integers(4, 10), 7921 scale=st.floats(0.02, 2.6), 7922 zero_point=st.integers(0, 25)) 7923 def test_mean(self, batch_size, channels, height, width, scale, zero_point): 7924 with override_quantized_engine('qnnpack'): 7925 dim = (2, 3) 7926 X_init = torch.from_numpy(np.random.randint( 7927 0, 50, (batch_size, channels, height, width))) 7928 X = scale * (X_init - zero_point).to(dtype=torch.float) 7929 7930 qX = torch.quantize_per_tensor(X, scale, zero_point, torch.quint8) 7931 Y = torch.mean(qX.dequantize(), dim) 7932 Y = torch.quantize_per_tensor(Y, scale, zero_point, torch.quint8) 7933 qY = torch.mean(qX, dim) 7934 np.testing.assert_array_almost_equal(Y.int_repr().numpy(), qY.int_repr().numpy(), decimal=0) 7935 7936 """Tests the correctness of the quantized::hardtanh op.""" 7937 def test_hardtanh(self): 7938 if 'qnnpack' not in torch.backends.quantized.supported_engines: 7939 return 7940 with override_quantized_engine('qnnpack'): 7941 shapes = ((4,), (4, 4), (4, 4, 4), (4, 4, 4, 4)) 7942 memory_formats = (torch.channels_last, torch.contiguous_format) 
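# channels_last is only meaningful for 4-D tensors; other shapes are skipped inside the loop below.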
7943 min_vals = (-0.5, -0.3, 0.5) 7944 max_vals = (-0.3, 0.3, 0.7) 7945 test_cases = itertools.product(shapes, memory_formats, min_vals, max_vals) 7946 for shape, memory_format, min_val, max_val in test_cases: 7947 X, scale, zero_point, torch_type = torch.randn(*shape), 1.0, 0, torch.quint8 7948 if memory_format == torch.channels_last and len(shape) != 4: 7949 continue 7950 7951 Y = X.clone() 7952 Y[Y < min_val] = min_val 7953 Y[Y > max_val] = max_val 7954 qY = torch.quantize_per_tensor(Y, scale=scale, 7955 zero_point=zero_point, dtype=torch_type) 7956 qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, 7957 dtype=torch_type) 7958 7959 qY_hat = torch.ao.nn.quantized.functional.hardtanh(qX, min_val, max_val) 7960 self.assertEqual( 7961 qY, qY_hat, 7962 msg=f"hardtanh failed:\nactual {qY_hat}\nexpected {qY}\nmemory_format {memory_format}") 7963 7964"""Tests the correctness of the tensor comparators.""" 7965class TestComparatorOps(TestCase): 7966 """Tests the element-wise equality ops.""" 7967 @given(A=hu.tensor(shapes=((3, 4, 5),), 7968 qparams=hu.qparams()), 7969 B=hu.tensor(shapes=((5,), (1, 5), (1, 1, 5), (4, 5), (3, 4, 5)), 7970 qparams=hu.qparams())) 7971 def test_compare_tensor_tensor(self, A, B): 7972 A, (scale_a, zero_point_a, dtype_a) = A 7973 B, (scale_b, zero_point_b, dtype_b) = B 7974 tA = torch.from_numpy(A) 7975 tB = torch.from_numpy(B) 7976 7977 qA = torch.quantize_per_tensor(tA, scale=scale_a, zero_point=zero_point_a, 7978 dtype=dtype_a) 7979 qB = torch.quantize_per_tensor(tB, scale=scale_b, zero_point=zero_point_b, 7980 dtype=dtype_b) 7981 dqA = qA.dequantize() 7982 dqB = qB.dequantize() 7983 7984 ops_under_test = ('__eq__', '__ne__', '__ge__', '__le__', '__gt__', 7985 '__lt__', 'eq', 'ne', 'ge', 'le', 'gt', 'lt') 7986 7987 for op in ops_under_test: 7988 result_ref = getattr(dqA, op)(dqB) 7989 result = getattr(qA, op)(qB) 7990 self.assertEqual(result_ref, result, 7991 msg=f"'tensor.{op}(tensor)'' failed") 7992 # Reversed broadcasting. 7993 result_ref = getattr(dqB, op)(dqA) 7994 result = getattr(qB, op)(qA) 7995 self.assertEqual(result_ref, result, 7996 msg=f"'tensor.{op}(tensor)'' failed") 7997 7998 @given(A=hu.tensor(shapes=((3, 4, 5),), 7999 qparams=hu.qparams()), 8000 b=hu.floats(allow_infinity=False, allow_nan=False)) 8001 def test_compare_tensor_scalar(self, A, b): 8002 A, (scale_a, zero_point_a, dtype_a) = A 8003 tA = torch.from_numpy(A) 8004 8005 qA = torch.quantize_per_tensor(tA, scale=scale_a, zero_point=zero_point_a, 8006 dtype=dtype_a) 8007 dqA = qA.dequantize() 8008 8009 ops_under_test_reversible = ('__eq__', '__ne__', '__ge__', '__le__', 8010 '__gt__', '__lt__') 8011 ops_under_test_nonreversible = ('eq', 'ne', 'ge', 'le', 'gt', 'lt') 8012 8013 for op in ops_under_test_reversible: 8014 result_ref = getattr(dqA, op)(b) 8015 result = getattr(qA, op)(b) 8016 note(f"result_ref 1: {result_ref}") 8017 note(f"result 1: {result}") 8018 self.assertEqual(result_ref, result, 8019 msg=f"'tensor.{op}(scalar)'' failed") 8020 # Reversed broadcasting. 8021 result_ref = getattr(b, op)(dqA) 8022 result = getattr(b, op)(qA) 8023 note(f"result_ref 2: {result_ref}") 8024 note(f"result 2: {result}") 8025 self.assertEqual(result_ref, result, 8026 msg=f"'scalar.{op}(tensor)'' failed") 8027 8028 for op in ops_under_test_nonreversible: 8029 result_ref = getattr(dqA, op)(b) 8030 result = getattr(qA, op)(b) 8031 note(f"result_ref 3: {result_ref}") 8032 note(f"result 3: {result}") 8033 self.assertEqual(result_ref, result, 8034 msg=f"'tensor.{op}(scalar)'' failed") 8035
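# A minimal standalone sketch (illustrative only, not part of the test classes above):
# the property exercised by TestComparatorOps is that an element-wise comparison of two
# quantized tensors matches the comparison of their dequantized floating-point
# counterparts. Shapes and quantization parameters below are arbitrary.
def _comparator_property_sketch():
    a = torch.quantize_per_tensor(torch.randn(3, 4, 5), scale=0.1, zero_point=0, dtype=torch.quint8)
    b = torch.quantize_per_tensor(torch.randn(3, 4, 5), scale=0.2, zero_point=0, dtype=torch.quint8)
    # Both sides should produce the same boolean mask (this is what the tests above assert).
    assert torch.equal(a > b, a.dequantize() > b.dequantize())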