xref: /aosp_15_r20/external/pytorch/torch/ao/ns/fx/qconfig_multi_mapping.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import copy
5from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union
6
7import torch
8from torch.ao.quantization import QConfigMapping
9from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
10
11
12if TYPE_CHECKING:
13    from torch.ao.quantization.qconfig import QConfigAny
14
# Only the multi-mapping container is part of this module's public API.
__all__ = [
    "QConfigMultiMapping",
]
16
# Maps each QConfig "style" (the name of a QConfigMapping attribute) to the
# name of the QConfigMapping setter that populates that attribute. Used to
# forward insertions from QConfigMultiMapping to every underlying
# QConfigMapping via getattr.
_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = dict(
    global_qconfig="set_global",
    object_type_qconfigs="set_object_type",
    module_name_regex_qconfigs="set_module_name_regex",
    module_name_qconfigs="set_module_name",
    module_name_object_type_order_qconfigs="set_module_name_object_type_order",
)
24
25
def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
    """
    Remove every ``None`` and every duplicate QConfig from ``qconfig_list`` in place.

    Two QConfigs are considered duplicates when
    :func:`torch.ao.quantization.qconfig_equals` says they are equal. The first
    occurrence of each distinct QConfig is kept and the relative order of the
    kept entries is preserved, so the resulting order is deterministic.

    Args:
        qconfig_list: list to filter; it is modified in place.
    """
    kept: List[QConfigAny] = []
    for cur_qconfig in qconfig_list:
        # Skip (rather than stop at) None entries: stopping here would leave
        # any later None or duplicate entries in the list.
        if cur_qconfig is None:
            continue
        if any(
            torch.ao.quantization.qconfig_equals(cur_qconfig, seen_qconfig)
            for seen_qconfig in kept
        ):
            continue
        kept.append(cur_qconfig)
    # slice assignment keeps the caller's list object (in-place contract)
    qconfig_list[:] = kept
