# mypy: allow-untyped-defs
import torch
import torch.ao.nn.quantized as nnq
import torch.ao.ns._numeric_suite as ns
import torch.ao.quantization
import torch.nn as nn


__all__ = [
    "get_module",
    "parent_child_names",
    "get_param",
    "MeanShadowLogger",
    "bias_correction",
]

# Float module types handed to ``ns.prepare_model_with_stubs`` in
# ``bias_correction``.
_supported_modules = {nn.Linear, nn.Conv2d}
# Default set of module types whose bias ``bias_correction`` adjusts.
_supported_modules_quantized = {nnq.Linear, nnq.Conv2d}


def get_module(model, name):
    """Given name of submodule, this function grabs the submodule from given model.

    Args:
        model: module whose ``named_modules()`` is searched
        name: fully qualified submodule name as produced by ``named_modules()``

    Returns:
        The submodule registered under ``name``.

    Raises:
        KeyError: if ``name`` is not among the names in ``model.named_modules()``.
    """
    return dict(model.named_modules())[name]


def parent_child_names(name):
    """Split full name of submodule into parent submodule's full name and submodule's name.

    The split happens at the last ``"."``. A name without any ``"."`` has an
    empty parent name.

    Args:
        name: dotted, fully qualified submodule name

    Returns:
        Tuple ``(parent_name, child_name)``.
    """
    parent, _, child = name.rpartition(".")
    return parent, child


def get_param(module, attr):
    """Get the parameter given a module and attribute.

    Sometimes the weights/bias attribute gives you the raw tensor, but sometimes
    gives a function that will give you the raw tensor, this function takes care of that logic

    Args:
        module: object to read the attribute from
        attr: name of the attribute to read

    Returns:
        ``None`` if the attribute does not exist, the result of calling the
        attribute if it is callable, and the attribute itself otherwise.
    """
    param = getattr(module, attr, None)
    return param() if callable(param) else param


class MeanShadowLogger(ns.Logger):
    """Mean Logger for a Shadow module.

    A logger for a Shadow module whose purpose is to record the rolling mean
    of the data passed to the floating point and quantized models
    """

    def __init__(self):
        """Set up initial values for float and quantized stats, count, float sum, and quant sum."""
        super().__init__()
        self._reset()

    def _reset(self):
        """Put the means, the running sums and the sample count back to their empty state."""
        self.stats["float"] = None
        self.stats["quantized"] = None
        self.count = 0
        self.float_sum = None
        self.quant_sum = None

    def forward(self, x, y):
        """Compute the average of quantized and floating-point data from modules.

        The inputs x,y are output data from the quantized and floating-point modules.
        x is for the quantized module, y is for the floating point module

        After each call ``self.stats["quantized"]`` and ``self.stats["float"]``
        hold the mean of all ``x`` (dequantized if needed) and all ``y`` seen
        since construction or the last ``clear()``.
        """
        if x.is_quantized:
            x = x.dequantize()

        self.count += 1

        # The sums are accumulated out of place on purpose: on the first call
        # the running sum *is* the caller's tensor, so an in-place ``+=`` on
        # the next call would silently overwrite the first batch's output.
        if self.stats["quantized"] is None:
            self.stats["quantized"] = x
            self.quant_sum = x
        else:
            self.quant_sum = self.quant_sum + x
            self.stats["quantized"] = self.quant_sum / self.count

        if self.stats["float"] is None:
            self.stats["float"] = y
            self.float_sum = y
        else:
            self.float_sum = self.float_sum + y
            self.stats["float"] = self.float_sum / self.count

    def clear(self):
        """Discard everything recorded so far so the logger can start a new average."""
        self._reset()


def bias_correction(
    float_model,
    quantized_model,
    img_data,
    target_modules=_supported_modules_quantized,
    neval_batches=None,
):
    """Perform bias correction on a module.

    Using numeric suite shadow module, the expected output of the floating point and quantized modules
    is recorded. Using that data the bias of supported modules is shifted to compensate for the drift caused
    by quantization
    Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2)

    Modules are corrected one at a time: for every targeted module that has a
    bias, the calibration data is run through ``quantized_model`` again, the
    bias is updated from the recorded means, and all ``MeanShadowLogger``
    instances are cleared before moving on to the next module.

    Args:
        float_model: a trained model that serves as a reference to what bias correction should aim for
        quantized_model: quantized form of float_model that bias correction is to be applied to
        img_data: calibration data to estimate the expected output (used to find quantization error);
            each item is indexed with ``[0]`` to obtain the model input
        target_modules: specifies what submodules in quantized_model need bias correction (can be extended to
            unquantized submodules)
        neval_batches: a cap to the number of batches you want to be used for estimating the expected output;
            with the default ``None`` every batch in ``img_data`` is used
    """
    ns.prepare_model_with_stubs(
        float_model, quantized_model, _supported_modules, MeanShadowLogger
    )

    # Exact type match (not isinstance): only the listed classes themselves
    # are selected, subclasses are not.
    uncorrected_modules = {
        name: submodule
        for name, submodule in quantized_model.named_modules()
        if type(submodule) in target_modules
    }

    for module_name, quantized_submodule in uncorrected_modules.items():
        bias = get_param(quantized_submodule, "bias")
        if bias is None:
            # Nothing to correct for modules without a bias.
            continue

        # Feed calibration batches so the loggers record the mean outputs.
        count = 0
        for data in img_data:
            quantized_model(data[0])
            count += 1
            if count == neval_batches:
                break

        ob_dict = ns.get_logger_dict(quantized_model)
        # The logger stats are looked up under the name of the module that
        # contains the targeted submodule.
        parent_name, _ = parent_child_names(module_name)
        stats = ob_dict[parent_name + ".stats"]
        float_data = stats["float"]
        quant_data = stats["quantized"]

        # math for expected_error
        quantization_error = quant_data - float_data
        dims = list(range(quantization_error.dim()))
        # Note: we don't want to take the mean over the output channel dimension
        # (list.remove raises ValueError if the recorded data has fewer than 2 dims)
        dims.remove(1)
        expected_error = torch.mean(quantization_error, dims)

        bias.data = bias.data - expected_error

        # Resets the data contained in the loggers
        for submodule in quantized_model.modules():
            if isinstance(submodule, MeanShadowLogger):
                submodule.clear()