# mypy: allow-untyped-defs
from __future__ import annotations

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Tuple, Union

import torch

from .fake_quantize import default_weight_fake_quant, FixedQParamsFakeQuantize
from .observer import (
    _PartialWrapper,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    default_placeholder_observer,
    default_weight_observer,
)
from .qconfig import (
    default_quint8_weight_qconfig,
    default_reuse_input_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    default_symmetric_qnnpack_qconfig,
    get_default_qat_qconfig,
    get_default_qconfig,
    QConfig,
    QConfigAny,
)


__all__ = [
    "get_default_qconfig_mapping",
    "get_default_qat_qconfig_mapping",
    "QConfigMapping",
]


# Keys of the dictionary form produced by ``QConfigMapping.to_dict`` and
# consumed by ``QConfigMapping.from_dict``.
# TODO: replace all usages with these constants
_GLOBAL_DICT_KEY = ""
_OBJECT_TYPE_DICT_KEY = "object_type"
_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex"
_MODULE_NAME_DICT_KEY = "module_name"
_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"

# Ops (module types, functions, and method names) whose output uses fixed
# quantization parameters, mapped to the observer carrying those parameters.
# TODO: derive this map from the BackendConfig
_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = {
    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,
    torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer,
    "hardsigmoid": default_fixed_qparams_range_0to1_observer,
    "hardsigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,
    torch.sigmoid: default_fixed_qparams_range_0to1_observer,
    "sigmoid": default_fixed_qparams_range_0to1_observer,
    "sigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,
    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,
    torch.tanh: default_fixed_qparams_range_neg1to1_observer,
    "tanh": default_fixed_qparams_range_neg1to1_observer,
    "tanh_": default_fixed_qparams_range_neg1to1_observer,
}


def _get_default_qconfig_mapping(
    is_qat: bool, backend: str, version: int
) -> QConfigMapping:
    """
    Return the default QConfigMapping for the given quantization type and backend.

    Args:
        * ``is_qat`` (bool) : if True, build the mapping from the default QAT qconfig
          (``get_default_qat_qconfig``); otherwise from the default post training
          qconfig (``get_default_qconfig``)
        * ``backend`` (str) : the quantization backend, forwarded to the default qconfig getter
        * ``version`` (int) : the version, forwarded to the default qconfig getter

    Returns:
        A new ``QConfigMapping`` with the global qconfig set, plus object-type
        overrides for reshape, transposed convolutions, layer norm, PReLU and
        every op listed in ``_FIXED_QPARAMS_OP_TO_OBSERVER``.
    """
    if is_qat:
        qconfig = get_default_qat_qconfig(backend, version)
    else:
        qconfig = get_default_qconfig(backend, version)
    default_weight = default_weight_fake_quant if is_qat else default_weight_observer

    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
    # so we have to modify the weight observer to default_weight_observer or another
    # per tensor supported observer.
    # see https://github.com/pytorch/pytorch/issues/47535
    if backend in ("fbgemm", "x86"):
        qconfig_transpose = QConfig(
            activation=qconfig.activation, weight=default_weight
        )
    else:
        qconfig_transpose = qconfig

    # currently layernorm only supports float weights
    # we have to add this because otherwise there will be an extra quantize-dequantize pair
    qconfig_layernorm = QConfig(
        activation=qconfig.activation, weight=default_placeholder_observer
    )

    qconfig_mapping = (
        QConfigMapping()
        .set_global(qconfig)
        .set_object_type("reshape", default_reuse_input_qconfig)
        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose)
        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose)
        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose)
        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose)
        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose)
        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose)
        .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm)
        .set_object_type(torch.nn.LayerNorm, qconfig_layernorm)
        .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig)
    )
    # Use special observers for ops with fixed qparams.
    # One QConfig is built per distinct observer and then shared by every op
    # that maps to that observer.
    fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
    for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
        if observer in fixed_qparams_observer_to_qconfig:
            fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer]
        else:
            if is_qat:
                activation = FixedQParamsFakeQuantize.with_args(observer=observer)
            else:
                activation = observer
            fixed_qparams_qconfig = QConfig(
                activation=activation, weight=default_weight
            )
            fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
        qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)

    # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
    #      Need to be able to support fusion of ops with different qconfigs

    return qconfig_mapping


