xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/_model_report/model_report.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from collections import OrderedDict
3from typing import Any, Callable, Dict, Set, Tuple
4
5import torch
6from torch.ao.quantization.fx._equalize import EqualizationQConfig
7from torch.ao.quantization.fx._model_report.detector import (
8    DETECTOR_IS_POST_OBS_KEY,
9    DETECTOR_OBS_ARGS_KEY,
10    DETECTOR_OBS_TO_INSERT_KEY,
11    DETECTOR_TARGET_NODE_KEY,
12    DetectorBase,
13    DetectorQConfigInfo,
14)
15from torch.ao.quantization.fx._model_report.model_report_visualizer import (
16    ModelReportVisualizer,
17)
18from torch.ao.quantization.fx.graph_module import GraphModule
19from torch.ao.quantization.observer import ObserverBase
20from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping
21
22
23class ModelReport:
24    r"""
25    The ModelReport class aims to provide users an easy way to diagnose issues that they run into
26    with their models. The class works with all traceable GraphModules to help diagnose issues,
27    though the requirements on the type of model more-so depends on the specific report the user
28    is trying to generate. With respect to the reports, the ModelReport class is initialized with
29    a set of Detector classes, each of which generate reports on quantization configuration
30    issues a user might have.
31
32    Currently supports generating reports on:
33    - Suggestions for per-channel vs. per-tensor quantization (nn.Module)
34    - Suggestions for dynamic vs static quantization for linear layers (Graph Modules)
35    - Suggestions for input-weight equalization for linear and conv layers (Graph Modules)
36    - Suggestions for outlier detection for all layers (Graph Modules)
37
38    The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver)
39    where needed for each detector to gather the information it needs, and then after callibration, the ModelReport
40    class compiles the report generated by each Detector class into a single report to return to the user. It also
41    has the capability to remove all the observers it inserted as well.
42
43    * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule
44
45    * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class
46        Make sure that these are all unique types of detectors [do not have more than 1 of the same class]
47
48    * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors.
49        This set is generated by calling the get_detector_name() of each detector
50
51    * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest
52        The purpose of this is to keep track of what observers were inserted for each detector, so that they
53        can be removed at the end if desired
54
55    * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not
56        This is to ensure we only insert observers once with the ModelReport instance
57
58    * :attr:`_removed_observers` A boolean to track if we have removed observers already
59        The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport
60        instance. This also allows the functionality where we can generate the report multiple times
61        as long as we haven't removed the observers yet.
62
63    Note:
64        This class was initially designed to work with the Fx Graph Mode workflow in mind. However,
65        full functionality is available as long as there is a traceable GraphModule that is being used.
66        One method to get a traceable GraphModule without going through the Fx workflow is to use
67        the QuantizationTracer class.
68
69    General Flow for Fx workflow:
70    1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model
71    2.) Prepare your model with prepare_fx
72    3.) Call model_report.prepare_detailed_calibration to add relevant observers
73    4.) Callibrate your model with data
74    5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
75    Optional
76        6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
77        7.) To help in parsing report information and debugging, view report info as a:
78            - Table
79            - Histogram
80            - Line plot
81    8.) Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions
82
83    Example (with QuantizationTracer):
84        >>> # xdoctest: +SKIP
85        >>> # get the necessary qconfig
86        >>> config = PrepareCustomConfig()
87        >>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False)
88
89        >>> # initialize our model and get GraphModule
90        >>> model = SomeModel()
91        >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
92        >>> graph_module = GraphModule(model, tracer.trace(model))
93
94        >>> # get our set of detectors and ModelReport instance
95        >>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)])
96        >>> tracer_reporter = ModelReport(graph_module, detector_set)
97
98        >>> # now we insert the observers and callibrate the model
99        >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration()
100        >>> for i in range(num_callibration_batches):
101        >>>     example_input = get_callibration_input()
102        >>>     tracer_model_with_observers(example_input)
103
104        >>> # finally we generate the reports and optionally remove the observers we inserted
105        >>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True)
106
107        >>> # Optional: we can generate the qconfig mapping based on the suggestions
108        >>> qconfigs = tracer_reporter.generate_qconfig_mapping()
109
110        >>> # Optional: we can generate the equalization mapping based on the suggestions
111        >>> qconfigs = tracer_reporter.generate_equalization_mapping()
112
113        >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired
114        >>> model_report_visualizer = tracer_reporter.generate_visualizer()
115
116    """
117
118    def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]):
119        if len(desired_report_detectors) == 0:
120            raise ValueError("Should include at least 1 desired report")
121
122        # keep track of the model we wish to generate report for
123        self._model: GraphModule = model
124
125        # keep the reports private so they can't be modified
126        self._desired_report_detectors = desired_report_detectors
127        self._desired_detector_names = {
128            detector.get_detector_name() for detector in desired_report_detectors
129        }
130
131        # keep a mapping of desired reports to observers of interest
132        # this is to get the readings, and to remove them, can create a large set
133        # this set can then be used to traverse the graph and remove added observers
134        self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}
135
136        # initialize each report to have empty set of observers of interest
137        for desired_report in self._desired_detector_names:
138            self._detector_name_to_observer_fqns[desired_report] = set()
139
140        # flags to ensure that we can only prepare and remove observers once
141        self._prepared_flag = False
142        self._removed_observers = False
143
144        # store the reports that we generated for visualization purposes
145        # initially empty since no reports generated
146        self._generated_reports: Dict[str, Dict] = {}
147
148    def get_desired_reports_names(self) -> Set[str]:
149        """Returns a copy of the desired reports for viewing"""
150        return self._desired_detector_names.copy()
151
152    def get_observers_of_interest(self) -> Dict[str, Set[str]]:
153        """Returns a copy of the observers of interest for viewing"""
154        return self._detector_name_to_observer_fqns.copy()
155
156    def prepare_detailed_calibration(self) -> GraphModule:
157        r"""
158        Takes in a graph model and inserts the following observers:
159        - ModelReportObserver
160
161        Each observer is inserted based on the desired_reports into the relevant locations
162
163        Right now, each report in self._desired_detector_names has independent insertions
164            However, if a module already has a Observer of the same type, the insertion will not occur
165            This is because all of the same type of Observer collect same information, so redundant
166
167        Returns the same GraphModule with the observers inserted
168        """
169
170        # if already prepared once, cannot prepare again
171        if self._prepared_flag:
172            raise ValueError(
173                "Already ran preparing detailed callibration. Run the report generation next after callibration."
174            )
175
176        # loop through each detector, find where placements should be, and keep track
177        insert_observers_fqns: Dict[str, Any] = {}
178
179        for detector in self._desired_report_detectors:
180            # determine observer points for each detector
181            obs_fqn_to_info = detector.determine_observer_insert_points(self._model)
182            # map each insert point to the observer to use
183            insert_observers_fqns.update(obs_fqn_to_info)
184            # update the set of observers this report cares about
185            self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(
186                obs_fqn_to_info.keys()
187            )
188
189        # now insert all the observers at their desired locations
190        for observer_fqn in insert_observers_fqns:
191            target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY]
192            insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY]
193            insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY]
194            observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY]
195            self._insert_observer_around_module(
196                observer_fqn, target_node, insert_obs, observer_args, insert_post
197            )
198
199        self._prepared_flag = True
200
201        return self._model
202
203    def _insert_observer_around_module(
204        self,
205        obs_fqn: str,
206        target_node: torch.fx.node.Node,
207        obs_to_insert: ObserverBase,
208        observer_args: Tuple,
209        insert_post: bool,
210    ):
211        r"""
212        Helper function that inserts the observer into both the graph structure and the module of the model
213
214        Args
215            node_fqn (str): The fully qualified name of the observer we want to insert
216            target_node (torch.fx.node.Node): The node in model we are inserting observers around
217            obs_to_insert (ObserverBase): The observer we are inserting around target_node
218            observer_args (Tuple): The arguments we want to pass into the observer
219            insert_post (bool): whether this is meant to be a post observer for this node
220        """
221        # if we are inserting post, then our target node is the next node
222        if insert_post:
223            target_node = target_node.next
224
225        with self._model.graph.inserting_before(target_node):
226            self._model.add_submodule(obs_fqn, obs_to_insert)
227            self._model.graph.create_node(
228                op="call_module", target=obs_fqn, args=observer_args
229            )
230
231        # recompile model after inserts are made
232        self._model.recompile()
233
234    def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node:
235        r"""
236        Takes in a node fqn and returns the node based on the fqn
237
238        Args
239            node_fqn (str): The fully qualified name of the node we want to find in model
240
241        Returns the Node object of the given node_fqn otherwise returns None
242        """
243        node_to_return = None
244        for node in self._model.graph.nodes:
245            # if the target matches the fqn, it's the node we are looking for
246            if node.target == node_fqn:
247                node_to_return = node
248                break
249
250        if node_to_return is None:
251            raise ValueError("The node_fqn is was not found within the module.")
252
253        # assert for MyPy
254        assert isinstance(node_to_return, torch.fx.node.Node)
255
256        return node_to_return
257
258    def generate_model_report(
259        self, remove_inserted_observers: bool
260    ) -> Dict[str, Tuple[str, Dict]]:
261        r"""
262        Generates all the requested reports.
263
264        Note:
265            You should have callibrated the model with relevant data before calling this
266
267        The reports generated are specified by the desired_reports specified in desired_reports
268
269        Can optionally remove all the observers inserted by the ModelReport instance
270
271        Args:
272            remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance
273
274        Returns a mapping of each desired report name to a tuple with:
275            The textual summary of that report information
276            A dictionary containing relevant statistics or information for that report
277
278        Note:
279            Throws exception if we try to generate report on model we already removed observers from
280            Throws exception if we try to generate report without preparing for callibration
281        """
282        # if we haven't prepped model for callibration, then we shouldn't generate report yet
283        if not self._prepared_flag:
284            raise Exception(  # noqa: TRY002
285                "Cannot generate report without preparing model for callibration"
286            )
287
288        # if we already removed the observers, we cannot generate report
289        if self._removed_observers:
290            raise Exception(  # noqa: TRY002
291                "Cannot generate report on model you already removed observers from"
292            )
293
294        # keep track of all the reports of interest and their outputs
295        reports_of_interest = {}
296
297        for detector in self._desired_report_detectors:
298            # generate the individual report for the detector
299            report_output = detector.generate_detector_report(self._model)
300            reports_of_interest[detector.get_detector_name()] = report_output
301
302        # if user wishes to remove inserted observers, go ahead and remove
303        if remove_inserted_observers:
304            self._removed_observers = True
305            # get the set of all Observers inserted by this instance of ModelReport
306            all_observers_of_interest: Set[str] = set()
307            for desired_report in self._detector_name_to_observer_fqns:
308                observers_of_interest = self._detector_name_to_observer_fqns[
309                    desired_report
310                ]
311                all_observers_of_interest.update(observers_of_interest)
312
313            # go through all_observers_of_interest and remove them from the graph and model
314            for observer_fqn in all_observers_of_interest:
315                # remove the observer from the model
316                self._model.delete_submodule(observer_fqn)
317
318                # remove the observer from the graph structure
319                node_obj = self._get_node_from_fqn(observer_fqn)
320
321                if node_obj:
322                    self._model.graph.erase_node(node_obj)
323                else:
324                    raise ValueError("Node no longer exists in GraphModule structure")
325
326            # remember to recompile the model
327            self._model.recompile()
328
329        # save the generated reports for visualization purposes
330        saved_reports: Dict[str, Dict] = {
331            report_name: report_tuple[1]
332            for report_name, report_tuple in reports_of_interest.items()
333        }
334
335        self._generated_reports = saved_reports
336
337        # return the reports of interest
338        return reports_of_interest
339
340    def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool:
341        r"""
342        Takes in two dictionaries and ensures that any common keys between the two have the same
343        values.
344
345        Args:
346            info_dict_a (Dict): First dictionary we wish to compare
347            info_dict_b (Dict): Second dictionary we wish to compare
348
349        Returns True if all shared keys have same values, false otherwise
350        """
351        # get the set of keys for both
352        dict_a_keys: Set = set(info_dict_a.keys())
353        dict_b_keys: Set = set(info_dict_b.keys())
354
355        # get the insersection keys and check if same value for both dicts
356        intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys)
357
358        for key in intersecting_keys:
359            dict_a_val = info_dict_a[key]
360            dict_b_val = info_dict_b[key]
361
362            # if it's a tensor we have to handle separately
363            if type(dict_a_val) == torch.Tensor:
364                # if dict_b_val not tensor, automatically false
365                if (
366                    type(dict_b_val) != torch.Tensor
367                    or sum(dict_a_val != dict_b_val) != 0
368                ):
369                    return False
370            else:
371                # for non-tensor vals
372                if dict_a_val != dict_b_val:
373                    return False
374
375        # if no non matching shared keys found, return true
376        return True
377
378    def _reformat_reports_for_visualizer(self) -> OrderedDict:
379        r"""
380        Takes the generated reports and reformats them into the format that is desired by the
381        ModelReportVisualizer
382
383        Returns an OrderedDict mapping module_fqns to their features
384        """
385        # we want to reorder and reformat the information so it is ordered in terms of order
386        # found in the model
387
388        # first create new dict with all modules as keys and features under respective module
389        module_fqns_to_features: Dict[str, Dict] = {}
390
391        for report_name in self._generated_reports:
392            # get mod -> feature dict and go through
393            module_info = self._generated_reports[report_name]
394
395            for module_fqn in module_info:
396                # check if already in our accumulation dict
397                if module_fqn in module_fqns_to_features:
398                    # we merge all the features together
399                    new_info: Dict = module_info[module_fqn]
400                    present_info: Dict = module_fqns_to_features[module_fqn]
401
402                    # merge them together into the new unioned dict
403                    # same features keys -> same info, so okay if override
404
405                    # do safety check to make sure shared keys have same info
406                    if self._is_same_info_for_same_key(new_info, present_info):
407                        module_fqns_to_features[module_fqn] = {
408                            **new_info,
409                            **present_info,
410                        }
411                    else:
412                        error_str = "You have the same key with different values across detectors. "
413                        error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors."
414                        raise ValueError(error_str)
415                else:
416                    # we just set it
417                    module_fqns_to_features[module_fqn] = module_info[module_fqn]
418
419        # our ordered dict so that modules can be ordered in order of how they appear in model
420        features_by_module: OrderedDict[str, Dict] = OrderedDict()
421
422        # we loop through modules in graph in order
423        for fqn, module in self._model.named_modules():
424            # find that fqn in fqns_to_features
425            if fqn in module_fqns_to_features:
426                # add it to our ordered dict
427                features_by_module[fqn] = module_fqns_to_features[fqn]
428
429        # return the ordered dict of info we created
430        return features_by_module
431
432    def generate_visualizer(self) -> ModelReportVisualizer:
433        r"""
434        Generates a ModelReportVisualizer instance using the reports generated
435        by the generate_model_report() method.
436
437        Returns the generated ModelReportVisualizer instance initialized
438
439        Note:
440            Throws exception if attempt to get visualizers without generating report
441        """
442        # check if user has generated reports at least once
443        if len(self._generated_reports) == 0:
444            raise Exception(  # noqa: TRY002
445                "Unable to generate visualizers without first generating reports"
446            )
447
448        # get the ordered dict mapping modules to their full set of collected features / stats
449        module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer()
450
451        # create and return ModelReportVisualizer instance
452        visualizer: ModelReportVisualizer = ModelReportVisualizer(
453            module_fqns_to_features
454        )
455
456        return visualizer
457
458    def _generate_qconfig_mapping_helper(
459        self,
460        detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo],
461        generation_function: Callable,
462    ) -> QConfigMapping:
463        r"""
464        This helper takes in the compiled detector qconfig info that
465        has been compiled together and merges it into a QConfigMapping
466        """
467        # keep track of the qconfigmapping
468        qconfig_mapping = QConfigMapping()
469
470        # loop through each module / fqn and attempt to create QConfigMapping
471        for fqn, module in self._model.named_modules():
472            # if we have a qconfig info for this module
473            if fqn in detector_qconfig_info_combined:
474                qconfig_info_compiled = detector_qconfig_info_combined[fqn]
475
476                # now generate the qconfig and add it to the mapping
477                generated_qconfig = generation_function(qconfig_info_compiled, module)
478
479                # add to our config
480                qconfig_mapping.set_module_name(fqn, generated_qconfig)
481
482        # return compiled mapping
483        return qconfig_mapping
484
485    def _update_detector_quantizaiton_qconfig_info(
486        self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo
487    ):
488        r"""
489        Takes in the old and new information and updates the combined information.
490
491        Args:
492            combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
493            new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info
494                into it
495        """
496        combined_info.is_activation_dynamic = (
497            combined_info.is_activation_dynamic or new_info.is_activation_dynamic
498        )
499        combined_info.is_weight_per_channel = (
500            combined_info.is_weight_per_channel or new_info.is_weight_per_channel
501        )
502
503    def _update_detector_equalization_qconfig_info(
504        self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo
505    ):
506        r"""
507        Takes in the old and new information and updates the combined information.
508
509        Args:
510            combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
511            new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info
512                into it
513        """
514        is_equalization_recommended = (
515            combined_info.is_equalization_recommended
516            or new_info.is_equalization_recommended
517        )
518        combined_info.is_equalization_recommended = is_equalization_recommended
519
520    def _generate_module_fqn_to_detector_info_mapping(
521        self, update_qconfig_info_function: Callable
522    ) -> Dict[str, DetectorQConfigInfo]:
523        r"""
524        Generates a QConfigMapping based on the suggestions of the
525        ModelReport API. The generated mapping encompasses all the
526        different types of feedback from the different detectors
527        all into one place.
528
529        These configs are based on the suggestions provided by the ModelReport API
530        and can only be generated once the reports have been generated.
531
532        Args:
533            update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo
534            and updates the one that is being compiled
535
536        Returns a Dict mapping module_fqns to DetectorQConfigInfo objects
537
538        Note:
539            Throws exception if we try to generate mapping on model we already removed observers from
540            Throws exception if we try to generate mapping without preparing for callibration
541        """
542        # if we haven't prepped model for callibration, then we shouldn't generate mapping yet
543        if not self._prepared_flag:
544            raise Exception(  # noqa: TRY002
545                "Cannot generate report without preparing model for callibration"
546            )
547
548        # if we already removed the observers, we cannot mapping
549        if self._removed_observers:
550            raise Exception(  # noqa: TRY002
551                "Cannot generate report on model you already removed observers from"
552            )
553
554        # keep track of qconfig info for each module across detectors
555        detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo] = {}
556
557        for detector in self._desired_report_detectors:
558            # get the info from the detector
559            detector_info: Dict[str, DetectorQConfigInfo] = detector.get_qconfig_info(
560                self._model
561            )
562
563            # we go through the modules
564            for module_fqn in detector_info:
565                # see if we already have info on it
566                if module_fqn in detector_qconfig_info_combined:
567                    # we combine the current options with what is there
568                    current_options = detector_qconfig_info_combined[module_fqn]
569                    detector_options = detector_info[module_fqn]
570
571                    update_qconfig_info_function(current_options, detector_options)
572                else:
573                    # we just use this for now
574                    detector_qconfig_info_combined[module_fqn] = detector_info[
575                        module_fqn
576                    ]
577
578        return detector_qconfig_info_combined
579
580    def generate_qconfig_mapping(self) -> QConfigMapping:
581        r"""
582        Generates a QConfigMapping based on the suggestions of the
583        ModelReport API. The generated mapping encompasses all the
584        different types of feedback from the different detectors
585        all into one place.
586
587        These configs are based on the suggestions provided by the ModelReport API
588        and can only be generated once the reports have been generated.
589
590        Returns a QConfigMapping for the quantization configuration
591
592        Note:
593            Throws exception if we try to generate mapping on model we already removed observers from
594            Throws exception if we try to generate mapping without preparing for callibration
595        """
596        # get the mapping info
597        detector_qconfig_info_combined = (
598            self._generate_module_fqn_to_detector_info_mapping(
599                self._update_detector_quantizaiton_qconfig_info
600            )
601        )
602
603        # we will do a bit of processing and remove fqns that don't have input weight recommended
604
605        # now we generate the QConfig for each of the options
606        mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
607            detector_qconfig_info_combined, self._quantization_config_generator
608        )
609
610        # return the generated mapping
611        return mapping
612
613    def _quantization_config_generator(
614        self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module
615    ) -> QConfig:
616        r"""
617        Returns the quantization configuration generated by the DetectorQConfigInfo object
618        """
619        return detector_qconfig_info.generate_quantization_qconfig(module)
620
621    def _equalization_config_generator(
622        self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module
623    ) -> EqualizationQConfig:
624        r"""
625        We ignore the module argument here, and only focus on thedetector_qconfig_info
626
627        Returns the equalization configuration generated by the DetectorQConfigInfo object
628        """
629        return detector_qconfig_info.generate_equalization_qconfig()
630
631    def generate_equalization_mapping(self) -> QConfigMapping:
632        r"""
633        Generates a QConfigMapping based on the suggestions of the
634        ModelReport API for equalization. The generated mapping encompasses all the
635        different types of feedback from the input-weight equalization detector.
636
637        These configs are based on the suggestions provided by the ModelReport API
638        and can only be generated once the reports have been generated.
639
640        Returns a QConfigMapping for the equalization configuration
641        """
642        # get the mapping info
643        detector_qconfig_info_combined = (
644            self._generate_module_fqn_to_detector_info_mapping(
645                self._update_detector_equalization_qconfig_info
646            )
647        )
648
649        # now we generate the QConfig for each of the options
650        mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
651            detector_qconfig_info_combined, self._equalization_config_generator
652        )
653
654        # return the generated mapping
655        return mapping
656