# mypy: allow-untyped-defs
from __future__ import annotations

import copy
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union

import torch
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER


if TYPE_CHECKING:
    from torch.ao.quantization.qconfig import QConfigAny

__all__ = ["QConfigMultiMapping"]

# Maps each QConfigMapping attribute name (a "qconfig style") to the name of
# the setter method that populates it. The same method names exist on both
# QConfigMapping and QConfigMultiMapping.
_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = {
    "global_qconfig": "set_global",
    "object_type_qconfigs": "set_object_type",
    "module_name_regex_qconfigs": "set_module_name_regex",
    "module_name_qconfigs": "set_module_name",
    "module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
}


def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
    """Remove every ``None`` and every duplicate qconfig from ``qconfig_list`` in place.

    Duplicates are detected with :func:`torch.ao.quantization.qconfig_equals`.
    The first occurrence of each distinct qconfig is kept, and the relative
    order of the surviving entries is preserved.
    """
    deduplicated: List[QConfigAny] = []
    for qconfig in qconfig_list:
        # Skip (rather than stop at) None so that entries after a None are
        # still checked for duplicates and later Nones are removed too.
        if qconfig is None:
            continue
        if any(
            torch.ao.quantization.qconfig_equals(qconfig, kept)
            for kept in deduplicated
        ):
            continue
        deduplicated.append(qconfig)
    # Slice assignment keeps the in-place contract of this helper.
    qconfig_list[:] = deduplicated


