# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Set, Tuple

import torch
import torch.ao.nn.qat as nnqat
import torch.nn as nn
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.fx._equalize import (
    default_equalization_qconfig,
    EqualizationQConfig,
)
from torch.ao.quantization.fx._model_report.model_report_observer import (
    ModelReportObserver,
)
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.observer import (
    _is_activation_post_process,
    default_dynamic_quant_observer,
    default_observer,
    default_per_channel_weight_observer,
    default_weight_observer,
    ObserverBase,
)
from torch.ao.quantization.qconfig import (
    _assert_valid_qconfig,
    default_qconfig,
    QConfig,
)


# Names for observer insert keys
DETECTOR_TARGET_NODE_KEY = "target_node"
DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"
DETECTOR_IS_POST_OBS_KEY = "is_post_observer"
DETECTOR_OBS_ARGS_KEY = "observer_args"


# Mapping related code
class DetectorQConfigInfo:
    r"""
    This class contains the QConfig information for a single module.
    The list of variables / values this contains can grow depending on the
    extensibility of the qconfig mapping feature set, but it currently includes:
    - if the activation observer is dynamic
    - if the weight observer is per channel

    Args:
        module_fqn (str): The fully qualified name (fqn) of the module that this
            information contains info relevant to qconfig for
    """

    def __init__(self, module_fqn: str):
        super().__init__()
        self.module_fqn = module_fqn

        # populate this section with all the variables we might find important
        # change from False if your detector is actually using this
        self.is_activation_dynamic = False
        self.is_weight_per_channel = False

        # equalization related options
        self.is_equalization_recommended = False

    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
        r"""
        Args:
            module (torch.nn.Module): The module we are generating
                the qconfig for

        Returns the generated quantization QConfig according to what a valid configuration is
        """
        # Apply suggestions to new qconfig
        module_qconfig = default_qconfig

        # keep track of dynamic and per_channel recommendations
        recommendations_list = []
        # append as a list of combinations, in decreasing order of priority
        recommendations_list.append(
            (self.is_activation_dynamic, self.is_weight_per_channel)
        )
        recommendations_list.append(
            (self.is_activation_dynamic, False)
        )  # only trying dynamic rec
        recommendations_list.append(
            (False, self.is_weight_per_channel)
        )  # only trying per_channel rec

        # now we try each of the combinations
        for rec in recommendations_list:
            # rec[0] -> dynamic recommended
            # rec[1] -> per channel recommended
            activation = default_dynamic_quant_observer if rec[0] else default_observer
            weight = (
                default_per_channel_weight_observer
                if rec[1]
                else default_weight_observer
            )
            test_config = QConfig(activation, weight)
            try:
                _assert_valid_qconfig(test_config, module)
                module_qconfig = test_config
                break
            except AssertionError:
                # if not a valid configuration, we move on to the next one in priority
                continue

        # return the QConfig chosen
        return module_qconfig

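    # A minimal usage sketch of the fallback behavior above (illustrative only;
    # the fqn is hypothetical). The (dynamic, per_channel) pairs are tried in
    # priority order until one passes _assert_valid_qconfig:
    #
    #   qconfig_info = DetectorQConfigInfo("block.linear")
    #   qconfig_info.is_activation_dynamic = True
    #   qconfig_info.is_weight_per_channel = True
    #   qconfig = qconfig_info.generate_quantization_qconfig(torch.nn.Linear(4, 4))
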
    def generate_equalization_qconfig(self) -> EqualizationQConfig:
        r"""
        This returns the equalization configuration for a module.

        For now, it just returns the default, but as more equalization options become
        possible, this method can get more fleshed out with more nuanced granularity.

        Returns the generated equalization QConfig according to what a valid configuration is
        """
        # in this case, we just return the default equalization config
        # we know this is valid because only valid modules would even
        # have this option
        return default_equalization_qconfig


# Adding base class for detectors
class DetectorBase(ABC):
    r"""Base Detector Module
    Any detector class should derive from this class.

    Concrete detectors should follow the same general API, which includes:
    - A method to calculate and return observer insertion points
        - Should return both the fqns and the Observer class to insert
    - A method to return a report based on the detector
        - Should return a str-based report and dict info in Tuple[str, Dict] format
    """

    def __init__(self) -> None:
        super().__init__()
        self.detector_config_info = None

    @abstractmethod
    def determine_observer_insert_points(self, model) -> Dict:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
            This dict maps string keys to detector specific information
        """

    @abstractmethod
    def get_detector_name(self) -> str:
        r"""Returns the name of the current detector"""

    @abstractmethod
    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """

    def _get_targeting_node(
        self, prepared_fx_model: GraphModule, target_fqn: str
    ) -> torch.fx.node.Node:
        r"""
        Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn.

        If it's not found, it is most likely inside a fused layer,
        so we go one layer up in terms of the fqn we are searching for until we find the parent node.
        If we get to an empty string, then we know that it doesn't exist.

        The reason for the recursion is that if the module we are looking for got fused,
        we will have a module fqn such as x.linear.0, but the graph will only have a node for
        the fused module, whose fqn is x.linear, so they will not match.
        To handle this, if we don't match, we take off the last segment of the fqn,
        e.g. x.linear.0 -> x.linear, or more generally foo.bar.baz -> foo.bar, and search again;
        this allows us to locate the correct module even in cases with fusion.

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule
            target_fqn (str): The fqn of the layer we are trying to target

        Returns the node object we are trying to add observers around
        """
        for node in prepared_fx_model.graph.nodes:
            # if the node's target is our target, return it
            if node.target == target_fqn:
                return node

        # getting here means node not found
        # if there is no ".", we are already at the base and have failed
        parent_fqn_sep_index = target_fqn.rfind(".")
        if parent_fqn_sep_index == -1:
            raise ValueError("passed in target_fqn not found in graph's targets.")
        else:
            # recursively call it with the parent fqn
            return self._get_targeting_node(
                prepared_fx_model, target_fqn[:parent_fqn_sep_index]
            )

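    # A sketch of the fqn walk above (hypothetical fqns, for illustration):
    # a fused submodule "block.linear.0" has no graph node of its own, so the
    # search retries with progressively shorter fqns and only raises once the
    # fqn cannot be shortened further:
    #
    #   "block.linear.0" -> not found -> "block.linear" -> found -> return node
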
    @abstractmethod
    def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Tuple of two elements:
            Str: string report of the suggested improvements
            Dict: contains useful data collected by the observer pertinent to this report
        """


class PerChannelDetector(DetectorBase):
    r"""This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
    Only Linear and Conv layers can use per_channel as of now, so only these two are currently checked.

    per_channel quantization can lead to major benefits in the form of accuracy.
    Therefore, if the backend used by the user supports it, it is recommended to use it.

    Args:
        backend (str, optional): the backend the user wishes to use in production
            Default value is current torch.backends.quantized.engine
    """

    # Keys for return dictionary
    BACKEND_KEY = "backend"
    PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
    PER_CHAN_USED_KEY = "per_channel_quantization_used"

    # Default map for representing supported per channel quantization modules for different backends
    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
        "fbgemm": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "qnnpack": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "onednn": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "x86": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
    }

    def __init__(self, backend: str = torch.backends.quantized.engine):
        super().__init__()

        # store the backend information
        self.backend_chosen = backend
        self.supported_modules = set()
        if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
            self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[
                self.backend_chosen
            ]
        else:
            raise ValueError(
                f"Not configured to work with {self.backend_chosen}. Try a different default backend"
            )

    def get_detector_name(self) -> str:
        r"""Returns the string name of this detector"""
        return "per_channel_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn in per_channel_info:
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # see if per channel quantization is supported
            per_chan_supported: bool = per_channel_info[module_fqn][
                self.PER_CHAN_SUPPORTED_KEY
            ]
            detector_qconfig_info.is_weight_per_channel = per_chan_supported
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(self, model: nn.Module) -> Dict:
        r"""
        There are no observers inserted for the PerChannelDetector.

        Returns an empty dictionary since no observers are added or needed
        """
        return {}

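    # A minimal usage sketch (illustrative; `prepared_model` stands in for a
    # model prepared with torch.ao.quantization.quantize_fx.prepare_fx):
    #
    #   detector = PerChannelDetector(backend="fbgemm")
    #   report_str, per_channel_info = detector.generate_detector_report(prepared_model)
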
    def _detect_per_channel_helper(self, model: nn.Module):
        r"""
        Determines if per_channel quantization is supported in modules and submodules.

        Returns a dictionary in the higher level _detect_per_channel function.
        Each entry maps the fully-qualified-name to information on whether per_channel
        quantization is supported and used.

        Args:
            model: The current module that is being checked to see if it is per_channel quantizable

        Returns a dictionary mapping fqns to whether per_channel quantization is possible
        """
        # create dict we will return
        per_channel_info: Dict = {}

        # get the fully qualified name and check if in list of modules to include and list of modules to ignore
        for fqn, module in model.named_modules():
            is_in_include_list = any(
                isinstance(module, x) for x in self.supported_modules
            )

            # check if per_channel is supported for the module
            # based on the backend
            per_channel_supported = False

            if is_in_include_list:
                per_channel_supported = True

                # assert statement for MyPy
                q_config_file = module.qconfig
                assert isinstance(q_config_file, QConfig)

                # this object should either be fake quant or observer
                q_or_s_obj = module.qconfig.weight.p.func()
                assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))

                per_channel_used = False  # will be true if found in qconfig

                if hasattr(
                    q_or_s_obj, "ch_axis"
                ):  # then we know that per_channel quantization is used
                    # all fake quants have a channel axis, so we need to check is_per_channel
                    if isinstance(q_or_s_obj, FakeQuantize):
                        if (
                            hasattr(q_or_s_obj, "is_per_channel")
                            and q_or_s_obj.is_per_channel
                        ):
                            per_channel_used = True
                    elif isinstance(q_or_s_obj, ObserverBase):
                        # should be an observer otherwise
                        per_channel_used = True
                    else:
                        raise ValueError("Should be either observer or fake quant")

                per_channel_info[fqn] = {
                    self.PER_CHAN_SUPPORTED_KEY: per_channel_supported,
                    self.PER_CHAN_USED_KEY: per_channel_used,
                    self.BACKEND_KEY: self.backend_chosen,
                }

        return per_channel_info

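    # For illustration, one entry of the returned dict might look like this
    # (the fqn and values are hypothetical):
    #
    #   {"fc1": {"per_channel_quantization_supported": True,
    #            "per_channel_quantization_used": False,
    #            "backend": "fbgemm"}}
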
    def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any]]:
        r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
        Only Linear and Conv layers can use per_channel as of now, so only these two are currently checked.

        Looks at the qconfig format and backend to determine if per_channel can be utilized.
        Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support

        Args:
            model: The prepared and calibrated model we want to check if using per_channel

        Returns a tuple with two elements:
            String report of potential actions to improve model (if per_channel quantization is available in backend)
            Dictionary mapping per_channel quantizable elements to:
                whether per_channel quantization is supported by the backend
                if it is being utilized in the current model
        """

        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # String to let the user know of further optimizations
        further_optims_str = (
            f"Further Optimizations for backend {self.backend_chosen}: \n"
        )

        optimizations_possible = False
        for fqn in per_channel_info:
            fqn_dict = per_channel_info[fqn]
            if (
                fqn_dict[self.PER_CHAN_SUPPORTED_KEY]
                and not fqn_dict[self.PER_CHAN_USED_KEY]
            ):
                optimizations_possible = True
                further_optims_str += (
                    f"Module {fqn} can be configured to use per_channel quantization.\n"
                )

        if optimizations_possible:
            further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
        else:
            further_optims_str += "No further per_channel optimizations possible."

        # return the string and the dictionary form of same information
        return (further_optims_str, per_channel_info)


class DynamicStaticDetector(DetectorBase):
    r"""
    Determines whether dynamic or static quantization is more appropriate for a given module.

    Takes advantage of the ModelReportObserver that records range information.
    Stationary distributions of data are strictly above the tolerance level for the comparison statistic:

        S = average_batch_activation_range / epoch_activation_range

    Non-stationary distributions are below or at the tolerance level for this metric.

    If the distribution of data right after the module is non-stationary, recommend dynamic quantization
    Otherwise recommend static quantization

    Args:
        tolerance (float, optional): The threshold above which the S metric is considered stationary,
            and non-stationary otherwise. Default: 0.5
    """

    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
    DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"

    # naming conventions for stationary vs non-stationary data
    STATIONARY_STR = "stationary"
    NON_STATIONARY_STR = "non-stationary"

    # naming for activation
    INPUT_ACTIVATION_PREFIX = "input_activation_"
    OUTPUT_ACTIVATION_PREFIX = "output_activation_"

    # naming conventions for the keys of the return module info
    TOLERANCE_KEY = "dynamic_static_tolerance"
    DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
    PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    PRE_OBS_DATA_DIST_KEY = (
        INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    )
    POST_OBS_DATA_DIST_KEY = (
        OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    )
    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"

    # modules that are supported both dynamic and static for this report function
    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}

    # modules that will be supported soon for both
    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}

    def __init__(self, tolerance=0.5):
        super().__init__()

        # set tolerance level and initialize a set to keep track of useful fqn locations
        self.tolerance = tolerance
        self.useful_observer_fqns: Set[str] = set()

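    # Worked example of the comparison statistic above (numbers are hypothetical):
    # if the average per-batch activation range after a module is 2.0 and the
    # range over the whole epoch is 10.0, then S = 2.0 / 10.0 = 0.2. Since
    # 0.2 <= tolerance (0.5), the output distribution is classified as
    # non-stationary and dynamic quantization is recommended for that module.
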
    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> Dict[str, Dict[str, Any]]:
        r"""
        Determines where observers need to be inserted for the Dynamic vs Static detector.
        For this detector, we want to place observers on either side of linear layers in the model.

        Currently inserts observers for:
            linear layers

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """

        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # make sure module is supported
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

                # add entry for post-observer
                post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME

                obs_fqn_to_info[post_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: True,
                    DETECTOR_OBS_ARGS_KEY: (targeted_node,),
                }

        return obs_fqn_to_info

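    # For illustration, one pair of entries in the returned mapping might look
    # like this (the fqn is hypothetical, and the node placeholders stand in
    # for actual torch.fx nodes):
    #
    #   {"fc1.model_report_pre_observer":  {"target_node": <node for fc1>,
    #                                       "observer_to_insert": ModelReportObserver(),
    #                                       "is_post_observer": False,
    #                                       "observer_args": <args of the fc1 node>},
    #    "fc1.model_report_post_observer": {"target_node": <node for fc1>,
    #                                       "observer_to_insert": ModelReportObserver(),
    #                                       "is_post_observer": True,
    #                                       "observer_args": (<node for fc1>,)}}
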
    def get_detector_name(self) -> str:
        r"""Returns the string name of this detector"""
        return "dynamic_vs_static_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        dynamic_static_info = self._generate_dict_info(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn in dynamic_static_info:
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # see if dynamic quantization is recommended
            dynamic_static_recommended: bool = dynamic_static_info[module_fqn][
                self.DEFAULT_DYNAMIC_REC_KEY
            ]
            detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args
            module: The module to check and ensure is supported
            insert: True if this is a check for observer insertion, False if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # check to see if module is of a supported type
        is_supported_type = any(
            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
        )

        # check if it will be supported
        future_supported_type = any(
            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED
        )

        # supported
        supported = is_supported_type or future_supported_type

        # this is a check for observer insertion
        if insert:
            return supported
        else:
            # this is for report gen and we also need to check if it contains observers
            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(
                module, self.DEFAULT_POST_OBSERVER_NAME
            )
            return supported and has_obs

    def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in the generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a Dictionary mapping modules with ModelReportObservers around them to:
            whether dynamic quantization is recommended
            their S metric of input to module
            whether input to module is stationary or non-stationary
            their S metric of output of module
            whether output of module is stationary or non-stationary
            the tolerance level used to decide whether input/output is stationary or non-stationary
            whether it is currently supported or planned for the future
        """
        # store modules dynamic vs static information
        module_dynamic_static_info = {}

        # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info
        # This information primarily includes whether the data distributions around a supported module are stationary or not
        # Based on this, it is recorded whether dynamic or static quantization is recommended

        # loop through all submodules, including nested ones
        for fqn, module in model.named_modules():
            # if module is supported and has the ModelReportObserver attached to it
            if self._is_supported(module):
                # get pre and post observers for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
                post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)

                # get the statistics for each module
                pre_stat = pre_obs.get_batch_to_epoch_ratio()
                post_stat = post_obs.get_batch_to_epoch_ratio()

                # record module, pre and post stat, and whether to do dynamic or static based off it
                # true if post observer data distribution is non-stationary, false if it's stationary
                dynamic_recommended = post_stat <= self.tolerance

                # specify the classifications for whether data distributions are considered stationary or non-stationary
                pre_obs_dist_classif = (
                    self.STATIONARY_STR
                    if pre_stat > self.tolerance
                    else self.NON_STATIONARY_STR
                )
                post_obs_dist_classif = (
                    self.STATIONARY_STR
                    if post_stat > self.tolerance
                    else self.NON_STATIONARY_STR
                )

                # check if current support or future support
                is_supported_type = any(
                    isinstance(module, x)
                    for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
                )

                # store the set of important information for this module
                module_info = {
                    self.TOLERANCE_KEY: self.tolerance,
                    self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
                    self.PRE_OBS_COMP_STAT_KEY: pre_stat,
                    self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif,
                    self.POST_OBS_COMP_STAT_KEY: post_stat,
                    self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif,
                    self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
                }

                module_dynamic_static_info[fqn] = module_info

        return module_dynamic_static_info

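    # For illustration, one entry of the returned dict might look like this
    # (the fqn and statistics are hypothetical):
    #
    #   {"fc1": {"dynamic_static_tolerance": 0.5,
    #            "dynamic_recommended": True,
    #            "input_activation_dynamic_static_comp_stat": 0.7,
    #            "input_activation_dynamic_static_data_classification": "stationary",
    #            "output_activation_dynamic_static_comp_stat": 0.2,
    #            "output_activation_dynamic_static_data_classification": "non-stationary",
    #            "is_dynamic_supported": True}}
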
    def generate_detector_report(
        self, model: GraphModule
    ) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether dynamic or static quantization is more appropriate for a given module.

        Takes advantage of the ModelReportObserver that records range information.
        Stationary distributions of data are strictly above the tolerance level for the comparison statistic:

            S = average_batch_activation_range / epoch_activation_range

        Non-stationary distributions are below or at the tolerance level for this metric.

        If the distribution of data right after the module is non-stationary, recommend dynamic quantization
        Otherwise recommend static quantization

        This will then generate suggestions for dynamic vs static quantization focused around Linear.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether dynamic or static quantization is recommended for certain modules
            Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level used to decide whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """

        # get the dictionary of the information to format the string report
        module_dynamic_static_info = self._generate_dict_info(model)

        dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"

        modules_added: bool = False  # check to make sure at least 1 module added.

        dynamic_benefit = (
            " You will get more accurate results if you use dynamic quantization"
        )
        static_benefit = (
            " You can increase model efficiency if you use static quantization"
        )
        future_support_str = (
            ". This layer is not yet supported for dynamic quantization"
        )
        # This for loop goes through the information collected in module_dynamic_static_info and:
        #   Populates the string based report with the information from module_dynamic_static_info
        #   Compiles the complete report by appending relevant formatted strings

        for module_fqn in module_dynamic_static_info.keys():
            # there is at least 1 module for suggestion
            modules_added = True
            module_info = module_dynamic_static_info[module_fqn]
            suggestion_string_template = (
                "For module {} it is suggested to use {} quantization because {}.\n"
            )

            # decide what string formatting values will be
            quantization_type = ""
            quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."

            benefit_str = ""

            # strings for if dynamic quantized per tensor is needed
            recommend_per_tensor = (
                ". We recommend adding a {} before this module if it is static."
            )
            rec_lay_to_add = "dynamic quantize per tensor layer"
            dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
            dynamic_per_tensor_reasoning_string = " This is because the input to this module has a non-stationary distribution"

            # start composing explanation
            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
                quantization_type = "dynamic"
                # check if currently supported or future supported
                benefit_str = dynamic_benefit
                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
                    benefit_str += future_support_str
            else:
                quantization_type = "static"
                benefit_str = static_benefit

            # now set the quantization explanation string
            quantization_reasoning = (
                quantization_reasoning.format(
                    module_fqn,
                    module_info[self.PRE_OBS_DATA_DIST_KEY],
                    module_info[self.POST_OBS_DATA_DIST_KEY],
                )
                + benefit_str
            )

            # if we have a non-stationary input -> linear -> stationary, we suggest static;
            # however, we also want to recommend they add a dynamic quantize per tensor layer right before if this change is made
            if (
                module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR
                and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR
            ):
                quantization_reasoning = (
                    quantization_reasoning
                    + dynamic_per_tensor_string
                    + dynamic_per_tensor_reasoning_string
                )

            # format the overall suggestion string with the specific inputs
            module_suggestion_string = suggestion_string_template.format(
                module_fqn, quantization_type, quantization_reasoning
            )

            # append to overall suggestion
            dynamic_vs_static_string += module_suggestion_string

        if not modules_added:
            dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"

        # return the string as well as the dictionary of information
        return (dynamic_vs_static_string, module_dynamic_static_info)

class InputWeightEqualizationDetector(DetectorBase):
    r"""
    Determines whether input-weight equalization can help improve quantization for certain modules.

    Specifically, this list of modules includes:
        linear
        conv

    Determines whether input-weight equalization is recommended based on the comp stat:
        s_c = sqrt(w_c/W) / sqrt(i_c/I)
        where:
            w_c is range of weight for channel c, W is range of weight over all channels
            i_c is range of input for channel c, I is range of input over all channels

        if ratio_threshold <= s_c <= 1 / ratio_threshold, recommends input-weight equalization

    Args:
        ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
            Should be between 0 and 1 (both non-inclusive)
        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
            Default: 1

    * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization

    * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    SUPPORTED_MODULES: Set[Callable] = {
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Linear,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
    }

    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # weight / activation prefix for each of the below info
    WEIGHT_PREFIX = "weight_"
    ACTIVATION_PREFIX = "input_activation_"

    # string names for keys of info dictionaries
    PER_CHANNEL_MAX_KEY = "per_channel_max"
    PER_CHANNEL_MIN_KEY = "per_channel_min"
    GLOBAL_MAX_KEY = "global_max"
    GLOBAL_MIN_KEY = "global_min"

    # keys for return dict of recommendations
    RECOMMENDED_KEY = "input_weight_equalization_recommended"
    COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
    THRESHOLD_KEY = "input_weight_threshold"
    CHANNEL_KEY = "input_weight_channel_axis"

    # default weight and info strings
    WEIGHT_STR = "weight"
    INPUT_STR = "input"

    # default for what ratio we recommend input weight
    DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4

    def __init__(self, ratio_threshold: float, ch_axis: int = 1):
        # ensure passed in inputs are valid
        if ratio_threshold <= 0 or ratio_threshold >= 1:
            raise ValueError("Make sure threshold is > 0 and < 1")

        # initialize attributes based on args
        self.ratio_threshold: float = ratio_threshold
        self.ch_axis: int = ch_axis

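    # Worked example of the comp stat above (numbers are hypothetical): with
    # ratio_threshold = 0.4, a channel whose weight range ratio w_c/W = 0.64
    # and input range ratio i_c/I = 0.25 gives
    #   s_c = sqrt(0.64) / sqrt(0.25) = 0.8 / 0.5 = 1.6,
    # and since 0.4 <= 1.6 <= 2.5 (= 1 / 0.4), equalization is recommended for
    # that channel.
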
    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args
            module: The module to check and ensure is supported
            insert: True if this is a check for observer insertion, False if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # check to see if module is of a supported type
        is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES)

        # this is a check for observer insertion
        if insert:
            return is_supported_type
        else:
            # this is for report gen and we also need to check if it contains observers
            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
            return is_supported_type and has_obs

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper functions to populate the dictionary
        # find the range of inputs
        input_values: Dict[str, Dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: Dict[str, Dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(
            input_values, weight_values
        )

        # generate the return dictionary
        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(
            input_values, weight_values, comp_stats
        )

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn in input_weight_equalization_info:
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # see if input-weight equalization is recommended
            input_weight_recommended: bool = input_weight_equalization_info[module_fqn][
                self.RECOMMENDED_KEY
            ]
            detector_qconfig_info.is_equalization_recommended = input_weight_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> Dict[str, Dict[str, Any]]:
        r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
        For this detector, we want to place observers in front of supported layers.

        Currently inserts observers for:
            linear layers
            conv layers

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """

        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # check to see if module is of a supported type
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""Returns the name of this detector"""
        return "input_weight_equalization_detector"

    def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the input information for each observer and returns it

        Args
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns (str) to a dict with keys:
            "input_activation_per_channel_max" : maps to the per_channel max values
            "input_activation_per_channel_min" : maps to the per_channel min values
            "input_activation_global_max" : maps to the global max recorded
            "input_activation_global_min" : maps to the global min recorded
        """

        # return dictionary mapping observer fqns to desired info
        input_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # get pre observer for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)

                input_info[fqn] = {
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val,
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val,
                    self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val),
                    self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val),
                }

        return input_info

    def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the weight information for each layer an observer is attached to.

        Args
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping module fqns (str) to a dict with keys:
            "per_channel_max" : maps to the per_channel max values
            "per_channel_min" : maps to the per_channel min values
            "global_max" : maps to the global max recorded
            "global_min" : maps to the global min recorded
        """
        # return dictionary mapping observer fqns to desired info
        weight_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # we don't need the actual observer, just the module weights
                # calculate min and max vals
                device = module.weight.device
                min_val: torch.Tensor = torch.tensor([float("inf")], device=device)
                max_val: torch.Tensor = torch.tensor([float("-inf")], device=device)
                x_copy = module.weight
                x_dim = x_copy.size()

                new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
                new_axis_list[self.ch_axis] = 0
                new_axis_list[0] = self.ch_axis
                y = x_copy.permute(new_axis_list)

                # Need to match dtype of min/max because the updates to buffers
                # are done in place and types need to match for comparisons
                y = y.to(min_val.dtype)
                y = torch.flatten(y, start_dim=1)
                if min_val.numel() == 0 or max_val.numel() == 0:
                    min_val, max_val = torch.aminmax(y, dim=1)
                else:
                    min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
                    min_val = torch.min(min_val_cur, min_val)
                    max_val = torch.max(max_val_cur, max_val)

                weight_info[fqn] = {
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
                    self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
                    self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
                }

        return weight_info

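    # A sketch of the per-channel reduction above (shapes are hypothetical):
    # for a weight of shape (out=8, in=4) with ch_axis=1, the permute swaps
    # axes so the channel axis leads, and flattening leaves one row per channel:
    #
    #   y = weight.permute([1, 0])            # shape (4, 8)
    #   y = torch.flatten(y, start_dim=1)     # shape (4, 8)
    #   mins, maxs = torch.aminmax(y, dim=1)  # per-channel min/max, shape (4,)
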
    def _calculate_range_ratio(
        self, info_dict: Dict, info_str: str, module_fqn: str
    ) -> torch.Tensor:
        r"""
        Takes in an info dict and calculates the per-channel range ratios used for the s_c stat.

        Args:
            info_dict (dict): A dictionary of either input or weight range info
            info_str (str): A str describing whether currently looking at weight or input info
                Either "weight" or "input"
            module_fqn (str): The fqn of the module we are looking at

        Returns a tensor of values, where each value is the range ratio for a different channel
        """
        # calculate the ratios of the info
        # get the prefix str
        prefix_str = (
            self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX
        )

        per_channel_range = (
            info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY]
            - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
        )
        global_range = (
            info_dict[prefix_str + self.GLOBAL_MAX_KEY]
            - info_dict[prefix_str + self.GLOBAL_MIN_KEY]
        )

        if global_range == 0:
            range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
            raise ValueError(
                f"The range of the {info_str} data for module {module_fqn} is 0, "
                f"which means you have a constant value channel. {range_zero_explanation}"
            )

        ratio = per_channel_range / global_range

        return ratio

    def _generate_comparison_values(
        self, input_info: Dict, weight_info: Dict
    ) -> Dict[str, torch.Tensor]:
        r"""
        Takes in the information on the min and max values of the inputs and weights and:
            Calculates the comp stat for each channel: s_c = sqrt(w_c/W) / sqrt(i_c/I)

        Args:
            input_info (dict): A dict mapping each observer to input range information
            weight_info (dict): A dict mapping each observer to weight range information

        Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
            Each value is a different s_c value for a different channel
        """
        # create return dictionary for each observer
        module_fqn_to_channel: Dict[str, torch.Tensor] = {}

        # for each module (both passed in dicts should have same keys)
        for module_fqn in input_info:
            # raise error if not in weight info
            if module_fqn not in weight_info:
                raise KeyError(
                    f"Unable to find weight range stats for module {module_fqn}"
                )

            # calculate the ratios of the weight info and input info
            weight_ratio = self._calculate_range_ratio(
                weight_info[module_fqn], self.WEIGHT_STR, module_fqn
            )
            input_ratio = self._calculate_range_ratio(
                input_info[module_fqn], self.INPUT_STR, module_fqn
            )

            # if mismatched size, because of grouping, we want to replicate weight enough times
            weight_channels = len(weight_ratio)
            input_channels = len(input_ratio)
            if weight_channels != input_channels:
                # we try to replicate
                assert (
                    input_channels % weight_channels == 0
                ), "input channels should be divisible by weight channels."
                # get replication factor
                rep_factor: int = input_channels // weight_channels

                # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n times
                weight_ratio = weight_ratio.repeat(rep_factor)

            # calculate the s metric per channel
            s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)
            module_fqn_to_channel[module_fqn] = s

        # return compiled observer ratios
        return module_fqn_to_channel

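    # For illustration: with grouped convolutions the weight may have fewer
    # channels than the input. A weight_ratio of shape (2,) against an
    # input_ratio of shape (4,) gives rep_factor = 4 // 2 = 2, so weight_ratio
    # is tiled to shape (4,) before s = sqrt(weight_ratio) / sqrt(input_ratio)
    # is computed per channel.
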
    def _generate_dict_info(
        self, input_info: Dict, weight_info: Dict, comp_stats: Dict
    ) -> Dict[str, Dict]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in the generate_detector_report documentation

        Args:
            input_info (dict): A dict mapping each module to input range information
            weight_info (dict): A dict mapping each module to weight range information
            comp_stats (dict): A dict mapping each module to its corresponding comp stat

        Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
            whether input weight equalization is recommended
            their s_c metric compared to the threshold
            the threshold used to make the recommendation
            the channel used for recording data
            the input channel range info
            the weight channel range info
        """
        # store modules input weight equalization info
        input_weight_equalization_info: Dict[str, Dict] = {}

        # for each module we add a separate set of suggestions
        for module_fqn in input_info:
            # get relevant info for this module
            mod_input_info: Dict = input_info[module_fqn]
            mod_weight_info: Dict = weight_info[module_fqn]
            mod_comp_stat: Dict = comp_stats[module_fqn]

            # decide if each channel should have input weight equalization or not
            channel_rec_vals: list = []

            for val in mod_comp_stat:
                float_rep: float = val.item()

                # decide if recommending input weight equalization
                recommended: bool = (
                    float_rep >= self.ratio_threshold
                    and float_rep <= 1 / self.ratio_threshold
                )
                channel_rec_vals.append(recommended)

            # build the return dict input
            # also unpack input and weight dicts into it
            input_weight_equalization_info[module_fqn] = {
                self.RECOMMENDED_KEY: channel_rec_vals,
                self.COMP_METRIC_KEY: mod_comp_stat,
                self.THRESHOLD_KEY: self.ratio_threshold,
                self.CHANNEL_KEY: self.ch_axis,
                **mod_input_info,
                **mod_weight_info,
            }

        # return our compiled info for each module
        return input_weight_equalization_info

    def generate_detector_report(
        self, model: GraphModule
    ) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether input weight equalization is appropriate for a given module.

        Takes advantage of the ModelReport Observer which records per channel information of input range
        It then uses the passed in weight info in conjunction to compute the desired ratio
        Finally, it gives suggestions based on this information for each module of interest

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether input weight equalization is recommended for certain modules
            Dictionary mapping modules of interest to:
                whether input weight equalization is recommended
                their s_c metric compared to the threshold
                the threshold used to make the recommendation
                the channel used for recording data
                the input channel range info
                the weight channel range info
        """

        # find the range of inputs
        input_values: Dict[str, Dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: Dict[str, Dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(
            input_values, weight_values
        )

        # generate the return dictionary
        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(
            input_values, weight_values, comp_stats
        )

        # now we can generate report based on this information
        input_weight_string = "Input-Weight Equalization suggestions: \n"

        # some strings to be formatted depending on module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = (
            "\tWe suggest {} input weight equalization because {}\n"
        )
        use_str = "to use"
        no_use_str = "to not use"
        input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error."
        input_weight_non_benefit_reasoning = (
            "{}/{} channels benefiting from input-weight equalization being applied."
        )
        input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}"

        # added module check
        added_module: bool = False

        # compile the suggestion string
        for module_fqn in input_weight_equalization_info:
            # we added at least 1 module
            added_module = True
            # add the module level description
            input_weight_string += module_suggestion_str.format(
                module_fqn, self.ch_axis
            )

            mod_info: Dict[str, Any] = input_weight_equalization_info[module_fqn]

            # gather info on how many channels would benefit from input-weight equalization
            recommendation_per_channel: torch.Tensor = mod_info[self.RECOMMENDED_KEY]
            num_recs = sum(recommendation_per_channel)

            if (
                num_recs / len(recommendation_per_channel)
                >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO
            ):
                input_benefit_formatted = input_weight_benefit_str.format(
                    num_recs, len(recommendation_per_channel)
                )
                channel_str = channel_suggestion_str.format(
                    use_str, input_benefit_formatted
                )
                input_weight_string += channel_str
            else:
                non_benefit_reason_formatted = (
                    input_weight_non_benefit_reasoning.format(
                        num_recs, len(recommendation_per_channel)
                    )
                )
                non_benefit_str = input_weight_non_benefit_str.format(
                    non_benefit_reason_formatted
                )
                channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str)
                input_weight_string += channel_str

        # if no modules looked at, amend return string
        if not added_module:
            input_weight_string += (
                "No applicable layers for suggestions. Only linear and conv valid.\n"
            )

        # return a tuple with the string explanation and the compiled dict info
        return (input_weight_string, input_weight_equalization_info)

class OutlierDetector(DetectorBase):
    r"""
    Determines whether there are significant outliers in activation data around a certain layer.

    This is ideally used in conjunction with information on stationary vs. non-stationary distribution:
        If the data is stationary, and there are significant outliers, then we want to flag them
        We want to do this on a per channel basis for detecting outliers

    Determines whether activation data is flagged as outlier based on if data is stationary and:
        p_r = avg(100th percentile / "reference_percentile"th percentile)
        where:
            p_r is the average percentile ratio across all batches in the epoch
            reference_percentile is a fractional percentile value between 0 and 1 (exclusive)

        if p_r is above some threshold, then we consider the activations to have significant outliers

    Args:
        ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
            Should be >= 1
            Default: 3.5
        reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
            Should be between 0 and 1
            Default: 0.975
        fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
            If the fraction is below this, we deem the number of samples used to calculate outliers as insignificant and alert the user
            regardless of whether we detected outliers or not in the channel, so they can take a closer look at the channel results
            Should be between 0 and 1
            Default: 0.95
        ch_axis (int, optional): The channel axis being observed to determine outliers
            Default: 1

    * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
        The p_r value (average ratio of 100th percentile / reference_percentile) is compared to ratio_threshold
        If it is significantly greater, then we consider it an outlier
        This threshold was calculated based on the ratio of the percentiles in a normal distribution
        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing

    * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
        Should be between 0 and 1
        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing

    * :attr:`fraction_batches_used_threshold`: The fraction of batches used to determine outliers for each channel should be above this
        Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine outliers

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    # names for the pre observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # pre activation prefix
    INPUT_ACTIVATION_PREFIX = "input_activation_"

    # names for dict keys
    OUTLIER_KEY = "outliers_detected"
    NUM_BATCHES_KEY = "outlier_detection_batches_used"
    IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
    COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
    RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
    REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
    CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
    MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
    CONSTANT_COUNTS_KEY = "constant_batch_counts"

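    # Worked example of the p_r metric above (numbers are hypothetical): with
    # reference_percentile = 0.975, a channel whose 100th percentile averages
    # 8.0 while its 97.5th percentile averages 2.0 has p_r = 8.0 / 2.0 = 4.0.
    # Since 4.0 > ratio_threshold (3.5), the channel is flagged as containing
    # significant outliers.
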
    def __init__(
        self,
        ratio_threshold: float = 3.5,
        reference_percentile: float = 0.975,
        fraction_batches_used_threshold: float = 0.95,
        ch_axis: int = 1,
    ):
        # initialize the variables of interest
        self.ratio_threshold = ratio_threshold

        # make sure passed in percentile is valid
        assert reference_percentile >= 0 and reference_percentile <= 1
        assert (
            fraction_batches_used_threshold >= 0
            and fraction_batches_used_threshold <= 1
        )
        self.reference_percentile = reference_percentile
        self.fraction_batches_used_threshold = fraction_batches_used_threshold
        self.ch_axis = ch_axis

    def get_detector_name(self) -> str:
        r"""Returns the name of this detector"""
        return "outlier_detector"

    def _supports_insertion(self, module: nn.Module) -> bool:
        r"""Returns whether the given module is supported for observer insertion

        Any module that doesn't have children and isn't an observer itself is supported

        Args
            module: The module to check and ensure is supported

        Returns True if the module is supported by observer, False otherwise
        """
        # case for insertion of module
        # check if the module has any children and isn't an observer
        num_children = len(list(module.children()))
        return num_children == 0 and not _is_activation_post_process(module)

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # currently doesn't do anything for the outlier detector
        return {}

    def _supports_report_gen(self, module: nn.Module) -> bool:
        r"""Returns whether the given module is supported for report generation

        Any module that has a model report pre-observer is supported

        Args
            module: The module to check and ensure is supported

        Returns True if the module is supported by observer, False otherwise
        """
        return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)


    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> Dict[str, Dict[str, Any]]:
        r"""Determines where observers need to be inserted for the Outlier Detector.

        For this detector, we want to place observers in front of supported layers.

        Currently inserts observers for:
            all layers that do not have children (leaf-level layers)

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """
        # the observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # check to see if the module is of a supported type
            if self._supports_insertion(module):
                # if it's a supported type, get the node and record the insert location
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(
                        ch_axis=self.ch_axis, comp_percentile=self.reference_percentile
                    ),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

        return obs_fqn_to_info

    def _calculate_outlier_info(
        self,
        percentile_ratios: torch.Tensor,
        counted_batches: torch.Tensor,
        total_batches: int,
    ) -> Dict[str, List[bool]]:
        r"""
        Gives info on whether the percentile ratios calculated would be considered outliers
        Also gives information on whether the collected data is statistically significant enough to make this claim

        Args:
            percentile_ratios (torch.Tensor): The average percentile ratios per channel calculated by the observer
            counted_batches (torch.Tensor): The number of batches used for the average calculation, per channel
            total_batches (int): The total number of batches that passed through the observer in this epoch

        Returns a dictionary mapping:
            "outliers_detected" : list of bools per channel that are True if the channel is considered an outlier
            "is_sufficient_batches": list of bools per channel that are True if o_r >= fraction_batches_used_threshold,
                where o_r = counted_batches / total_batches
        """
        outlier_dict: Dict[str, List[bool]] = {
            self.OUTLIER_KEY: [],
            self.IS_SUFFICIENT_BATCHES_KEY: [],
        }

        # get both as flattened lists for easy mapping
        ratios_list: List = percentile_ratios.tolist()
        num_batches_list: List = counted_batches.tolist()

        # calculate whether each channel used a statistically significant number of batches
        significant_size = [
            batch_size / total_batches >= self.fraction_batches_used_threshold
            for batch_size in num_batches_list
        ]
        outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size

        # calculate for each channel whether it's an outlier or not based on the ratio
        outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
        outlier_dict[self.OUTLIER_KEY] = outlier_detected

        # return the dictionary with the two lists
        return outlier_dict
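
    # Editor's worked example for _calculate_outlier_info, using the default
    # thresholds (ratio_threshold=3.5, fraction_batches_used_threshold=0.95):
    #
    #     percentile_ratios = torch.tensor([1.2, 5.0])
    #     counted_batches = torch.tensor([100, 90])
    #     total_batches = 100
    #
    # channel 0: 1.2 <= 3.5 -> no outlier;  100/100 >= 0.95 -> sufficient batches
    # channel 1: 5.0 >  3.5 -> outlier;      90/100 <  0.95 -> insufficient batches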

    def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Helper function for generate_detector_report that generates the informational dictionary.
        This process is done as specified in the generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns to:
            whether there were outliers found in the preceding activation
            the number of batches used for each channel
            whether the fraction of applicable batches used is above fraction_batches_used_threshold
            their p_r metric compared to the threshold
            the threshold used to make the recommendation
            the reference_percentile used to make the recommendation
            the channel axis used to determine individual channels
            the constant batch counts per channel
            the per-channel max values
        """
        # return dictionary mapping observer fqns to desired info
        info_dict: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # report generation is supported whenever the module has a pre-observer
            if self._supports_report_gen(module):
                # get the pre-observer for the module
                pre_obs: ModelReportObserver = getattr(
                    module, self.DEFAULT_PRE_OBSERVER_NAME
                )

                # get the number of batches and calculated ratio thresholds
                num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
                average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
                channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
                total_batches: int = pre_obs.num_batches_tracked

                # also get the max values
                max_vals: torch.Tensor = pre_obs.max_val

                # we have to specifically modify how we record negative ratios for pre-relu layers
                for index, ratio_val in enumerate(average_ratios):
                    # a ratio might be negative if the 100th percentile is > 0 while
                    # the reference percentile is < 0, in which case the channel would
                    # never be detected as an outlier. Since we care about magnitude,
                    # we normalize the ratio to be positive and >= 1.
                    if ratio_val.item() < 0:
                        # first make it positive
                        ratio_val = -ratio_val
                        average_ratios[index] = ratio_val

                    if ratio_val.item() < 1:
                        # if it's less than 1, we have to flip it as well
                        average_ratios[index] = 1 / ratio_val

                outlier_calcs = self._calculate_outlier_info(
                    average_ratios, num_batches, total_batches
                )

                # record the computed outlier info for this module
                info_dict[fqn] = {
                    self.CHANNEL_AXIS_KEY: self.ch_axis,
                    self.REF_PERCENTILE_KEY: self.reference_percentile,
                    self.RATIO_THRES_KEY: self.ratio_threshold,
                    self.COMP_METRIC_KEY: average_ratios,
                    self.NUM_BATCHES_KEY: num_batches,
                    self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
                    self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[
                        self.IS_SUFFICIENT_BATCHES_KEY
                    ],
                    self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
                    self.MAX_VALS_KEY: max_vals,
                }

        return info_dict
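
    # Editor's illustration of the ratio normalization above: a raw average ratio
    # of -0.5 is first negated to 0.5, then inverted to 2.0, so every value compared
    # against ratio_threshold is a positive magnitude >= 1.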

    def generate_detector_report(
        self, model: GraphModule
    ) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether there are significant outliers in the activation data around the modules of a model.

        Takes advantage of the ModelReportObserver, which records the relevant percentile information

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether there are outliers in the activations around certain modules
            Dictionary mapping modules of interest to:
                whether there were outliers found in the preceding activation
                the number of batches used for each channel
                whether the fraction of applicable batches used is above fraction_batches_used_threshold
                their p_r metric compared to the threshold
                the threshold used to make the recommendation
                the reference_percentile used to make the recommendation
                the channel axis used to determine individual channels
                the constant batch counts per channel
                the per-channel max values
        """
        # generate the information dictionary of outlier information
        info_dict = self._generate_info_dict(model)

        # now we can generate the report based on this information
        outlier_string = "Outlier detection report: \n"

        # tracks whether any module contributed to the report
        added_module: bool = False

        # some strings to be formatted depending on the module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
        channel_max_value_str = "a max value across all batches of {}"
        note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
        note_distribution = "stationary distributions"
        note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"

        # suggestion for the constant batch check, since constant batches can hide outliers
        constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
        constant_suggestion = "We recommend taking a look at the dict and data to see how frequently this occurred and why."

        # compile the suggestion string
        for module_fqn in info_dict:
            # get module-specific info
            mod_info: Dict[str, Any] = info_dict[module_fqn]
            # check to see if we already added the module-level description
            added_model_desc = False
            # look at each individual channel and add a suggestion
            for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
                if outlier_detected:
                    # we found at least 1 outlier
                    if not added_model_desc:
                        # add the module-level description
                        outlier_string += module_suggestion_str.format(
                            module_fqn, self.ch_axis
                        )
                        added_model_desc = True

                    # we mark that we found at least one outlier
                    added_module = True
                    max_value_found_str = channel_max_value_str.format(
                        mod_info[self.MAX_VALS_KEY][index]
                    )
                    channel_str = channel_suggestion_str.format(
                        index, max_value_found_str
                    )
                    outlier_string += channel_str
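
                # Editor's note: a channel whose batches were largely constant may
                # report no outliers simply because those batches were skipped in
                # the percentile ratio calculation, so nonzero constant-batch
                # counts are surfaced separately below.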
                # also check if we found constant value batches
                if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0:
                    # make sure we add a module-level highlight
                    if not added_model_desc:
                        # add the module-level description
                        outlier_string += module_suggestion_str.format(
                            module_fqn, self.ch_axis
                        )
                        added_model_desc = True

                    constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][
                        index
                    ]
                    formatted_str = constant_str.format(
                        index, constant_values_for_channel, constant_suggestion
                    )
                    outlier_string += formatted_str
                    # we also added at least one thing to the description
                    added_module = True

        # if we found an outlier, add the note; otherwise give the default response
        if added_module:
            # compose the note string
            note_composed = note_string.format(note_distribution, note_rec)
            outlier_string += note_composed
        else:
            outlier_string += "There were no outliers found in the activations.\n"

        return (outlier_string, info_dict)
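

# Editor's end-to-end usage sketch (hypothetical `prepared_model`; assumes the
# standard ModelReport flow from torch.ao.quantization.fx._model_report.model_report):
#
#     from torch.ao.quantization.fx._model_report.model_report import ModelReport
#
#     detector = OutlierDetector()
#     model_report = ModelReport(prepared_model, {detector})
#     ready_model = model_report.prepare_detailed_calibration()
#     # ... run calibration data through ready_model ...
#     reports = model_report.generate_model_report(remove_inserted_observers=True)
#     outlier_report_str, outlier_info = reports[detector.get_detector_name()]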