def get_default_qconfig_mapping(
    backend: str = "x86", version: int = 0
) -> QConfigMapping:
    """
    Return the default QConfigMapping for post training quantization.

    Args:
      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
      * ``version`` (int) : the version for the default qconfig mapping
    """
    # TODO: add assert for backend choices
    return _get_default_qconfig_mapping(False, backend, version)


def get_default_qat_qconfig_mapping(
    backend: str = "x86", version: int = 1
) -> QConfigMapping:
    """
    Return the default QConfigMapping for quantization aware training.

    Args:
      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
      * ``version`` (int) : the version for the default qconfig mapping
    """
    return _get_default_qconfig_mapping(True, backend, version)


def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping:
    """
    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig`
    as the default QConfig.
    """
    default_qconfig = default_symmetric_qnnpack_qconfig
    return _get_default_qconfig_mapping_with_default_qconfig(
        False, "qnnpack", default_qconfig
    )


def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping:
    """
    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig`
    as the default QConfig.
    """
    default_qconfig = default_symmetric_qnnpack_qat_qconfig
    return _get_default_qconfig_mapping_with_default_qconfig(
        True, "qnnpack", default_qconfig
    )


def _get_default_qconfig_mapping_with_default_qconfig(
    is_qat: bool,
    backend: str,
    default_qconfig: QConfig,
) -> QConfigMapping:
    """
    Return a QConfigMapping that uses the provided qconfig as the default QConfig.

    Starts from the default (QAT or post training) mapping for ``backend``, replaces
    the global qconfig with ``default_qconfig``, and also replaces every object-type
    qconfig whose key is not listed in ``_FIXED_QPARAMS_OP_TO_OBSERVER``.

    Args:
        * ``is_qat`` (bool) : whether to start from the default QAT mapping
        * ``backend`` (str) : the quantization backend for the starting mapping
        * ``default_qconfig`` (QConfig) : the qconfig to install
    """
    if is_qat:
        qconfig_mapping = get_default_qat_qconfig_mapping(backend)
    else:
        qconfig_mapping = get_default_qconfig_mapping(backend)
    qconfig_mapping.set_global(default_qconfig)
    # Iterate over a snapshot of the keys because the mapping is updated inside the loop.
    for pattern in list(qconfig_mapping.object_type_qconfigs):
        if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
            qconfig_mapping.set_object_type(pattern, default_qconfig)
    return qconfig_mapping


# Attribute names of ``QConfigMapping`` in the order ``__repr__`` prints them.
_QCONFIG_STYLE_ORDER: List[str] = [
    "global_qconfig",
    "object_type_qconfigs",
    "module_name_regex_qconfigs",
    "module_name_qconfigs",
    "module_name_object_type_order_qconfigs",
]


