# xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/_model_report/detector.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Set, Tuple

import torch
import torch.ao.nn.qat as nnqat
import torch.nn as nn
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.fx._equalize import (
    default_equalization_qconfig,
    EqualizationQConfig,
)
from torch.ao.quantization.fx._model_report.model_report_observer import (
    ModelReportObserver,
)
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.observer import (
    _is_activation_post_process,
    default_dynamic_quant_observer,
    default_observer,
    default_per_channel_weight_observer,
    default_weight_observer,
    ObserverBase,
)
from torch.ao.quantization.qconfig import (
    _assert_valid_qconfig,
    default_qconfig,
    QConfig,
)


# Names for observer insert keys
DETECTOR_TARGET_NODE_KEY = "target_node"
DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"
DETECTOR_IS_POST_OBS_KEY = "is_post_observer"
DETECTOR_OBS_ARGS_KEY = "observer_args"


# Mapping related code
class DetectorQConfigInfo:
    r"""
    This class contains the QConfig information for a single module.
    The set of values it tracks can grow as the qconfig mapping feature set
    is extended, but it currently includes:
    - whether the activation observer is dynamic
    - whether the weight observer is per channel

    Args:
        module_fqn (str): The fully qualified name (fqn) of the module that this
            information contains qconfig-relevant info for
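
    Example (a minimal usage sketch; the fqn and module are illustrative)::

        >>> info = DetectorQConfigInfo("block.linear")
        >>> info.is_weight_per_channel = True
        >>> qconfig = info.generate_quantization_qconfig(torch.nn.Linear(3, 3))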
    """

    def __init__(self, module_fqn: str):
        super().__init__()
        self.module_fqn = module_fqn

        # populate this section with all the variables we might find important
        # change from the default if your detector is actually using this
        self.is_activation_dynamic = False
        self.is_weight_per_channel = False

        # equalization related options
        self.is_equalization_recommended = False

    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
        r"""
        Args:
            module (torch.nn.Module): The module we are generating
                the qconfig for

        Returns the generated quantization QConfig according to what a valid configuration is
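
        Example (illustrative; assumes a dynamic-activation recommendation was
        recorded by a detector)::

            >>> info = DetectorQConfigInfo("fc")
            >>> info.is_activation_dynamic = True
            >>> qconfig = info.generate_quantization_qconfig(torch.nn.Linear(4, 4))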
        """
        # Apply suggestions to new qconfig
        module_qconfig = default_qconfig

        # keep track of dynamic and per_channel recommendations
        recommendations_list = []
        # append as a list of combinations, in decreasing order of priority
        recommendations_list.append(
            (self.is_activation_dynamic, self.is_weight_per_channel)
        )
        recommendations_list.append(
            (self.is_activation_dynamic, False)
        )  # only trying dynamic rec
        recommendations_list.append(
            (False, self.is_weight_per_channel)
        )  # only trying per channel rec

        # now we try each of the combinations
        for rec in recommendations_list:
            # rec[0] -> dynamic recommended
            # rec[1] -> per channel recommended
            activation = default_dynamic_quant_observer if rec[0] else default_observer
            weight = (
                default_per_channel_weight_observer
                if rec[1]
                else default_weight_observer
            )
            test_config = QConfig(activation, weight)
            try:
                _assert_valid_qconfig(test_config, module)
                module_qconfig = test_config
                break
            except AssertionError:
                # if not a valid configuration, we move on to the next one in priority
                continue

        # return the QConfig chosen
        return module_qconfig

    def generate_equalization_qconfig(self) -> EqualizationQConfig:
        r"""
        This returns the equalization configuration for a module.

        For now, it just returns the default, but as more equalization options become
        possible, this method can be fleshed out with more nuanced granularity.

        Returns the generated equalization QConfig according to what a valid configuration is
        """
        # in this case, we just return default equalization config
        # we know this is valid because only valid modules would even
        # have this option
        return default_equalization_qconfig


# Base class for detectors
class DetectorBase(ABC):
    r"""Base Detector Module
    Any detector class should derive from this class.

    Concrete detectors should follow the same general API, which includes:
    - A method to calculate and return observer insertion points
        - Should return both the fqns and the Observer class to insert
    - A method to return a report based on the detector
        - Should return a str-based report and dict info in Tuple[str, Dict] format
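
    Example (an illustrative skeleton of a concrete detector; not one of the
    shipped detectors)::

        class NoOpDetector(DetectorBase):
            def determine_observer_insert_points(self, model) -> Dict:
                return {}

            def get_detector_name(self) -> str:
                return "no_op_detector"

            def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
                return {}

            def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]:
                return ("nothing to report", {})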
    """

    def __init__(self) -> None:
        super().__init__()
        self.detector_config_info = None

    @abstractmethod
    def determine_observer_insert_points(self, model) -> Dict:
        r"""
        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
            This dict maps string keys to detector-specific information
        """

    @abstractmethod
    def get_detector_name(self) -> str:
        r"""Returns the name of the current detector"""

    @abstractmethod
    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from relevant module fqns to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """

    def _get_targeting_node(
        self, prepared_fx_model: GraphModule, target_fqn: str
    ) -> torch.fx.node.Node:
        r"""
        Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn.

        If it's not found, it is most likely inside a fused layer.
            We go one layer up in the fqn we are searching for until we find the parent node.
            If we get to an empty string, then we know that it doesn't exist.

        The reason for the recursion is that if the module that we are looking for got fused,
        we will have a module fqn such as x.linear.0, but the graph will only have a node for the fused module,
        which would have fqn x.linear, so they will not match.
        To handle this, if we don't match, we take off the last part of the fqn, e.g. x.linear.0 -> x.linear,
        or more generally foo.bar.baz -> foo.bar, and search again; this allows us to locate the correct module
        even in cases with fusion.

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule
            target_fqn (str): The fqn of the layer we are trying to target

        Returns the node object we are trying to add observers around
        """
        for node in prepared_fx_model.graph.nodes:
            # if the node's target is our target, return it
            if node.target == target_fqn:
                return node

        # getting here means the node was not found
        # if there is no ".", we are already at the base and have failed
        parent_fqn_sep_index = target_fqn.rfind(".")
        if parent_fqn_sep_index == -1:
            raise ValueError("passed in target_fqn not found in graph's targets.")
        else:
            # recursively call it with the parent fqn
            return self._get_targeting_node(
                prepared_fx_model, target_fqn[:parent_fqn_sep_index]
            )

    @abstractmethod
    def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]:
        r"""
        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Tuple of two elements:
            Str: string report of the suggested improvements
            Dict: contains useful data collected by the observer pertinent to this report
        """


class PerChannelDetector(DetectorBase):
    r"""This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
    Only Linear and Conv layers can use per_channel as of now, so only these two are currently checked.

    per_channel quantization can lead to major benefits in the form of accuracy.
    Therefore, if the backend used by the user supports it, it is recommended to use it.

    Args:
        backend (str, optional): the backend the user wishes to use in production
            Default value is the current torch.backends.quantized.engine
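
    Example (illustrative usage; ``model`` is assumed to be a prepared model
    whose modules already carry qconfig attributes)::

        >>> detector = PerChannelDetector(backend="fbgemm")
        >>> report_str, per_channel_info = detector.generate_detector_report(model)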
    """

    # Keys for return dictionary
    BACKEND_KEY = "backend"
    PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
    PER_CHAN_USED_KEY = "per_channel_quantization_used"

    # Default map for representing supported per channel quantization modules for different backends
    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
        "fbgemm": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "qnnpack": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "onednn": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "x86": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
    }

    def __init__(self, backend: str = torch.backends.quantized.engine):
        super().__init__()

        # store the backend information
        self.backend_chosen = backend
        self.supported_modules = set()
        if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
            self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[
                self.backend_chosen
            ]
        else:
            raise ValueError(
                f"Not configured to work with {self.backend_chosen}. Try a different default backend"
            )

    def get_detector_name(self) -> str:
        r"""returns the string name of this detector"""
        return "per_channel_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from relevant module fqns to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn in per_channel_info:
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # see if per channel quantization is supported
            per_chan_supported: bool = per_channel_info[module_fqn][
                self.PER_CHAN_SUPPORTED_KEY
            ]
            detector_qconfig_info.is_weight_per_channel = per_chan_supported
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(self, model: nn.Module) -> Dict:
        r"""
        There are no observers inserted for the PerChannelDetector.

        Returns an empty dictionary since no observers are added or needed
        """
        return {}

    def _detect_per_channel_helper(self, model: nn.Module):
        r"""
        Determines if per_channel quantization is supported in modules and submodules.

        Returns a dictionary used by the higher level functions that generate the report.
        Each entry maps a fully qualified name to information on whether per_channel quantization
        is supported and used.

        Args:
            model: The current module that is being checked to see if it is per_channel quantizable

        Returns a dictionary mapping fqns to whether per_channel quantization is possible
        """
        # create dict we will return
        per_channel_info: Dict = {}

        # get the fully qualified name and check if it is in the list of supported modules
        for fqn, module in model.named_modules():
            is_in_include_list = any(
                isinstance(module, x) for x in self.supported_modules
            )

            # check if per_channel is supported for the module
            # based on backend
            per_channel_supported = False

            if is_in_include_list:
                per_channel_supported = True

                # assert statement for MyPy
                q_config_file = module.qconfig
                assert isinstance(q_config_file, QConfig)

                # this object should either be fake quant or observer
                q_or_s_obj = module.qconfig.weight.p.func()
                assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))

                per_channel_used = False  # will be true if found in qconfig

                if hasattr(
                    q_or_s_obj, "ch_axis"
                ):  # then we know that per_channel quantization is used
                    # all fake quants have a channel axis, so we need to check is_per_channel
                    if isinstance(q_or_s_obj, FakeQuantize):
                        if (
                            hasattr(q_or_s_obj, "is_per_channel")
                            and q_or_s_obj.is_per_channel
                        ):
                            per_channel_used = True
                    elif isinstance(q_or_s_obj, ObserverBase):
                        # should be an observer otherwise
                        per_channel_used = True
                    else:
                        raise ValueError("Should be either observer or fake quant")

                per_channel_info[fqn] = {
                    self.PER_CHAN_SUPPORTED_KEY: per_channel_supported,
                    self.PER_CHAN_USED_KEY: per_channel_used,
                    self.BACKEND_KEY: self.backend_chosen,
                }

        return per_channel_info

    def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any]]:
        r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
        Only Linear and Conv layers can use per_channel as of now, so only these two are currently checked.

        Looks at the qconfig format and backend to determine if per_channel can be utilized.
        Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support

        Args:
            model: The prepared and calibrated model we want to check for per_channel usage

        Returns a tuple with two elements:
            String report of potential actions to improve the model (if per_channel quantization is available in the backend)
            Dictionary mapping per_channel quantizable elements to:
                whether per_channel quantization is supported by the backend
                whether it is being utilized in the current model
        """

        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # String to let the user know of further optimizations
        further_optims_str = (
            f"Further Optimizations for backend {self.backend_chosen}: \n"
        )

        optimizations_possible = False
        for fqn in per_channel_info:
            fqn_dict = per_channel_info[fqn]
            if (
                fqn_dict[self.PER_CHAN_SUPPORTED_KEY]
                and not fqn_dict[self.PER_CHAN_USED_KEY]
            ):
                optimizations_possible = True
                further_optims_str += (
                    f"Module {fqn} can be configured to use per_channel quantization.\n"
                )

        if optimizations_possible:
            further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
        else:
            further_optims_str += "No further per_channel optimizations possible."

        # return the string and the dictionary form of same information
        return (further_optims_str, per_channel_info)


class DynamicStaticDetector(DetectorBase):
    r"""
    Determines whether dynamic or static quantization is more appropriate for a given module.

    Takes advantage of the ModelReportObserver that records range information.
    Stationary distributions of data are strictly above the tolerance level for the comparison statistic:

        S = average_batch_activation_range / epoch_activation_range

    Non-stationary distributions are at or below the tolerance level for this metric.

    If the distribution of data right after the module is non-stationary, dynamic quantization is recommended;
        otherwise static quantization is recommended.

    Args:
        tolerance (float, optional): The threshold above which the S metric is considered stationary
            and at or below which it is considered non-stationary. Default: 0.5
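
    Example (illustrative usage; ``prepared_model`` is assumed to be a
    GraphModule that was calibrated with ModelReportObservers inserted)::

        >>> detector = DynamicStaticDetector(tolerance=0.5)
        >>> report_str, info = detector.generate_detector_report(prepared_model)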
    """
    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
    DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"

    # naming conventions for stationary vs non-stationary data
    STATIONARY_STR = "stationary"
    NON_STATIONARY_STR = "non-stationary"

    # naming for activation
    INPUT_ACTIVATION_PREFIX = "input_activation_"
    OUTPUT_ACTIVATION_PREFIX = "output_activation_"

    # naming conventions for the keys of the return module info
    TOLERANCE_KEY = "dynamic_static_tolerance"
    DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
    PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    PRE_OBS_DATA_DIST_KEY = (
        INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    )
    POST_OBS_DATA_DIST_KEY = (
        OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    )
    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"

    # modules that are supported for both dynamic and static quantization by this report function
    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}

    # modules that will soon be supported for both
    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}

    def __init__(self, tolerance=0.5):
        super().__init__()

        # set tolerance level and initialize a set to keep track of useful fqn locations
        self.tolerance = tolerance
        self.useful_observer_fqns: Set[str] = set()

    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> Dict[str, Dict[str, Any]]:
        r"""
        Determines where observers need to be inserted for the Dynamic vs Static detector.
        For this detector, we want to place observers on either side of linear layers in the model.

        Currently inserts observers for:
            linear layers

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
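
        Example of one entry in the returned dict (illustrative values)::

            {
                "fc.model_report_pre_observer": {
                    "target_node": fc_node,
                    "observer_to_insert": ModelReportObserver(),
                    "is_post_observer": False,
                    "observer_args": fc_node.args,
                }
            }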
        """

        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # make sure module is supported
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

                # add entry for post-observer
                post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME

                obs_fqn_to_info[post_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: True,
                    DETECTOR_OBS_ARGS_KEY: (targeted_node,),
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""returns the string name of this detector"""
        return "dynamic_vs_static_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from relevant module fqns to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        dynamic_static_info = self._generate_dict_info(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn in dynamic_static_info:
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # see if dynamic quantization is recommended
            dynamic_static_recommended: bool = dynamic_static_info[module_fqn][
                self.DEFAULT_DYNAMIC_REC_KEY
            ]
            detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args:
            module: The module to check and ensure is supported
            insert: True if this check is for observer insertion, False if for report generation

        Returns True if the module is supported by the observer, False otherwise
        """
        # check to see if module is of a supported type
        is_supported_type = any(
            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
        )

        # check if it will be supported in the future
        future_supported_type = any(
            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED
        )

        # supported
        supported = is_supported_type or future_supported_type

        # this check is for observer insertion
        if insert:
            return supported
        else:
            # this is for report generation and we also need to check if it contains observers
            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(
                module, self.DEFAULT_POST_OBSERVER_NAME
            )
            return supported and has_obs

    def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in the generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level used to decide whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
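
        Example of one entry in the returned dict (illustrative values)::

            {
                "fc": {
                    "dynamic_static_tolerance": 0.5,
                    "dynamic_recommended": True,
                    "input_activation_dynamic_static_comp_stat": 0.3,
                    "input_activation_dynamic_static_data_classification": "non-stationary",
                    "output_activation_dynamic_static_comp_stat": 0.4,
                    "output_activation_dynamic_static_data_classification": "non-stationary",
                    "is_dynamic_supported": True,
                }
            }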
        """
        # store modules dynamic vs static information
        module_dynamic_static_info = {}

        # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info
        #   This information primarily includes whether the data distributions around a supported module are stationary or not
        #   Based on this, it is recorded whether dynamic or static quantization is recommended

        # loop through all submodules, including nested ones
        for fqn, module in model.named_modules():
            # if module is supported and has ModelReportObservers attached to it
            if self._is_supported(module):
                # get pre and post observers for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
                post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)

                # get the statistics for each module
                pre_stat = pre_obs.get_batch_to_epoch_ratio()
                post_stat = post_obs.get_batch_to_epoch_ratio()

                # record module, pre and post stat, and whether to do dynamic or static based off it
                # true if post observer data distribution is non-stationary, false if it's stationary
                dynamic_recommended = post_stat <= self.tolerance

                # specify the classifications for whether data distributions are considered stationary or non-stationary
                pre_obs_dist_classif = (
                    self.STATIONARY_STR
                    if pre_stat > self.tolerance
                    else self.NON_STATIONARY_STR
                )
                post_obs_dist_classif = (
                    self.STATIONARY_STR
                    if post_stat > self.tolerance
                    else self.NON_STATIONARY_STR
                )

                # check if current support or future support
                is_supported_type = any(
                    isinstance(module, x)
                    for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
                )

                # store the set of important information for this module
                module_info = {
                    self.TOLERANCE_KEY: self.tolerance,
                    self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
                    self.PRE_OBS_COMP_STAT_KEY: pre_stat,
                    self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif,
                    self.POST_OBS_COMP_STAT_KEY: post_stat,
                    self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif,
                    self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
                }

                module_dynamic_static_info[fqn] = module_info

        return module_dynamic_static_info

    def generate_detector_report(
        self, model: GraphModule
    ) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether dynamic or static quantization is more appropriate for a given module.

        Takes advantage of the ModelReportObserver that records range information.
        Stationary distributions of data are strictly above the tolerance level for the comparison statistic:

            S = average_batch_activation_range / epoch_activation_range

        Non-stationary distributions are at or below the tolerance level for this metric.

        If the distribution of data right after the module is non-stationary, dynamic quantization is recommended;
            otherwise static quantization is recommended.

        This will then generate suggestions for dynamic vs static quantization focused around Linear.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether dynamic or static quantization is recommended for certain modules
            Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level used to decide whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """

        # get the dictionary of the information to format the string report
        module_dynamic_static_info = self._generate_dict_info(model)

        dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"

        modules_added: bool = False  # check to make sure at least 1 module added.

        dynamic_benefit = (
            " You will get more accurate results if you use dynamic quantization"
        )
        static_benefit = (
            " You can increase model efficiency if you use static quantization"
        )
        future_support_str = (
            ". This layer is not yet supported for dynamic quantization"
        )
        # This for loop goes through the information collected in module_dynamic_static_info and:
        #   Populates the string based report with the information from module_dynamic_static_info
        #   Compiles the complete report by appending relevant formatted strings

        for module_fqn in module_dynamic_static_info.keys():
            # there is at least 1 module for suggestion
            modules_added = True
            module_info = module_dynamic_static_info[module_fqn]
            suggestion_string_template = (
                "For module {} it is suggested to use {} quantization because {}.\n"
            )

            # decide what the string formatting values will be
            quantization_type = ""
            quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."

            benefit_str = ""

            # strings for if a dynamic quantize per tensor layer is needed
            recommend_per_tensor = (
                ". We recommend to add a {} before this module if it is static."
            )
            rec_lay_to_add = "dynamic quantize per tensor layer"
            dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
            dynamic_per_tensor_reasoning_string = " This is because the input to this module has a non-stationary distribution"

            # start composing explanation
            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
                quantization_type = "dynamic"
                # check if currently supported or future supported
                benefit_str = dynamic_benefit
                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
                    benefit_str += future_support_str
            else:
                quantization_type = "static"
                benefit_str = static_benefit

            # now set the quantization explanation string
            quantization_reasoning = (
                quantization_reasoning.format(
                    module_fqn,
                    module_info[self.PRE_OBS_DATA_DIST_KEY],
                    module_info[self.POST_OBS_DATA_DIST_KEY],
                )
                + benefit_str
            )

            # if we have non-stationary input -> linear -> stationary, we suggested static
            # however, we also want to recommend adding a dynamic quantize per tensor layer right before it if this change is made
            if (
                module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR
                and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR
            ):
                quantization_reasoning = (
                    quantization_reasoning
                    + dynamic_per_tensor_string
                    + dynamic_per_tensor_reasoning_string
                )

            # format the overall suggestion string with the specific inputs
            module_suggestion_string = suggestion_string_template.format(
                module_fqn, quantization_type, quantization_reasoning
            )

            # append to overall suggestion
            dynamic_vs_static_string += module_suggestion_string

        if not modules_added:
            dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"

        # return the string as well as the dictionary of information
        return (dynamic_vs_static_string, module_dynamic_static_info)


class InputWeightEqualizationDetector(DetectorBase):
    r"""
    Determines whether input-weight equalization can help improve quantization for certain modules.

    Specifically, this list of modules includes:
        linear
        conv

    Determines whether input-weight equalization is recommended based on the comp stat:
        s_c = sqrt(w_c/W) / sqrt(i_c/I)
        where:
            w_c is the range of the weight for channel c, W is the range of the weight over all channels
            i_c is the range of the input for channel c, I is the range of the input over all channels

        if threshold <= s_c <= 1 / threshold, input-weight equalization is recommended for that channel

    Args:
        ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
            Should be between 0 and 1 (both non-inclusive)
        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
            Default: 1

    * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization

    * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
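
    Example (illustrative usage; ``prepared_model`` is assumed to be a
    GraphModule calibrated with the pre-observers this detector inserts)::

        >>> detector = InputWeightEqualizationDetector(ratio_threshold=0.4)
        >>> report_str, equalization_info = detector.generate_detector_report(prepared_model)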
    """

    SUPPORTED_MODULES: Set[Callable] = {
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Linear,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
    }

    # name for the pre-observer that is inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # weight / activation prefix for each of the below info
    WEIGHT_PREFIX = "weight_"
    ACTIVATION_PREFIX = "input_activation_"

    # string names for keys of info dictionaries
    PER_CHANNEL_MAX_KEY = "per_channel_max"
    PER_CHANNEL_MIN_KEY = "per_channel_min"
    GLOBAL_MAX_KEY = "global_max"
    GLOBAL_MIN_KEY = "global_min"

    # keys for return dict of recommendations
    RECOMMENDED_KEY = "input_weight_equalization_recommended"
    COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
    THRESHOLD_KEY = "input_weight_threshold"
    CHANNEL_KEY = "input_weight_channel_axis"

    # default weight and info strings
    WEIGHT_STR = "weight"
    INPUT_STR = "input"

    # default fraction of channels at which we recommend input-weight equalization
    DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4

    def __init__(self, ratio_threshold: float, ch_axis: int = 1):
        # ensure passed in inputs are valid
        if ratio_threshold <= 0 or ratio_threshold >= 1:
            raise ValueError("Make sure threshold is > 0 and < 1")

        # initialize attributes based on args
        self.ratio_threshold: float = ratio_threshold
        self.ch_axis: int = ch_axis

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args:
            module: The module to check and ensure is supported
            insert: True if this check is for observer insertion, False if for report generation

        Returns True if the module is supported by the observer, False otherwise
        """
        # check to see if module is of a supported type
        is_supported_type = any(type(module) is x for x in self.SUPPORTED_MODULES)

        # this check is for observer insertion
        if insert:
            return is_supported_type
        else:
            # this is for report generation and we also need to check if it contains observers
            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
            return is_supported_type and has_obs

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn
        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from relevant module fqns to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper functions to populate the dictionary
        # find the range of inputs
        input_values: Dict[str, Dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: Dict[str, Dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(
            input_values, weight_values
        )

        # generate the return dictionary
        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(
            input_values, weight_values, comp_stats
        )

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn in input_weight_equalization_info:
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # see if input-weight equalization is recommended
            input_weight_recommended: bool = input_weight_equalization_info[module_fqn][
                self.RECOMMENDED_KEY
            ]
            detector_qconfig_info.is_equalization_recommended = input_weight_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> Dict[str, Dict[str, Any]]:
        r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
        For this detector, we want to place observers in front of supported layers.

        Currently inserts observers for:
            linear layers
            conv layers

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """

        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # check to see if module is of a supported type
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""Returns the name of this detector"""
        return "input_weight_equalization_detector"

    def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the input information for each observer and returns it.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns (str) to a dict with keys:
            "input_activation_per_channel_max" : maps to the per_channel max values
            "input_activation_per_channel_min" : maps to the per_channel min values
            "input_activation_global_max" : maps to the global max recorded
            "input_activation_global_min" : maps to the global min recorded
        """

        # return dictionary mapping observer fqns to desired info
        input_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # get pre observer for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)

                input_info[fqn] = {
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val,
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val,
                    self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val),
                    self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val),
                }

        return input_info

    def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the weight information for each layer an observer is attached to.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping module fqns (str) to a dict with keys:
            "weight_per_channel_max" : maps to the per_channel max values
            "weight_per_channel_min" : maps to the per_channel min values
            "weight_global_max" : maps to the global max recorded
            "weight_global_min" : maps to the global min recorded
        """
        # return dictionary mapping observer fqns to desired info
        weight_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # we don't need the actual observer, just the module weights
                # calculate min and max vals
                device = module.weight.device
                min_val: torch.Tensor = torch.tensor([float("inf")], device=device)
                max_val: torch.Tensor = torch.tensor([float("-inf")], device=device)
                x_copy = module.weight
                x_dim = x_copy.size()

                # move the channel axis to the front so that, after flattening,
                # each row holds the values of one channel
                new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
                new_axis_list[self.ch_axis] = 0
                new_axis_list[0] = self.ch_axis
                y = x_copy.permute(new_axis_list)

                # Need to match dtype of min/max because the updates to buffers
                # are done in place and types need to match for comparisons
                y = y.to(min_val.dtype)
                y = torch.flatten(y, start_dim=1)
                if min_val.numel() == 0 or max_val.numel() == 0:
                    min_val, max_val = torch.aminmax(y, dim=1)
                else:
                    min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
                    min_val = torch.min(min_val_cur, min_val)
                    max_val = torch.max(max_val_cur, max_val)

                weight_info[fqn] = {
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
                    self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
                    self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
                }

        return weight_info

    def _calculate_range_ratio(
        self, info_dict: Dict, info_str: str, module_fqn: str
    ) -> torch.Tensor:
        r"""
        Takes in an info dict and calculates the per-channel range ratio used in the s_c stat.

        Args:
            info_dict (dict): A dictionary of either input or weight range info
            info_str (str): A str describing whether we are currently looking at weight or input info
                Either "weight" or "input"
            module_fqn (str): The fqn of the module we are looking at

        Returns a tensor of values, where each value is the range ratio for a different channel
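
        Example (illustrative numbers): per-channel ranges of [2.0, 4.0]
        with a global range of 4.0 give ratios of [0.5, 1.0].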
        """
        # calculate the ratios of the info
        # get the prefix str
        prefix_str = (
            self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX
        )

        per_channel_range = (
            info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY]
            - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
        )
        global_range = (
            info_dict[prefix_str + self.GLOBAL_MAX_KEY]
            - info_dict[prefix_str + self.GLOBAL_MIN_KEY]
        )

        if global_range == 0:
            range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
            raise ValueError(
                f"The range of the {info_str} data for module {module_fqn} is 0, "
                f"which means you have a constant value channel. {range_zero_explanation}"
            )

        ratio = per_channel_range / global_range

        return ratio

    def _generate_comparison_values(
        self, input_info: Dict, weight_info: Dict
    ) -> Dict[str, torch.Tensor]:
        r"""
        Takes in the information on the min and max values of the inputs and weights and:
            Calculates the comp stat for each channel: s_c = sqrt(w_c/W) / sqrt(i_c/I)

        Args:
            input_info (dict): A dict mapping each observer to input range information
            weight_info (dict): A dict mapping each observer to weight range information

        Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
            Each value is a different s_c value for a different channel
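
        Example (a minimal sketch of the per-channel computation; the tensors
        are illustrative stand-ins for the range ratios)::

            >>> weight_ratio = torch.tensor([0.5, 1.0])
            >>> input_ratio = torch.tensor([0.8, 0.2])
            >>> s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)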
        """
        # create return dictionary for each observer
        module_fqn_to_channel: Dict[str, torch.Tensor] = {}

        # for each module (both passed in dicts should have same keys)
        for module_fqn in input_info:
            # raise error if not in weight info
            if module_fqn not in weight_info:
                raise KeyError(
                    f"Unable to find weight range stats for module {module_fqn}"
                )

            # calculate the ratios of the weight info and input info
            weight_ratio = self._calculate_range_ratio(
                weight_info[module_fqn], self.WEIGHT_STR, module_fqn
            )
            input_ratio = self._calculate_range_ratio(
                input_info[module_fqn], self.INPUT_STR, module_fqn
            )

            # if mismatched size, because of grouping, we want to replicate weight enough times
            weight_channels = len(weight_ratio)
            input_channels = len(input_ratio)
            if weight_channels != input_channels:
                # we try to replicate
                assert (
                    input_channels % weight_channels == 0
                ), "input channels should be divisible by weight channels."
                # get replication factor
                rep_factor: int = input_channels // weight_channels

                # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n times
                weight_ratio = weight_ratio.repeat(rep_factor)

            # calculate the s metric per channel
            s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)
            module_fqn_to_channel[module_fqn] = s

        # return compiled observer ratios
        return module_fqn_to_channel

    def _generate_dict_info(
        self, input_info: Dict, weight_info: Dict, comp_stats: Dict
    ) -> Dict[str, Dict]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in the generate_detector_report documentation

        Args:
            input_info (dict): A dict mapping each module to input range information
            weight_info (dict): A dict mapping each module to weight range information
            comp_stats (dict): A dict mapping each module to its corresponding comp stat

        Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
            whether input weight equalization is recommended
            their s_c metric compared to the threshold
            the threshold used to make the recommendation
            the channel used for recording data
            the input channel range info
            the weight channel range info
        """
        # store modules input weight equalization info
        input_weight_equalization_info: Dict[str, Dict] = {}

        # for each module we add a separate set of suggestions
        for module_fqn in input_info:
            # get relevant info for this module
            mod_input_info: Dict = input_info[module_fqn]
            mod_weight_info: Dict = weight_info[module_fqn]
            mod_comp_stat: torch.Tensor = comp_stats[module_fqn]

            # decide if each channel should have input weight equalization or not
            channel_rec_vals: list = []

            for val in mod_comp_stat:
                float_rep: float = val.item()

                # decide if recommending input weight equalization
                recommended: bool = (
                    float_rep >= self.ratio_threshold
                    and float_rep <= 1 / self.ratio_threshold
                )
                channel_rec_vals.append(recommended)

            # build the return dict input
            # also unpack input and weight dicts into it
            input_weight_equalization_info[module_fqn] = {
                self.RECOMMENDED_KEY: channel_rec_vals,
                self.COMP_METRIC_KEY: mod_comp_stat,
                self.THRESHOLD_KEY: self.ratio_threshold,
                self.CHANNEL_KEY: self.ch_axis,
                **mod_input_info,
                **mod_weight_info,
            }

        # return our compiled info for each module
        return input_weight_equalization_info

    def generate_detector_report(
        self, model: GraphModule
    ) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether input weight equalization is appropriate for a given module.

        Takes advantage of the ModelReportObserver, which records per channel information of the input range.
        It then uses the passed in weight info in conjunction with this to compute the desired ratio.
        Finally, it gives suggestions based on this information for each module of interest.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether input weight equalization is recommended for certain modules
            Dictionary mapping modules of interest to:
                whether input weight equalization is recommended
                their s_c metric compared to the threshold
                the threshold used to make the recommendation
                the channel used for recording data
                the input channel range info
                the weight channel range info
        """
1257
1258        # find the range of inputs
1259        input_values: Dict[str, Dict] = self._extract_input_info(model)
1260
1261        # find the range of weights
1262        weight_values: Dict[str, Dict] = self._extract_weight_info(model)
1263
1264        # calculate per_channel comparison statistic s_c
1265        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(
1266            input_values, weight_values
1267        )
1268
1269        # generate the return dictionary
1270        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(
1271            input_values, weight_values, comp_stats
1272        )
1273
1274        # now we can generate report based on this information
1275        input_weight_string = "Input-Weight Equalization suggestions: \n"
1276
1277        # some strings to be formatted depending on module we are adding
1278        module_suggestion_str = "For Module {} looked at with axis {}: \n"
1279        channel_suggestion_str = (
1280            "\tWe suggest {} input weight equalization because {}\n"
1281        )
1282        use_str = "to use"
1283        no_use_str = "to not use"
1284        input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error."
1285        input_weight_non_benefit_reasoning = (
1286            "{}/{} channels benefitting from input-weight equalization being applied."
1287        )
1288        input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}"
1289
1290        # added module check
1291        added_module: bool = False
1292
1293        # compile the suggestion string
1294        for module_fqn in input_weight_equalization_info:
1295            # we added at least 1 module
1296            added_module = True
1297            # add the module level description
1298            input_weight_string += module_suggestion_str.format(
1299                module_fqn, self.ch_axis
1300            )
1301
1302            mod_info: Dict[str, Any] = input_weight_equalization_info[module_fqn]
1303
1304            # gather info on how many channels would benefit from input weight and
1305            recommendation_per_channel: torch.Tensor = mod_info[self.RECOMMENDED_KEY]
1306            num_recs = sum(recommendation_per_channel)
1307
1308            if (
1309                num_recs / len(recommendation_per_channel)
1310                >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO
1311            ):
1312                input_benefit_formatted = input_weight_benefit_str.format(
1313                    num_recs, len(recommendation_per_channel)
1314                )
1315                channel_str = channel_suggestion_str.format(
1316                    use_str, input_benefit_formatted
1317                )
1318                input_weight_string += channel_str
1319            else:
1320                non_benefit_reason_formatted = (
1321                    input_weight_non_benefit_reasoning.format(
1322                        num_recs, len(recommendation_per_channel)
1323                    )
1324                )
1325                non_benefit_str = input_weight_non_benefit_str.format(
1326                    non_benefit_reason_formatted
1327                )
1328                channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str)
1329                input_weight_string += channel_str
1330
1331        # if no modules looked at, amend return string
1332        if not added_module:
1333            input_weight_string += (
1334                "No applicable layers for suggestions. Only linear and conv valid.\n"
1335            )
1336
1337        # return a tuple with the string explanation and the compiled dict info
1338        return (input_weight_string, input_weight_equalization_info)
1339
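    # Usage sketch (not part of the class; `detector` is a hypothetical instance
    # of this detector and `prepared_model` a calibrated GraphModule):
    #
    #     report_str, info = detector.generate_detector_report(prepared_model)
    #     print(report_str)
    #     for fqn, mod_info in info.items():
    #         recs = mod_info[detector.RECOMMENDED_KEY]  # per-channel bool list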

class OutlierDetector(DetectorBase):
    r"""
    Determines whether there are significant outliers in activation data around a certain layer.

    This is ideally used in conjunction with information on stationary vs. non-stationary distributions:
        If the data is stationary, and there are significant outliers, then we want to flag them
        We want to do this on a per channel basis for detecting outliers

    Determines whether activation data is flagged as outlier based on if data is stationary and:
        p_r = avg(100th percentile / "reference_percentile"th percentile)
        where:
            p_r is the average percentile ratio across all batches in the epoch
            reference_percentile is a percentile value between 0 and 1 exclusive

        if p_r is above some threshold, then we consider the activations to have significant outliers

    Args:
        ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
            Should be >= 1
            Default: 3.5
        reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
            Should be between 0 and 1
            Default: 0.975
        fraction_batches_used_threshold (float, optional): Threshold on the fraction of batches used per channel for outlier detection
            If the fraction used for a channel is below this, we deem the number of samples insignificant and alert the user
            to take a closer look at that channel's results, regardless of whether outliers were detected
            Should be between 0 and 1
            Default: 0.95
        ch_axis (int, optional): The channel axis being observed to determine outliers
            Default: 1

    * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
        The p_r value (the average ratio of the 100th percentile to the reference_percentile-th percentile) is compared to ratio_threshold
        If it is significantly greater, then we consider it an outlier
        This threshold was calculated based on the ratio of the percentiles in a normal distribution
        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing

    * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
        Should be between 0 and 1
        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing

    * :attr:`fraction_batches_used_threshold`: The fraction of batches used to determine outliers for each channel should be above this
        Some batches may be skipped (e.g., to avoid division by zero), so this ensures a good amount of the total batches are used
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine outliers

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    # names for the pre observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # pre activation prefix
    INPUT_ACTIVATION_PREFIX = "input_activation_"

    # names for dict keys
    OUTLIER_KEY = "outliers_detected"
    NUM_BATCHES_KEY = "outlier_detection_batches_used"
    IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
    COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
    RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
    REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
    CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
    MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
    CONSTANT_COUNTS_KEY = "constant_batch_counts"

    def __init__(
        self,
        ratio_threshold: float = 3.5,
        reference_percentile: float = 0.975,
        fraction_batches_used_threshold: float = 0.95,
        ch_axis: int = 1,
    ):
        # initialize the variables of interest
        self.ratio_threshold = ratio_threshold

        # make sure the passed-in percentile and fraction are valid
        assert 0 <= reference_percentile <= 1
        assert 0 <= fraction_batches_used_threshold <= 1
        self.reference_percentile = reference_percentile
        self.fraction_batches_used_threshold = fraction_batches_used_threshold
        self.ch_axis = ch_axis

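    # A minimal construction sketch (hypothetical values) showing the
    # validation above; out-of-range fractions trip the asserts:
    #
    #     detector = OutlierDetector(ratio_threshold=4.0, reference_percentile=0.99)
    #     # OutlierDetector(reference_percentile=99)  -> AssertionError
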
    def get_detector_name(self) -> str:
        r"""Returns the name of this detector"""
        return "outlier_detector"

    def _supports_insertion(self, module: nn.Module) -> bool:
        r"""Returns whether the given module is supported for observer insertion

        Any module that doesn't have children and isn't an observer itself is supported

        Args:
            module: The module to check and ensure is supported

        Returns True if the module is supported for observer insertion, False otherwise
        """
        # case for insertion of module
        # check that the module has no children and isn't an observer itself
        num_children = len(list(module.children()))
        return num_children == 0 and not _is_activation_post_process(module)

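    # A quick sketch of the leaf-module check above (hypothetical `detector`):
    #
    #     detector._supports_insertion(nn.Linear(3, 3))                 # True: leaf module
    #     detector._supports_insertion(nn.Sequential(nn.Linear(3, 3)))  # False: has children
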
    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn

        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # the outlier detector currently makes no qconfig recommendations
        return {}

    def _supports_report_gen(self, module: nn.Module) -> bool:
        r"""Returns whether the given module is supported for report generation

        Any module that has a model report pre-observer is supported

        Args:
            module: The module to check and ensure is supported

        Returns True if the module is supported for report generation, False otherwise
        """
        return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)

    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> Dict[str, Dict[str, Any]]:
        r"""Determines where observers need to be inserted for the Outlier Detector.

        For this detector, we want to place observers in front of supported layers.

        Currently inserts observers for:
            all layers that do not have children (leaf level layers)

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """
        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # check to see if module is of a supported type
            if self._supports_insertion(module):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(
                        ch_axis=self.ch_axis, comp_percentile=self.reference_percentile
                    ),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

        return obs_fqn_to_info

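    # Illustrative shape of one entry in the returned mapping (hypothetical fqn):
    #
    #     {
    #         "block1.linear.model_report_pre_observer": {
    #             "target_node": <torch.fx.node.Node for block1.linear>,
    #             "observer_to_insert": ModelReportObserver(...),
    #             "is_post_observer": False,
    #             "observer_args": <target_node.args>,
    #         },
    #     }
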
    def _calculate_outlier_info(
        self,
        percentile_ratios: torch.Tensor,
        counted_batches: torch.Tensor,
        total_batches: int,
    ) -> Dict[str, List[bool]]:
        r"""
        Gives info on whether the percentile ratios calculated would be considered outliers
        Also gives information on whether the collected data is statistically significant enough to make this claim

        Args:
            percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
            counted_batches (torch.Tensor): The number of batches used for the average calculation per channel
            total_batches (int): The total number of batches that passed through the observer in this epoch

        Returns a dictionary mapping:
            "outliers_detected": list of bools per channel that are True if it is considered an outlier
            "is_sufficient_batches": list of bools per channel that are True if o_r >= fraction_batches_used_threshold,
                where o_r = counted_batches / total_batches
        """
        outlier_dict: Dict[str, List[bool]] = {
            self.OUTLIER_KEY: [],
            self.IS_SUFFICIENT_BATCHES_KEY: [],
        }

        # get both as flattened lists for easy mapping
        ratios_list: List = percentile_ratios.tolist()
        num_batches_list: List = counted_batches.tolist()

        # calculate whether channels were statistically significant
        significant_size = [
            batch_size / total_batches >= self.fraction_batches_used_threshold
            for batch_size in num_batches_list
        ]
        outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size

        # calculate for each channel whether it's an outlier or not based on ratio
        outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
        outlier_dict[self.OUTLIER_KEY] = outlier_detected

        # return the dictionary with the two lists
        return outlier_dict

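    # Worked example (hypothetical numbers) of the two checks above, assuming
    # ratio_threshold = 3.5 and fraction_batches_used_threshold = 0.95:
    #
    #     percentile_ratios = torch.tensor([2.0, 5.0])
    #     counted_batches = torch.tensor([100, 90])
    #     total_batches = 100
    #     # outliers_detected     -> [False, True]  (only 5.0 > 3.5)
    #     # is_sufficient_batches -> [True, False]  (90 / 100 < 0.95)
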
    def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Helper function for generate_detector_report that generates the information dictionary.
        This process is done as specified in the generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns to:
            whether outliers were found in the preceding activations
            the number of batches used for each channel
            whether the fraction of applicable batches used is above fraction_batches_used_threshold
            their p_r metric compared to the threshold
            the threshold used to make the recommendation
            the reference_percentile used to make the recommendation
            the channel axis used to determine individual channels
            the constant batch counts per channel
            the per channel max values
        """
        # return dictionary mapping observer fqns to desired info
        info_dict: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._supports_report_gen(module):
                # get pre observer for the module
                pre_obs: ModelReportObserver = getattr(
                    module, self.DEFAULT_PRE_OBSERVER_NAME
                )

                # get the number of batches and calculated ratio thresholds
                num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
                average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
                channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
                total_batches: int = pre_obs.num_batches_tracked

                # also get the max values
                max_vals: torch.Tensor = pre_obs.max_val

                # we specifically modify how we record negative ratios for pre-relu layers
                for index, ratio_val in enumerate(average_ratios):
                    # check if we have a negative ratio
                    # a ratio might be negative if we have a situation where the 100th percentile is
                    # > 0 while the nth percentile is < 0, in which case this would not be detected
                    # as an outlier. Since we care more about magnitude, we make it positive.
                    if ratio_val.item() < 0:
                        # first make it positive so the magnitude check below sees the flipped value
                        ratio_val = -ratio_val
                        average_ratios[index] = ratio_val

                    if ratio_val.item() < 1:
                        # if it's less than 1 we have to flip it as well, so the
                        # recorded magnitude is always >= 1
                        average_ratios[index] = 1 / ratio_val

                outlier_calcs = self._calculate_outlier_info(
                    average_ratios, num_batches, total_batches
                )

                # calculate whether ratios were outliers
                info_dict[fqn] = {
                    self.CHANNEL_AXIS_KEY: self.ch_axis,
                    self.REF_PERCENTILE_KEY: self.reference_percentile,
                    self.RATIO_THRES_KEY: self.ratio_threshold,
                    self.COMP_METRIC_KEY: average_ratios,
                    self.NUM_BATCHES_KEY: num_batches,
                    self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
                    self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[
                        self.IS_SUFFICIENT_BATCHES_KEY
                    ],
                    self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
                    self.MAX_VALS_KEY: max_vals,
                }

        return info_dict

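    # Numeric illustration (hypothetical values) of the ratio normalization in
    # the loop above; every recorded ratio ends up as a magnitude >= 1:
    #
    #     average_ratios = torch.tensor([-4.0, 0.25, 2.0])
    #     # after the loop: tensor([4.0, 4.0, 2.0])
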
    def generate_detector_report(
        self, model: GraphModule
    ) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether there are outliers in the activation data around the modules of interest.

        Takes advantage of the ModelReportObserver, which records the relevant percentile information

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether there are outliers in the activations around certain modules
            Dictionary mapping modules of interest to:
                whether outliers were found in the preceding activations
                the number of batches used for each channel
                whether the fraction of applicable batches used is above fraction_batches_used_threshold
                their p_r metric compared to the threshold
                the threshold used to make the recommendation
                the reference_percentile used to make the recommendation
                the channel axis used to determine individual channels
                the constant batch counts per channel
                the per channel max values
        """
        # generate the information dictionary of outlier information
        info_dict = self._generate_info_dict(model)

        # now we can generate report based on this information
        outlier_string = "Outlier detection report: \n"

        # added module check
        added_module: bool = False

        # some strings to be formatted depending on module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
        channel_max_value_str = "a max value across all batches of {}"
        note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
        note_distribution = "stationary distributions"
        note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"

        # suggestion for the constant-batch check, since constant batches can hide outliers
        constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
        constant_suggestion = "We recommend taking a look at the dict and data to see how frequently this occurred and why."

        # compile the suggestion string
        for module_fqn in info_dict:
            # get module specific info
            mod_info: Dict[str, Any] = info_dict[module_fqn]
            # check to see if we already added high level model desc
            added_model_desc = False
            # look at each individual channel and add a suggestion
            for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
                if outlier_detected:
                    # we found at least 1 outlier
                    if not added_model_desc:
                        # add the module level description
                        outlier_string += module_suggestion_str.format(
                            module_fqn, self.ch_axis
                        )
                        added_model_desc = True

                    # we mark that we found at least one outlier
                    added_module = True
                    max_value_found_str = channel_max_value_str.format(
                        mod_info[self.MAX_VALS_KEY][index]
                    )
                    channel_str = channel_suggestion_str.format(
                        index, max_value_found_str
                    )
                    outlier_string += channel_str

                # also check if we found constant value batches
                if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0:
                    # make sure we add a module level highlight.
                    if not added_model_desc:
                        # add the module level description
                        outlier_string += module_suggestion_str.format(
                            module_fqn, self.ch_axis
                        )
                        added_model_desc = True

                    constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][
                        index
                    ]
                    formatted_str = constant_str.format(
                        index, constant_values_for_channel, constant_suggestion
                    )
                    outlier_string += formatted_str
                    # we also added at least one thing to the description
                    added_module = True

        # if we found an outlier, give the suggestion, else give the default response
        if added_module:
            # compose the note string
            note_composed = note_string.format(note_distribution, note_rec)
            outlier_string += note_composed
        else:
            outlier_string += "There were no outliers found in the activations.\n"

        return (outlier_string, info_dict)

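# End-to-end usage sketch (hypothetical `prepared_model`; assumes observers were
# inserted via determine_observer_insert_points and the model was calibrated):
#
#     detector = OutlierDetector()
#     report_str, report_dict = detector.generate_detector_report(prepared_model)
#     print(report_str)
#     for fqn, info in report_dict.items():
#         flagged = [i for i, out in enumerate(info[detector.OUTLIER_KEY]) if out]
#         print(fqn, "channels with outliers:", flagged)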