# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import IntEnum, unique
from typing import Callable, Optional, Sequence, Set

import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.reduce_dynamic_range import ReduceDynamicRange
from executorch.backends.qualcomm._passes.replace_inf_buffer import ReplaceInfBuffer
from executorch.backends.transforms.decompose_sdpa import (
    DecomposeScaledDotProductAttention,
)

from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .annotators import OP_ANNOTATOR
from .qconfig import (
    get_16a16w_qnn_ptq_config,
    get_16a4w_qnn_ptq_config,
    get_16a4w_qnn_qat_config,
    get_16a8w_qnn_ptq_config,
    get_8a8w_qnn_ptq_config,
    get_8a8w_qnn_qat_config,
    get_ptq_per_channel_quant_config,
    get_qat_per_channel_quant_config,
    QuantizationConfig,
)

# To bypass the meta internal test error
get_default_16bit_qnn_ptq_config = get_16a16w_qnn_ptq_config

__all__ = [
    "QnnQuantizer",
    "QuantDtype",
    "get_16a4w_qnn_ptq_config",
    "get_16a8w_qnn_ptq_config",
    "get_16a16w_qnn_ptq_config",
    "get_8a8w_qnn_ptq_config",
    "get_8a8w_qnn_qat_config",
    "get_16a4w_qnn_qat_config",
]


@unique
class QuantDtype(IntEnum):
    """
    Activation bit width and weight bit width.
    """

    use_16a16w = 0
    use_16a8w = 1
    use_16a4w = 2
    use_8a8w = 3


# Maps (QuantDtype, is_qat) to a pair of
# (per-tensor quant-config factory, per-channel QuantizationConfig).
quant_config_dict = {
    # PTQ
    (QuantDtype.use_16a16w, False): (
        get_16a16w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
    ),
    (QuantDtype.use_16a8w, False): (
        get_16a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
    ),
    (QuantDtype.use_16a4w, False): (
        get_16a4w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, False): (
        get_8a8w_qnn_ptq_config,
        get_ptq_per_channel_quant_config(),
    ),
    # QAT
    (QuantDtype.use_16a4w, True): (
        get_16a4w_qnn_qat_config,
        get_qat_per_channel_quant_config(torch.uint16, "int4"),
    ),
    (QuantDtype.use_8a8w, True): (
        get_8a8w_qnn_qat_config,
        get_qat_per_channel_quant_config(),
    ),
}


class QnnQuantizer(Quantizer):
    SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

    def __init__(self):
        super().__init__()
        self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

        self.is_qat = False
        self.quant_dtype = QuantDtype.use_8a8w
        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()

        self.custom_quant_annotations: Sequence[Callable] = []
        self.discard_nodes: Set[str] = set()

    def _annotate(self, gm: GraphModule) -> None:
        for node in gm.graph.nodes:
            if node.name in self.discard_nodes:
                continue

            quant_config = self._get_quant_config(node.target)
            if quant_config:
                OP_ANNOTATOR[node.target](node, quant_config)

    def _annotate_custom_annotation(self, gm: GraphModule) -> None:
        for annotation_func in self.custom_quant_annotations:
            annotation_func(gm)

    def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]:
        """
        Priority:
            1. op is one of use_per_channel_weight_quant_ops
            2. default quant config
        """
        if isinstance(op, str):
            return

        if op in self.use_per_channel_weight_quant_ops:
            return self.per_channel_quant_config

        if op in self.quant_ops:
            return self.quant_config

        print(f"No quant config is implemented for op, {op}")

    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
        if enable:
            self.use_per_channel_weight_quant_ops.update(ops)
        else:
            self.use_per_channel_weight_quant_ops.difference_update(ops)

    def add_custom_quant_annotations(
        self, custom_quant_annotations: Sequence[Callable]
    ) -> None:
        self.custom_quant_annotations = custom_quant_annotations

    def add_discard_nodes(self, nodes: Sequence[str]) -> None:
        # Nodes with these names are skipped during annotation.
        self.discard_nodes = set(nodes)

    def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
        for op in ops:
            self.quant_ops.remove(op)

    def annotate(self, model: GraphModule) -> GraphModule:
        self._annotate(model)
        self._annotate_custom_annotation(model)

        return model

    def get_supported_ops(self) -> Set[OpOverload]:
        return self.SUPPORTED_OPS

    def set_quant_config(
        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
    ) -> None:
        self.quant_dtype = quant_dtype
        self.is_qat = is_qat
        if (quant_dtype, is_qat) not in quant_config_dict:
            raise RuntimeError(
                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not supported"
            )

        quant_config_func, self.per_channel_quant_config = quant_config_dict[
            (quant_dtype, is_qat)
        ]
        self.quant_config = (
            quant_config_func(act_observer) if act_observer else quant_config_func()
        )

    def set_per_channel_conv_quant(self, enable: bool) -> None:
        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
        self._update_per_channel_weight_quant_ops(conv_ops, enable)

    def set_per_channel_linear_quant(self, enable: bool) -> None:
        linear_ops = {
            torch.ops.aten.linear.default,
        }
        self._update_per_channel_weight_quant_ops(linear_ops, enable)

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        # Graph transforms applied before annotation so that decomposed and
        # recomposed patterns match what the annotators expect.
        model = ReduceDynamicRange()(model).graph_module
        model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module
        model = DecomposeScaledDotProductAttention()(model).graph_module
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module

        return model

    def validate(self, model: GraphModule) -> None:
        pass
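

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): one way this
# quantizer is typically wired into the PT2E quantization flow. `MyModel` and
# `example_inputs` are placeholders, and the export entry point may differ
# across PyTorch versions.
#
#   from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#   quantizer = QnnQuantizer()
#   quantizer.set_quant_config(QuantDtype.use_8a8w, is_qat=False)
#   quantizer.set_per_channel_conv_quant(True)
#   quantizer.set_per_channel_linear_quant(True)
#
#   exported = torch.export.export(MyModel().eval(), example_inputs).module()
#   prepared = prepare_pt2e(exported, quantizer)
#   prepared(*example_inputs)  # calibration
#   quantized = convert_pt2e(prepared)
# ---------------------------------------------------------------------------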