class QConfigMapping:
    """
    Mapping from model ops to :class:`torch.ao.quantization.QConfig` s.

    The user can specify QConfigs using the following methods (in increasing match priority):

        ``set_global`` : sets the global (default) QConfig

        ``set_object_type`` : sets the QConfig for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string

        ``set_module_name`` : sets the QConfig for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    Example usage::

        qconfig_mapping = QConfigMapping()
            .set_global(global_qconfig)
            .set_object_type(torch.nn.Linear, qconfig1)
            .set_object_type(torch.nn.ReLU, qconfig1)
            .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
            .set_module_name_regex("foo.*", qconfig2)
            .set_module_name("module1", qconfig1)
            .set_module_name("module2", qconfig2)
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3)

    """

    def __init__(self) -> None:
        # In increasing match priority:
        self.global_qconfig: QConfigAny = None
        self.object_type_qconfigs: OrderedDict[
            Union[Callable, str], QConfigAny
        ] = OrderedDict()
        self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_object_type_order_qconfigs: OrderedDict[
            Tuple[str, Callable, int], QConfigAny
        ] = OrderedDict()

    def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping:
        """
        Set the global (default) QConfig.
        """
        self.global_qconfig = global_qconfig
        return self

    def set_object_type(
        self, object_type: Union[Callable, str], qconfig: QConfigAny
    ) -> QConfigMapping:
        """
        Set the QConfig for a given module type, function, or method name.
        If the QConfig for an existing object type was already set, the new QConfig will override the old one.
        """
        self.object_type_qconfigs[object_type] = qconfig
        return self

    def set_module_name_regex(
        self, module_name_regex: str, qconfig: QConfigAny
    ) -> QConfigMapping:
        """
        Set the QConfig for modules matching the given regex string.

        Regexes will be matched in the order in which they are registered through this method.
        Thus, the caller should register more specific patterns first, e.g.::

            qconfig_mapping = QConfigMapping()
                .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
                .set_module_name_regex("foo.*bar.*", qconfig2)
                .set_module_name_regex("foo.*", qconfig3)

        In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2,
        and "foo.baz.relu" would match qconfig3.

        If the QConfig for an existing module name regex was already set, the new QConfig will override the
        old one while preserving the order in which the regexes were originally registered.
        """
        self.module_name_regex_qconfigs[module_name_regex] = qconfig
        return self

    def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping:
        """
        Set the QConfig for modules matching the given module name.
        If the QConfig for an existing module name was already set, the new QConfig will override the old one.
        """
        self.module_name_qconfigs[module_name] = qconfig
        return self

    def set_module_name_object_type_order(
        self, module_name: str, object_type: Callable, index: int, qconfig: QConfigAny
    ) -> QConfigMapping:
        """
        Set the QConfig for modules matching a combination of the given module name, object type,
        and the index at which the module appears.

        If the QConfig for an existing (module name, object type, index) was already set, the new QConfig
        will override the old one.
        """
        self.module_name_object_type_order_qconfigs[
            (module_name, object_type, index)
        ] = qconfig
        return self

    def __repr__(self) -> str:
        # Collect the fragments and join once instead of growing a string with `+=`.
        parts: List[str] = [self.__class__.__name__ + " ("]
        for style_name in _QCONFIG_STYLE_ORDER:
            parts.append(f"\n {style_name}")
            qconfigs = getattr(self, style_name)
            if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0:
                parts.extend(
                    f"\n {key}: {qconfig}" for key, qconfig in qconfigs.items()
                )
            else:
                # The global qconfig (not a dict) and empty dicts are printed as-is.
                parts.append(f"\n {qconfigs}")
        parts.append("\n)")
        return "".join(parts)

    # TODO: remove this
    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``QConfigMapping`` to a dictionary with the following keys:

            "" (for global QConfig)

            "object_type"

            "module_name_regex"

            "module_name"

            "module_name_object_type_order"

        The value under "" is the global QConfig itself; the values under the
        other keys are lists of tuples.
        """
        return {
            _GLOBAL_DICT_KEY: self.global_qconfig,
            _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()),
            _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()),
            _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()),
            # Flatten ((module_name, object_type, index), qconfig) into 4-tuples.
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
                (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items()
            ],
        }

    # TODO: remove this
    @classmethod
    def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
        """
        Create a ``QConfigMapping`` from a dictionary with the following keys (all optional):

            "" (for global QConfig)

            "object_type"

            "module_name_regex"

            "module_name"

            "module_name_object_type_order"

        The value under "" is used directly as the global QConfig; the values under
        the other keys are expected to be lists of tuples.
        """
        conf = cls()
        if _GLOBAL_DICT_KEY in qconfig_dict:
            conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY])
        for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []):
            conf.set_object_type(object_type, qconfig)
        for module_name_regex, qconfig in qconfig_dict.get(
            _MODULE_NAME_REGEX_DICT_KEY, []
        ):
            conf.set_module_name_regex(module_name_regex, qconfig)
        for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []):
            conf.set_module_name(module_name, qconfig)
        for module_name, object_type, index, qconfig in qconfig_dict.get(
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []
        ):
            conf.set_module_name_object_type_order(
                module_name, object_type, index, qconfig
            )
        return conf