# Owner(s): ["oncall: quantization"]

import unittest

import torch
from torch.ao.quantization.experimental.observer import APoTObserver

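# The tests below exercise APoTObserver.calculate_qparams, which returns a
# tuple (alpha, gamma, quantization_levels, level_indices). As observed in
# these tests (a summary of behavior, not the observer's documented contract):
#   alpha                -- clipping level, max(-min_val, max_val)
#   gamma                -- scale factor, 1 / sum(2**-i for i in range(n))
#   quantization_levels  -- tensor of representable values (2**b of them when signed=False)
#   level_indices        -- tensor of indices with no duplicate entries
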
class TestNonUniformObserver(unittest.TestCase):
    """
        Test case 1: calculate_qparams
        Test that an error is raised when k == 0
    """
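    # Hypothesized rationale (not stated in the source): with k == 0 there is
    # no valid base bitwidth, so the observer is expected to fail an internal
    # assertion before producing any qparams.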
    def test_calculate_qparams_invalid(self):
        obs = APoTObserver(b=0, k=0)
        obs.min_val = torch.tensor([0.0])
        obs.max_val = torch.tensor([0.0])

        with self.assertRaises(AssertionError):
            obs.calculate_qparams(signed=False)

    """
        Test case 2: calculate_qparams
        APoT paper example: https://arxiv.org/pdf/1909.13144.pdf
        Assume hardcoded parameters:
        * b = 4 (total number of bits across all terms)
        * k = 2 (base bitwidth, i.e. bitwidth of every term)
        * n = 2 (number of additive terms)
        * note: b = k * n
    """
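    # Worked arithmetic for this case, derived from the gamma loop in the test
    # body (a derivation for readability, not taken from the source): with
    # n = 2, gamma = 1 / (2**0 + 2**-1) = 1 / 1.5 = 2/3, and an unsigned
    # b = 4 grid has 2**4 = 16 quantization levels.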
    def test_calculate_qparams_2terms(self):
        obs = APoTObserver(b=4, k=2)

        obs.min_val = torch.tensor([0.0])
        obs.max_val = torch.tensor([1.0])
        alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False)

        alpha_test = torch.max(-obs.min_val, obs.max_val)

        # check alpha value
        self.assertEqual(alpha, alpha_test)

        # calculate expected gamma value
        gamma_test = 0
        for i in range(2):
            gamma_test += 2**(-i)

        gamma_test = 1 / gamma_test

        # check gamma value
        self.assertEqual(gamma, gamma_test)

        # check quantization levels size
        quantlevels_size_test = len(quantization_levels)
        quantlevels_size = 2**4
        self.assertEqual(quantlevels_size_test, quantlevels_size)

        # check level indices size
        levelindices_size_test = len(level_indices)
        self.assertEqual(levelindices_size_test, 16)

        # check level indices unique values
        level_indices_test_list = level_indices.tolist()
        self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

    """
        Test case 3: calculate_qparams
        Assume hardcoded parameters:
        * b = 6 (total number of bits across all terms)
        * k = 2 (base bitwidth, i.e. bitwidth of every term)
        * n = 3 (number of additive terms)
    """
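    # Worked arithmetic (a derivation for readability, not from the source):
    # with n = 3, gamma = 1 / (2**0 + 2**-1 + 2**-2) = 1 / 1.75 = 4/7, and an
    # unsigned b = 6 grid has 2**6 = 64 quantization levels.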
    def test_calculate_qparams_3terms(self):
        obs = APoTObserver(b=6, k=2)

        obs.min_val = torch.tensor([0.0])
        obs.max_val = torch.tensor([1.0])
        alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False)

        alpha_test = torch.max(-obs.min_val, obs.max_val)

        # check alpha value
        self.assertEqual(alpha, alpha_test)

        # calculate expected gamma value
        gamma_test = 0
        for i in range(3):
            gamma_test += 2**(-i)

        gamma_test = 1 / gamma_test

        # check gamma value
        self.assertEqual(gamma, gamma_test)

        # check quantization levels size
        quantlevels_size_test = len(quantization_levels)
        quantlevels_size = 2**6
        self.assertEqual(quantlevels_size_test, quantlevels_size)

        # check level indices size
        levelindices_size_test = len(level_indices)
        self.assertEqual(levelindices_size_test, 64)

        # check level indices unique values
        level_indices_test_list = level_indices.tolist()
        self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

    """
        Test case 4: calculate_qparams
        Same as test case 2 but with signed = True
        Assume hardcoded parameters:
        * b = 4 (total number of bits across all terms)
        * k = 2 (base bitwidth, i.e. bitwidth of every term)
        * n = 2 (number of additive terms)
        * signed = True
    """
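    # With signed = True the level grid is expected to be symmetric about
    # zero: for every quantization level, its negation should also be a
    # level. The check in the test body verifies exactly that.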
    def test_calculate_qparams_signed(self):
        obs = APoTObserver(b=4, k=2)

        obs.min_val = torch.tensor([0.0])
        obs.max_val = torch.tensor([1.0])
        alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True)
        alpha_test = torch.max(-obs.min_val, obs.max_val)

        # check alpha value
        self.assertEqual(alpha, alpha_test)

        # calculate expected gamma value
        gamma_test = 0
        for i in range(2):
            gamma_test += 2**(-i)

        gamma_test = 1 / gamma_test

        # check gamma value
        self.assertEqual(gamma, gamma_test)

        # check quantization levels size
        quantlevels_size_test = len(quantization_levels)
        self.assertEqual(quantlevels_size_test, 49)

        # check that the negative of each quantization level
        # is also a quantization level
        quantlevels_test_list = quantization_levels.tolist()
        negatives_contained = all(-ele in quantlevels_test_list for ele in quantlevels_test_list)
        self.assertTrue(negatives_contained)

        # check level indices size
        levelindices_size_test = len(level_indices)
        self.assertEqual(levelindices_size_test, 49)

        # check level indices unique elements
        level_indices_test_list = level_indices.tolist()
        self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

    """
        Test case 5: calculate_qparams
        Assume hardcoded parameters:
        * b = 6 (total number of bits across all terms)
        * k = 1 (base bitwidth, i.e. bitwidth of every term)
        * n = 6 (number of additive terms)
    """
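    # Worked arithmetic (a derivation for readability, not from the source):
    # with n = 6, sum(2**-i for i in range(6)) = 63/32 = 1.96875, so
    # gamma = 32/63 (roughly 0.5079), and a b = 6 grid has 2**6 = 64 levels.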
    def test_calculate_qparams_k1(self):
        obs = APoTObserver(b=6, k=1)

        obs.min_val = torch.tensor([0.0])
        obs.max_val = torch.tensor([1.0])

        alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False)

        # calculate expected gamma value
        gamma_test = 0
        for i in range(6):
            gamma_test += 2**(-i)

        gamma_test = 1 / gamma_test

        # check gamma value
        self.assertEqual(gamma, gamma_test)

        # check quantization levels size
        quantlevels_size_test = len(quantization_levels)
        quantlevels_size = 2**6
        self.assertEqual(quantlevels_size_test, quantlevels_size)

        # check level indices size
        levelindices_size_test = len(level_indices)
        level_indices_size = 2**6
        self.assertEqual(levelindices_size_test, level_indices_size)

        # check level indices unique values
        level_indices_test_list = level_indices.tolist()
        self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list)))

    """
        Test the forward method on a hard-coded tensor with arbitrary values.
        Checks that alpha equals the max of the absolute values of the
        tensor's min and max.
    """
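    # Worked arithmetic for the tensor below (a derivation for readability,
    # not from the source): min = -100.23, max = 87.92, so
    # alpha = max(abs(-100.23), abs(87.92)) = 100.23.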
    def test_forward(self):
        obs = APoTObserver(b=4, k=2)

        X = torch.tensor([0.0, -100.23, -37.18, 3.42, 8.93, 9.21, 87.92])

        X = obs.forward(X)

        alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True)

        min_val = torch.min(X)
        max_val = torch.max(X)

        expected_alpha = torch.max(-min_val, max_val)

        self.assertEqual(alpha, expected_alpha)


if __name__ == '__main__':
    unittest.main()