# mypy: allow-untyped-defs
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""

import itertools

import matplotlib.pyplot as plt

import torch
from torch.ao.quantization.experimental.apot_utils import apot_to_float, float_to_apot
from torch.ao.quantization.observer import ObserverBase


# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented


class APoTObserver(ObserverBase):
    b: int
    k: int
    n: int
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(self, b, k, dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k

        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    def calculate_qparams(self, signed):
        r"""Calculates qparams from the min_val and max_val observed by forward."""
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.
        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that overrides the min_val internal attribute
            max_val: optional arg that overrides the max_val internal attribute
        Returns:
            alpha: quantization parameter, the maximum absolute value of observed values
            gamma: quantization parameter, a scale factor defined so that alpha is the maximum of the range
            quantization_levels: nonuniform quantization levels (fp representation)
            level_indices: integer indices of the quantization_levels
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        assert self.k != 0
        assert self.b % self.k == 0

        # compute n and store as member variable
        self.n = self.b // self.k

        # store a tensor of subtensors (all levels)
        p_all = []

        # create levels
        for i in range(self.n):
            p_curr = torch.tensor([0])

            for j in range(2**self.k - 1):
                curr_ele = 2 ** (-(i + j * self.n))
                p_append = torch.tensor([curr_ele])
                p_curr = torch.cat((p_curr, p_append))
                # introduce signed numbers
                if signed:
                    p_curr = torch.cat((p_curr, torch.tensor([-curr_ele])))

            if signed:
                # sort tensor in descending order before adding to list if signed
                sorted_levels, _ = torch.sort(p_curr, descending=True)
                p_all.append(sorted_levels)
            else:
                p_all.append(p_curr)

        # gamma calculation:
        # loop through all tensors and sum the largest level of each
        # (index 0 if signed because of the descending sort, else index 1);
        # gamma is defined to ensure alpha is the maximum of the range
        p_sum = 0.0
        for tens in p_all:
            if signed:
                p_sum += float(tens[0])
            else:
                p_sum += float(tens[1])

        # assign gamma
        gamma = alpha / p_sum
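        # Illustrative example (not from the paper; assumes b=4, k=2, so
        # n=2, unsigned): the two level sets built above are
        # {0, 2^0, 2^-2, 2^-4} and {0, 2^-1, 2^-3, 2^-5}. Their largest
        # entries are 1 and 0.5, so p_sum = 1.5 and gamma = alpha / 1.5,
        # which scales the largest composed level, 1 + 0.5, exactly to alpha.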
        # calculate the cartesian product of all level sets
        cartesian_product = list(itertools.product(*p_all))

        quantization_levels_list = []

        # each quantization level is the sum of one entry from each level set
        for row in cartesian_product:
            row_sum = 0.0
            for ele in row:
                row_sum += float(ele)
            quantization_levels_list.append(row_sum)

        quantization_levels_gamma = [
            float(gamma) * ele for ele in quantization_levels_list
        ]
        quantization_levels = torch.tensor(quantization_levels_gamma)
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x_orig``.
        Args:
            x_orig: Tensor to be observed for min and max val
        """
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val, max_val = torch.aminmax(x)
        if self.min_val.numel():
            min_val = torch.min(min_val, self.min_val)
        if self.max_val.numel():
            max_val = torch.max(max_val, self.max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig

    def quant_levels_visualization(self, signed=False):
        r"""Displays a visualization of the APoT quantization levels.
        Args:
            signed: bool to indicate whether qparams should be signed or unsigned
        """
        alpha, gamma, quantization_levels, level_indices = self.calculate_qparams(
            signed
        )

        xs = [float(x) / 1000.0 for x in range(1000)]
        ys = [
            apot_to_float(
                float_to_apot(x, quantization_levels, level_indices, alpha),
                quantization_levels,
                level_indices,
            ).item()
            for x in xs
        ]

        plt.figure(figsize=(15, 10))

        plt.plot(xs, ys)
        plt.title("APoT Quantization Plot")
        plt.xlabel("Full Precision")
        plt.ylabel("Quantized")
        plt.show()
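

# Minimal usage sketch (illustrative only, not part of the observer API;
# the random calibration tensor below is arbitrary sample data):
if __name__ == "__main__":
    observer = APoTObserver(b=4, k=2)
    # calling the observer invokes forward(), which records min_val/max_val
    observer(torch.randn(100))
    alpha, gamma, quantization_levels, level_indices = observer.calculate_qparams(
        signed=True
    )
    print(f"alpha={alpha}, gamma={gamma}")
    print(f"quantization_levels={quantization_levels}")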