xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/observer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3"""
4This module implements observers which are used to collect statistics about
5the values observed during calibration (PTQ) or training (QAT).
6"""
7
8import re
9import warnings
10from abc import ABCMeta, abstractmethod
11from collections import OrderedDict
12from functools import partial
13from typing import Any, Dict, List, Optional, Tuple
14
15import torch
16import torch.nn as nn
17from torch.ao.quantization.utils import (
18    calculate_qmin_qmax,
19    check_min_max_valid,
20    is_per_channel,
21    is_per_tensor,
22    validate_qmin_qmax,
23)
24
25
26__all__ = [
27    "default_affine_fixed_qparams_observer",
28    "default_debug_observer",
29    "default_dynamic_quant_observer",
30    "default_fixed_qparams_range_0to1_observer",
31    "default_fixed_qparams_range_neg1to1_observer",
32    "default_float_qparams_observer",
33    "default_float_qparams_observer_4bit",
34    "default_histogram_observer",
35    "default_observer",
36    "default_per_channel_weight_observer",
37    "default_placeholder_observer",
38    "default_reuse_input_observer",
39    "default_symmetric_fixed_qparams_observer",
40    "default_weight_observer",
41    "get_observer_state_dict",
42    "load_observer_state_dict",
43    "per_channel_weight_observer_range_neg_127_to_127",
44    "weight_observer_range_neg_127_to_127",
45    "FixedQParamsObserver",
46    "HistogramObserver",
47    "MinMaxObserver",
48    "MovingAverageMinMaxObserver",
49    "MovingAveragePerChannelMinMaxObserver",
50    "NoopObserver",
51    "ObserverBase",
52    "PerChannelMinMaxObserver",
53    "PlaceholderObserver",
54    "RecordingObserver",
55    "ReuseInputObserver",
56    "UniformQuantizationObserverBase",
57]
58
59
60class _PartialWrapper:
61    def __init__(self, p):
62        self.p = p
63        self.callable_args = {}
64
65    def __call__(self, *args, **keywords):
66        # call each arg in callable_args and add them partial, then run with keywords
67        # skip if arg_name in keywords so its possible to overwrite
68        for arg_name in self.callable_args:
69            if arg_name not in keywords:
70                keywords = {**keywords, arg_name: self.callable_args[arg_name]()}
71        return self.p(*args, **keywords)
72
73    def __repr__(self):
74        return self.p.__repr__() + self.callable_args.__repr__()
75
76    def with_args(self, **kwargs):
77        return _with_args(self, **kwargs)
78
79    def with_callable_args(self, **kwargs):
80        result = _PartialWrapper(p=self.p)
81        result.callable_args = {**self.callable_args, **kwargs}
82        return result
83
84
85def _with_args(cls_or_self, **kwargs):
86    r"""Wrapper that allows creation of class factories.
87
88    This can be useful when there is a need to create classes with the same
89    constructor arguments, but different instances. Can be used in conjunction with
90    _callable_args
91
92    Example::
93
94        >>> # xdoctest: +SKIP("Undefined vars")
95        >>> Foo.with_args = classmethod(_with_args)
96        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
97        >>> foo_instance1 = foo_builder()
98        >>> foo_instance2 = foo_builder()
99        >>> id(foo_instance1) == id(foo_instance2)
100        False
101    """
102    r = _PartialWrapper(partial(cls_or_self, **kwargs))
103    return r
104
105
106def _with_callable_args(cls_or_self, **kwargs):
107    r"""Wrapper that allows creation of class factories args that need to be
108    called at construction time.
109
110    This can be useful when there is a need to create classes with the same
111    constructor arguments, but different instances and those arguments should only
112    be calculated at construction time. Can be used in conjunction with _with_args
113
114    Example::
115
116        >>> # xdoctest: +SKIP("Undefined vars")
117        >>> Foo.with_callable_args = classmethod(_with_callable_args)
118        >>> Foo.with_args = classmethod(_with_args)
119        >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan")
120        >>> foo_instance1 = foo_builder()
121        >>> # wait 50
122        >>> foo_instance2 = foo_builder()
123        >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time)
124        False
125    """
126    r = _PartialWrapper(partial(cls_or_self))
127    return r.with_callable_args(**kwargs)
128
129
130ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:
131
132
133class ObserverBase(ABC, nn.Module):
134    r"""Base observer Module.
135    Any observer implementation should derive from this class.
136
137    Concrete observers should follow the same API. In forward, they will update
138    the statistics of the observed Tensor. And they should provide a
139    `calculate_qparams` function that computes the quantization parameters given
140    the collected statistics.
141
142    Args:
143        dtype: dtype argument to the `quantize` node needed to implement the
144               reference model spec.
145        is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization
146        or static quantization
147    """
148
149    def __init__(self, dtype, is_dynamic=False):
150        super().__init__()
151        self.dtype = dtype
152        self.is_dynamic = is_dynamic
153
154    @abstractmethod
155    def forward(self, x):
156        pass
157
158    @abstractmethod
159    def calculate_qparams(self, **kwargs):
160        pass
161
162    with_args = classmethod(_with_args)
163    with_callable_args = classmethod(_with_callable_args)
164
165
166class UniformQuantizationObserverBase(ObserverBase):
167    r"""Common base for all observers using uniform quantization to calculate
168    scale and zero_point.
169
170    Args:
171        dtype: dtype argument to the `quantize` node needed to implement the
172               reference model spec.
173        qscheme: Quantization scheme to be used.
174        reduce_range: Reduces the range of the quantized data type by 1 bit.
175                      This is sometimes required to avoid instruction overflow.
176        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
177        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
178        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
179
180    .. warning::
181
182        :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
183               or `torch.int8` or `torch.uint8`
184
185    .. warning::
186
187        :attr:`qscheme` can only take one of the following options:
188
189        - ``torch.per_tensor_affine``
190        - ``torch.per_tensor_symmetric``
191        - ``torch.per_channel_affine``
192        - ``torch.per_channel_symmetric``
193    """
194
195    # Note: the version is shared by all observer types
196    #
197    # Version 1/None
198    #   self
199    #
200    # Version 2 (base class only, does not include child class buffers)
201    #   self
202    #   |--- eps : Tensor
203    #
204    # Version 3
205    #   for HistogramObserver only, changed the shape of uninitialized
206    #   min_val and max_val buffers from torch.Size([0]) to torch.Size([])
207    #   for PerChannelObservers, changed the name of the buffers from min_vals
208    #   to min_val and from max_vals to max_val.
209    _version = 3
210
211    eps: torch.Tensor
212
213    def __init__(
214        self,
215        dtype=torch.quint8,
216        qscheme=torch.per_tensor_affine,
217        reduce_range=False,
218        quant_min=None,
219        quant_max=None,
220        factory_kwargs=None,
221        eps=torch.finfo(torch.float32).eps,
222        is_dynamic=False,
223        **kwargs,
224    ) -> None:
225        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
226        super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs)
227        self.qscheme = qscheme
228        if reduce_range:
229            warnings.warn(
230                "Please use quant_min and quant_max to specify the range for observers. \
231                    reduce_range will be deprecated in a future release of PyTorch."
232            )
233        self.reduce_range = reduce_range
234        self.register_buffer("eps", torch.tensor([eps], **factory_kwargs))
235        assert self.qscheme in (
236            torch.per_tensor_affine,
237            torch.per_tensor_symmetric,
238            torch.per_channel_affine,
239            torch.per_channel_symmetric,
240            torch.per_channel_affine_float_qparams,
241        ), "Default Observer only works for per_tensor_affine, \
242                per_tensor_symmetric, per_channel_affine, \
243                per_channel_symmetric and per_channel_float_qparams quantization scheme"
244
245        _ALLOWED_DTYPES = (
246            torch.qint8,
247            torch.quint8,
248            torch.quint4x2,
249            torch.qint32,
250            torch.int8,
251            torch.uint8,
252            torch.int16,
253            torch.int32,
254            torch.float8_e5m2,
255            torch.float8_e4m3fn,
256        )
257
258        assert (
259            self.dtype in _ALLOWED_DTYPES
260        ), f"Default Observer only works for {_ALLOWED_DTYPES} data type"
261        self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
262        if self.has_customized_qrange:
263            validate_qmin_qmax(quant_min, quant_max)
264        self.quant_min, self.quant_max = calculate_qmin_qmax(
265            quant_min,
266            quant_max,
267            self.has_customized_qrange,
268            self.dtype,
269            self.reduce_range,
270        )
271
272    def _load_from_state_dict(
273        self,
274        state_dict,
275        prefix,
276        local_metadata,
277        strict,
278        missing_keys,
279        unexpected_keys,
280        error_msgs,
281    ):
282        version = local_metadata.get("version", None)
283
284        if version is None or version == 1:
285            # eps was moved to a buffer in version 2
286            eps = torch.tensor([torch.finfo(torch.float32).eps])
287            state_dict[prefix + "eps"] = eps
288
289        super()._load_from_state_dict(
290            state_dict,
291            prefix,
292            local_metadata,
293            strict,
294            missing_keys,
295            unexpected_keys,
296            error_msgs,
297        )
298
299    @torch.jit.export
300    def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
301        r"""Validates that the user-specified quantization range is properly initialized
302        and within the given bound supported by the observer dtype.
303
304        To accommodate lower-bit quantization with respect to the existing torch.qint8 and
305        torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
306        in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
307        values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
308        fake quantization. These estimates are compared against parameters learned through backpropagation.
309        The related literatures for scale and zero point via backpropagation are as follows:
310
311        Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
312        Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
313        """
314        # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
315        # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
316        assert (
317            quant_min <= 0 <= quant_max
318        ), "Used-specified quantization range must include 0."
319        assert (
320            quant_min < quant_max
321        ), "qmin must be strictly less than qmax for user-specified quantization range."
322
323    @torch.jit.export
324    def _calculate_qparams(
325        self, min_val: torch.Tensor, max_val: torch.Tensor
326    ) -> Tuple[torch.Tensor, torch.Tensor]:
327        r"""Calculates the quantization parameters, given min and max
328        value tensors. Works for both per tensor and per channel cases
329
330        Args:
331            min_val: Minimum values per channel
332            max_val: Maximum values per channel
333
334        Returns:
335            scales: Scales tensor of shape (#channels,)
336            zero_points: Zero points tensor of shape (#channels,)
337        """
338        # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme
339        # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer
340        # to use this utility a massive pain and very gross. For now Im opting just to duplicate as this code
341        # seems unlikey to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor.
342        # TODO(jakeszwe, jerryzh168)
343        if not check_min_max_valid(min_val, max_val):
344            return torch.tensor([1.0], device=min_val.device.type), torch.tensor(
345                [0], device=min_val.device.type
346            )
347
348        quant_min, quant_max = self.quant_min, self.quant_max
349        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
350        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
351
352        device = min_val_neg.device
353        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
354        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
355
356        if (
357            self.qscheme == torch.per_tensor_symmetric
358            or self.qscheme == torch.per_channel_symmetric
359        ):
360            max_val_pos = torch.max(-min_val_neg, max_val_pos)
361            scale = max_val_pos / (float(quant_max - quant_min) / 2)
362            scale = torch.max(scale, self.eps)
363            if self.dtype in [torch.quint8, torch.uint8]:
364                if self.has_customized_qrange:
365                    # When customized quantization range is used, down-rounded midpoint of the range is chosen.
366                    zero_point = zero_point.new_full(
367                        zero_point.size(), (quant_min + quant_max) // 2
368                    )
369                else:
370                    zero_point = zero_point.new_full(zero_point.size(), 128)
371        elif self.qscheme == torch.per_channel_affine_float_qparams:
372            scale = (max_val - min_val) / float(quant_max - quant_min)
373            scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
374            # We use the quantize function
375            # xq = Round(Xf * inv_scale + zero_point),
376            # setting zero_point to (-1 * min *inv_scale) we get
377            # Xq = Round((Xf - min) * inv_scale)
378            zero_point = -1 * min_val / scale
379        else:
380            scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
381            scale = torch.max(scale, self.eps)
382            zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
383            zero_point = torch.clamp(zero_point, quant_min, quant_max)
384
385        # For scalar values, cast them to Tensors of size 1 to keep the shape
386        # consistent with default values in FakeQuantize.
387        if len(scale.shape) == 0:
388            # TODO: switch to scale.item() after adding JIT support
389            scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
390        if len(zero_point.shape) == 0:
391            # TODO: switch to zero_point.item() after adding JIT support
392            zero_point = torch.tensor(
393                [int(zero_point)], dtype=zero_point.dtype, device=device
394            )
395            if self.qscheme == torch.per_channel_affine_float_qparams:
396                zero_point = torch.tensor(
397                    [float(zero_point)], dtype=zero_point.dtype, device=device
398                )
399
400        return scale, zero_point
401
402    @torch.jit.export
403    def reset_min_max_vals(self):
404        raise NotImplementedError("Cannot reset min/max values in the given observer.")
405
406
407# Originally, this class was called `_ObserverBase`.  Keeping the old name around
408# for backwards compatibility.
409# TODO(after v1.13): delete this
410_ObserverBase = UniformQuantizationObserverBase
411
412
413class MinMaxObserver(UniformQuantizationObserverBase):
414    r"""Observer module for computing the quantization parameters based on the
415    running min and max values.
416
417    This observer uses the tensor min/max statistics to compute the quantization
418    parameters. The module records the running minimum and maximum of incoming
419    tensors, and uses this statistic to compute the quantization parameters.
420
421    Args:
422        dtype: dtype argument to the `quantize` node needed to implement the
423               reference model spec.
424        qscheme: Quantization scheme to be used
425        reduce_range: Reduces the range of the quantized data type by 1 bit
426        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
427        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
428        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
429
430    Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`,
431    scale :math:`s` and zero point :math:`z` are computed as:
432
433    The running minimum/maximum :math:`x_\text{min/max}` is computed as:
434
435    .. math::
436
437        \begin{array}{ll}
438        x_\text{min} &= \begin{cases}
439            \min(X) & \text{if~}x_\text{min} = \text{None} \\
440            \min\left(x_\text{min}, \min(X)\right) & \text{otherwise}
441        \end{cases}\\
442        x_\text{max} &= \begin{cases}
443            \max(X) & \text{if~}x_\text{max} = \text{None} \\
444            \max\left(x_\text{max}, \max(X)\right) & \text{otherwise}
445        \end{cases}\\
446        \end{array}
447
448    where :math:`X` is the observed tensor.
449
450    The scale :math:`s` and zero point :math:`z` are then computed as:
451
452    .. math::
453
454        \begin{aligned}
455            \text{if Symmetric:}&\\
456            &s = 2 \max(|x_\text{min}|, x_\text{max}) /
457                \left( Q_\text{max} - Q_\text{min} \right) \\
458            &z = \begin{cases}
459                0 & \text{if dtype is qint8} \\
460                128 & \text{otherwise}
461            \end{cases}\\
462            \text{Otherwise:}&\\
463                &s = \left( x_\text{max} - x_\text{min}  \right ) /
464                    \left( Q_\text{max} - Q_\text{min} \right ) \\
465                &z = Q_\text{min} - \text{round}(x_\text{min} / s)
466        \end{aligned}
467
468    where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and
469    maximum of the quantized data type.
470
471    .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
472
473    .. note:: If the running minimum equals to the running maximum, the scale
474              and zero_point are set to 1.0 and 0.
475    """
476    min_val: torch.Tensor
477    max_val: torch.Tensor
478
479    def __init__(
480        self,
481        dtype=torch.quint8,
482        qscheme=torch.per_tensor_affine,
483        reduce_range=False,
484        quant_min=None,
485        quant_max=None,
486        factory_kwargs=None,
487        eps=torch.finfo(torch.float32).eps,
488        is_dynamic=False,
489        **kwargs,
490    ) -> None:
491        if not is_per_tensor(qscheme):
492            raise NotImplementedError(
493                "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \
494                    and torch.per_tensor_affine."
495            )
496        # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but
497        # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it
498        # supports dynamic quantization, we may need to better error checking here
499
500        # For x86 quantized kernels, we need to ensure that the vpmaddubsw
501        # instruction does not overflow. We allow for a reduce_range argument to
502        # observers that reduces the quantized range to (0,127) or (-64, 63).
503        # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
504        # This is not an optimal choice for non x86 backends as it loses a bit
505        # of precision for activations.
506        super().__init__(
507            dtype=dtype,
508            qscheme=qscheme,
509            reduce_range=reduce_range,
510            quant_min=quant_min,
511            quant_max=quant_max,
512            factory_kwargs=factory_kwargs,
513            eps=eps,
514            is_dynamic=is_dynamic,
515            **kwargs,
516        )
517        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
518        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
519        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
520        if (
521            self.qscheme == torch.per_tensor_symmetric
522            and self.reduce_range
523            and self.dtype == torch.quint8
524        ):
525            raise NotImplementedError(
526                "Cannot reduce range for symmetric \
527                                       quantization for quint8"
528            )
529
530    def forward(self, x_orig):
531        r"""Records the running minimum and maximum of ``x``."""
532        if x_orig.numel() == 0:
533            return x_orig
534        x = x_orig.detach()  # avoid keeping autograd tape
535        x = x.to(self.min_val.dtype)
536        min_val_cur, max_val_cur = torch.aminmax(x)
537        min_val = torch.min(min_val_cur, self.min_val)
538        max_val = torch.max(max_val_cur, self.max_val)
539        self.min_val.copy_(min_val)
540        self.max_val.copy_(max_val)
541        return x_orig
542
543    @torch.jit.export
544    def calculate_qparams(self):
545        r"""Calculates the quantization parameters."""
546        return self._calculate_qparams(self.min_val, self.max_val)
547
548    @torch.jit.export
549    def extra_repr(self):
550        return f"min_val={self.min_val}, max_val={self.max_val}"
551
552    @torch.jit.export
553    def reset_min_max_vals(self):
554        """Resets the min/max values."""
555        self.min_val.copy_(torch.tensor(float("inf")))
556        self.max_val.copy_(torch.tensor(float("-inf")))
557
558
559class MovingAverageMinMaxObserver(MinMaxObserver):
560    r"""Observer module for computing the quantization parameters based on the
561    moving average of the min and max values.
562
563    This observer computes the quantization parameters based on the moving
564    averages of minimums and maximums of the incoming tensors. The module
565    records the average minimum and maximum of incoming tensors, and uses this
566    statistic to compute the quantization parameters.
567
568    Args:
569        averaging_constant: Averaging constant for min/max.
570        dtype: dtype argument to the `quantize` node needed to implement the
571               reference model spec.
572        qscheme: Quantization scheme to be used
573        reduce_range: Reduces the range of the quantized data type by 1 bit
574        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
575        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
576        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
577
578    The moving average min/max is computed as follows
579
580    .. math::
581
582        \begin{array}{ll}
583                x_\text{min} = \begin{cases}
584                    \min(X) & \text{if~}x_\text{min} = \text{None} \\
585                    (1 - c) x_\text{min} + c \min(X) & \text{otherwise}
586                \end{cases}\\
587                x_\text{max} = \begin{cases}
588                    \max(X) & \text{if~}x_\text{max} = \text{None} \\
589                    (1 - c) x_\text{max} + c \max(X) & \text{otherwise}
590                \end{cases}\\
591        \end{array}
592
593    where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is
594    is the incoming tensor, and :math:`c` is the ``averaging_constant``.
595
596    The scale and zero point are then computed as in
597    :class:`~torch.ao.quantization.observer.MinMaxObserver`.
598
599    .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme.
600
601    .. note:: If the running minimum equals to the running maximum, the scale
602              and zero_point are set to 1.0 and 0.
603    """
604
605    def __init__(
606        self,
607        averaging_constant=0.01,
608        dtype=torch.quint8,
609        qscheme=torch.per_tensor_affine,
610        reduce_range=False,
611        quant_min=None,
612        quant_max=None,
613        eps=torch.finfo(torch.float32).eps,
614        is_dynamic=False,
615        **kwargs,
616    ) -> None:
617        if not is_per_tensor(qscheme):
618            raise NotImplementedError(
619                f"MovingAverageMinMaxObserver's qscheme only support \
620                torch.per_tensor_symmetric and torch.per_tensor_affine. \
621                but got: {qscheme}"
622            )
623        self.averaging_constant = averaging_constant
624        if is_dynamic and self.averaging_constant != 1:
625            raise NotImplementedError(
626                "MovingAverageMinMaxObserver doesn't support dynamic quantization for "
627                f"averaging constant of {self.averaging_constant}"
628            )
629        super().__init__(
630            dtype=dtype,
631            qscheme=qscheme,
632            reduce_range=reduce_range,
633            quant_min=quant_min,
634            quant_max=quant_max,
635            eps=eps,
636            is_dynamic=is_dynamic,
637            **kwargs,
638        )
639
640    def forward(self, x_orig):
641        if x_orig.numel() == 0:
642            return x_orig
643        x = x_orig.detach()  # avoid keeping autograd tape
644        x = x.to(self.min_val.dtype)
645        min_val = self.min_val
646        max_val = self.max_val
647        if min_val == float("inf") and max_val == float("-inf"):
648            min_val, max_val = torch.aminmax(x)
649        else:
650            min_val_cur, max_val_cur = torch.aminmax(x)
651            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
652            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
653        self.min_val.copy_(min_val)
654        self.max_val.copy_(max_val)
655        return x_orig
656
657
658class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
659    r"""Observer module for computing the quantization parameters based on the
660    running per channel min and max values.
661
662    This observer uses the tensor min/max statistics to compute the per channel
663    quantization parameters. The module records the running minimum and maximum
664    of incoming tensors, and uses this statistic to compute the quantization
665    parameters.
666
667    Args:
668        ch_axis: Channel axis
669        dtype: dtype argument to the `quantize` node needed to implement the
670               reference model spec.
671        qscheme: Quantization scheme to be used
672        reduce_range: Reduces the range of the quantized data type by 1 bit
673        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
674        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
675        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
676
677    The quantization parameters are computed the same way as in
678    :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
679    that the running min/max values are stored per channel.
680    Scales and zero points are thus computed per channel as well.
681
682    .. note:: If the running minimum equals to the running maximum, the scales
683              and zero_points are set to 1.0 and 0.
684    """
685    min_val: torch.Tensor
686    max_val: torch.Tensor
687
688    def __init__(
689        self,
690        ch_axis=0,
691        dtype=torch.quint8,
692        qscheme=torch.per_channel_affine,
693        reduce_range=False,
694        quant_min=None,
695        quant_max=None,
696        factory_kwargs=None,
697        eps=torch.finfo(torch.float32).eps,
698        is_dynamic=False,
699        **kwargs,
700    ) -> None:
701        if not is_per_channel(qscheme):
702            raise NotImplementedError(
703                "PerChannelMinMaxObserver's qscheme only support \
704                    torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams."
705            )
706        if is_dynamic:
707            raise NotImplementedError(
708                "PerChannelMinMaxObserver doesn't support dynamic quantization"
709            )
710        super().__init__(
711            dtype=dtype,
712            qscheme=qscheme,
713            reduce_range=reduce_range,
714            quant_min=quant_min,
715            quant_max=quant_max,
716            factory_kwargs=factory_kwargs,
717            eps=eps,
718            is_dynamic=is_dynamic,
719            **kwargs,
720        )
721        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
722        self.ch_axis = ch_axis
723        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
724        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
725        if (
726            self.qscheme == torch.per_channel_symmetric
727            and self.reduce_range
728            and self.dtype == torch.quint8
729        ):
730            raise NotImplementedError(
731                "Cannot reduce range for symmetric quantization for quint8"
732            )
733
734    def forward(self, x_orig):
735        return self._forward(x_orig)
736
737    def _forward(self, x_orig):
738        if x_orig.numel() == 0:
739            return x_orig
740        x = x_orig.detach()  # avoid keeping autograd tape
741        min_val = self.min_val
742        max_val = self.max_val
743        x_dim = x.size()
744
745        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
746        new_axis_list[self.ch_axis] = 0
747        new_axis_list[0] = self.ch_axis
748        y = x.permute(new_axis_list)
749        # Need to match dtype of min/max because the updates to buffers
750        # are done in place and types need to match for comparisons
751        y = y.to(self.min_val.dtype)
752        y = torch.flatten(y, start_dim=1)
753        if min_val.numel() == 0 or max_val.numel() == 0:
754            min_val, max_val = torch.aminmax(y, dim=1)
755        else:
756            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
757            min_val = torch.min(min_val_cur, min_val)
758            max_val = torch.max(max_val_cur, max_val)
759        self.min_val.resize_(min_val.shape)
760        self.max_val.resize_(max_val.shape)
761        self.min_val.copy_(min_val)
762        self.max_val.copy_(max_val)
763        return x_orig
764
765    @torch.jit.export
766    def calculate_qparams(self):
767        return self._calculate_qparams(self.min_val, self.max_val)
768
769    def extra_repr(self):
770        return f"min_val={self.min_val}, max_val={self.max_val}"
771
772    def _load_from_state_dict(
773        self,
774        state_dict: Dict[str, Any],
775        prefix: str,
776        local_metadata: Dict[str, torch.Tensor],
777        strict: bool,
778        missing_keys: List[str],
779        unexpected_keys: List[str],
780        error_msgs: List[str],
781    ):
782        version = local_metadata.get("version", None)
783        if version is not None and version < 3:
784            local_state = ["min_vals", "max_vals"]
785            expected_min_name = "min_vals"
786            expected_max_name = "max_vals"
787        else:
788            local_state = ["min_val", "max_val"]
789            expected_min_name = "min_val"
790            expected_max_name = "max_val"
791        for name in local_state:
792            key = prefix + name
793            if key in state_dict:
794                val = state_dict[key]
795                # Custom handling to allow loading min_val or max_val
796                # of size N into uninitialized buffers of size 0. The
797                # buffers are resized here, and the values are copied in
798                # the default state_dict loading code of the parent.
799                if name == expected_min_name:
800                    self.min_val.resize_(val.shape)
801                elif name == expected_max_name:
802                    self.max_val.resize_(val.shape)
803                else:
804                    warnings.warn(
805                        f"Observer load_from_state_dict got unexpected name {name}"
806                    )
807                # For torchscript module we need to update the attributes here since we do not
808                # call the `_load_from_state_dict` function defined module.py
809                if torch.jit.is_scripting():
810                    if name == expected_min_name:
811                        self.min_val.copy_(val)
812                    elif name == expected_max_name:
813                        self.max_val.copy_(val)
814                    else:
815                        warnings.warn(
816                            f"Observer load_from_state_dict got unexpected name {name}"
817                        )
818            elif strict:
819                missing_keys.append(key)
820
821        if not torch.jit.is_scripting():
822            super()._load_from_state_dict(
823                state_dict,
824                prefix,
825                local_metadata,
826                False,
827                missing_keys,
828                unexpected_keys,
829                error_msgs,
830            )
831
832    def _load_from_state_dict_script(
833        self,
834        state_dict: Dict[str, Any],
835        prefix: str,
836        local_metadata: Dict[str, torch.Tensor],
837        strict: bool,
838        missing_keys: List[str],
839        unexpected_keys: List[str],
840        error_msgs: List[str],
841    ):
842        self._load_from_state_dict(
843            state_dict,
844            prefix,
845            local_metadata,
846            strict,
847            missing_keys,
848            unexpected_keys,
849            error_msgs,
850        )
851
852    @torch.jit.export
853    def reset_min_max_vals(self):
854        """Resets the min/max values."""
855        # This used to be torch.ones but that does not work because
856        # JIT compiler can optimize it via common subexpression elimination
857        # in which case both min_val and max_val point to the same tensor.
858        self.min_val = torch.rand(
859            0,
860        )
861        self.max_val = torch.rand(
862            0,
863        )
864
865
866class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
867    r"""Observer module for computing the quantization parameters based on the
868    running per channel min and max values.
869
870    This observer uses the tensor min/max statistics to compute the per channel
871    quantization parameters. The module records the running minimum and maximum
872    of incoming tensors, and uses this statistic to compute the quantization
873    parameters.
874
875    Args:
876        averaging_constant: Averaging constant for min/max.
877        ch_axis: Channel axis
878        dtype: Quantized data type
879        qscheme: Quantization scheme to be used
880        reduce_range: Reduces the range of the quantized data type by 1 bit
881        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
882        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
883        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
884
885    The quantization parameters are computed the same way as in
886    :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the
887    difference that the running min/max values are stored per channel.
888    Scales and zero points are thus computed per channel as well.
889
890    .. note:: If the running minimum equals to the running maximum, the scales
891              and zero_points are set to 1.0 and 0.
892    """
893
894    def __init__(
895        self,
896        averaging_constant=0.01,
897        ch_axis=0,
898        dtype=torch.quint8,
899        qscheme=torch.per_channel_affine,
900        reduce_range=False,
901        quant_min=None,
902        quant_max=None,
903        eps=torch.finfo(torch.float32).eps,
904        is_dynamic=False,
905        **kwargs,
906    ) -> None:
907        if not is_per_channel(qscheme):
908            raise NotImplementedError(
909                "MovingAveragePerChannelMinMaxObserver's qscheme only support \
910                    torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams."
911            )
912        if is_dynamic:
913            raise NotImplementedError(
914                "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization"
915            )
916        super().__init__(
917            ch_axis=ch_axis,
918            dtype=dtype,
919            qscheme=qscheme,
920            reduce_range=reduce_range,
921            quant_min=quant_min,
922            quant_max=quant_max,
923            eps=eps,
924            is_dynamic=is_dynamic,
925            **kwargs,
926        )
927        self.averaging_constant = averaging_constant
928
929    def forward(self, x_orig):
930        if x_orig.numel() == 0:
931            return x_orig
932        x = x_orig.detach()  # avoid keeping autograd tape
933        x = x.to(self.min_val.dtype)
934        min_val = self.min_val
935        max_val = self.max_val
936        x_dim = x.size()
937
938        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
939        new_axis_list[self.ch_axis] = 0
940        new_axis_list[0] = self.ch_axis
941        y = x.permute(new_axis_list)
942        y = torch.flatten(y, start_dim=1)
943        if min_val.numel() == 0 or max_val.numel() == 0:
944            min_val, max_val = torch.aminmax(y, dim=1)
945        else:
946            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
947            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
948            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
949        self.min_val.resize_(min_val.shape)
950        self.max_val.resize_(max_val.shape)
951        self.min_val.copy_(min_val)
952        self.max_val.copy_(max_val)
953        return x_orig
954
955
956class HistogramObserver(UniformQuantizationObserverBase):
957    r"""
958    The module records the running histogram of tensor values along with
959    min/max values. ``calculate_qparams`` will calculate scale and zero_point.
960
961    Args:
962        bins: Number of bins to use for the histogram
963        dtype: dtype argument to the `quantize` node needed to implement the
964               reference model spec
965        qscheme: Quantization scheme to be used
966        reduce_range: Reduces the range of the quantized data type by 1 bit
967        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
968
969    The scale and zero point are computed as follows:
970
971    1. Create the histogram of the incoming inputs.
972        The histogram is computed continuously, and the ranges per bin change
973        with every new tensor observed.
974    2. Search the distribution in the histogram for optimal min/max values.
975        The search for the min/max values ensures the minimization of the
976        quantization error with respect to the floating point model.
977    3. Compute the scale and zero point the same way as in the
978        :class:`~torch.ao.quantization.MinMaxObserver`
979    """
980    histogram: torch.Tensor
981    min_val: torch.Tensor
982    max_val: torch.Tensor
983
984    def __init__(
985        self,
986        bins: int = 2048,
987        dtype: torch.dtype = torch.quint8,
988        qscheme=torch.per_tensor_affine,
989        reduce_range=False,
990        quant_min=None,
991        quant_max=None,
992        factory_kwargs=None,
993        eps=torch.finfo(torch.float32).eps,
994        is_dynamic=False,
995        **kwargs,
996    ) -> None:
997        if not is_per_tensor(qscheme):
998            raise NotImplementedError(
999                "HistogramObserver's qscheme only support torch.per_tensor_symmetric \
1000                    and torch.per_tensor_affine."
1001            )
1002        if is_dynamic:
1003            raise NotImplementedError(
1004                "HistogramObserver doesn't support dynamic quantization"
1005            )
1006        # bins: The number of bins used for histogram calculation.
1007        super().__init__(
1008            dtype=dtype,
1009            qscheme=qscheme,
1010            reduce_range=reduce_range,
1011            quant_min=quant_min,
1012            quant_max=quant_max,
1013            factory_kwargs=factory_kwargs,
1014            eps=eps,
1015            is_dynamic=is_dynamic,
1016            **kwargs,
1017        )
1018        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
1019        self.bins = bins
1020        self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
1021        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
1022        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
1023        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
1024        self.upsample_rate = (
1025            16  # used to reduce quantization errors when upscaling histogram
1026        )
1027
1028    def _get_norm(
1029        self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor
1030    ) -> torch.Tensor:
1031        r"""
1032        Compute the norm of the values uniformaly distributed between
1033        delta_begin and delta_end.
1034        Currently only L2 norm is supported.
1035
1036        norm = density * (integral_{begin, end} x^2)
1037             = density * (end^3 - begin^3) / 3
1038        """
1039        norm = (
1040            delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin
1041        ) / 3
1042        return density * norm
1043
1044    def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int):
1045        r"""
1046        Compute the quantization error if we use start_bin to end_bin as the
1047        min and max to do the quantization.
1048        """
1049        bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
1050
1051        dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
1052        if dst_bin_width == 0.0:
1053            return 0.0
1054
1055        src_bin = torch.arange(self.bins, device=self.histogram.device)
1056        # distances from the beginning of first dst_bin to the beginning and
1057        # end of src_bin
1058        src_bin_begin = (src_bin - next_start_bin) * bin_width
1059        src_bin_end = src_bin_begin + bin_width
1060
1061        # which dst_bins the beginning and end of src_bin belong to?
1062        dst_bin_of_begin = torch.clamp(
1063            torch.div(src_bin_begin, dst_bin_width, rounding_mode="floor"),
1064            0,
1065            self.dst_nbins - 1,
1066        )
1067        dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width
1068
1069        dst_bin_of_end = torch.clamp(
1070            torch.div(src_bin_end, dst_bin_width, rounding_mode="floor"),
1071            0,
1072            self.dst_nbins - 1,
1073        )
1074        density = self.histogram / bin_width
1075
1076        norm = torch.zeros(self.bins, device=self.histogram.device)
1077
1078        delta_begin = src_bin_begin - dst_bin_of_begin_center
1079        delta_end = dst_bin_width / 2
1080        norm += self._get_norm(
1081            delta_begin,
1082            torch.ones(self.bins, device=self.histogram.device) * delta_end,
1083            density,
1084        )
1085
1086        norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
1087            torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
1088        )
1089
1090        dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2
1091
1092        delta_begin = -dst_bin_width / 2
1093        delta_end = src_bin_end - dst_bin_of_end_center
1094        norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)
1095
1096        return norm.sum().item()
1097
1098    def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]:
1099        r"""Non-linear parameter search.
1100
1101        An approximation for L2 error minimization for selecting min/max.
1102        By selecting new min/max, we filter out outliers in input distribution.
1103        This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
1104        caffe2/quantization/server/norm_minimization.cc
1105        """
1106        assert self.histogram.size()[0] == self.bins, "bins mismatch"
1107        bin_width = (self.max_val - self.min_val) / self.bins
1108
1109        # cumulative sum
1110        total = torch.sum(self.histogram).item()
1111        cSum = torch.cumsum(self.histogram, dim=0)
1112
1113        stepsize = 1e-5  # granularity
1114        alpha = 0.0  # lower bound
1115        beta = 1.0  # upper bound
1116        start_bin = 0
1117        end_bin = self.bins - 1
1118        norm_min = float("inf")
1119
1120        while alpha < beta:
1121            # Find the next step
1122            next_alpha = alpha + stepsize
1123            next_beta = beta - stepsize
1124
1125            # find the left and right bins between the quantile bounds
1126            l = start_bin
1127            r = end_bin
1128            while l < end_bin and cSum[l] < next_alpha * total:
1129                l = l + 1
1130            while r > start_bin and cSum[r] > next_beta * total:
1131                r = r - 1
1132
1133            # decide the next move
1134            next_start_bin = start_bin
1135            next_end_bin = end_bin
1136            if (l - start_bin) > (end_bin - r):
1137                # move the start bin
1138                next_start_bin = l
1139                alpha = next_alpha
1140            else:
1141                # move the end bin
1142                next_end_bin = r
1143                beta = next_beta
1144
1145            if next_start_bin == start_bin and next_end_bin == end_bin:
1146                continue
1147
1148            # calculate the quantization error using next_start_bin and next_end_bin
1149            norm = self._compute_quantization_error(next_start_bin, next_end_bin)
1150
1151            if norm > norm_min:
1152                break
1153            norm_min = norm
1154            start_bin = next_start_bin
1155            end_bin = next_end_bin
1156
1157        new_min = self.min_val + bin_width * start_bin
1158        new_max = self.min_val + bin_width * (end_bin + 1)
1159        return new_min, new_max
1160
1161    def _upscale_histogram(
1162        self,
1163        histogram: torch.Tensor,
1164        orig_min: torch.Tensor,
1165        orig_max: torch.Tensor,
1166        update_min: torch.Tensor,
1167        update_max: torch.Tensor,
1168    ):
1169        # this turns the histogram into a more fine-coarsed histogram to reduce
1170        # bin quantization errors
1171        histogram = histogram.repeat_interleave(self.upsample_rate) / self.upsample_rate
1172        bin_size = (orig_max - orig_min) / (self.bins * self.upsample_rate)
1173        mid_points_histogram = (
1174            torch.linspace(
1175                orig_min,
1176                orig_max,
1177                self.bins * self.upsample_rate + 1,
1178                device=orig_min.device,
1179            )[:-1].to(histogram.device)
1180            + 0.5 * bin_size
1181        )
1182        boundaries_new_histogram = torch.linspace(
1183            update_min, update_max, self.bins + 1, device=update_min.device
1184        ).to(histogram.device)
1185        # this maps the mid-poits of the histogram to the new histogram's space
1186        bucket_assignments = (
1187            torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True)
1188            - 1
1189        )
1190        # this then maps the histogram mid-points in the new space, weighted by the original histogram's values
1191        # this is just the old histogram in the new histogram's space
1192
1193        # In case due to numerical issues the values land higher/lower than the maximum/minimum
1194        bucket_assignments[bucket_assignments >= self.bins] = self.bins - 1
1195        bucket_assignments[bucket_assignments < 0] = 0
1196
1197        update_histogram = torch.bincount(
1198            bucket_assignments, weights=histogram, minlength=self.bins
1199        )
1200        return update_histogram
1201
1202    def _combine_histograms(
1203        self,
1204        orig_hist: torch.Tensor,
1205        orig_min: torch.Tensor,
1206        orig_max: torch.Tensor,
1207        update_hist: torch.Tensor,
1208        update_min: torch.Tensor,
1209        update_max: torch.Tensor,
1210    ) -> torch.Tensor:
1211        # If the new min and max are the same as the current min and max,
1212        # we can just add the new histogram to the original histogram
1213        if update_min == orig_min and update_max == orig_max:
1214            return orig_hist + update_hist
1215
1216        # If the orig hist only has one value (i.e., the min and max are the same)
1217        # we can just add it into new histogram
1218        if orig_min == orig_max:
1219            bin_value = torch.sum(update_hist)
1220            transformed_orig_hist = (
1221                torch.histc(orig_min, bins=self.bins, min=update_min, max=update_max)  # type: ignore[arg-type]
1222                * bin_value
1223            )
1224            return transformed_orig_hist + update_hist
1225
1226        # We assume the update_hist is already in the target range, we will map the orig_max to it
1227        assert update_min <= orig_min
1228        assert update_max >= orig_max
1229
1230        # Now we need to turn the old_histogram, into the range of the new histogram
1231        transformed_orig_hist = self._upscale_histogram(
1232            orig_hist,
1233            orig_min,
1234            orig_max,
1235            update_min,
1236            update_max,
1237        )
1238
1239        return update_hist + transformed_orig_hist
1240
1241    def reset_histogram(
1242        self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor
1243    ) -> None:
1244        self.min_val.resize_(min_val.shape)
1245        self.min_val.copy_(min_val)
1246        self.max_val.resize_(max_val.shape)
1247        self.max_val.copy_(max_val)
1248        assert (
1249            min_val.numel() == 1 and max_val.numel() == 1
1250        ), "histogram min/max values must be scalar."
1251        new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val)  # type: ignore[arg-type]
1252        self.histogram.detach_().resize_(new_histogram.shape)
1253        self.histogram.copy_(new_histogram)
1254
1255    def forward(self, x_orig: torch.Tensor) -> torch.Tensor:  # pyre-ignore[14]
1256        if x_orig.numel() == 0:
1257            return x_orig
1258        x = x_orig.detach()
1259        x_min, x_max = torch.aminmax(x)
1260        # want to ignore torch.inf since we don't actually
1261        # want to make our quantization range infinite
1262        # and in practice those values will be clamped
1263        if x_min == -torch.inf or x_max == torch.inf:
1264            warnings.warn("torch.inf detected in input tensor, ignoring input")
1265            x = x[x.abs() != torch.inf]
1266            if x.numel() == 0:
1267                return x_orig
1268            x_min, x_max = torch.aminmax(x)
1269
1270        current_min = self.min_val
1271        current_max = self.max_val
1272
1273        is_uninitialized = self.min_val == float("inf") or self.max_val == float("-inf")
1274        if is_uninitialized:
1275            self.reset_histogram(x, x_min, x_max)
1276        else:
1277            update_min, update_max = x_min, x_max
1278            new_min = torch.min(current_min, update_min)
1279            new_max = torch.max(current_max, update_max)
1280
1281            # TODO: For some reason, this is required for it to pass torchscript test
1282            # new_min and new_max should already have requires_grad set to False
1283            new_min, new_max = new_min.detach(), new_max.detach()
1284            update_histogram = torch.histc(
1285                x, self.bins, min=new_min, max=new_max  # type: ignore[arg-type]
1286            ).to(self.histogram.device)
1287            if new_min == current_min and new_max == current_max:
1288                combined_histogram = self.histogram + update_histogram
1289                self.histogram.detach_().resize_(combined_histogram.shape)
1290                self.histogram.copy_(combined_histogram)
1291            else:
1292                combined_histogram = self._combine_histograms(
1293                    self.histogram,
1294                    current_min,
1295                    current_max,
1296                    update_histogram,
1297                    new_min,
1298                    new_max,
1299                )
1300                self.histogram.detach_().resize_(combined_histogram.shape)
1301                self.histogram.copy_(combined_histogram)
1302                self.min_val.detach_().resize_(new_min.shape)
1303                self.min_val.copy_(new_min)
1304                self.max_val.detach_().resize_(new_max.shape)
1305                self.max_val.copy_(new_max)
1306
1307        return x_orig
1308
1309    @torch.jit.export
1310    def calculate_qparams(self):
1311        is_uninitialized = self.min_val == float("inf") and self.max_val == float(
1312            "-inf"
1313        )
1314        if is_uninitialized:
1315            warnings.warn(
1316                "must run observer before calling calculate_qparams.\
1317                                    Returning default scale and zero point "
1318            )
1319            return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor(
1320                [0], device=self.min_val.device.type
1321            )
1322        assert self.bins == len(self.histogram), (
1323            "The number of bins in histogram should be equal to the number of bins "
1324            "supplied while making this observer"
1325        )
1326
1327        new_min, new_max = self._non_linear_param_search()
1328
1329        return self._calculate_qparams(new_min, new_max)
1330
1331    def _save_to_state_dict(self, destination, prefix, keep_vars):
1332        super()._save_to_state_dict(destination, prefix, keep_vars)
1333        destination[prefix + "min_val"] = self.min_val
1334        destination[prefix + "max_val"] = self.max_val
1335
1336    def _load_from_state_dict(
1337        self,
1338        state_dict,
1339        prefix,
1340        local_metadata,
1341        strict,
1342        missing_keys,
1343        unexpected_keys,
1344        error_msgs,
1345    ):
1346        version = local_metadata.get("version", None)
1347
1348        if version is None or version < 3:
1349            # if min_val and max_val are not initialized, update their shape
1350            # to account for the differences between v2 and v3
1351            min_val_name, max_val_name = prefix + "min_val", prefix + "max_val"
1352            if min_val_name in state_dict:
1353                if state_dict[min_val_name].shape == torch.Size([0]):
1354                    state_dict[min_val_name] = torch.tensor(float("inf"))
1355            if max_val_name in state_dict:
1356                if state_dict[max_val_name].shape == torch.Size([0]):
1357                    state_dict[max_val_name] = torch.tensor(float("-inf"))
1358
1359        local_state = ["min_val", "max_val"]
1360        for name in local_state:
1361            key = prefix + name
1362            if key in state_dict:
1363                val = state_dict[key]
1364                setattr(self, name, val)
1365            elif strict:
1366                missing_keys.append(key)
1367        super()._load_from_state_dict(
1368            state_dict,
1369            prefix,
1370            local_metadata,
1371            strict,
1372            missing_keys,
1373            unexpected_keys,
1374            error_msgs,
1375        )
1376
1377    def extra_repr(self):
1378        return f"min_val={self.min_val}, max_val={self.max_val}"
1379
1380
1381class FixedQParamsObserver(ObserverBase):
1382    r"""
1383    Observer that simulates quantize and dequantize with fixed
1384    quantization parameters in training time. Only per tensor
1385    quantization is supported.
1386
1387    Args:
1388        `scale` (float): fixed scale for the observer
1389        `zero_point` (int): fixed zero point for the observer
1390        `dtype`, `qscheme`, `quant_min`, `quant_max`
1391    """
1392
1393    scale: torch.Tensor
1394    zero_point: torch.Tensor
1395
1396    def __init__(
1397        self,
1398        scale,
1399        zero_point,
1400        dtype=torch.quint8,
1401        qscheme=torch.per_tensor_affine,
1402        quant_min=0,
1403        quant_max=255,
1404        is_dynamic=False,
1405        **kwargs,
1406    ):
1407        if is_dynamic:
1408            raise NotImplementedError(
1409                "FixedQParamsObserver doesn't support dynamic quantization"
1410            )
1411        super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs)
1412        self.quant_min = quant_min
1413        self.quant_max = quant_max
1414        self.register_buffer("scale", torch.tensor([scale], dtype=torch.float))
1415        self.register_buffer("zero_point", torch.tensor([zero_point], dtype=torch.int))
1416        self.dtype = dtype
1417        self.qscheme = qscheme
1418
1419    def forward(self, X):
1420        return X
1421
1422    @torch.jit.export
1423    def calculate_qparams(self):
1424        return self.scale, self.zero_point
1425
1426
1427class PlaceholderObserver(ObserverBase):
1428    r"""
1429    Observer that doesn't do anything and just passes its configuration to the
1430    quantized module's ``.from_float()``.
1431
1432    Can be used for quantization to float16 which doesn't require determining
1433    ranges.
1434
1435    Args:
1436        dtype: dtype argument to the `quantize` node needed to implement the
1437               reference model spec.
1438        quant_min: minimum value in quantized domain (TODO: align behavior with other observers)
1439        quant_max: maximum value in quantized domain
1440        custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
1441                        (Can be used in Graph Mode Passes for special case ops).
1442        compute_dtype (deprecated): if set, marks the future quantize function to use
1443                       dynamic quantization instead of static quantization.
1444                       This field is deprecated, use `is_dynamic=True` instead.
1445        is_dynamic: if True, the `quantize` function in the reference model
1446                    representation taking stats from this observer instance will
1447                    use dynamic quantization.
1448    """
1449
1450    def __init__(
1451        self,
1452        dtype=torch.float32,
1453        custom_op_name="",
1454        compute_dtype=None,
1455        quant_min=None,
1456        quant_max=None,
1457        qscheme=None,
1458        eps=None,
1459        is_dynamic=False,
1460    ) -> None:
1461        super().__init__(dtype=dtype, is_dynamic=is_dynamic)
1462        if qscheme is None:
1463            qscheme = torch.per_tensor_affine
1464        if eps is None:
1465            eps = torch.finfo(torch.float32).eps
1466
1467        # dtype of input of the target operator, e.g. for dynamic quantization
1468        # ops, the dtype will be float32
1469        self.dtype = dtype
1470        self.qscheme = qscheme
1471        self.quant_min = quant_min
1472        self.quant_max = quant_max
1473        self.eps = eps
1474        self.custom_op = custom_op_name
1475        # used for configuration of computation type for dynamic quantization
1476        if compute_dtype:
1477            is_dynamic = True
1478            warnings.warn(
1479                "Please use `is_dynamic` instead of `compute_dtype`. \
1480                    `compute_dtype` will be deprecated in a future release \
1481                    of PyTorch."
1482            )
1483
1484    def forward(self, x):
1485        return x
1486
1487    @torch.jit.export
1488    def extra_repr(self):
1489        return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}"
1490
1491    @torch.jit.export
1492    def calculate_qparams(self):
1493        raise Exception(  # noqa: TRY002
1494            "calculate_qparams should not be called for PlaceholderObserver"
1495        )
1496
1497
1498class RecordingObserver(ObserverBase):
1499    r"""
1500    The module is mainly for debug and records the tensor values during runtime.
1501
1502    Args:
1503        dtype: Quantized data type
1504        qscheme: Quantization scheme to be used
1505        reduce_range: Reduces the range of the quantized data type by 1 bit
1506    """
1507    __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]}
1508
1509    def __init__(self, dtype=torch.quint8):
1510        super().__init__(dtype=dtype, is_dynamic=False)  # type: ignore[call-arg]
1511        self.tensor_val = []
1512
1513    def forward(self, x):
1514        self.tensor_val.append(x.clone())
1515        return x
1516
1517    @torch.jit.export
1518    def calculate_qparams(self):
1519        raise Exception(  # noqa: TRY002
1520            "calculate_qparams should not be called for RecordingObserver"
1521        )
1522
1523    @torch.jit.export
1524    def get_tensor_value(self):
1525        return self.tensor_val
1526
1527
1528class NoopObserver(ObserverBase):
1529    r"""
1530    Observer that doesn't do anything and just passes its configuration to the
1531    quantized module's ``.from_float()``.
1532
1533    Primarily used for quantization to float16 which doesn't require determining
1534    ranges.
1535
1536    Args:
1537        dtype: Quantized data type
1538        custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
1539                        (Can be used in Graph Mode Passes for special case ops).
1540    """
1541
1542    def __init__(self, dtype=torch.float16, custom_op_name="") -> None:
1543        super().__init__(dtype=dtype, is_dynamic=False)
1544        self.dtype = dtype
1545        self.custom_op = custom_op_name
1546
1547    def forward(self, x):
1548        return x
1549
1550    @torch.jit.export
1551    def calculate_qparams(self):
1552        raise Exception(  # noqa: TRY002
1553            "calculate_qparams should not be called for NoopObserver"
1554        )
1555
1556
1557class ReuseInputObserver(ObserverBase):
1558    r"""This observer is used when we want to reuse the observer from the operator
1559    that produces the input Tensor, typically used for operators like reshape, e.g.
1560    ```
1561    x0 = ...
1562    x1 = x0.reshape()
1563    ```
1564    if we configure x0 to be observed by some observer, let's say MinMaxObserver,
1565    and reshape is configured with ReuseInputObserver, we'll reuse the observer instance
1566    for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1.
1567
1568    Note: this is only enabled in FX Graph Mode Quantization
1569    """
1570
1571    def __init__(self) -> None:
1572        super().__init__(torch.quint8, is_dynamic=False)
1573
1574    def forward(self, x):
1575        return x
1576
1577    @torch.jit.export
1578    def calculate_qparams(self):
1579        raise Exception(  # noqa: TRY002
1580            "calculate_qparams should not be called for ReuseInputObserver"
1581        )
1582
1583
1584def _is_observer_script_module(mod, obs_type_name):
1585    """Returns true if given mod is an instance of Observer script module."""
1586    if isinstance(mod, torch.jit.RecursiveScriptModule):
1587        # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver'
1588        suffix = mod._c.qualified_name.split(".", 1)[1]
1589        name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
1590        return obs_type_name in name
1591    return False
1592
1593
1594def _is_activation_post_process(module):
1595    return isinstance(
1596        module,
1597        (torch.ao.quantization.ObserverBase, torch.ao.quantization.FakeQuantizeBase),
1598    ) or _is_observer_script_module(module, "quantization.observer")
1599
1600
1601def _is_per_channel_script_obs_instance(module):
1602    if isinstance(module, torch.jit.RecursiveScriptModule):
1603        return _is_observer_script_module(
1604            module, "quantization.observer.PerChannelMinMaxObserver"
1605        ) or _is_observer_script_module(
1606            module, "quantization.observer.MovingAveragePerChannelMinMaxObserver"
1607        )
1608    return False
1609
1610
1611def get_observer_state_dict(mod):
1612    r"""
1613    Returns the state dict corresponding to the observer stats.
1614    Traverse the model state_dict and extract out the stats.
1615    """
1616    od = OrderedDict()
1617    if isinstance(mod, torch.jit.RecursiveScriptModule):
1618        for k, v in mod.state_dict().items():
1619            if "observer" in k:
1620                od[k] = v
1621    else:
1622        # path for GraphModule and nn.Module (eager mode)
1623        for k, v in mod.state_dict().items():
1624            if "activation_post_process" in k:
1625                od[k] = v
1626    od._metadata = mod.state_dict()._metadata  # type: ignore[attr-defined]
1627    return od
1628
1629
1630def load_observer_state_dict(mod, obs_dict):
1631    r"""
1632    Given input model and a state_dict containing model observer stats,
1633    load the stats back into the model. The observer state_dict can be saved
1634    using torch.ao.quantization.get_observer_state_dict
1635    """
1636    missing_keys: List[str] = []
1637    unexpected_keys: List[str] = []
1638    for name, module in mod.named_modules():
1639        prefix = name + "."
1640        if _is_activation_post_process(module):
1641            if _is_per_channel_script_obs_instance(module):
1642                # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor.
1643                # However this is not called when the module is scripted and we end up calling the default one in module.py
1644                module._load_from_state_dict_script(
1645                    obs_dict, prefix, {}, True, missing_keys, unexpected_keys, []
1646                )
1647            else:
1648                module._load_from_state_dict(
1649                    obs_dict, prefix, {}, False, missing_keys, unexpected_keys, []
1650                )
1651    for k in missing_keys:
1652        if "observer" in k or "activation_post_process" in k:
1653            raise Exception(  # noqa: TRY002
1654                f"Missing keys for observer {k} in state_dict"
1655            )
1656    for k in unexpected_keys:
1657        if "observer" in k or "activation_post_process" in k:
1658            raise Exception(  # noqa: TRY002
1659                f"Unexpected keys for observer {k} in state_dict"
1660            )
1661
1662
1663# Restrict activations to be in the range (0,127)
1664default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127)
1665"""
1666Default observer for static quantization, usually used for debugging.
1667"""
1668
1669default_placeholder_observer = PlaceholderObserver
1670"""
1671Default placeholder observer, usually used for quantization to torch.float16.
1672"""
1673
1674default_debug_observer = RecordingObserver
1675"""
1676Default debug-only observer.
1677"""
1678
1679default_weight_observer = MinMaxObserver.with_args(
1680    dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
1681)
1682"""
1683Default weight observer.
1684"""
1685
1686weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args(
1687    dtype=torch.qint8,
1688    qscheme=torch.per_tensor_symmetric,
1689    quant_min=-127,
1690    quant_max=127,
1691    eps=2**-12,
1692)
1693"""
1694Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
1695"""
1696
1697default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127)
1698"""
1699Default histogram observer, usually used for PTQ.
1700"""
1701
1702default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
1703    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
1704)
1705"""
1706Default per-channel weight observer, usually used on backends where per-channel
1707weight quantization is supported, such as `fbgemm`.
1708"""
1709
1710per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args(
1711    dtype=torch.qint8,
1712    qscheme=torch.per_channel_symmetric,
1713    quant_min=-127,
1714    quant_max=127,
1715    eps=2**-12,
1716)
1717"""
1718Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
1719"""
1720
1721default_dynamic_quant_observer = PlaceholderObserver.with_args(
1722    dtype=torch.quint8,
1723    quant_min=0,
1724    quant_max=255,
1725    is_dynamic=True,
1726)
1727"""
1728Default observer for dynamic quantization.
1729"""
1730
1731default_float_qparams_observer = PerChannelMinMaxObserver.with_args(
1732    dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
1733)
1734"""
1735Default observer for a floating point zero-point.
1736"""
1737
1738default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args(
1739    dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
1740)
1741"""
1742Default observer for a floating point zero-point and 4 bit activations.
1743"""
1744
1745# TODO(future PR): remove these defaults and enforce activation functions
1746# to explicitly specify their output range
1747default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args(
1748    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255
1749)
1750default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args(
1751    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255
1752)
1753# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
1754default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer
1755default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer
1756
1757"""
1758Default observers for fixed qparams operations.
1759"""
1760
1761default_reuse_input_observer = ReuseInputObserver
1762"""
1763Default observer for operators like reshape that reuses the observer of input to
1764the operator
1765"""
1766