from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
    MinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    PerChannelMinMaxObserver,
)
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
from torch.fx import Node


@dataclass(eq=True, frozen=True)
class QuantizationConfig:
    input_activation: Optional[QuantizationSpec]
    output_activation: Optional[QuantizationSpec]
    weight: Optional[QuantizationSpec]
    bias: Optional[QuantizationSpec | Callable]


def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
    def _derive_bias_qparams_fn(
        obs_or_fqs: List,
    ) -> Tuple[Tensor, Tensor]:
        assert (
            len(obs_or_fqs) == 2
        ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
        act_obs_or_fq = obs_or_fqs[0]
        weight_obs_or_fq = obs_or_fqs[1]
        weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
        act_scale, act_zp = act_obs_or_fq.calculate_qparams()
        (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors(
            act_scale, weight_scale
        )
        derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32)
        derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
        return (derived_scale, derived_zero)

    input_act = node.args[0]
    assert isinstance(input_act, Node)
    weight = node.args[1]
    assert isinstance(weight, Node)

    return DerivedQuantizationSpec(
        derived_from=[(input_act, node), (weight, node)],
        derive_qparams_fn=_derive_bias_qparams_fn,
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        ch_axis=0,
        qscheme=torch.per_channel_symmetric,
    )
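
# Illustrative sketch (not used by the quantizer itself): the derived bias scale
# above is simply the element-wise product of the activation scale and the
# (possibly per-channel) weight scale, with a zero-point of 0. The helper below
# is hypothetical and only demonstrates that arithmetic with plain tensors.
def _example_derived_bias_scale() -> None:
    act_scale = torch.tensor(0.02)  # per-tensor activation scale
    weight_scale = torch.tensor([0.1, 0.05, 0.2])  # per-channel weight scales
    # Broadcast and multiply, mirroring _derive_bias_qparams_fn.
    broadcast_act, broadcast_weight = torch.broadcast_tensors(act_scale, weight_scale)
    derived_scale = (broadcast_act * broadcast_weight).to(torch.float32)
    derived_zero = torch.zeros(derived_scale.size()).to(torch.int32)
    print(derived_scale)  # tensor([0.0020, 0.0010, 0.0040])
    print(derived_zero)  # tensor([0, 0, 0], dtype=torch.int32)
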
def get_8a8w_qnn_ptq_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        ch_axis=0,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


# 4-bit quantization only supports specific ops.
def get_16a4w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-20}

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a8w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-20}

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_16a16w_qnn_ptq_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-20}

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int16,
        quant_min=torch.iinfo(torch.int16).min + 1,
        quant_max=torch.iinfo(torch.int16).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    # torch does not support uint16 quantization, use int32 to bypass
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config
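
# Sketch of how the 16-bit activation configs above encode their range: torch has
# no uint16 quantization dtype, so the specs carry the unsigned 16-bit bounds
# (0..65535) inside an int32 QuantizationSpec. The helper below is illustrative
# only and just prints those fields back out.
def _example_inspect_16bit_act_spec() -> None:
    cfg = get_16a16w_qnn_ptq_config()
    act_spec = cfg.input_activation
    # Expected: torch.int32 0 65535
    print(act_spec.dtype, act_spec.quant_min, act_spec.quant_max)
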
def get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO: accept "int4" temporarily; remove it once torch supports a torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}

    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch does not support uint16 quantization, use int32 to bypass
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


# TODO: merge the QAT and PTQ helpers into one function and use a bool flag to switch between them
def get_8a8w_qnn_qat_config(
    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
) -> QuantizationConfig:
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=(
            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
        ),
        ch_axis=0,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config
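
# Sketch (illustrative only) of the per-channel PTQ config with the provisional
# "int4" weight option: the weights are still described by an int8
# QuantizationSpec but clamped to the symmetric 4-bit range [-7, 7], and the bias
# entry is the _derived_bias_quant_spec callable rather than a fixed spec.
def _example_int4_per_channel_config() -> None:
    cfg = get_ptq_per_channel_quant_config(act_dtype=torch.uint16, weight_dtype="int4")
    # Expected: torch.int8 -7 7
    print(cfg.weight.dtype, cfg.weight.quant_min, cfg.weight.quant_max)
    # Expected: True (bias qparams are derived per node from act/weight observers)
    print(cfg.bias is _derived_bias_quant_spec)
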
def get_16a4w_qnn_qat_config(
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.uint16).min,
        quant_max=torch.iinfo(torch.uint16).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-7,
        quant_max=7,
        qscheme=torch.per_tensor_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        reduce_range=True,
        observer=MovingAverageMinMaxObserver,
    )
    bias_quantization_spec = QuantizationSpec(
        dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
    )

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config


def get_qat_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
    supported_act_types = {
        torch.uint8,
        torch.uint16,
        torch.int8,
        torch.int16,
    }
    # TODO: accept "int4" temporarily; remove it once torch supports a torch.int4 dtype
    supported_weight_dtypes = {"int4", torch.int8, torch.int16}

    assert (
        act_dtype in supported_act_types
    ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"

    assert (
        weight_dtype in supported_weight_dtypes
    ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

    # torch does not support uint16 quantization, use int32 to bypass
    act_fake_quant_ctr = FakeQuantize.with_args(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        reduce_range=True,
        observer=act_observer,
    )
    act_quantization_spec = QuantizationSpec(
        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=act_fake_quant_ctr,
    )

    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer=MovingAveragePerChannelMinMaxObserver,
    )
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
        quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
        quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config
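
# Usage sketch: these helpers only construct QuantizationConfig objects; wiring
# them into an actual PT2E quantization flow is backend-specific and outside this
# module. The quantizer object and its setter method named in the comments below
# are assumptions for illustration, not APIs defined here.
def _example_pick_config() -> QuantizationConfig:
    # Per-channel 8-bit weights with uint16-range activations.
    quant_config = get_ptq_per_channel_quant_config(
        act_dtype=torch.uint16, weight_dtype=torch.int8
    )
    # Hypothetical hand-off to a backend quantizer, e.g.:
    #   quantizer.set_default_quant_config(quant_config)  # assumed API
    #   prepared = prepare_pt2e(exported_module, quantizer)  # torch PT2E entry point
    return quant_config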