1r""" 2**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not 3rely on it for anything!** 4""" 5import operator 6import sys 7from typing import Optional 8 9import torch 10from torch.fx import Graph, GraphModule, Node 11from torch.fx.graph import map_arg 12from torch.fx.proxy import Proxy 13from torch.nn.utils import fuse_conv_bn_weights 14 15 16# can be a 17# module type, a builtin function, or a string to match target 18 19 20def _minmax_scale_zeropoint( 21 min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps 22): 23 min_val = min(0.0, min_val) 24 max_val = max(0.0, max_val) 25 if max_val == min_val: 26 return 1.0, 0 27 else: 28 scale = (max_val - min_val) / float(qmax - qmin) 29 scale = max(scale, eps) 30 zero_point = qmin - round(min_val / scale) 31 zero_point = max(qmin, zero_point) 32 zero_point = min(qmax, zero_point) 33 zero_point = int(zero_point) 34 return scale, zero_point 35 36 37class MinMaxObserver: 38 def __init__(self, quantizer, node): 39 self.min, self.max = float("inf"), float("-inf") 40 self.all_tensors = True 41 42 def observe(self, node, env): 43 v = env[node.name] 44 if not isinstance(v, torch.Tensor): 45 self.all_tensors = False 46 return 47 self.max = max(self.max, float(v.max())) 48 self.min = min(self.min, float(v.min())) 49 50 def scale_zeropoint(self): 51 return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255) 52 53 54class NoObserver: 55 def __init__(self, quantizer, node): 56 pass 57 58 def observe(self, node, env): 59 pass 60 61 62_DEFAULT_QUANTIZATION_PATTERNS = {} 63 64 65def register_pattern(pattern): 66 def insert(fn): 67 _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn 68 return fn 69 70 return insert 71 72 73@register_pattern(operator.add) 74class Add(MinMaxObserver): 75 def quantize(self, quantizer, node, load_arg): 76 if not self.all_tensors: 77 return NotImplemented 78 scale, zeropoint = self.scale_zeropoint() 79 return quantizer.quantized_graph.create_node( 80 "call_function", 81 torch.ops.quantized.add, 82 load_arg(node.args), 83 {"scale": scale, "zero_point": zeropoint}, 84 ) 85 86 87class Relu(NoObserver): 88 def quantize(self, quantizer, node, load_arg): 89 return torch.relu( 90 load_arg(node.args[0]) 91 ) # torch.relu works directly on quantized tensors? 


class MinMaxObserver:
    def __init__(self, quantizer, node):
        self.min, self.max = float("inf"), float("-inf")
        self.all_tensors = True

    def observe(self, node, env):
        v = env[node.name]
        if not isinstance(v, torch.Tensor):
            self.all_tensors = False
            return
        self.max = max(self.max, float(v.max()))
        self.min = min(self.min, float(v.min()))

    def scale_zeropoint(self):
        return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)


class NoObserver:
    def __init__(self, quantizer, node):
        pass

    def observe(self, node, env):
        pass


# a pattern can be a module type, a builtin function, or a string to match
# against a node's target
_DEFAULT_QUANTIZATION_PATTERNS = {}


def register_pattern(pattern):
    def insert(fn):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        return fn

    return insert


@register_pattern(operator.add)
class Add(MinMaxObserver):
    def quantize(self, quantizer, node, load_arg):
        if not self.all_tensors:
            return NotImplemented
        scale, zeropoint = self.scale_zeropoint()
        return quantizer.quantized_graph.create_node(
            "call_function",
            torch.ops.quantized.add,
            load_arg(node.args),
            {"scale": scale, "zero_point": zeropoint},
        )


class Relu(NoObserver):
    def quantize(self, quantizer, node, load_arg):
        return torch.relu(
            load_arg(node.args[0])
        )  # torch.relu works directly on quantized tensors?


# these ops have quantized equivalents that do not need any extra information
@register_pattern(torch.nn.ReLU)
@register_pattern(torch.nn.AvgPool2d)
@register_pattern(torch.nn.MaxPool2d)
@register_pattern(torch.nn.AdaptiveAvgPool2d)
class CopyNode(NoObserver):
    def quantize(self, quantizer, node, load_arg):
        return quantizer.quantized_graph.node_copy(node, load_arg)


class IdentityModule(torch.nn.Module):
    def forward(self, x):
        return x


# handle conv, maybe followed by bn, maybe followed by relu
@register_pattern(torch.nn.modules.conv.Conv2d)
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
@register_pattern(
    (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)
)
@register_pattern(
    (
        torch.nn.ReLU,
        (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d),
    )
)
class ConvNormRelu(MinMaxObserver):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node, self.bn_node = None, None
        if isinstance(quantizer.modules[node.target], torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d)
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def quantize(self, quantizer, node, load_arg):
        mod = self.conv
        weight, bias = mod.weight, mod.bias

        if self.bn_node is not None:
            weight, bias = fuse_conv_bn_weights(
                weight,
                bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
            )

        min_val, max_val = float(weight.min()), float(weight.max())

        act_scale, act_zp = self.scale_zeropoint()

        weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
        qweight = torch.quantize_per_tensor(
            weight, weight_scale, weight_zp, torch.qint8
        )

        ctor = (
            torch.ao.nn.intrinsic.quantized.ConvReLU2d
            if self.relu_node is not None
            else torch.ao.nn.quantized.Conv2d
        )

        qconv = ctor(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            mod.stride,
            mod.padding,
            mod.dilation,
            mod.groups,
            mod.bias is not None,
            mod.padding_mode,
        )

        qconv.set_weight_bias(qweight, bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, qconv)
        if self.bn_node is not None:
            # we can't just delete the batchnorm because the submodule's
            # forward (which is no longer used) still tries to call it, so
            # replace it with something that does nothing
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())

        return quantizer.quantized_graph.create_node(
            "call_module",
            self.conv_node.target,
            (load_arg(self.conv_node.args[0]),),
            {},
        )
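

# Pattern tuples are matched outside-in: the first element matches the node
# itself and the remaining elements match its args (see `matches` below).
# For example, the last registration above,
#
#     (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
#
# matches a relu whose input is a batchnorm whose input is a conv, i.e. the
# subgraph relu(bn(conv(x))), with ConvNormRelu handling the whole fused
# match.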


# split a qualified name like 'foo.bar' into ('foo', 'bar')
def _parent_name(target):
    r = target.rsplit(".", 1)
    if len(r) == 1:
        return "", r[0]
    else:
        return r[0], r[1]


class DefaultQuant(MinMaxObserver):
    def quantize(self, input):
        assert self.all_tensors
        scale, zeropoint = self.scale_zeropoint()
        return torch.quantize_per_tensor(
            Proxy(input), scale, zeropoint, torch.quint8
        ).node


def matches(modules, node, pattern, max_uses=sys.maxsize):
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
    else:
        self_match = pattern
        arg_matches = None

    if len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != "call_module":
            return False
        if not isinstance(modules[node.target], self_match):
            return False
    elif callable(self_match):
        if node.op != "call_function" or node.target is not self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(
        matches(modules, arg_node, arg_match, max_uses=1)
        for arg_node, arg_match in zip(node.args, arg_matches)
    )
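

# Note the max_uses=1 on the recursive calls above: interior nodes of a match
# (e.g. the conv inside the (ReLU, Conv2d) pattern) may only be consumed by
# the pattern itself. If some other node also read the conv's output, fusing
# the conv away would break that consumer, so the match is rejected.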


class Quantizer:
    def __init__(
        self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant
    ):
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so
        # initialize a quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        # most of this function is just an interpreter for the graph.
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people
        # copy/paste and then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == "placeholder":
                result = next(args_iter)
            elif node.op == "get_attr":
                result = self.state_dict[node.target]
            elif node.op == "call_function":
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == "call_method":
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == "call_module":
                result = self.modules[node.target](
                    *load_arg(node.args), **load_arg(node.kwargs)
                )
            elif node.op == "output":
                return load_arg(node.args[0])

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        raise RuntimeError("Graph had no output node!")

    def quantize(self):
        self.quantized_graph = Graph()

        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False)
                )
            elif root_node is node:
                r = obj.quantize(
                    self,
                    node,
                    lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)),
                )
                if r is NotImplemented:
                    # the quantizer declined to quantize this node, so take
                    # the entire match and just copy it over unquantized
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        return GraphModule(self.root, self.quantized_graph)
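
    # A note on quantize() above: it keeps two parallel environments. `env`
    # maps original node names to float nodes in the new graph, while
    # `quant_env` maps them to quantized nodes. load_arg converts between the
    # two lazily, emitting a dequantize() node (or this input's quant_ctor)
    # on first use and caching the result, so each value is converted at most
    # once no matter how many consumers request it.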

    def _find_matches(self, patterns):
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information even for
            # nodes where we might not use it, e.g. because the input is
            # already quantized. This is because each match has the option
            # to return NotImplemented (if, for instance, it is an __add__
            # and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
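

# A minimal end-to-end sketch of how Quantizer might be used (the model,
# shapes, calibration loop, and quantized engine below are illustrative
# assumptions, not part of this file):
#
#     import torch
#     from torch.fx import symbolic_trace
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.conv = torch.nn.Conv2d(3, 8, 3)
#             self.relu = torch.nn.ReLU()
#
#         def forward(self, x):
#             return self.relu(self.conv(x))
#
#     torch.backends.quantized.engine = "fbgemm"
#     gm = symbolic_trace(M().eval())        # Quantizer needs a GraphModule
#     quantizer = Quantizer(gm)
#     for _ in range(10):                    # calibration passes
#         quantizer.observe((torch.randn(1, 3, 16, 16),))
#     qgm = quantizer.quantize()             # quantized GraphModule
#     out = qgm(torch.randn(1, 3, 16, 16))
#
# Here the (ReLU, Conv2d) pattern fuses the two modules into a quantized
# ConvReLU2d, and DefaultQuant inserts the quantize/dequantize nodes at the
# graph's boundary.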