# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import numbers
import operator
from functools import partial
from typing import Callable, Dict, List, Sequence, Tuple

import torch
from torch._ops import OpOverload

from torch._subclasses import FakeTensor
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize

from torch.ao.quantization.observer import FixedQParamsObserver
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    QuantizationAnnotation,
    QuantizationSpec,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
)
from torch.fx import Node

from .qconfig import (
    get_16a16w_qnn_ptq_config,
    get_16a4w_qnn_qat_config,
    get_8a8w_qnn_qat_config,
    QuantizationConfig,
)


QUANT_ANNOTATION_KEY = "quantization_annotation"
OP_ANNOTATOR: Dict[OpOverload, Callable] = {}


def register_annotator(ops: List[OpOverload]):
    def decorator(annotator: Callable):
        for op in ops:
            OP_ANNOTATOR[op] = annotator
        # return the annotator unchanged so the decorated name stays usable
        return annotator

    return decorator

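
# Illustrative sketch (not part of this module) of how OP_ANNOTATOR is
# typically consumed: a quantizer's annotate pass walks the exported graph and
# dispatches each call_function node to its registered annotator, roughly
#
#     for n in graph_module.graph.nodes:
#         if n.op == "call_function" and n.target in OP_ANNOTATOR:
#             OP_ANNOTATOR[n.target](n, quantization_config)
#
# where `graph_module` and `quantization_config` are assumed to be supplied by
# the caller.
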
def _is_annotated(nodes: List[Node]):
    """
    Given a list of nodes (that represents an operator pattern),
    return True if any of the nodes is annotated, otherwise return False.
    """
    annotated = False
    for node in nodes:
        annotated = annotated or (
            QUANT_ANNOTATION_KEY in node.meta
            and node.meta[QUANT_ANNOTATION_KEY]._annotated
        )
    return annotated


def _is_float_tensor(node: Node):
    """Check if the node's tensor is a float tensor, so that we can skip
    quantization for the node, since observers only work with float Tensors.
    """
    if (
        not isinstance(node, Node)
        or "val" not in node.meta
        or not isinstance(node.meta["val"], FakeTensor)
    ):
        return False
    return node.meta["val"].dtype == torch.float32


def _mark_nodes_as_annotated(nodes: List[Node]):
    for node in nodes:
        if QUANT_ANNOTATION_KEY not in node.meta:
            node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation()
        node.meta[QUANT_ANNOTATION_KEY]._annotated = True


def annotate_in_out_obs_sharing_op(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_act = node.args[0]
    assert isinstance(input_act, Node)

    # only annotate input output sharing operator
    # when the output of the input node is annotated
    if (
        QUANT_ANNOTATION_KEY not in input_act.meta
        or not input_act.meta[QUANT_ANNOTATION_KEY]._annotated
        or input_act.meta[QUANT_ANNOTATION_KEY].output_qspec is None
    ):
        return

    act_qspec = SharedQuantizationSpec(input_act)
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={
            input_act: act_qspec,
        },
        output_qspec=act_qspec,
        _annotated=True,
    )


def annotate_single_in_single_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    if _is_float_tensor(node):
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )


@register_annotator([torch.ops.aten.topk.default])
def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return
    # We can reuse the single-in-single-out annotation since we don't want to
    # quantize the indices output
    annotate_single_in_single_out(node, quantization_config)


def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = (
        quantization_config.output_activation if _is_float_tensor(node) else None
    )

    input_qspec_map = {}
    input_act0 = node.args[0]
    if _is_float_tensor(input_act0):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if _is_float_tensor(input_act1):
        input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator(
    [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar]
)
def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator(
    [torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor]
)
def annotate_div(node: Node, quantization_config: QuantizationConfig) -> None:
    def _derived_inp1_const_div_quant_spec(
        node: torch.fx.Node, output_qspec: QuantizationSpec
    ) -> DerivedQuantizationSpec:
        def _derive_div_qparams_fn(
            obs_or_fqs: List,
            const_val: float,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            inp_0_obs_or_fq = obs_or_fqs[0]
            inp_0_scale, inp_0_zp = inp_0_obs_or_fq.calculate_qparams()
            derived_scale = inp_0_scale / const_val
            return (derived_scale, inp_0_zp)

        inp_0 = node.args[0]
        const_inp_1 = node.args[1]
        _derive_div_qparams_with_const_fn = partial(
            _derive_div_qparams_fn, const_val=const_inp_1
        )

        q_min = (
            torch.iinfo(output_qspec.dtype).min
            if output_qspec.quant_min is None
            else output_qspec.quant_min
        )
        q_max = (
            torch.iinfo(output_qspec.dtype).max
            if output_qspec.quant_max is None
            else output_qspec.quant_max
        )
        return DerivedQuantizationSpec(
            derived_from=[(inp_0, node)],
            derive_qparams_fn=_derive_div_qparams_with_const_fn,
            dtype=output_qspec.dtype,
            quant_min=q_min,
            quant_max=q_max,
            ch_axis=0,
            qscheme=output_qspec.qscheme,
        )

    if all(isinstance(a, Node) for a in node.args):
        annotate_binary(node, quantization_config)
    # special constant divisor case
    elif isinstance(node.args[0], Node) and isinstance(node.args[1], numbers.Number):
        if _is_annotated([node]):
            return

        input_act_qspec = quantization_config.input_activation
        output_act_qspec = _derived_inp1_const_div_quant_spec(
            node, quantization_config.output_activation
        )
        input_qspec_map = {}
        input_act0 = node.args[0]
        if _is_float_tensor(input_act0):
            input_qspec_map[input_act0] = input_act_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    else:
        raise NotImplementedError(f"No quant annotation is implemented for {node}.")

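
# Worked example for the constant-divisor branch above (illustrative numbers):
# if input0 was observed with scale 0.02 and zero_point 128, then for
# `x / 4.0` the derived output spec keeps zero_point 128 and uses scale
# 0.02 / 4.0 = 0.005, since dequantizing with it gives 0.005 * (q - 128)
# == (0.02 * (q - 128)) / 4.0. The integer values pass through unchanged, so
# no extra observer is needed for the division output.
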
@register_annotator([torch.ops.aten.rsub.Scalar])
def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.sum.dim_IntList])
def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_binary(node, quantization_config)


@register_annotator([torch.ops.aten.ceil.default])
def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.clamp.default])
def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.tanh.default])
def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default]
)
def annotate_hardswish(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardsigmoid_.default]
)
def annotate_hardsigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default])
def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.mean.default])
def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.max_pool2d.default])
def annotate_max_pool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.max_pool2d_with_indices.default])
def annotate_max_pool2d_with_indices(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
def annotate_adaptive_avgpool2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.avg_pool2d.default])
def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)

@register_annotator([torch.ops.aten.permute.default])
def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.leaky_relu.default,
        torch.ops.aten.leaky_relu_.default,
        torch.ops.aten.prelu.default,
    ]
)
def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default])
def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pixel_shuffle.default])
def annotate_pixel_shuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pixel_unshuffle.default])
def annotate_pixel_unshuffle_default(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_bilinear2d.vec])
def annotate_upsample_bilinear2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.upsample_nearest2d.vec])
def annotate_upsample_nearest2d(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.softmax.int,
        torch.ops.aten._softmax.default,
        torch.ops.aten._safe_softmax.default,
    ]
)
def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.log_softmax.int])
def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.pad.default])
def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.reshape.default])
def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.select.int])
def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.mean.dim])
def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.slice.Tensor])
def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sqrt.default])
def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.gelu.default])
def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.scaled_dot_product_attention.default])
def annotate_scaled_dot_product_attention(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.squeeze.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dims,
    ]
)
def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.rms_norm.default])
def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[2]

    if _is_annotated([node]):
        return

    # TODO: currently only supports 16a16w
    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )

    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.input_activation,
    )
    nodes_to_mark_annotated = [node]
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)


@register_annotator([torch.ops.aten.rsqrt.default])
def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default])
def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    out_qconf = quantization_config.output_activation

    q_max = (
        torch.iinfo(out_qconf.dtype).max
        if out_qconf.quant_max is None
        else out_qconf.quant_max
    )
    q_min = (
        torch.iinfo(out_qconf.dtype).min
        if out_qconf.quant_min is None
        else out_qconf.quant_min
    )

    scale = 1 / (q_max - q_min + 1)

    observer_ctr = FixedQParamsObserver.with_args(
        scale=scale,
        zero_point=0,
        dtype=quantization_config.output_activation.dtype,
        qscheme=torch.per_tensor_affine,
        quant_max=q_max,
        quant_min=q_min,
    )
    obs_or_fq_ctr = observer_ctr
    if quantization_config in (
        get_8a8w_qnn_qat_config(),
        get_16a4w_qnn_qat_config(),
    ):
        obs_or_fq_ctr = FixedQParamsFakeQuantize.with_args(
            observer=observer_ctr,
            scale=scale,
            zero_point=0,
            dtype=quantization_config.output_activation.dtype,
            qscheme=torch.per_tensor_affine,
            quant_max=q_max,
            quant_min=q_min,
        )

    # make sigmoid's output map onto the range 0~1
    out_act_quantization_spec = QuantizationSpec(
        dtype=quantization_config.output_activation.dtype,
        quant_max=q_max,
        quant_min=q_min,
        observer_or_fake_quant_ctr=obs_or_fq_ctr,
        qscheme=torch.per_tensor_affine,
    )

    if _is_float_tensor(node):
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=out_act_quantization_spec,
            _annotated=True,
        )
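
# Example of the fixed qparams chosen above (illustrative): for an unsigned
# 8-bit output spec with quant_min=0 and quant_max=255, scale becomes
# 1 / (255 - 0 + 1) = 1 / 256 with zero_point 0, so the representable range
# [0, 255/256] covers sigmoid's output range [0, 1) without needing any
# observer statistics.
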
@register_annotator([torch.ops.aten.pow.Tensor_Scalar])
def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.unsqueeze.default])
def annotate_unsqueeze(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator(
    [
        torch.ops.aten.unsqueeze_copy.default,
    ]
)
def annotate_unsqueeze_copy(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.transpose.int])
def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.embedding.default])
def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
    weight = node.args[0]

    input_qspec_map = {}
    input_qspec_map[weight] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=SharedQuantizationSpec((weight, node)),
        _annotated=True,
    )


@register_annotator([torch.ops.aten.index.Tensor])
def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        input_qspec_map = {}
        input = node.args[0]
        input_qspec_map[input] = quantization_config.input_activation
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=SharedQuantizationSpec((input, node)),
            _annotated=True,
        )


@register_annotator(
    [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default]
)
def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
    input = node.args[0]
    value = node.args[2]

    input_qspec_map = {}
    input_qspec_map[input] = quantization_config.input_activation
    input_qspec_map[value] = SharedQuantizationSpec((input, node))

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=SharedQuantizationSpec((input, node)),
        _annotated=True,
    )
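
# Note on the SharedQuantizationSpec forms used in this file: passing a single
# node (as in annotate_in_out_obs_sharing_op) shares qparams with that node's
# output, while passing an (input_node, consumer_node) tuple, e.g.
# SharedQuantizationSpec((input, node)) in annotate_index_put above, shares
# qparams with the observed value on that specific edge. Here both the `value`
# operand and the index_put output reuse the destination tensor's qparams, so
# scattering values into it does not introduce a second scale/zero_point.
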
@register_annotator([torch.ops.aten.expand.default])
def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.group_norm.default])
def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[2]
    bias_node = None
    # bias sits at index 3 (after input, num_groups, weight)
    if len(node.args) > 3:
        bias_node = node.args[3]

    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.weight,
    )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        _annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)


@register_annotator([torch.ops.aten.flatten.using_ints])
def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
    annotate_in_out_obs_sharing_op(node, quantization_config)
    if not _is_annotated([node]):
        annotate_single_in_single_out(node, quantization_config)


@register_annotator([torch.ops.aten.stack.default])
def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
    # skip integer outputs (e.g. stacked indices); observers only handle float tensors
    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
        return

    input_qspec_map = {}
    for input_act in node.args[0]:
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.matmul.default])
def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = quantization_config.output_activation

    input_qspec_map = {}
    input_act0 = node.args[0]
    if isinstance(input_act0, Node):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if isinstance(input_act1, Node):
        # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized.
        if input_act_qspec.dtype == torch.int32:
            # we should use int16 for mm / bmm instead of int4
            input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight
        else:
            input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.bmm.default])
def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_act_qspec = quantization_config.input_activation
    output_act_qspec = quantization_config.output_activation

    input_qspec_map = {}
    input_act0 = node.args[0]
    if isinstance(input_act0, Node):
        input_qspec_map[input_act0] = input_act_qspec

    input_act1 = node.args[1]
    if isinstance(input_act1, Node):
        # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized.
        if input_act_qspec.dtype == torch.int32:
            # we should use int16 for mm / bmm instead of int4
            input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight
        else:
            input_qspec_map[input_act1] = input_act_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_act_qspec,
        _annotated=True,
    )

    # We use get_source_partition in a pass, but MultiheadAttention shares the
    # same source, so we need to change its source_fn_stack.
    node.meta["source_fn_stack"] = [(node, torch.bmm)]


@register_annotator(
    [
        torch.ops.aten.conv2d.default,
        torch.ops.aten.conv1d.default,
        torch.ops.aten.conv_transpose2d.input,
    ]
)
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_spec = quantization_config.input_activation
    input_qspec_map[input_act] = input_spec

    weight = node.args[1]
    assert isinstance(weight, Node)
    input_qspec_map[weight] = quantization_config.weight

    if len(node.args) > 2:
        bias = node.args[2]
        if isinstance(bias, Node):
            if callable(quantization_config.bias):
                input_qspec_map[bias] = quantization_config.bias(node)
            else:
                input_qspec_map[bias] = quantization_config.bias

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.linear.default])
def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[1]
    bias_node = None
    if len(node.args) > 2:
        bias_node = node.args[2]

    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.weight,
    )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        if callable(quantization_config.bias):
            bias_config = quantization_config.bias(node)
        else:
            bias_config = quantization_config.bias
        _annotate_input_qspec_map(node, bias_node, bias_config)
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)

    # We use get_source_partition in a pass, but MultiheadAttention shares the
    # same source, so we need to change its source_fn_stack.
    node.meta["source_fn_stack"] = [(node, torch.nn.Linear)]


@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default])
def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act, weight, bias = node.args[0:3]
    if _is_annotated([node]):
        return

    _annotate_input_qspec_map(
        node,
        act,
        quantization_config.input_activation,
    )
    # QNN requires uint8 instead of int8 in 'weight' config
    _annotate_input_qspec_map(
        node,
        weight,
        quantization_config.input_activation,
    )
    _annotate_input_qspec_map(
        node,
        bias,
        quantization_config.bias,
    )
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated([node, *node.args[0:3]])


@register_annotator([operator.getitem])
def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    if _is_float_tensor(node):
        _annotate_output_qspec(node, quantization_config.output_activation)
        _mark_nodes_as_annotated([node])


@register_annotator([torch.ops.aten.layer_norm.default])
def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
    act_node = node.args[0]
    weight_node = node.args[2]
    bias_node = None
    # bias sits at index 3 (after input, normalized_shape, weight)
    if len(node.args) > 3:
        bias_node = node.args[3]

    if _is_annotated([node]):
        return
    input_act_qspec = quantization_config.input_activation

    _annotate_input_qspec_map(
        node,
        act_node,
        input_act_qspec,
    )
    if input_act_qspec.dtype == torch.int32:
        _annotate_input_qspec_map(
            node,
            weight_node,
            get_16a16w_qnn_ptq_config().weight,
        )
    else:
        _annotate_input_qspec_map(
            node,
            weight_node,
            input_act_qspec,
        )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
        _annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
    _annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)


@register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
    input_nodes = node.args[0]
    if _is_annotated([node]):
        return

    assert isinstance(input_nodes, Sequence)

    first_input_node = input_nodes[0]
    input_qspec_map = {}
    assert isinstance(first_input_node, Node)
    assert isinstance(node, Node)
    input_qspec_map[first_input_node] = quantization_config.input_activation
    share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
        (first_input_node, node)
    )

    for input_node in input_nodes[1:]:
        if input_node not in input_qspec_map:
            assert isinstance(input_node, Node)
            input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=share_qparams_with_input_act0_qspec,
        _annotated=True,
    )
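
# For example, with `torch.cat([a, b, c])` the annotation above observes `a`
# with the config's input-activation spec, while `b`, `c`, and the cat output
# all reuse `a`'s qparams via SharedQuantizationSpec, so every operand of the
# concat ends up on a single scale/zero_point and no per-input requantization
# is required.
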
@register_annotator([torch.ops.aten.unbind.int])
def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    node_tensor = node.meta.get("val")
    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
        return

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )


@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default])
def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
    if _is_annotated([node]):
        return

    input_qspec_map = {}
    input_act = node.args[0]
    assert isinstance(input_act, Node)
    input_qspec_map[input_act] = quantization_config.input_activation

    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        _annotated=True,
    )

    for user in node.users:
        user.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )
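
# For multi-output ops such as split/chunk (and unbind above), the op node
# itself only carries the input qspec; the per-chunk output specs are attached
# to the `getitem` users instead. E.g. for `torch.split(x, 2)`, each chunk
# retrieved via operator.getitem is annotated with
# quantization_config.output_activation and therefore gets its own observer.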