# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass

import torch

from torch.ao.quantization.quantizer import (
    FixedQParamsQuantizationSpec,
    QuantizationSpec,
)


@dataclass(eq=True, frozen=True)
class QuantizationConfig:
    input_activation: QuantizationSpec | None
    output_activation: QuantizationSpec | None
    weight: QuantizationSpec | None
    bias: QuantizationSpec | None

    def get_input_act_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid."""
        if self.input_activation is None:
            return None
        assert self.input_activation.qscheme in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
        ], f"Unsupported quantization_spec {self.input_activation} for input_activation."
        return self.input_activation

    def get_output_act_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid."""
        if self.output_activation is None:
            return None
        assert self.output_activation.qscheme in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
        ], f"Unsupported quantization_spec {self.output_activation} for output_activation."
        return self.output_activation

    def get_weight_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid."""
        if self.weight is None:
            return None
        assert self.weight.qscheme in [
            torch.per_tensor_symmetric,
            torch.per_channel_symmetric,
        ], f"Unsupported quantization_spec {self.weight} for weight."
        return self.weight

    def get_bias_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float."""
        if self.bias is None:
            return None
        assert (
            self.bias.dtype == torch.float
        ), "Only float dtype is supported for bias right now"
        return self.bias

    def get_fixed_qspec(
        self,
        scale: float,
        zp: int,
        dtype: torch.dtype = torch.int8,
        quant_min: int = -128,
        quant_max: int = 127,
    ) -> FixedQParamsQuantizationSpec:
        """Returns a new FixedQParamsQuantizationSpec with the given parameters."""
        return FixedQParamsQuantizationSpec(
            dtype=dtype,
            qscheme=torch.per_tensor_affine,
            scale=scale,
            zero_point=zp,
            quant_min=quant_min,
            quant_max=quant_max,
        )
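

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module above): builds a
# QuantizationConfig and exercises its getters. The observer constructors
# chosen here (HistogramObserver, PerChannelMinMaxObserver,
# PlaceholderObserver) are assumptions for demonstration; any observer or
# fake-quant constructor accepted by QuantizationSpec would do.
if __name__ == "__main__":
    from torch.ao.quantization.observer import (
        HistogramObserver,
        PerChannelMinMaxObserver,
        PlaceholderObserver,
    )

    # int8 affine activations, calibrated with a histogram observer.
    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        observer_or_fake_quant_ctr=HistogramObserver,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    )
    # int8 symmetric per-channel weights along the output-channel axis.
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
    )
    # Bias stays in float, as get_bias_qspec() requires.
    bias_qspec = QuantizationSpec(
        dtype=torch.float,
        observer_or_fake_quant_ctr=PlaceholderObserver,
    )

    config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=weight_qspec,
        bias=bias_qspec,
    )

    # The getters validate each spec before handing it out.
    assert config.get_weight_qspec() is weight_qspec
    # Fixed qparams, e.g. for an op whose output spans [0, 1) such as sigmoid.
    print(config.get_fixed_qspec(scale=1.0 / 256.0, zp=-128))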