# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional, Tuple

import torch
from torch._higher_order_ops.out_dtype import out_dtype
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
    _get_aten_graph_module_for_pattern,
    _replace_literals_with_existing_placeholders,
    _replace_literals_with_new_placeholders,
    remove_tensor_overload_for_qdq_ops,
)
from torch.fx import GraphModule
from torch.fx.subgraph_rewriter import replace_pattern


__all__ = [
    "reference_representation_rewrite",
]


_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_linear(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8,
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        torch.int8,
    )
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_linear(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    # Without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
    # This results in failure to match the pattern.
    # Therefore, we call torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None,
    )
    # TODO: change to mul.Scalar
    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = (
        out_dtype(
            torch.ops.aten.mul.Tensor,
            torch.int32,
            acc_i32,
            x_scale * weight_scale / out_scale,
        )
        + out_zero_point
    )
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8


_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randn((2, 5), dtype=torch.float),
    -128,
    127,
    torch.finfo(torch.float32).eps,
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
)


def _qdq_dynamic_quantized_linear(
    x_fp32,
    x_quant_min,
    x_quant_max,
    x_eps,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
):
    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
        x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
    )
    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8,
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        torch.int8,
    )
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    return out_fp32


def _reference_dynamic_quantized_linear(
    x_fp32,
    x_quant_min,
    x_quant_max,
    x_eps,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
):
    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
        x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
    )
    # decomposed representation for quantize_per_tensor
    # TODO: use out_dtype(mul, ...) here when the op is ready
    x_fp32 = x_fp32 / x_scale  # fp32
    # rounding modes might be different here
    # pytorch rounds to even, which is also common for most backends
    x_fp32 = torch.round(x_fp32)  # fp32
    x_i32 = x_fp32.to(dtype=torch.int32)  # int32
    x_i32 = x_i32 + x_zero_point  # int32
    # clamp works for fp32, int32 and int8 dtypes
    x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max)  # int32
    x_i8 = x_i32.to(dtype=torch.int8)

    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None,
    )
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    out_fp32 = acc_i32 * (x_scale * weight_scale)
    return out_fp32


_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_conv2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8,
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        torch.int8,
    )
    out_fp32 = torch.ops.aten.convolution.default(
        x_fp32,
        weight_fp32,
        bias_fp32,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_conv2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1
    # Without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
    # This results in failure to match the pattern.
    # Therefore, we call torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # always set bias to None so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
    acc_i32 = out_dtype(
        torch.ops.aten.convolution.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )
    # Note: we are quantizing bias with these scales without signal from user, but it might be OK
    bias_scale = x_scale * weight_scale
    # Bias quantization to int32 uses bias_scale = x_scale * weight_scale for the following reason.
    # Take the linear calculation as an example:
    # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(j, k)_fp32] + bias_(j)_fp32
    # Represent X and W in fp32 via their dequant transforms:
    # A_fp32 = (A_q - A_zero_point) * A_scale
    # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_q - X_zp) * X_scale * (W_(j, k)_q - W_zp) * W_scale] + bias_(j)_fp32
    # Factor out X_scale and W_scale:
    # Out_(i, j)_fp32 = (X_scale * W_scale) * Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(j, k)_q - W_zp)] + bias_(j)_fp32
    # In order to move the addition of bias_(j)_fp32 inside the parentheses, we must write
    # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(j, k)_q - W_zp)] + (1 / (X_scale * W_scale)) * bias_(j)_fp32)  # noqa: B950
    # Note we had to scale bias_fp32 by 1 / (X_scale * W_scale), i.e. divide it by bias_scale
    # Thus bias quantization to int32 must use X_scale * W_scale as the scale

    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    # Unsqueeze to match broadcast dims
    # Unfortunately we cannot do bias_i32.unsqueeze(0) due to literal-matching issues
    # in graph pattern replacement
    bias_i32 = bias_i32.unsqueeze(-1)
    bias_i32 = bias_i32.unsqueeze(-1)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = (
        out_dtype(
            torch.ops.aten.mul.Tensor,
            torch.int32,
            acc_i32,
            x_scale * weight_scale / out_scale,
        )
        + out_zero_point
    )
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8


_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_add_relu(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
    )
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
    )
    out_fp32 = x_fp32 + y_fp32
    out_fp32 = torch.ops.aten.relu(out_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_add_relu(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
    """
    See comments for `_reference_quantized_add` for more information on
    how to derive the formula for out_i8 based on x_i8 and y_i8
    """
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: change this to mul.Scalar?
    x_i32 = out_dtype(
        torch.ops.aten.mul.Tensor,
        torch.int32,
        (x_i32 - x_zero_point),
        (x_scale / out_scale),
    )
    y_i32 = out_dtype(
        torch.ops.aten.mul.Tensor,
        torch.int32,
        (y_i32 - y_zero_point),
        (y_scale / out_scale),
    )
    out_i32 = x_i32 + y_i32 + out_zero_point
    # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point)
    out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
    return out_i8


def _qdq_quantized_add(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
    )
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
    )
    out_fp32 = x_fp32 + y_fp32
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_add(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
    """
    # How to derive the formula for out_i8 based on x_i8 and y_i8
    # (quantized add takes x_i8, y_i8 and their quantization parameters, and produces an out_i8)

    # out_i8 is the quantized output, so we can write down its formula first:
    out_i8 = out_fp32 / out_scale + out_zero_point  (1)

    # out_fp32 is computed from x_fp32 + y_fp32, where x_fp32 and y_fp32 are the dequantized x_i8 and y_i8:
    out_fp32 = x_fp32 + y_fp32  (2)
    x_fp32 = (x_i8 - x_zero_point) * x_scale  (3)
    y_fp32 = (y_i8 - y_zero_point) * y_scale  (4)

    # applying the above formulas to the out_i8 equation, we get:
    out_i8 = out_fp32 / out_scale + out_zero_point  # (1)
           = (x_fp32 + y_fp32) / out_scale + out_zero_point  # apply (2) to substitute out_fp32 with x_fp32 + y_fp32
           = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point  # apply (3) and (4)
    """
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: use out_dtype op
    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
    out_i32 = x_i32 + y_i32 + out_zero_point
    quant_min = -128
    quant_max = 127
    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
    return out_i8


_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_max_pool2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    kernel_size = 1
    stride = 1
    padding = 0
    dilation = 1
    ceil_mode = False
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(
        x_fp32, kernel_size, stride, padding, dilation, ceil_mode
    )
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_max_pool2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    kernel_size = 1
    stride = 1
    padding = 0
    dilation = 1
    ceil_mode = False
    # to preserve x_quant_min, x_quant_max in the graph for pattern matching
    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
    x_i32 = x_i8.to(torch.int32)
    out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
        x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode
    )
    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
    out_i8 = out_fp32.to(torch.int8)
    return out_i8


_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
    )
    return x


def _reference_quantize_per_tensor_int8(
    x_fp32, scale, zero_point, quant_min, quant_max
):
    # TODO: use out_dtype(mul, ...) here when the op is ready
    x = x_fp32 / scale  # fp32
    # rounding modes might be different here
    # pytorch rounds to even, which is also common for most backends
    x = torch.round(x)  # fp32
    x = x.to(dtype=torch.int32)  # int32
    x = x + zero_point  # int32
    # clamp works for fp32, int32 and int8 dtypes
    x = torch.clamp(x, quant_min, quant_max)  # int32
    x = x.to(dtype=torch.int8)
    return x


_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, scale, zero_point, quant_min, quant_max, torch.int8
    )
    return x_fp32


def _reference_dequantize_per_tensor_int8(
    x_i8, scale, zero_point, quant_min, quant_max
):
    # Without using quant_min/max in clamp, the traced graph will not have quant_min/max args.
    # This results in failure to match the pattern.
    # Therefore, we call torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
    # TODO: use out_dtype op
    # note: x_i8.to(torch.int32) does not work here
    # TODO: debug the implementation later when the torchdynamo timeout issue is resolved
    return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)


_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3, dtype=torch.float),
    torch.randn(3, dtype=torch.float),
    torch.zeros(3, dtype=torch.int),
    1,
    -128,
    127,
)


def _quantize_per_channel_int8(
    x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
):
    out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
        x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
    )
    return out_i8


def _reference_quantize_per_channel_int8(
    x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
):
    x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
    out_i32 = torch.ops.aten.clamp(
        torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max
    )
    out_i32 = torch.transpose(out_i32, ch_axis, -1)
    return out_i32.to(torch.int8)


_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(3, dtype=torch.float),
    torch.zeros(3, dtype=torch.int),
    1,
    -128,
    127,
)


def _dequantize_per_channel_int8(
    x_i8, scales, zero_points, ch_axis, quant_min, quant_max
):
    # the literals below will be replaced with placeholders
    out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
        x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
    )
    return out_fp32


def _reference_dequantize_per_channel_int8(
    x_i8, scales, zero_points, ch_axis, quant_min, quant_max
):
    # the literals below will be replaced with placeholders;
    # in order to preserve the quant_min/quant_max args for pattern matching (e.g. matching for int4 quantized ops),
    # we call torch.ops.aten.clamp here
    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
    x_i8 = torch.transpose(x_i8, ch_axis, -1)
    x_i32 = x_i8.to(torch.int32)
    out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
    out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
    return out_fp32


def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
    return _replace_literals_with_existing_placeholders(
        gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5}
    )


@dataclass
class _RewriteInfo:
    """Data needed for a rewrite: example inputs, pattern and replacement functions,
    and post-transformation functions for the exported pattern and replacement GraphModules.
    """

    # example inputs used for exporting the pattern into a GraphModule
    example_inputs: Tuple[Any, ...]
    pattern: Callable
    replacement: Callable
    # post transformation on the exported pattern and replacement GraphModule
    pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None


_REWRITE_INFO_LIST = [
    _RewriteInfo(
        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_dynamic_quantized_linear),
        _WrapperModule(_reference_dynamic_quantized_linear),
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
        ),
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
        ),
    ),
    _RewriteInfo(
        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_linear),
        _WrapperModule(_reference_quantized_linear),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_conv2d),
        _WrapperModule(_reference_quantized_conv2d),
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add_relu),
        _WrapperModule(_reference_quantized_add_relu),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add),
        _WrapperModule(_reference_quantized_add),
    ),
    _RewriteInfo(
        _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_max_pool2d),
        _WrapperModule(_reference_quantized_max_pool2d),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_tensor_int8),
        _WrapperModule(_reference_quantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_tensor_int8),
        _WrapperModule(_reference_dequantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_channel_int8),
        _WrapperModule(_reference_quantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement,
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_channel_int8),
        _WrapperModule(_reference_dequantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement,
    ),
]


def reference_representation_rewrite(model: GraphModule) -> GraphModule:
    remove_tensor_overload_for_qdq_ops(model)
    for rewrite_info in _REWRITE_INFO_LIST:
        example_inputs = rewrite_info.example_inputs
        pattern = rewrite_info.pattern
        replacement = rewrite_info.replacement
        pattern_post_trans = rewrite_info.pattern_post_trans
        replacement_post_trans = rewrite_info.replacement_post_trans
        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
        if pattern_post_trans:
            pattern = pattern_post_trans(pattern)
        if replacement_post_trans:
            replacement = replacement_post_trans(replacement)
        pattern.recompile()  # type: ignore[attr-defined]
        replacement.recompile()  # type: ignore[attr-defined]
        matches = replace_pattern(model, pattern, replacement)
    return model
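

# The guarded example below is an illustrative sketch only, not part of the
# library API: it shows where `reference_representation_rewrite` might fit at
# the end of a PT2E quantization flow. The capture/prepare/convert helpers used
# here (torch.export.export_for_training, prepare_pt2e, convert_pt2e,
# XNNPACKQuantizer) are assumptions about the surrounding tooling and may
# differ across PyTorch versions; note that convert_pt2e can also apply this
# rewrite itself via `use_reference_representation=True`.
if __name__ == "__main__":
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )

    class _LinearModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(2, 5),)
    # capture the model into an ATen graph (the capture API may vary by version)
    m = torch.export.export_for_training(_LinearModule(), example_inputs).module()
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibration
    m = convert_pt2e(m)
    # rewrite q/dq ops into the integer-arithmetic reference representation
    m = reference_representation_rewrite(m)
    print(m.graph)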