# mypy: allow-untyped-defs
import numpy as np

import torch
from torch import Tensor
from torch.ao.quantization.experimental.apot_utils import (
    apot_to_float,
    float_to_apot,
    quant_dequant_util,
)


# class to store APoT quantizer qparams and
# implement quantize and dequantize
class APoTQuantizer:
    alpha: torch.Tensor
    gamma: torch.Tensor
    quantization_levels: torch.Tensor
    level_indices: torch.Tensor

    def __init__(
        self,
        alpha: torch.Tensor,
        gamma: torch.Tensor,
        quantization_levels: torch.Tensor,
        level_indices: torch.Tensor,
    ) -> None:
        self.alpha = alpha
        self.gamma = gamma
        self.quantization_levels = quantization_levels
        self.level_indices = level_indices

    def quantize(self, tensor2quantize: Tensor):
        r"""Quantizes a fp Tensor to its integer APoT representation.
        Conversion is based on the qparams from a specified APoT non-uniform observer.
        The approach follows the method outlined in the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.
        Args:
            tensor2quantize: fp Tensor
        Returns:
            result: APoT Tensor representation of tensor2quantize
        """
        # map float_to_apot over tensor2quantize elements; clone first so the
        # in-place apply_ does not mutate the caller's tensor (detach alone
        # shares storage with the input)
        tensor2quantize = tensor2quantize.detach().clone().apply_(
            lambda x: float_to_apot(
                x, self.quantization_levels, self.level_indices, self.alpha
            )
        )

        # convert to int dtype to hold the APoT level indices
        tensor2quantize = tensor2quantize.int()

        # deferred import to avoid a circular dependency with APoT_tensor
        from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT

        result = TensorAPoT(self, tensor2quantize)

        return result

    def dequantize(self, apot_tensor) -> Tensor:
        r"""Dequantizes an integer APoT Tensor to a floating point (fp) representation
        based on the calculated quantization levels from a specified APoT non-uniform observer.
        The approach follows the method outlined in the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.
        Args:
            apot_tensor: APoT Tensor to dequantize
        Returns:
            result: fp reduced precision representation of the input Tensor
        """
        orig_size = apot_tensor.data.size()
        apot_tensor_data = apot_tensor.data.flatten()

        # map apot_to_float over the flattened APoT data elements
        result_temp = np.empty(shape=apot_tensor_data.size())
        for i in range(len(apot_tensor_data)):
            new_ele = apot_to_float(
                apot_tensor_data[i], self.quantization_levels, self.level_indices
            )
            result_temp[i] = new_ele

        result = torch.from_numpy(result_temp).reshape(orig_size)

        return result

    def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
        r"""Returns the result of quantize -> dequantize on a fp Tensor (reduced precision)
        based on the calculated quantization levels from a specified APoT non-uniform observer.
        The approach follows the method outlined in the APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.
        Args:
            tensor2quantize: fp Tensor
        Returns:
            result: fp reduced precision representation of tensor2quantize
        """
        levels_lst = list(self.quantization_levels)

        # snap each element to its nearest quantization level (in place)
        result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))  # type: ignore[call-arg]

        return result

    def q_apot_alpha(self) -> float:
        raise NotImplementedError


def quantize_APoT(
    tensor2quantize: Tensor,
    alpha: Tensor,
    gamma: Tensor,
    quantization_levels: Tensor,
    level_indices: Tensor,
):
    r"""Global function to create a quantizer and call its quantize method.
    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices
    Returns:
        result: APoT Tensor representation of tensor2quantize
    """
    quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices,
    )
    result = quantizer.quantize(tensor2quantize)
    return result


def dequantize_APoT(apot_tensor) -> Tensor:
    r"""Global function to retrieve the quantizer stored on an APoT Tensor
    and call its dequantize method.
    Args:
        apot_tensor: APoT Tensor to dequantize
    Returns:
        result: fp Tensor dequantized from apot_tensor
    """
    quantizer = apot_tensor.quantizer
    result = quantizer.dequantize(apot_tensor)
    return result


def quant_dequant_APoT(
    tensor2quantize: Tensor,
    alpha: Tensor,
    gamma: Tensor,
    quantization_levels: Tensor,
    level_indices: Tensor,
) -> Tensor:
    r"""Global function to create a quantizer and call its quant_dequant method.
    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices
    Returns:
        result: fp reduced precision Tensor from tensor2quantize
    """
    quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices,
    )
    result = quantizer.quant_dequant(tensor2quantize)
    return result
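

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the public API). It round-trips
# a random tensor through the module-level helpers above. The toy qparams are
# hand-built assumptions chosen for readability, not realistic APoT levels; in
# practice alpha, gamma, quantization_levels, and level_indices come from an
# APoT non-uniform observer's calculate_qparams(), which is not exercised here.
if __name__ == "__main__":
    # toy qparams: five uniform levels on [0, 1] with matching integer indices
    quantization_levels = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
    level_indices = torch.tensor([0, 1, 2, 3, 4])
    alpha = torch.tensor(1.0)  # clipping level; values beyond +/- alpha are clipped
    gamma = torch.tensor(1.0)  # scale factor for the quantization levels

    x = torch.rand(2, 3)

    # fp -> integer APoT representation -> fp
    x_apot = quantize_APoT(x, alpha, gamma, quantization_levels, level_indices)
    x_dq = dequantize_APoT(x_apot)

    # fused round trip; quant_dequant works in place, so pass a clone
    x_qdq = quant_dequant_APoT(
        x.clone(), alpha, gamma, quantization_levels, level_indices
    )

    # both paths should snap each element to its nearest quantization level
    print("input:        ", x)
    print("dequantized:  ", x_dq)
    print("quant_dequant:", x_qdq)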