1# mypy: allow-untyped-defs 2from typing import Any, Callable, Dict, List, Optional, Set, Union 3 4import torch 5import torch.ao.nn.quantized as nnq 6import torch.ao.nn.quantized.dynamic as nnqd 7import torch.nn as nn 8from torch.ao.quantization import prepare 9from torch.ao.quantization.quantization_mappings import ( 10 get_default_compare_output_module_list, 11) 12 13 14NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = { 15 nnqd.Linear, 16 nnq.Linear, 17 nnqd.LSTM, 18 nn.LSTM, 19} 20 21 22def _find_match( 23 str_list: Union[Dict[str, Any], List[str]], 24 key_str: str, 25 postfix: str, 26) -> Optional[str]: 27 split_str = key_str.split(".") 28 if split_str[-1] == postfix: 29 match_string = "".join(key_str.split(".")[0:-1]) 30 for s2 in str_list: 31 pattern1 = "".join(s2.split(".")[0:-1]) 32 pattern2 = "".join(s2.split(".")[0:-2]) 33 if match_string == pattern1: 34 return s2 35 if match_string == pattern2: 36 return s2 37 38 # For matching "fc.weight" and "fc._packed_params._packed_params" 39 if postfix == "_packed_params": 40 match_string = "".join(key_str.split(".")[0:-2]) 41 if len(match_string) == 0: 42 return None 43 for s2 in str_list: 44 pattern1 = "".join(s2.split(".")[0:-1]) 45 pattern2 = "".join(s2.split(".")[0:-2]) 46 if match_string == pattern1: 47 return s2 48 if match_string == pattern2: 49 return s2 50 return None 51 else: 52 return None 53 54 55def compare_weights( 56 float_dict: Dict[str, Any], quantized_dict: Dict[str, Any] 57) -> Dict[str, Dict[str, torch.Tensor]]: 58 r"""Compare the weights of the float module with its corresponding quantized 59 module. Return a dict with key corresponding to module names and each entry being 60 a dictionary with two keys 'float' and 'quantized', containing the float and 61 quantized weights. This dict can be used to compare and compute the quantization 62 error of the weights of float and quantized models. 63 64 Example usage:: 65 66 wt_compare_dict = compare_weights( 67 float_model.state_dict(), qmodel.state_dict()) 68 for key in wt_compare_dict: 69 print( 70 key, 71 compute_error( 72 wt_compare_dict[key]['float'], 73 wt_compare_dict[key]['quantized'].dequantize() 74 ) 75 ) 76 77 Args: 78 float_dict: state dict of the float model 79 quantized_dict: state dict of the quantized model 80 81 Return: 82 weight_dict: dict with key corresponding to module names and each entry being 83 a dictionary with two keys 'float' and 'quantized', containing the float and 84 quantized weights 85 """ 86 torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights") 87 weight_dict: Dict[str, Dict] = {} 88 for key in quantized_dict: 89 match_key = _find_match(float_dict, key, "weight") 90 if match_key is not None: 91 weight_dict[key] = {} 92 weight_dict[key]["float"] = float_dict[match_key] 93 weight_dict[key]["quantized"] = quantized_dict[key] 94 continue 95 96 # For matching "fc.weight" and "fc._packed_params._packed_params" 97 match_key = _find_match(float_dict, key, "_packed_params") 98 if match_key is not None: 99 weight_dict[key] = {} 100 weight_dict[key]["float"] = float_dict[match_key] 101 weight_dict[key]["quantized"] = quantized_dict[key][0] 102 103 # For LSTM 104 split_str = key.split(".") 105 if split_str[-1] == "param" and split_str[-3] == "_all_weight_values": 106 layer = split_str[-2] 107 module_name = ".".join(split_str[:-3]) 108 float_weight_ih_key = module_name + ".weight_ih_l" + layer 109 float_weight_hh_key = module_name + ".weight_hh_l" + layer 110 if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict: 111 weight_dict[key] = {} 112 weight_dict[key]["float"] = float_dict[float_weight_ih_key] 113 weight_dict[key]["quantized"] = ( 114 quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0] 115 ) 116 weight_dict[key]["float"] = float_dict[float_weight_hh_key] 117 weight_dict[key]["quantized"] = ( 118 quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0] 119 ) 120 121 return weight_dict 122 123 124def _get_logger_dict_helper( 125 mod: nn.Module, 126 target_dict: Dict[str, Any], 127 prefix: str = "", 128) -> None: 129 r"""This is the helper function for get_logger_dict 130 131 Args: 132 mod: module we want to save all logger stats 133 prefix: prefix for the current module 134 target_dict: the dictionary used to save all logger stats 135 """ 136 137 def get_prefix(prefix): 138 return prefix if prefix == "" else prefix + "." 139 140 for name, child in mod.named_children(): 141 if isinstance(child, Logger): 142 target_dict[get_prefix(prefix) + "stats"] = child.stats 143 break 144 145 for name, child in mod.named_children(): 146 module_prefix = get_prefix(prefix) + name if prefix else name 147 _get_logger_dict_helper(child, target_dict, module_prefix) 148 149 150def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]: 151 r"""Traverse the modules and save all logger stats into target dict. 152 This is mainly used for quantization accuracy debug. 153 154 Type of loggers supported: 155 ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module, 156 OutputLogger: used to log the outputs of the modules 157 158 Args: 159 mod: module we want to save all logger stats 160 prefix: prefix for the current module 161 162 Return: 163 target_dict: the dictionary used to save all logger stats 164 165 """ 166 torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict") 167 168 target_dict: Dict[str, Dict] = {} 169 _get_logger_dict_helper(mod, target_dict, prefix) 170 return target_dict 171 172 173class Logger(nn.Module): 174 r"""Base class for stats logging""" 175 176 def __init__(self): 177 super().__init__() 178 self.stats = {} 179 # We only insert observer if the op is quantized with static quantization, 180 # which is identified by activation_observer.dtype == quint8. This is needed 181 # when attaching Logger as observer for FX mode 182 self.dtype = torch.quint8 183 184 def forward(self, x): 185 # fmt: off 186 """ 187 """ # blank docblock to make autodoc happy 188 # fmt: on 189 190 191class ShadowLogger(Logger): 192 r"""Class used in Shadow module to record the outputs of the original and 193 shadow modules. 194 """ 195 196 def __init__(self): 197 super().__init__() 198 self.stats["float"] = [] 199 self.stats["quantized"] = [] 200 201 def forward(self, x, y): 202 # fmt: off 203 """ 204 """ # blank docblock to make autodoc happy 205 # fmt: on 206 if len(x) > 1: 207 x = x[0] 208 if len(y) > 1: 209 y = y[0] 210 self.stats["quantized"].append(x.detach()) 211 self.stats["float"].append(y.detach()) 212 213 214class OutputLogger(Logger): 215 r"""Class used to log the outputs of the module""" 216 217 def __init__(self): 218 super().__init__() 219 self.stats["tensor_val"] = [] 220 221 def forward(self, x): 222 # fmt: off 223 """ 224 """ # blank docblock to make autodoc happy 225 # fmt: on 226 self.stats["tensor_val"].append(x) 227 return x 228 229 230def _convert_tuple_to_list(t: Any) -> Any: 231 return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t 232 233 234def _dequantize_tensor_list(t: Any) -> Any: 235 return ( 236 [_dequantize_tensor_list(x) for x in t] 237 if type(t) is list 238 else t.dequantize() 239 if t.is_quantized 240 else t 241 ) 242 243 244class Shadow(nn.Module): 245 r"""Shadow module attaches the float module to its matching quantized module 246 as the shadow. Then it uses Logger module to process the outputs of both 247 modules. 248 249 Args: 250 q_module: module quantized from float_module that we want to shadow 251 float_module: float module used to shadow q_module 252 logger_cls: type of logger used to process the outputs of q_module and 253 float_module. ShadowLogger or custom loggers can be used. 254 """ 255 256 def __init__(self, q_module, float_module, logger_cls): 257 super().__init__() 258 self.orig_module = q_module 259 self.shadow_module = float_module 260 self.dequant = nnq.DeQuantize() 261 self.logger = logger_cls() 262 263 def forward(self, *x) -> torch.Tensor: 264 # fmt: off 265 """ 266 """ # blank docblock to make autodoc happy 267 # fmt: on 268 xl = _convert_tuple_to_list(x) 269 output = self.orig_module(*xl) 270 xl_float = _dequantize_tensor_list(xl) 271 shadow_output = self.shadow_module(*xl_float) 272 self.logger(output, shadow_output) 273 return output 274 275 def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 276 # fmt: off 277 """ 278 """ # blank docblock to make autodoc happy 279 # fmt: on 280 output = self.orig_module.add(x, y) 281 x = x.dequantize() 282 y = y.dequantize() 283 shadow_output = self.shadow_module.add(x, y) 284 self.logger(output, shadow_output) 285 return output 286 287 def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: 288 # fmt: off 289 """ 290 """ # blank docblock to make autodoc happy 291 # fmt: on 292 output = self.orig_module.add_scalar(x, y) 293 x = x.dequantize() 294 shadow_output = self.shadow_module.add_scalar(x, y) 295 self.logger(output, shadow_output) 296 return output 297 298 def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 299 # fmt: off 300 """ 301 """ # blank docblock to make autodoc happy 302 # fmt: on 303 output = self.orig_module.mul(x, y) 304 x = x.dequantize() 305 y = y.dequantize() 306 shadow_output = self.shadow_module.mul(x, y) 307 self.logger(output, shadow_output) 308 return output 309 310 def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: 311 # fmt: off 312 """ 313 """ # blank docblock to make autodoc happy 314 # fmt: on 315 output = self.orig_module.mul_scalar(x, y) 316 x = x.dequantize() 317 shadow_output = self.shadow_module.mul_scalar(x, y) 318 self.logger(output, shadow_output) 319 return output 320 321 def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor: 322 # fmt: off 323 """ 324 """ # blank docblock to make autodoc happy 325 # fmt: on 326 output = self.orig_module.cat(x, dim) 327 x = [y.dequantize() for y in x] 328 shadow_output = self.shadow_module.cat(x, dim) 329 self.logger(output, shadow_output) 330 return output 331 332 def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 333 # fmt: off 334 """ 335 """ # blank docblock to make autodoc happy 336 # fmt: on 337 output = self.orig_module.add_relu(x, y) 338 x = x.dequantize() 339 y = y.dequantize() 340 shadow_output = self.shadow_module.add_relu(x, y) 341 self.logger(output, shadow_output) 342 return output 343 344 345def prepare_model_with_stubs( 346 float_module: nn.Module, 347 q_module: nn.Module, 348 module_swap_list: Set[type], 349 logger_cls: Callable, 350) -> None: 351 r"""Prepare the model by attaching the float module to its matching quantized 352 module as the shadow if the float module type is in module_swap_list. 353 354 Example usage:: 355 356 prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) 357 q_model(data) 358 ob_dict = get_logger_dict(q_model) 359 360 Args: 361 float_module: float module used to generate the q_module 362 q_module: module quantized from float_module 363 module_swap_list: list of float module types to attach the shadow 364 logger_cls: type of logger to be used in shadow module to process the outputs of 365 quantized module and its float shadow module 366 """ 367 torch._C._log_api_usage_once( 368 "quantization_api._numeric_suite.prepare_model_with_stubs" 369 ) 370 371 float_module_children = {} 372 for name, mod in float_module.named_children(): 373 float_module_children[name] = mod 374 375 reassign = {} 376 for name, mod in q_module.named_children(): 377 if name not in float_module_children: 378 continue 379 380 float_mod = float_module_children[name] 381 382 if type(float_mod) not in module_swap_list: 383 prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls) 384 385 # Insert shadow module only if the module is not of the same type as 386 # the floating point module 387 if type(float_mod) in module_swap_list and not _is_identical_module_type( 388 mod, float_mod 389 ): 390 reassign[name] = Shadow(mod, float_mod, logger_cls) 391 392 for key, value in reassign.items(): 393 q_module._modules[key] = value 394 395 396def _is_identical_module_type(mod1, mod2): 397 # Compare if two modules have the same dtype 398 mod1_module_types = [type(mod) for mod in mod1.modules()] 399 mod2_module_types = [type(mod) for mod in mod2.modules()] 400 return mod1_module_types == mod2_module_types 401 402 403def compare_model_stub( 404 float_model: nn.Module, 405 q_model: nn.Module, 406 module_swap_list: Set[type], 407 *data, 408 logger_cls=ShadowLogger, 409) -> Dict[str, Dict]: 410 r"""Compare quantized module in a model with its floating point counterpart, 411 feeding both of them the same input. Return a dict with key corresponding to 412 module names and each entry being a dictionary with two keys 'float' and 413 'quantized', containing the output tensors of quantized and its matching 414 float shadow module. This dict can be used to compare and compute the module 415 level quantization error. 416 417 This function first call prepare_model_with_stubs() to swap the quantized 418 module that we want to compare with the Shadow module, which takes quantized 419 module, corresponding float module and logger as input, and creates a forward 420 path inside to make the float module to shadow quantized module sharing the 421 same input. The logger can be customizable, default logger is ShadowLogger 422 and it will save the outputs of the quantized module and float module that 423 can be used to compute the module level quantization error. 424 425 Example usage:: 426 427 module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock] 428 ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data) 429 for key in ob_dict: 430 print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize())) 431 432 Args: 433 float_model: float model used to generate the q_model 434 q_model: model quantized from float_model 435 module_swap_list: list of float module types at which shadow modules will 436 be attached. 437 data: input data used to run the prepared q_model 438 logger_cls: type of logger to be used in shadow module to process the outputs of 439 quantized module and its float shadow module 440 """ 441 torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub") 442 prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls) 443 q_model(*data) 444 ob_dict = get_logger_dict(q_model) 445 return ob_dict 446 447 448def get_matching_activations( 449 float_module: nn.Module, 450 q_module: nn.Module, 451) -> Dict[str, Dict[str, torch.Tensor]]: 452 r"""Find the matching activation between float and quantized modules. 453 454 Args: 455 float_module: float module used to generate the q_module 456 q_module: module quantized from float_module 457 458 Return: 459 act_dict: dict with key corresponding to quantized module names and each 460 entry being a dictionary with two keys 'float' and 'quantized', containing 461 the matching float and quantized activations 462 """ 463 torch._C._log_api_usage_once( 464 "quantization_api._numeric_suite.get_matching_activations" 465 ) 466 float_dict = get_logger_dict(float_module) 467 quantized_dict = get_logger_dict(q_module) 468 act_dict: Dict[str, Dict] = {} 469 for key in quantized_dict: 470 if len(quantized_dict[key]["tensor_val"]) == 0: 471 continue 472 match_key = _find_match(sorted(float_dict, reverse=True), key, "stats") 473 if match_key is not None: 474 act_dict[key] = {} 475 act_dict[key]["float"] = float_dict[match_key]["tensor_val"] 476 act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"] 477 return act_dict 478 479 480def prepare_model_outputs( 481 float_module: nn.Module, 482 q_module: nn.Module, 483 logger_cls=OutputLogger, 484 allow_list=None, 485) -> None: 486 r"""Prepare the model by attaching the logger to both float module 487 and quantized module if they are in the allow_list. 488 489 Args: 490 float_module: float module used to generate the q_module 491 q_module: module quantized from float_module 492 logger_cls: type of logger to be attached to float_module and q_module 493 allow_list: list of module types to attach logger 494 """ 495 torch._C._log_api_usage_once( 496 "quantization_api._numeric_suite.prepare_model_outputs" 497 ) 498 if allow_list is None: 499 allow_list = get_default_compare_output_module_list() 500 501 qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None) 502 float_module.qconfig = qconfig_debug # type: ignore[assignment] 503 prepare( 504 float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={} 505 ) 506 q_module.qconfig = qconfig_debug # type: ignore[assignment] 507 prepare( 508 q_module, 509 inplace=True, 510 allow_list=allow_list, 511 observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, 512 prepare_custom_config_dict={}, 513 ) 514 515 516def compare_model_outputs( 517 float_model: nn.Module, 518 q_model: nn.Module, 519 *data, 520 logger_cls=OutputLogger, 521 allow_list=None, 522) -> Dict[str, Dict[str, torch.Tensor]]: 523 r"""Compare output activations between float and quantized models at 524 corresponding locations for the same input. Return a dict with key corresponding 525 to quantized module names and each entry being a dictionary with two keys 526 'float' and 'quantized', containing the activations of quantized model and 527 float model at matching locations. This dict can be used to compare and 528 compute the propagation quantization error. 529 530 Example usage:: 531 532 act_compare_dict = compare_model_outputs(float_model, qmodel, data) 533 for key in act_compare_dict: 534 print( 535 key, 536 compute_error( 537 act_compare_dict[key]['float'], 538 act_compare_dict[key]['quantized'].dequantize() 539 ) 540 ) 541 542 Args: 543 float_model: float model used to generate the q_model 544 q_model: model quantized from float_model 545 data: input data used to run the prepared float_model and q_model 546 logger_cls: type of logger to be attached to float_module and q_module 547 allow_list: list of module types to attach logger 548 549 Return: 550 act_compare_dict: dict with key corresponding to quantized module names 551 and each entry being a dictionary with two keys 'float' and 'quantized', 552 containing the matching float and quantized activations 553 """ 554 torch._C._log_api_usage_once( 555 "quantization_api._numeric_suite.compare_model_outputs" 556 ) 557 if allow_list is None: 558 allow_list = get_default_compare_output_module_list() 559 prepare_model_outputs(float_model, q_model, logger_cls, allow_list) 560 float_model(*data) 561 q_model(*data) 562 act_compare_dict = get_matching_activations(float_model, q_model) 563 return act_compare_dict 564