# mypy: allow-untyped-defs
from __future__ import annotations

from collections import OrderedDict
from typing import Any, Callable, Dict, List, Tuple, Union

import torch

from .fake_quantize import default_weight_fake_quant, FixedQParamsFakeQuantize
from .observer import (
    _PartialWrapper,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    default_placeholder_observer,
    default_weight_observer,
)
from .qconfig import (
    default_quint8_weight_qconfig,
    default_reuse_input_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    default_symmetric_qnnpack_qconfig,
    get_default_qat_qconfig,
    get_default_qconfig,
    QConfig,
    QConfigAny,
)


__all__ = [
    "get_default_qconfig_mapping",
    "get_default_qat_qconfig_mapping",
    "QConfigMapping",
]


# TODO: replace all usages with these constants
_GLOBAL_DICT_KEY = ""
_OBJECT_TYPE_DICT_KEY = "object_type"
_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex"
_MODULE_NAME_DICT_KEY = "module_name"
_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"

# TODO: derive this map from the BackendConfig
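# These ops produce outputs with a known, fixed range (e.g. sigmoid, hardsigmoid,
# and softmax in [0, 1], tanh in [-1, 1]), so their quantization parameters can
# be fixed up front instead of derived from observed statistics.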
_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = {
    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,
    torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer,
    "hardsigmoid": default_fixed_qparams_range_0to1_observer,
    "hardsigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,
    torch.sigmoid: default_fixed_qparams_range_0to1_observer,
    "sigmoid": default_fixed_qparams_range_0to1_observer,
    "sigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,
    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,
    torch.tanh: default_fixed_qparams_range_neg1to1_observer,
    "tanh": default_fixed_qparams_range_neg1to1_observer,
    "tanh_": default_fixed_qparams_range_neg1to1_observer,
}


def _get_default_qconfig_mapping(
    is_qat: bool, backend: str, version: int
) -> QConfigMapping:
    """
    Return the default QConfigMapping for the given quantization type and backend.
    """
    if is_qat:
        qconfig = get_default_qat_qconfig(backend, version)
    else:
        qconfig = get_default_qconfig(backend, version)
    default_weight = default_weight_fake_quant if is_qat else default_weight_observer

    # default_per_channel_weight_observer is not currently compatible with the fbgemm backend,
    # so we have to change the weight observer to default_weight_observer or another
    # supported per-tensor observer.
    # See https://github.com/pytorch/pytorch/issues/47535
    if backend in ("fbgemm", "x86"):
        qconfig_transpose = QConfig(
            activation=qconfig.activation, weight=default_weight
        )
    else:
        qconfig_transpose = qconfig

    # currently layernorm only supports float weights
    # we have to add this because otherwise there will be an extra quantize-dequantize pair
    qconfig_layernorm = QConfig(
        activation=qconfig.activation, weight=default_placeholder_observer
    )

    qconfig_mapping = (
        QConfigMapping()
        .set_global(qconfig)
        .set_object_type("reshape", default_reuse_input_qconfig)
        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose)
        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose)
        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose)
        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose)
        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose)
        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose)
        .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm)
        .set_object_type(torch.nn.LayerNorm, qconfig_layernorm)
        .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig)
    )
    # Use special observers for ops with fixed qparams
    fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
    for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
        if observer in fixed_qparams_observer_to_qconfig:
            fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer]
        else:
            if is_qat:
                activation = FixedQParamsFakeQuantize.with_args(observer=observer)
            else:
                activation = observer
            fixed_qparams_qconfig = QConfig(
                activation=activation, weight=default_weight
            )
            fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
        qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)

    # TODO: Currently it's required that separate ops in a fused op/module have the same qconfig.
    #       Need to be able to support fusion of ops with different qconfigs.

    return qconfig_mapping


def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping:
    """
    Return the default QConfigMapping for post training quantization.

    Args:
      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
      * ``version`` (int) : the version for the default qconfig mapping
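
    A minimal usage sketch with FX graph mode quantization (the toy model and
    example inputs below are placeholders)::

        import torch
        from torch.ao.quantization.quantize_fx import prepare_fx

        qconfig_mapping = get_default_qconfig_mapping("x86")
        model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
        prepared = prepare_fx(model, qconfig_mapping, (torch.randn(1, 4),))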
    """
    # TODO: add assert for backend choices
    return _get_default_qconfig_mapping(False, backend, version)


def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping:
    """
    Return the default QConfigMapping for quantization aware training.

    Args:
      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
         one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
      * ``version`` (int) : the version for the default qconfig mapping
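
    A minimal QAT sketch (the toy model below is a placeholder); note the model
    is prepared in training mode::

        import torch
        from torch.ao.quantization.quantize_fx import prepare_qat_fx

        qconfig_mapping = get_default_qat_qconfig_mapping("x86")
        model = torch.nn.Sequential(torch.nn.Linear(4, 4)).train()
        prepared = prepare_qat_fx(model, qconfig_mapping, (torch.randn(1, 4),))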
    """
    return _get_default_qconfig_mapping(True, backend, version)


def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping:
    """
    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig`
    as the default QConfig.
    """
    default_qconfig = default_symmetric_qnnpack_qconfig
    return _get_default_qconfig_mapping_with_default_qconfig(
        False, "qnnpack", default_qconfig
    )


def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping:
    """
    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig`
    as the default QConfig.
    """
    default_qconfig = default_symmetric_qnnpack_qat_qconfig
    return _get_default_qconfig_mapping_with_default_qconfig(
        True, "qnnpack", default_qconfig
    )


def _get_default_qconfig_mapping_with_default_qconfig(
    is_qat: bool,
    backend: str,
    default_qconfig: QConfig,
) -> QConfigMapping:
    """
    Return a QConfigMapping that uses the provided qconfig as the default QConfig.
    """
    if is_qat:
        qconfig_mapping = get_default_qat_qconfig_mapping(backend)
    else:
        qconfig_mapping = get_default_qconfig_mapping(backend)
    qconfig_mapping.set_global(default_qconfig)
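    # Overwrite the per-op qconfigs with the new default, but keep the special
    # qconfigs for ops with fixed qparams, which require their own observers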
    for pattern in qconfig_mapping.object_type_qconfigs.keys():
        if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
            qconfig_mapping.set_object_type(pattern, default_qconfig)
    return qconfig_mapping


_QCONFIG_STYLE_ORDER: List[str] = [
    "global_qconfig",
    "object_type_qconfigs",
    "module_name_regex_qconfigs",
    "module_name_qconfigs",
    "module_name_object_type_order_qconfigs",
]


class QConfigMapping:
    """
    Mapping from model ops to :class:`torch.ao.quantization.QConfig` s.

    The user can specify QConfigs using the following methods (in increasing match priority):

        ``set_global`` : sets the global (default) QConfig

        ``set_object_type`` : sets the QConfig for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string

        ``set_module_name`` : sets the QConfig for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    Example usage::

        qconfig_mapping = (
            QConfigMapping()
            .set_global(global_qconfig)
            .set_object_type(torch.nn.Linear, qconfig1)
            .set_object_type(torch.nn.ReLU, qconfig1)
            .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
            .set_module_name_regex("foo.*", qconfig2)
            .set_module_name("module1", qconfig1)
            .set_module_name("module2", qconfig2)
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3)
        )

    """

    def __init__(self) -> None:
        # In increasing match priority:
        self.global_qconfig: QConfigAny = None
        self.object_type_qconfigs: OrderedDict[
            Union[Callable, str], QConfigAny
        ] = OrderedDict()
        self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_object_type_order_qconfigs: OrderedDict[
            Tuple[str, Callable, int], QConfigAny
        ] = OrderedDict()

    def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping:
        """
        Set the global (default) QConfig.
        """
        self.global_qconfig = global_qconfig
        return self

    def set_object_type(
        self, object_type: Union[Callable, str], qconfig: QConfigAny
    ) -> QConfigMapping:
        """
        Set the QConfig for a given module type, function, or method name.
        If the QConfig for an existing object type was already set, the new QConfig will override the old one.
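
        A minimal sketch (``qconfig`` is a placeholder); strings name method
        calls such as ``x.relu()``::

            qconfig_mapping = (
                QConfigMapping()
                .set_object_type(torch.nn.Linear, qconfig)
                .set_object_type(torch.nn.functional.linear, qconfig)
                .set_object_type("relu", qconfig)
            )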
        """
        self.object_type_qconfigs[object_type] = qconfig
        return self

    def set_module_name_regex(
        self, module_name_regex: str, qconfig: QConfigAny
    ) -> QConfigMapping:
        """
        Set the QConfig for modules matching the given regex string.

        Regexes will be matched in the order in which they are registered through this method.
        Thus, the caller should register more specific patterns first, e.g.::

            qconfig_mapping = (
                QConfigMapping()
                .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
                .set_module_name_regex("foo.*bar.*", qconfig2)
                .set_module_name_regex("foo.*", qconfig3)
            )

        In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2,
        and "foo.baz.relu" would match qconfig3.

        If the QConfig for an existing module name regex was already set, the new QConfig will override the
        old one while preserving the order in which the regexes were originally registered.
        """
        self.module_name_regex_qconfigs[module_name_regex] = qconfig
        return self

    def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping:
        """
        Set the QConfig for modules matching the given module name.
        If the QConfig for an existing module name was already set, the new QConfig will override the old one.
        """
        self.module_name_qconfigs[module_name] = qconfig
        return self

    def set_module_name_object_type_order(
        self, module_name: str, object_type: Callable, index: int, qconfig: QConfigAny
    ) -> QConfigMapping:
        """
        Set the QConfig for modules matching a combination of the given module name, object type,
        and the index at which the module appears.

        If the QConfig for an existing (module name, object type, index) was already set, the new QConfig
        will override the old one.
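
        A minimal sketch (``qconfig`` is a placeholder): target only the first
        (index 0) call to ``torch.nn.functional.linear`` inside the submodule
        named "foo.bar"::

            qconfig_mapping = QConfigMapping().set_module_name_object_type_order(
                "foo.bar", torch.nn.functional.linear, 0, qconfig
            )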
        """
        self.module_name_object_type_order_qconfigs[
            (module_name, object_type, index)
        ] = qconfig
        return self

    def __repr__(self) -> str:
        output = self.__class__.__name__ + " ("
        for style_name in _QCONFIG_STYLE_ORDER:
            output += f"\n {style_name}"
            qconfigs = getattr(self, style_name)
            if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0:
                for key, qconfig in qconfigs.items():
                    output += f"\n  {key}: {qconfig}"
            else:
                output += f"\n  {qconfigs}"
        return output + "\n)"

    # TODO: remove this
    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``QConfigMapping`` to a dictionary with the following keys:

            "" (for global QConfig)

            "object_type"

            "module_name_regex"

            "module_name"

            "module_name_object_type_order"

        The values of this dictionary are lists of tuples.
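
        A minimal sketch of the resulting layout::

            mapping = QConfigMapping().set_object_type(torch.nn.Linear, None)
            d = mapping.to_dict()
            # d[""] is the global QConfig (None here, since it was never set)
            # d["object_type"] == [(torch.nn.Linear, None)]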
        """
        return {
            _GLOBAL_DICT_KEY: self.global_qconfig,
            _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()),
            _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()),
            _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()),
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
                (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items()
            ],
        }

    # TODO: remove this
    @classmethod
    def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
        """
        Create a ``QConfigMapping`` from a dictionary with the following keys (all optional):

            "" (for global QConfig)

            "object_type"

            "module_name_regex"

            "module_name"

            "module_name_object_type_order"

        The values of this dictionary are expected to be lists of tuples.
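
        A minimal sketch (``qconfig`` is a placeholder)::

            qconfig_dict = {
                "": qconfig,
                "object_type": [(torch.nn.Linear, qconfig)],
            }
            qconfig_mapping = QConfigMapping.from_dict(qconfig_dict)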
        """
        conf = cls()
        if _GLOBAL_DICT_KEY in qconfig_dict:
            conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY])
        for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []):
            conf.set_object_type(object_type, qconfig)
        for module_name_regex, qconfig in qconfig_dict.get(
            _MODULE_NAME_REGEX_DICT_KEY, []
        ):
            conf.set_module_name_regex(module_name_regex, qconfig)
        for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []):
            conf.set_module_name(module_name, qconfig)
        for module_name, object_type, index, qconfig in qconfig_dict.get(
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []
        ):
            conf.set_module_name_object_type_order(
                module_name, object_type, index, qconfig
            )
        return conf