38
39
class QConfigMultiMapping:
    """
    This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
    so that multiple QConfigs can be specified for each QConfig matching style.

    The user can specify QConfigs using the following methods (in increasing match priority):

        ``set_global`` : sets the global (default) QConfigs

        ``set_object_type`` : sets the QConfigs for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string

        ``set_module_name`` : sets the QConfigs for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a
    single QConfig. ``None`` entries and duplicate QConfigs are dropped from a copy of the passed-in list, which is
    then padded with ``None`` so that every internal QConfigMapping receives an entry for the key. The caller's list
    itself is not modified.

    Example usage::

        qconfig_mapping = (
            QConfigMultiMapping()
            .set_global([qconfig1, qconfig2])
            .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
            .set_object_type(torch.nn.ReLU, [qconfig1])
            .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
            .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
            .set_module_name("module1", [None])
            .set_module_name("module2", [qconfig2])
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])
        )

    """

    def __init__(self) -> None:
        # initialize this with 1 QConfigMapping to avoid corner cases
        self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]

    def _handle_list_size_mismatch(
        self, qconfig_list: List[QConfigAny], style: str
    ) -> None:
        """
        Make ``qconfig_list`` and ``self.qconfig_mappings_list`` the same length.

        Issue: consider a user inserting global_qconfig A and B first, then
        inserting qconfig C as an object_type_qconfig for conv ops. If we
        internally stored one QConfigMapping with A and C and another with just
        B, the second QConfigMapping would match B to conv ops (which is not
        wanted), since B is global.

        We avoid this by maintaining the invariant that if any QConfigMapping
        has a qconfig style+key with a qconfig in it, all QConfigMappings must
        have either a qconfig or None for that same style+key. In the above
        example, a None qconfig prevents the unwanted match in the second
        QConfigMapping.

        Args:
            qconfig_list: QConfigs about to be inserted. When it is shorter than
                the number of QConfigMappings it is padded with ``None`` in place.
            style: the QConfig style being inserted (not used by this method).
        """
        if len(qconfig_list) > len(self.qconfig_mappings_list):
            # Case: more qconfigs than QConfigMappings -> add QConfigMappings,
            # pre-populated so that the invariant keeps holding.
            new_qconfig_mapping = QConfigMapping()
            # Thanks to the invariant every existing QConfigMapping carries the
            # same style+keys, so inspecting the first one is enough.
            if self.qconfig_mappings_list:
                reference_mapping = self.qconfig_mappings_list[0]
                # global_qconfig is None by default in a fresh QConfigMapping
                for check_style in _QCONFIG_STYLE_ORDER[1:]:
                    qconfigs_dict = getattr(reference_mapping, check_style)
                    target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
                    for key in qconfigs_dict:
                        target_qconfigs_dict[key] = None

            # insert copies of this new QConfigMapping until all entries
            # in qconfig_list can fit among the QConfigMappings
            while len(qconfig_list) > len(self.qconfig_mappings_list):
                self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
        else:
            # Case: fewer (or equal) qconfigs than QConfigMappings -> pad
            # qconfig_list with `None` until the lengths match
            while len(qconfig_list) < len(self.qconfig_mappings_list):
                qconfig_list.append(None)

    def _insert_qconfig_list(
        self,
        style: str,
        args: List[Union[str, int, Callable]],
        qconfig_list: List[QConfigAny],
    ) -> None:
        """
        Insert one QConfig per internal QConfigMapping for a single style+key.

        Args:
            style: a key of ``_QCONFIG_STYLE_TO_METHOD`` selecting the
                QConfigMapping setter to call.
            args: positional arguments identifying the key; passed to the setter
                before the QConfig.
            qconfig_list: QConfigs to insert. It is copied first, so the caller's
                list is never modified.
        """
        # Work on a private copy: the helpers below remove and pad entries in
        # place, which would otherwise rewrite the list the user passed in.
        qconfig_list = list(qconfig_list)
        # we remove duplicates and None to make the ordering of qconfigs
        # deterministic upon insertion.
        _remove_duplicates_and_none(qconfig_list)

        self._handle_list_size_mismatch(qconfig_list, style)
        method_name = _QCONFIG_STYLE_TO_METHOD[style]
        for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
            # uses QConfigMapping set method to insert qconfig
            set_method = getattr(qconfig_mapping, method_name)
            set_method(*args, qconfig)

    def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
        """
        Set global QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
        """
        self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
        return self

    def set_object_type(
        self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set object type QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
        """
        self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
        return self

    def set_module_name_regex(
        self, module_name_regex: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name_regex QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
        """
        self._insert_qconfig_list(
            "module_name_regex_qconfigs", [module_name_regex], qconfig_list
        )
        return self

    def set_module_name(
        self, module_name: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
        """
        self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
        return self

    def set_module_name_object_type_order(
        self,
        module_name: str,
        object_type: Callable,
        index: int,
        qconfig_list: List[QConfigAny],
    ) -> QConfigMultiMapping:
        """
        Set module_name_object_type_order QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
        """
        self._insert_qconfig_list(
            "module_name_object_type_order_qconfigs",
            [module_name, object_type, index],
            qconfig_list,
        )
        return self

    def __repr__(self):
        # One QConfigMapping repr per line, wrapped in "<ClassName> [ ... ]".
        return (
            self.__class__.__name__
            + " ["
            + "".join(
                f"\n{qconfig_mapping.__repr__()},"
                for qconfig_mapping in self.qconfig_mappings_list
            )
            + "\n]"
        )

    @classmethod
    def from_list_qconfig_mapping(
        cls, qconfig_mapping_list: List[QConfigMapping]
    ) -> QConfigMultiMapping:
        """
        Creates a QConfigMultiMapping from a list of QConfigMappings.

        The input QConfigMappings are deep-copied and are not modified. For every
        non-global style+key found in any input mapping, the QConfigs stored under
        that key (in input order) are re-inserted through the QConfigMultiMapping
        set methods. That step drops ``None``/duplicate entries and restores the
        invariant described in ``_handle_list_size_mismatch``. As a result, the
        QConfig for a key may end up in a different position than it had in the
        input list. Global QConfigs are kept exactly as in the deep copy.
        """
        new_qconfig_multi_mapping = cls()

        new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
            qconfig_mapping_list
        )

        # we need to avoid the issue described in _handle_list_size_mismatch,
        # so we reinsert all the qconfigs using the QConfigMultiMapping
        # set methods.
        # global_qconfig is skipped: every QConfigMapping always has a global
        # entry (None by default), so the deep copy already satisfies the
        # invariant for it.
        for style in _QCONFIG_STYLE_ORDER[1:]:
            # gather, for each key of this style, the qconfigs of every input
            # mapping that defines that key
            qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
            for qconfig_mapping in qconfig_mapping_list:
                qconfig_dict = getattr(qconfig_mapping, style)
                for key, qconfig in qconfig_dict.items():
                    qconfig_dict_list.setdefault(key, []).append(qconfig)

            # reinsert all gathered key+qconfigs
            set_method_name = _QCONFIG_STYLE_TO_METHOD[style]
            set_method = getattr(new_qconfig_multi_mapping, set_method_name)
            for key, qconfig_list in qconfig_dict_list.items():
                if isinstance(key, tuple):
                    # tuple keys (expected for module_name_object_type_order,
                    # whose setter takes module_name, object_type, index) are
                    # unpacked into the setter's positional arguments
                    set_method(*key, qconfig_list)
                else:
                    set_method(key, qconfig_list)

        return new_qconfig_multi_mapping
250