# mypy: allow-untyped-defs
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""

import itertools

import matplotlib.pyplot as plt

import torch
from torch.ao.quantization.experimental.apot_utils import apot_to_float, float_to_apot
from torch.ao.quantization.observer import ObserverBase


# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented


class APoTObserver(ObserverBase):
    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(self, b, k, dtype=torch.quint8) -> None:
        super().__init__(dtype)
        # b is the total bitwidth of the quantized levels; k is the base
        # bitwidth, so each level is built from n = b // k additive
        # power-of-two terms
        self.b = b
        self.k = k

        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    # delegates to _calculate_qparams, passing the min_val and max_val
    # recorded by forward
    def calculate_qparams(self, signed):
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that can override the min_val internal attribute
            max_val: optional arg that can override the max_val internal attribute

        Returns:
            alpha: alpha quantization parameter, the max of the absolute values observed
            gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
            quantization_levels: nonuniform quantization levels (fp representation)
            level_indices: int representation of quantization_levels indices
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha, the largest absolute value observed
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k: k must be nonzero and divide b evenly
        assert self.k
        assert self.b % self.k == 0

        # compute n, the number of additive terms per level, and store it
        self.n = self.b // self.k

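        # e.g. b=4, k=2 gives n = 2, so each quantization level is a sum of
        # two power-of-two terms (this choice is traced through the comments
        # below)
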
        # store a list of subtensors (one per additive term, holding the
        # candidate values for that term)
        p_all = []

        # create levels
        for i in range(self.n):
            p_curr = torch.tensor([0])

            for j in range(2**self.k - 1):
                curr_ele = 2 ** (-(i + j * self.n))
                p_append = torch.tensor([curr_ele])
                p_curr = torch.cat((p_curr, p_append))
                # introduce signed numbers
                if signed:
                    p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))

            if signed:
                # sort tensor in descending order before adding to list if signed
                sorted_vals, _ = torch.sort(p_curr, descending=True)
                p_all.append(sorted_vals)
            else:
                p_all.append(p_curr)

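        # e.g. with b=4, k=2 and signed=False, each of the n = 2 terms draws
        # from 2**2 - 1 = 3 powers of two plus the leading zero:
        #   p_all[0] = [0, 2**0,  2**-2, 2**-4] = [0, 1.0, 0.25, 0.0625]
        #   p_all[1] = [0, 2**-1, 2**-3, 2**-5] = [0, 0.5, 0.125, 0.03125]
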
        # gamma calculation:
        # loop through all tensors, accumulating each tensor's largest element:
        # index 0 for signed tensors (sorted in descending order), index 1 for
        # unsigned tensors (the largest element follows the leading 0)
        # gamma is defined so that the largest quantization level equals alpha
        p_sum = 0.0
        for tens in p_all:
            if signed:
                p_sum += float(tens[0])
            else:
                p_sum += float(tens[1])

        # assign gamma
        gamma = alpha / p_sum

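        # continuing the b=4, k=2, signed=False example: p_sum = 1.0 + 0.5 = 1.5,
        # so gamma = alpha / 1.5 and the largest level, gamma * (1.0 + 0.5),
        # lands exactly on alpha
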
        # calculate the cartesian product: every combination of one candidate
        # value per additive term
        cartesian_product = list(itertools.product(*p_all))

        quantization_levels_list = []

        # each quantization level is the sum of one row of the cartesian product
        for row in cartesian_product:
            quantization_levels_list.append(sum(row))

        quantization_levels_gamma = [
            float(gamma) * ele for ele in quantization_levels_list
        ]
        quantization_levels = torch.tensor(quantization_levels_gamma)
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)

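    # A minimal sketch of the qparam computation (illustrative data; `obs`
    # is hypothetical):
    #
    #   obs = APoTObserver(b=4, k=2)
    #   obs(torch.randn(1024))  # forward records the running min/max
    #   alpha, gamma, levels, indices = obs.calculate_qparams(signed=False)
    #
    # with b=4, k=2 the cartesian product has 4 * 4 = 2**b = 16 rows, so
    # `levels` holds 16 sorted nonuniform values spanning [0, alpha]
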
127    r"""Records the running minimum and maximum of ``x``.
128        Args:
129            x_orig: Tensor to be observed for min and max val"""
130
131    def forward(self, x_orig):
132        if x_orig.numel() == 0:
133            return x_orig
134        x = x_orig.detach()
135        min_val, max_val = torch.aminmax(x)
136        if self.min_val.numel():
137            min_val = torch.min(min_val, self.min_val)
138        if self.max_val.numel():
139            max_val = torch.max(max_val, self.max_val)
140        self.min_val = min_val
141        self.max_val = max_val
142        return x_orig
143
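    # e.g. the running min/max over two calibration batches (hypothetical `obs`):
    #
    #   obs(torch.tensor([-1.0, 2.0]))  # min_val = -1, max_val = 2
    #   obs(torch.tensor([-3.0, 1.0]))  # min_val = -3, max_val stays 2
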
144    r"""Displays visualization of APoT quantization levels
145        Args:
146            observer: APoTObserver to calculate qparams
147            signed: bool to indicate if qparams should be signed/unsigned
148    """
149
150    def quant_levels_visualization(self, signed=False):
151        alpha, gamma, quantization_levels, level_indices = self.calculate_qparams(
152            signed
153        )
154
155        xs = [float(x) / 1000.0 for x in range(1000)]
156        ys = [
157            apot_to_float(
158                float_to_apot(x, quantization_levels, level_indices, alpha),
159                quantization_levels,
160                level_indices,
161            ).item()
162            for x in xs
163        ]
164
165        f = plt.figure(figsize=(15, 10))
166
167        plt.plot(xs, ys)
168        plt.title("APoT Quantization Plot")
169        plt.xlabel("Full Precision")
170        plt.ylabel("Quantized")
171        plt.show()
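

# A minimal end-to-end usage sketch (illustrative only; the calibration data
# and the b, k choices are assumptions, not mandated by this module):
#
#   import torch
#   from torch.ao.quantization.experimental.observer import APoTObserver
#
#   observer = APoTObserver(b=4, k=2)
#   observer(torch.randn(1024))  # calibration pass records running min/max
#   alpha, gamma, levels, indices = observer.calculate_qparams(signed=True)
#   observer.quant_levels_visualization(signed=True)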