1# Owner(s): ["oncall: quantization"] 2 3from torch.ao.quantization.experimental.observer import APoTObserver 4import unittest 5import torch 6 7class TestNonUniformObserver(unittest.TestCase): 8 """ 9 Test case 1: calculate_qparams 10 Test that error is thrown when k == 0 11 """ 12 def test_calculate_qparams_invalid(self): 13 obs = APoTObserver(b=0, k=0) 14 obs.min_val = torch.tensor([0.0]) 15 obs.max_val = torch.tensor([0.0]) 16 17 with self.assertRaises(AssertionError): 18 alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) 19 20 """ 21 Test case 2: calculate_qparams 22 APoT paper example: https://arxiv.org/pdf/1909.13144.pdf 23 Assume hardcoded parameters: 24 * b = 4 (total number of bits across all terms) 25 * k = 2 (base bitwidth, i.e. bitwidth of every term) 26 * n = 2 (number of additive terms) 27 * note: b = k * n 28 """ 29 def test_calculate_qparams_2terms(self): 30 obs = APoTObserver(b=4, k=2) 31 32 obs.min_val = torch.tensor([0.0]) 33 obs.max_val = torch.tensor([1.0]) 34 alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) 35 36 alpha_test = torch.max(-obs.min_val, obs.max_val) 37 38 # check alpha value 39 self.assertEqual(alpha, alpha_test) 40 41 # calculate expected gamma value 42 gamma_test = 0 43 for i in range(2): 44 gamma_test += 2**(-i) 45 46 gamma_test = 1 / gamma_test 47 48 # check gamma value 49 self.assertEqual(gamma, gamma_test) 50 51 # check quantization levels size 52 quantlevels_size_test = int(len(quantization_levels)) 53 quantlevels_size = 2**4 54 self.assertEqual(quantlevels_size_test, quantlevels_size) 55 56 # check level indices size 57 levelindices_size_test = int(len(level_indices)) 58 self.assertEqual(levelindices_size_test, 16) 59 60 # check level indices unique values 61 level_indices_test_list = level_indices.tolist() 62 self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) 63 64 """ 65 Test case 3: calculate_qparams 66 Assume hardcoded parameters: 67 * b = 6 (total number of bits across all terms) 68 * k = 2 (base bitwidth, i.e. bitwidth of every term) 69 * n = 3 (number of additive terms) 70 """ 71 def test_calculate_qparams_3terms(self): 72 obs = APoTObserver(b=6, k=2) 73 74 obs.min_val = torch.tensor([0.0]) 75 obs.max_val = torch.tensor([1.0]) 76 alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) 77 78 alpha_test = torch.max(-obs.min_val, obs.max_val) 79 80 # check alpha value 81 self.assertEqual(alpha, alpha_test) 82 83 # calculate expected gamma value 84 gamma_test = 0 85 for i in range(3): 86 gamma_test += 2**(-i) 87 88 gamma_test = 1 / gamma_test 89 90 # check gamma value 91 self.assertEqual(gamma, gamma_test) 92 93 # check quantization levels size 94 quantlevels_size_test = int(len(quantization_levels)) 95 quantlevels_size = 2**6 96 self.assertEqual(quantlevels_size_test, quantlevels_size) 97 98 # check level indices size 99 levelindices_size_test = int(len(level_indices)) 100 self.assertEqual(levelindices_size_test, 64) 101 102 # check level indices unique values 103 level_indices_test_list = level_indices.tolist() 104 self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) 105 106 """ 107 Test case 4: calculate_qparams 108 Same as test case 2 but with signed = True 109 Assume hardcoded parameters: 110 * b = 4 (total number of bits across all terms) 111 * k = 2 (base bitwidth, i.e. bitwidth of every term) 112 * n = 2 (number of additive terms) 113 * signed = True 114 """ 115 def test_calculate_qparams_signed(self): 116 obs = APoTObserver(b=4, k=2) 117 118 obs.min_val = torch.tensor([0.0]) 119 obs.max_val = torch.tensor([1.0]) 120 alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True) 121 alpha_test = torch.max(-obs.min_val, obs.max_val) 122 123 # check alpha value 124 self.assertEqual(alpha, alpha_test) 125 126 # calculate expected gamma value 127 gamma_test = 0 128 for i in range(2): 129 gamma_test += 2**(-i) 130 131 gamma_test = 1 / gamma_test 132 133 # check gamma value 134 self.assertEqual(gamma, gamma_test) 135 136 # check quantization levels size 137 quantlevels_size_test = int(len(quantization_levels)) 138 self.assertEqual(quantlevels_size_test, 49) 139 140 # check negatives of each element contained 141 # in quantization levels 142 quantlevels_test_list = quantization_levels.tolist() 143 negatives_contained = True 144 for ele in quantlevels_test_list: 145 if -ele not in quantlevels_test_list: 146 negatives_contained = False 147 self.assertTrue(negatives_contained) 148 149 # check level indices size 150 levelindices_size_test = int(len(level_indices)) 151 self.assertEqual(levelindices_size_test, 49) 152 153 # check level indices unique elements 154 level_indices_test_list = level_indices.tolist() 155 self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) 156 157 """ 158 Test case 5: calculate_qparams 159 Assume hardcoded parameters: 160 * b = 6 (total number of bits across all terms) 161 * k = 1 (base bitwidth, i.e. bitwidth of every term) 162 * n = 6 (number of additive terms) 163 """ 164 def test_calculate_qparams_k1(self): 165 obs = APoTObserver(b=6, k=1) 166 167 obs.min_val = torch.tensor([0.0]) 168 obs.max_val = torch.tensor([1.0]) 169 170 alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=False) 171 172 # calculate expected gamma value 173 gamma_test = 0 174 for i in range(6): 175 gamma_test += 2**(-i) 176 177 gamma_test = 1 / gamma_test 178 179 # check gamma value 180 self.assertEqual(gamma, gamma_test) 181 182 # check quantization levels size 183 quantlevels_size_test = int(len(quantization_levels)) 184 quantlevels_size = 2**6 185 self.assertEqual(quantlevels_size_test, quantlevels_size) 186 187 # check level indices size 188 levelindices_size_test = int(len(level_indices)) 189 level_indices_size = 2**6 190 self.assertEqual(levelindices_size_test, level_indices_size) 191 192 # check level indices unique values 193 level_indices_test_list = level_indices.tolist() 194 self.assertEqual(len(level_indices_test_list), len(set(level_indices_test_list))) 195 196 """ 197 Test forward method on hard-coded tensor with arbitrary values. 198 Checks that alpha is max of abs value of max and min values in tensor. 199 """ 200 def test_forward(self): 201 obs = APoTObserver(b=4, k=2) 202 203 X = torch.tensor([0.0, -100.23, -37.18, 3.42, 8.93, 9.21, 87.92]) 204 205 X = obs.forward(X) 206 207 alpha, gamma, quantization_levels, level_indices = obs.calculate_qparams(signed=True) 208 209 min_val = torch.min(X) 210 max_val = torch.max(X) 211 212 expected_alpha = torch.max(-min_val, max_val) 213 214 self.assertEqual(alpha, expected_alpha) 215 216 217if __name__ == '__main__': 218 unittest.main() 219