# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass

import torch

from torch.ao.quantization.quantizer import (
    FixedQParamsQuantizationSpec,
    QuantizationSpec,
)


@dataclass(eq=True, frozen=True)
class QuantizationConfig:
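    """Collection of QuantizationSpecs for an operator's input activations,
    output activations, weight and bias. Each field may be None when no
    quantization is configured for that tensor."""
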
    input_activation: QuantizationSpec | None
    output_activation: QuantizationSpec | None
    weight: QuantizationSpec | None
    bias: QuantizationSpec | None

    def get_input_act_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'input_activation' after asserting that input_activation.qscheme is valid."""
        if self.input_activation is None:
            return None
        assert self.input_activation.qscheme in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
        ], f"Unsupported quantization_spec {self.input_activation} for input_activation."
        return self.input_activation

    def get_output_act_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'output_activation' after asserting that output_activation.qscheme is valid."""
        if self.output_activation is None:
            return None
        assert self.output_activation.qscheme in [
            torch.per_tensor_affine,
            torch.per_tensor_symmetric,
        ], f"Unsupported quantization_spec {self.output_activation} for output_activation."
        return self.output_activation

    def get_weight_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'weight' after asserting that weight.qscheme is valid."""
        if self.weight is None:
            return None
        assert self.weight.qscheme in [
            torch.per_tensor_symmetric,
            torch.per_channel_symmetric,
        ], f"Unsupported quantization_spec {self.weight} for weight."
        return self.weight

    def get_bias_qspec(self) -> QuantizationSpec | None:
        """Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float."""
        if self.bias is None:
            return None
        assert (
            self.bias.dtype == torch.float
        ), "Only float dtype is supported for bias right now"
        return self.bias

    def get_fixed_qspec(
        self,
        scale: float,
        zp: int,
        dtype: torch.dtype = torch.int8,
        quant_min: int = -128,
        quant_max: int = 127,
    ) -> FixedQParamsQuantizationSpec:
        """Returns a new FixedQParamsQuantizationSpec with the given parameters."""
        return FixedQParamsQuantizationSpec(
            dtype=dtype,
            qscheme=torch.per_tensor_affine,
            scale=scale,
            zero_point=zp,
            quant_min=quant_min,
            quant_max=quant_max,
        )
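
# --- Usage sketch (illustrative; not part of the upstream module) -----------
# A minimal example of building a QuantizationConfig and querying it. The
# observer choice (MinMaxObserver) and the int8 ranges below are assumptions
# made for demonstration; in practice, configs are constructed elsewhere in
# the Arm quantizer.
if __name__ == "__main__":
    from torch.ao.quantization.observer import MinMaxObserver

    act_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_symmetric,
        observer_or_fake_quant_ctr=MinMaxObserver,
    )
    config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=None,
        bias=None,
    )

    # per_tensor_symmetric is in the allowed set, so the assert passes.
    print(config.get_input_act_qspec())

    # Fixed qparams for an op with a known output range, e.g. sigmoid in
    # [0, 1): scale = 1/256 and zero_point = -128 map that range onto int8.
    print(config.get_fixed_qspec(scale=1.0 / 256.0, zp=-128))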