class QConfigMultiMapping:
    """
    This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
    so that multiple QConfigs can be specified for each QConfig matching style.

    The user can specify QConfigs using the following methods (in increasing match priority):

        ``set_global`` : sets the global (default) QConfigs

        ``set_object_type`` : sets the QConfigs for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string

        ``set_module_name`` : sets the QConfigs for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a
    single QConfig. ``None`` entries and duplicate QConfigs in a passed list are ignored, and the
    passed list itself is not modified.

    Example usage::

        qconfig_mapping = (
            QConfigMultiMapping()
            .set_global([qconfig1, qconfig2])
            .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
            .set_object_type(torch.nn.ReLU, [qconfig1])
            .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
            .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
            .set_module_name("module1", [None])
            .set_module_name("module2", [qconfig2])
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])
        )

    """

    def __init__(self) -> None:
        # initialize this with 1 QConfigMapping to avoid corner cases
        self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]

    def _handle_list_size_mismatch(
        self, qconfig_list: List[QConfigAny], style: str
    ) -> None:
        """Make ``qconfig_list`` and ``self.qconfig_mappings_list`` the same length.

        If there are more qconfigs than QConfigMappings, new QConfigMappings
        are appended. Otherwise ``qconfig_list`` is padded with ``None`` in place.

        Args:
            qconfig_list: the (already de-duplicated) qconfigs about to be inserted.
                This list may be extended in place.
            style: the qconfig style being inserted (currently unused).
        """
        # Issue: Consider a user inserting global_qconfig A and B first, then inserting
        # qconfig C as an object_type_qconfig for conv ops. If we internally store
        # 1 QConfigMapping with A and C and another with just B, then the
        # second QConfigMapping will match B to conv ops (which is not wanted), since B is global.

        # we avoid this by maintaining the invariant that if any QConfigMapping
        # has a qconfig style+key with a qconfig in it, all QConfigMappings must
        # have either a qconfig or None for that same style+key. In the above
        # example, a None qconfig would prevent the unwanted match in the
        # second QConfigMapping

        num_mappings = len(self.qconfig_mappings_list)
        if len(qconfig_list) > num_mappings:
            # Case: we have more qconfigs (in qconfig_list) than QConfigMappings.
            # Add new QConfigMappings, initialized so we maintain the invariant.
            new_qconfig_mapping = QConfigMapping()
            # By the invariant, every existing QConfigMapping carries the same
            # style+keys, so copying the keys of the first one (as `None`) is
            # sufficient. The list can be empty after from_list_qconfig_mapping([]).
            if self.qconfig_mappings_list:
                reference_mapping = self.qconfig_mappings_list[0]
                # global_qconfig has None by default
                for check_style in _QCONFIG_STYLE_ORDER[1:]:
                    source_dict = getattr(reference_mapping, check_style)
                    target_dict = getattr(new_qconfig_mapping, check_style)
                    for key in source_dict:
                        target_dict[key] = None

            # insert copies of this new QConfigMapping until all entries
            # in qconfig_list can fit among the QConfigMappings
            for _ in range(len(qconfig_list) - num_mappings):
                self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
        else:
            # Case: we have fewer (or equal) qconfigs in qconfig_list than
            # QConfigMappings; pad qconfig_list with `None` until lengths match.
            qconfig_list.extend([None] * (num_mappings - len(qconfig_list)))

    def _insert_qconfig_list(
        self,
        style: str,
        args: List[Union[str, int, Callable]],
        qconfig_list: List[QConfigAny],
    ) -> None:
        """Insert one qconfig into each QConfigMapping for the given style and key.

        Args:
            style: the QConfigMapping attribute being populated (a key of
                ``_QCONFIG_STYLE_TO_METHOD``).
            args: the key arguments forwarded positionally to the QConfigMapping
                setter, before the qconfig (empty for ``global_qconfig``).
            qconfig_list: the qconfigs to distribute across QConfigMappings, in
                order. The caller's list is not modified.
        """
        # Work on a copy: the normalization below removes and pads entries,
        # and that must not leak into a list the caller may reuse.
        qconfig_list = list(qconfig_list)
        # we remove duplicates and None to make the ordering of qconfigs
        # deterministic upon insertion.
        _remove_duplicates_and_none(qconfig_list)

        self._handle_list_size_mismatch(qconfig_list, style)
        method_name = _QCONFIG_STYLE_TO_METHOD[style]
        for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
            # uses QConfigMapping set method to insert qconfig
            getattr(qconfig_mapping, method_name)(*args, qconfig)

    def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
        """
        Set global QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
        """
        self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
        return self

    def set_object_type(
        self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set object type QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
        """
        self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
        return self

    def set_module_name_regex(
        self, module_name_regex: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name_regex QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
        """
        self._insert_qconfig_list(
            "module_name_regex_qconfigs", [module_name_regex], qconfig_list
        )
        return self

    def set_module_name(
        self, module_name: str, qconfig_list: List[QConfigAny]
    ) -> QConfigMultiMapping:
        """
        Set module_name QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
        """
        self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
        return self

    def set_module_name_object_type_order(
        self,
        module_name: str,
        object_type: Callable,
        index: int,
        qconfig_list: List[QConfigAny],
    ) -> QConfigMultiMapping:
        """
        Set module_name_object_type_order QConfigs
        see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
        """
        self._insert_qconfig_list(
            "module_name_object_type_order_qconfigs",
            [module_name, object_type, index],
            qconfig_list,
        )
        return self

    def __repr__(self) -> str:
        return (
            self.__class__.__name__
            + " ["
            + "".join(
                f"\n{qconfig_mapping.__repr__()},"
                for qconfig_mapping in self.qconfig_mappings_list
            )
            + "\n]"
        )

    @classmethod
    def from_list_qconfig_mapping(
        cls, qconfig_mapping_list: List[QConfigMapping]
    ) -> QConfigMultiMapping:
        """
        Creates a QConfigMultiMapping from a list of QConfigMappings
        """
        new_qconfig_multi_mapping = cls()

        new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
            qconfig_mapping_list
        )

        # we need to avoid the issue described in _handle_list_size_mismatch,
        # so we reinsert all the qconfigs using the QConfigMultiMapping
        # set methods

        # go through all qconfig styles
        # note: global can be ignored since it is None by default
        for style in _QCONFIG_STYLE_ORDER[1:]:
            # gather all key+qconfigs for current style, in mapping order
            qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
            for qconfig_mapping in qconfig_mapping_list:
                for key, qconfig in getattr(qconfig_mapping, style).items():
                    qconfig_dict_list.setdefault(key, []).append(qconfig)

            # reinsert all gathered key+qconfigs
            set_method = getattr(
                new_qconfig_multi_mapping, _QCONFIG_STYLE_TO_METHOD[style]
            )
            for key, qconfig_list in qconfig_dict_list.items():
                # module_name_object_type_order keys are
                # (module_name, object_type, index) tuples that map onto the
                # setter's positional key arguments.
                if isinstance(key, tuple):
                    set_method(*key, qconfig_list)
                else:
                    set_method(key, qconfig_list)

        return new_qconfig_multi_mapping