xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/experimental/quantizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import numpy as np
3
4import torch
5from torch import Tensor
6from torch.ao.quantization.experimental.apot_utils import (
7    apot_to_float,
8    float_to_apot,
9    quant_dequant_util,
10)
11
12
13# class to store APoT quantizer and
14# implement quantize and dequantize
15class APoTQuantizer:
16    alpha: torch.Tensor
17    gamma: torch.Tensor
18    quantization_levels: torch.Tensor
19    level_indices: torch.Tensor
20
21    def __init__(
22        self,
23        alpha: torch.Tensor,
24        gamma: torch.Tensor,
25        quantization_levels: torch.Tensor,
26        level_indices: torch.Tensor,
27    ) -> None:
28        self.alpha = alpha
29        self.gamma = gamma
30        self.quantization_levels = quantization_levels
31        self.level_indices = level_indices
32
33    r""" Quantizes fp Tensor to integer APoT representation.
34    Conversion is based on the qparams from a specified APoT non-uniform observer.
35    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
36    Args:
37        tensor2quantize: fp Tensor
38    Returns:
39        result: APoT Tensor representation of tensor2quantize
40    """
41
42    def quantize(self, tensor2quantize: Tensor):
43        result = torch.tensor([])
44
45        # map float_to_apot over tensor2quantize elements
46        tensor2quantize = tensor2quantize.detach().apply_(
47            lambda x: float_to_apot(
48                x, self.quantization_levels, self.level_indices, self.alpha
49            )
50        )
51
52        # convert to APoT int representation for dtype
53        tensor2quantize = tensor2quantize.int()
54
55        from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
56
57        result = TensorAPoT(self, tensor2quantize)  # type: ignore[assignment]
58
59        return result
60
61    r""" Dequantizes integer Tensor to floating point (fp) representation
62    based on the calculated quantization levels from a specified APoT non-uniform observer.
63    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
64    Args:
65        tensor2quantize: fp Tensor
66    Returns:
67        result: fp reduced precision representation of input Tensor
68    """
69
70    def dequantize(self, apot_tensor) -> Tensor:
71        orig_size = apot_tensor.data.size()
72        apot_tensor_data = apot_tensor.data.flatten()
73
74        print(apot_tensor_data)
75
76        # map apot_to_float over tensor2quantize elements
77        result_temp = np.empty(shape=apot_tensor_data.size())
78        for i in range(len(apot_tensor_data)):
79            new_ele = apot_to_float(
80                apot_tensor_data[i], self.quantization_levels, self.level_indices
81            )
82            result_temp[i] = new_ele
83
84        result = torch.from_numpy(result_temp).reshape(orig_size)
85
86        return result
87
88    r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
89    based on the calculated quantization levels from a specified APoT non-uniform observer.
90    The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
91    Args:
92        apot_tensor: quantized APoT Tensor to dequantize
93    Returns:
94        result: fp representation of input Tensor
95    """
96
97    def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
98        levels_lst = list(self.quantization_levels)
99
100        result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))  # type: ignore[call-arg]
101
102        return result
103
104    def q_apot_alpha(self) -> float:
105        raise NotImplementedError
106
107
108r""" Global method to create quantizer and call quantizer quantize_APoT
109    Args:
110        tensor2quantize: fp Tensor to quantize
111        alpha: Tensor qparam alpha (clipping level)
112        gamma: Tensor qparam gamma (scale factor for quantization levels)
113        quantization levels: Tensor with fp quantization levels
114        level indices: Tensor with integer quantization level indices
115    Returns:
116        result: ApoT Tensor representation of tensor2quantize
117"""
118
119
120def quantize_APoT(
121    tensor2quantize: Tensor,
122    alpha: Tensor,
123    gamma: Tensor,
124    quantization_levels: Tensor,
125    level_indices: Tensor,
126):
127    quantizer = APoTQuantizer(
128        alpha=alpha,
129        gamma=gamma,
130        quantization_levels=quantization_levels,
131        level_indices=level_indices,
132    )
133    result = quantizer.quantize(tensor2quantize)
134    return result
135
136
137r""" Global method to create quantizer and call quantizer dequantize_APoT
138    Args:
139        apot_tensor: APoT Tensor to dequantize
140    Returns:
141        result: fp Tensor dequantized from apot_tensor
142"""
143
144
145def dequantize_APoT(apot_tensor) -> Tensor:
146    quantizer = apot_tensor.quantizer
147    result = quantizer.dequantize(apot_tensor)
148    return result
149
150
151r""" Global method to create quantizer and call quantizer quant_dequant
152    Args:
153        tensor2quantize: fp Tensor to quantize
154        alpha: Tensor qparam alpha (clipping level)
155        gamma: Tensor qparam gamma (scale factor for quantization levels)
156        quantization levels: Tensor with fp quantization levels
157        level indices: Tensor with integer quantization level indices
158    Returns:
159        result: fp reduced precision Tensor from tensor2quantize
160"""
161
162
163def quant_dequant_APoT(
164    tensor2quantize: Tensor,
165    alpha: Tensor,
166    gamma: Tensor,
167    quantization_levels: Tensor,
168    level_indices: Tensor,
169) -> Tensor:
170    quantizer = APoTQuantizer(
171        alpha=alpha,
172        gamma=gamma,
173        quantization_levels=quantization_levels,
174        level_indices=level_indices,
175    )
176    result = quantizer.quant_dequant(tensor2quantize)
177    return result
178