# mypy: allow-untyped-defs
import torch
from torch.ao.quantization.observer import ObserverBase


class ModelReportObserver(ObserverBase):
    r"""This observer is used to record additional information regarding keeping track
    of S = average_batch_activation_range/epoch_activation_range.

    The purpose of this information is to prepare a report to present to users on whether
    Dynamic or Static Quantization is more appropriate for their model given the general
    distributions of their data.

    Args:
        ch_axis (int, optional): The channel axis for which the range and outlier stats are computed
            Default: 1
        comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers
            Should be between 0 and 1 exclusive
            Default: 0.9

    * :attr:`num_batches_tracked` specifies number of batches passed through the observer

    * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through

    * :attr:`epoch_activation_min` defines the minimum value passed through the observer

    * :attr:`epoch_activation_max` defines the maximum value passed through the observer

    * :attr:`ch_axis` defines the channel axis used to compute per channel min max stats

    * :attr:`min_val` defines the per channel minimum values passed through

    * :attr:`max_val` defines the per channel maximum values passed through

    * :attr:`comp_percentile` defines comparison percentile to find outliers

    * :attr:`average_percentile_ratio` defines the per channel average percentile ratios

    * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel

    * :attr:`constant_channels` defines, per channel, the number of batches in which that channel was constant

    Note: this tool is meant for FX Graph Mode Quantization
    """

    epoch_activation_min: torch.Tensor
    epoch_activation_max: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor
    comp_percentile: torch.Tensor
    average_percentile_ratio: torch.Tensor
    percentile_batches_tracked: torch.Tensor
    constant_channels: torch.Tensor

    def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9):
        super().__init__(torch.qint8)
        self.num_batches_tracked = 0

        # keep track of the min and max of the range for average batch and epoch as a whole
        # (epoch min/max start at +inf/-inf so the first batch always replaces them)
        self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0))
        self.register_buffer("epoch_activation_min", torch.tensor(float("inf")))
        self.register_buffer("epoch_activation_max", torch.tensor(float("-inf")))

        # keep track of per channel min max information using the given channel
        self.ch_axis: int = ch_axis
        self.register_buffer("min_val", torch.tensor([]))
        self.register_buffer("max_val", torch.tensor([]))

        # keep track of percentile ratio information per channel
        # (empty buffers are lazily sized on the first forward call)
        self.register_buffer("comp_percentile", torch.tensor([comp_percentile]))
        self.register_buffer("average_percentile_ratio", torch.tensor([]))
        self.register_buffer("percentile_batches_tracked", torch.tensor([]))
        self.register_buffer("constant_channels", torch.tensor([]))

    def forward(self, x):
        r"""Records range, per channel min/max and percentile stats for ``x``.

        Statistics are computed on a detached copy cast to the buffer dtype;
        the input is returned unchanged.
        """
        x_copy = x.detach()  # avoid keeping autograd tape
        x_copy = x_copy.to(self.epoch_activation_min.dtype)

        x_copy = self._calculate_range_stats(x_copy)
        x_copy = self._calculate_min_max_stats(x_copy)
        x_copy = self._calculate_percentile_stats(x_copy)

        # return the passed in the value
        return x

    def _per_channel_view(self, x_copy):
        r"""Returns ``x_copy`` laid out as ``(num_channels, values_per_channel)``.

        The axis ``self.ch_axis`` is swapped with axis 0 and every remaining
        axis is flattened, so row ``i`` holds all values of channel ``i``. The
        result is cast to the dtype of ``self.min_val`` because the per channel
        buffers are updated in place and the dtypes must match.
        """
        axis_order = list(range(x_copy.dim()))
        axis_order[self.ch_axis] = 0
        axis_order[0] = self.ch_axis
        y = x_copy.permute(axis_order)
        y = y.to(self.min_val.dtype)
        return torch.flatten(y, start_dim=1)

    def _calculate_range_stats(self, x_copy):
        r"""Calculates and stores range stats with forward values.

        Updates the epoch-wide min/max buffers and the running mean of
        per batch ranges, then increments ``num_batches_tracked``.

        Args
            x_copy: A copy of the forward data

        Returns the passed in x_copy
        """
        # get the min, max values of the data
        min_val_cur, max_val_cur = torch.aminmax(x_copy)

        # calculate new epoch range values
        epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
        epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)

        self.epoch_activation_min.copy_(epoch_min_val)
        self.epoch_activation_max.copy_(epoch_max_val)

        # incremental mean: (old_mean * n + new_value) / (n + 1)
        current_batch_range = max_val_cur - min_val_cur
        new_range = (
            self.average_batch_activation_range * self.num_batches_tracked
            + current_batch_range
        ) / (self.num_batches_tracked + 1)

        self.average_batch_activation_range = new_range
        self.num_batches_tracked += 1  # new batch was processed

        return x_copy

    def _calculate_min_max_stats(self, x_copy):
        r"""Calculates and stores the per_channel min, max stats with forward values.
        Does calculation based on channel axis: self.ch_axis

        Args
            x_copy: A copy of the forward data

        Returns the passed in x_copy
        """
        # get the current min and max vals
        min_val = self.min_val
        max_val = self.max_val
        y = self._per_channel_view(x_copy)

        if min_val.numel() == 0 or max_val.numel() == 0:
            # first batch: nothing to merge with yet
            min_val, max_val = torch.aminmax(y, dim=1)
        else:
            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
            min_val = torch.min(min_val_cur, min_val)
            max_val = torch.max(max_val_cur, max_val)

        # buffers start empty, so resize before the in-place copy
        self.min_val.resize_(min_val.shape)
        self.max_val.resize_(max_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)

        return x_copy

    def _calculate_percentile_stats(self, x_copy):
        r"""Calculates and stores the per_channel percentile stats with forward values.
        Does calculation based on channel axis: self.ch_axis

        For every channel this tracks the running average of
        ``100th percentile / comp_percentile`` (skipping channels whose two
        quantiles are both 0) and the number of batches in which the channel
        was constant (0th percentile == 100th percentile).

        Args
            x_copy: A copy of the forward data

        Returns the passed in x_copy
        """
        # rows are channels; quantiles are computed on CPU
        y = self._per_channel_view(x_copy).to(device="cpu")

        # find the percentile values along the axis
        # we want both 100th percentile and comp_percentile
        # we also want the 0th percentile to see if we have a constant channel
        quantiles_list = [0, self.comp_percentile, 1.00]
        quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype)

        # y is (num_channels, values_per_channel), so each channel's quantiles are
        # taken along dim 1 regardless of self.ch_axis (which was already moved to 0)
        desired_quantiles = torch.quantile(
            y, quantiles_to_find, dim=1, interpolation="lower"
        )
        zero_quantile = desired_quantiles[0]
        comp_quantile = desired_quantiles[1]
        hundredth_quantile = desired_quantiles[2]

        # a channel is ignored for this batch's ratio if both its comp and 100th quantiles are 0
        any_non_zero_quantile_value: torch.Tensor = (
            comp_quantile != torch.tensor([0])
        ) | (hundredth_quantile != torch.tensor([0]))
        any_non_zero_quantile_value = (
            any_non_zero_quantile_value.int()
        )  # transform boolean values to int values

        # a channel is constant in this batch if its 0th and 100th quantiles match
        any_constant_channels: torch.Tensor = (
            hundredth_quantile - zero_quantile
        ) == torch.tensor([0])
        any_constant_channels = (
            any_constant_channels.int()
        )  # transform boolean values to int values

        # special cases for hundredth / comp:
        #   (1) 0 in the numerator only: ratio is 0 (e.g. 0 is the max and the rest are negative)
        #   (2) 0 in the denominator only: ratio is +/-inf
        #   (3) 0 in both: ratio is nan; the channel is masked out by any_non_zero_quantile_value
        # nan_to_num maps nan to 0 and +/-inf to the largest finite values of the dtype
        quantile_ratios = hundredth_quantile / comp_quantile
        quantile_ratios = torch.nan_to_num(quantile_ratios)
        # update averages, remembering to only update if didn't have zeros
        ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios

        # if num_batches and average_ratio are not initialized, we want to initialize them
        if (
            self.percentile_batches_tracked.shape[0] == 0
            or self.average_percentile_ratio.shape[0] == 0
        ):
            self.percentile_batches_tracked = torch.zeros_like(
                any_non_zero_quantile_value
            )
            self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero)

        # also initialize the constant channel var if that is not initialized separately
        if self.constant_channels.shape[0] == 0:
            self.constant_channels = torch.zeros_like(any_constant_channels)

        # get current num batches and average ratio
        num_batches = self.percentile_batches_tracked
        average_ratio = self.average_percentile_ratio

        # per channel incremental mean; channels with 0 tracked batches give 0/0 = nan -> 0
        new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value
        new_ratios: torch.Tensor = (
            (average_ratio * num_batches) + ratio_if_not_zero
        ) / new_number_of_batches
        new_ratios = torch.nan_to_num(new_ratios)

        # update the per channel count of batches in which the channel was constant
        new_constant_count: torch.Tensor = (
            self.constant_channels + any_constant_channels
        )

        # update the values locally
        self.percentile_batches_tracked.copy_(new_number_of_batches)
        self.average_percentile_ratio.copy_(new_ratios)
        self.constant_channels.copy_(new_constant_count)

        return x_copy

    @torch.jit.export
    def get_batch_to_epoch_ratio(self):
        r"""Returns S = average_batch_activation_range / epoch_activation_range.

        Raises:
            ValueError: if the epoch range is 0, or if it is infinite. It is
                ``-inf`` when no data has been observed (the epoch min/max start
                at ``+inf``/``-inf``) and ``+inf`` when an infinite value was seen.
        """
        epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min

        if epoch_activation_range == torch.tensor(float(0)):
            raise ValueError("Range for Epoch is 0")
        elif torch.isinf(epoch_activation_range):
            raise ValueError(
                "No data has been run through observer or infinity value present"
            )
        else:
            return self.average_batch_activation_range / epoch_activation_range

    @torch.jit.export
    def reset_batch_and_epoch_values(self):
        r"""Resets all tracked statistics to their defaults for a new epoch.

        New tensors are created on the device of ``self.max_val``;
        ``comp_percentile`` and ``ch_axis`` are left untouched.
        """
        device = self.max_val.device
        self.num_batches_tracked = 0
        self.average_batch_activation_range = torch.tensor(float(0), device=device)
        self.epoch_activation_min = torch.tensor(float("inf"), device=device)
        self.epoch_activation_max = torch.tensor(float("-inf"), device=device)
        self.min_val = torch.tensor([], device=device)
        self.max_val = torch.tensor([], device=device)
        self.average_percentile_ratio = torch.tensor([], device=device)
        self.percentile_batches_tracked = torch.tensor([], device=device)
        self.constant_channels = torch.tensor([], device=device)

    @torch.jit.export
    def calculate_qparams(self):
        r"""Not supported: this observer only collects report statistics."""
        raise Exception(  # noqa: TRY002
            "calculate_qparams should not be called for ModelReportObserver"
        )