# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""Implements modules used to perform fake quantization."""

import re
from abc import ABC, abstractmethod
from typing import Any, Tuple

import torch
from torch.ao.quantization.observer import (
    _with_args,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    FixedQParamsObserver,
    HistogramObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
)
from torch.nn import Module


__all__ = [
    "FakeQuantizeBase",
    "FakeQuantize",
    "FixedQParamsFakeQuantize",
    "FusedMovingAvgObsFakeQuantize",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "default_fake_quant",
    "default_weight_fake_quant",
    "default_dynamic_fake_quant",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_fake_quant",
    "default_per_channel_weight_fake_quant",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_histogram_fake_quant",
    "default_fused_act_fake_quant",
    "default_fused_wt_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
]


def _is_per_channel(qscheme: "torch.qscheme") -> bool:
    return qscheme in [
        torch.per_channel_symmetric,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]


def _is_per_tensor(qscheme: "torch.qscheme") -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]


def _is_symmetric_quant(qscheme: "torch.qscheme") -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]


def _is_float_qparams(qscheme: "torch.qscheme") -> bool:
    return qscheme in [
        torch.per_channel_affine_float_qparams,
    ]


class FakeQuantizeBase(ABC, Module):
    r"""Base fake quantize module.

    Any fake quantize implementation should derive from this class.

    Concrete fake quantize modules should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.

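    A minimal sketch of a concrete subclass (illustrative only; the subclass name
    and the constant qparams are made up for this example)::

        class IdentityFakeQuantize(FakeQuantizeBase):
            def forward(self, x):
                # A real implementation would update observer statistics
                # and fake quantize ``x`` here.
                return x

            def calculate_qparams(self, **kwargs):
                return torch.tensor([1.0]), torch.tensor([0])
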
    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self) -> None:
        """Set fake_quant_enabled and observer_enabled."""
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    @classmethod
    def with_args(cls, **kwargs):
        fake_quant_constructor = _with_args(cls, **kwargs)
        # need to assign the correct module to fake_quantize
        # constructors to satisfy public v private requirements
        fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
        return fake_quant_constructor


class FakeQuantize(FakeQuantizeBase):
    r"""Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (
          clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point
        ) * scale

    * :attr:`is_dynamic` indicates whether the fake quantize is a placeholder for dynamic quantization
      operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq).

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors.

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization;
      allowable values are torch.qint8 and torch.quint8.

    Args:

        observer (module): Module for observing statistics on input tensors and calculating scale
          and zero-point.
        observer_kwargs (optional): Arguments for the observer module.

    Attributes:
        activation_post_process (Module): User provided module that collects statistics on the input tensor and
          provides a method to calculate scale and zero-point.

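    A minimal usage sketch (illustrative only; the tensor shape and the observer
    arguments below are arbitrary)::

        fq = FakeQuantize(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8,
        )
        x = torch.randn(4, 8)
        y = fq(x)  # updates observer statistics, then fake quantizes x
        scale, zero_point = fq.calculate_qparams()
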
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(
        self,
        observer=MovingAverageMinMaxObserver,
        quant_min=None,
        quant_max=None,
        is_dynamic=False,
        **observer_kwargs,
    ):
        super().__init__()
        # Populate quant_min/quant_max into observer_kwargs if valid
        if quant_min is not None and quant_max is not None:
            assert (
                quant_min <= quant_max
            ), "quant_min must be less than or equal to quant_max"
            dtype = observer_kwargs.get("dtype", torch.quint8)
            if hasattr(observer, "p"):
                # In case observer is _PartialWrapper, dtype can be stored in
                # observer.p.keywords["dtype"]
                dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
                    "dtype", dtype
                )
            assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
            assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
            observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
        observer_kwargs["is_dynamic"] = is_dynamic
        self.activation_post_process = observer(**observer_kwargs)
        # TODO: keeping self.quant_min/max for BC; remove after a couple releases
        # Users should use self.activation_post_process.quant_min
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        self.is_dynamic = self.activation_post_process.is_dynamic
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        self.register_buffer("scale", torch.tensor([1.0], dtype=torch.float))
        self.register_buffer("zero_point", torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = (
            self.activation_post_process.ch_axis
            if hasattr(self.activation_post_process, "ch_axis")
            else -1
        )
        assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), (
            "Only per channel and per tensor quantization are supported in fake quantize"
            + ", got qscheme: "
            + str(self.qscheme)
        )
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            # Observe the input and refresh the cached scale/zero_point
            # buffers from the newly computed quantization parameters.
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
                self.zero_point.device
            )
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            # Fake quantize the input using the current scale/zero_point.
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.ch_axis,
                    self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max,
                )
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.activation_post_process.quant_min,
                    self.activation_post_process.quant_max,
                )
        return X

    @torch.jit.export
    def extra_repr(self):
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
            f"scale={self.scale}, zero_point={self.zero_point}"
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so we need to
        # manually specify serialization here.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "scale"] = self.scale
        destination[prefix + "zero_point"] = self.zero_point

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        # Removing this function throws an error that the size of the loaded tensor does not match the original size,
        # i.e., these buffers start out with numel 0 and become numel 1 once they have their first forward pass.
        local_state = ["scale", "zero_point"]
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == "scale":
                    self.scale.resize_(val.shape)
                else:
                    assert name == "zero_point"
                    self.zero_point.resize_(val.shape)
                # For torchscript modules we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined in module.py
                if torch.jit.is_scripting():
                    if name == "scale":
                        self.scale.copy_(val)
                    else:
                        assert name == "zero_point"
                        self.zero_point.copy_(val)
            elif strict:
                missing_keys.append(key)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class FixedQParamsFakeQuantize(FakeQuantize):
    """Simulate quantize and dequantize in training time.

    Simulate quantize and dequantize with fixed quantization
    parameters in training time. Only per tensor quantization
    is supported.
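
    A minimal usage sketch (illustrative only), using one of the fixed-qparams
    observers imported at the top of this module::

        fq = FixedQParamsFakeQuantize(
            observer=default_fixed_qparams_range_0to1_observer
        )
        x = torch.rand(4, 8)
        y = fq(x)
        scale, zero_point = fq.calculate_qparams()  # fixed values from the observer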
    """

    # TODO: rename observer to observer_ctr
    def __init__(self, observer):
        super().__init__(observer=observer)
        assert (
            type(self.activation_post_process) == FixedQParamsObserver
        ), f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
        self._observer_ctr = observer
        self.scale = self.activation_post_process.scale
        self.zero_point = self.activation_post_process.zero_point
        assert _is_per_tensor(self.qscheme), (
            "Only per tensor quantization is supported in"
            + " FixedQParamsFakeQuantize module, got qscheme: "
            + str(self.qscheme)
        )

    @torch.jit.export
    def calculate_qparams(self):
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        """Define a string representation of the object's attributes."""
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"scale={self.scale}, zero_point={self.zero_point}, "
            f"dtype={self.dtype}, quant_min={self.activation_post_process.quant_min}, "
            f"quant_max={self.activation_post_process.quant_max}, qscheme={self.qscheme}"
        )


class FusedMovingAvgObsFakeQuantize(FakeQuantize):
    r"""Define a fused module to observe the tensor.

    Fused module that is used to observe the input tensor (compute min/max), compute
    scale/zero_point and fake_quantize the tensor.
    This module uses a calculation similar to MovingAverageMinMaxObserver for the inputs,
    to compute the min/max values in order to compute the scale/zero_point.
    The qscheme input in the observer is used to differentiate between symmetric/affine
    quantization schemes.

    The output of this module is given by::

        x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale

    This module is similar to :class:`~torch.ao.quantization.FakeQuantize` and accepts the same
    attributes as the base class.

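    A minimal usage sketch (illustrative only; the tensor shape is arbitrary)::

        fq = FusedMovingAvgObsFakeQuantize()
        x = torch.randn(4, 8)
        y = fq(x)  # observe and fake quantize in a single fused op
        scale, zero_point = fq.calculate_qparams()
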
    """

    def __init__(
        self,
        observer: Any = MovingAverageMinMaxObserver,
        quant_min: int = 0,
        quant_max: int = 255,
        **observer_kwargs: Any,
    ) -> None:
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
        assert isinstance(
            self.activation_post_process,
            (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver),
        ), "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
        self.is_symmetric_quant = _is_symmetric_quant(
            self.activation_post_process.qscheme
        )

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.activation_post_process.calculate_qparams()

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"qscheme={self.qscheme}, reduce_range={self.activation_post_process.reduce_range}"
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return torch.fused_moving_avg_obs_fake_quant(
            X,
            self.observer_enabled,
            self.fake_quant_enabled,
            self.activation_post_process.min_val,
            self.activation_post_process.max_val,
            self.scale,
            self.zero_point,
            self.activation_post_process.averaging_constant,
            self.activation_post_process.quant_min,
            self.activation_post_process.quant_max,
            self.ch_axis,
            self.is_per_channel,
            self.is_symmetric_quant,
        )


default_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)
"""
Default fake_quant for activations.
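
A minimal usage sketch (illustrative only): pair this default with one of the weight
fake quants defined below inside a ``QConfig``::

    qconfig = torch.ao.quantization.QConfig(
        activation=default_fake_quant,
        weight=default_weight_fake_quant,
    )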
435"""
436
437default_weight_fake_quant = FakeQuantize.with_args(
438    observer=MovingAverageMinMaxObserver,
439    quant_min=-128,
440    quant_max=127,
441    dtype=torch.qint8,
442    qscheme=torch.per_tensor_symmetric,
443    reduce_range=False,
444)
445"""
446Default fake_quant for weights.
447Observer is memoryless since averaging_constant is 1.
448"""
449
450default_dynamic_fake_quant = FakeQuantize.with_args(
451    observer=MovingAverageMinMaxObserver,
452    quant_min=0,
453    quant_max=255,
454    is_dynamic=True,
455    dtype=torch.quint8,
456    averaging_constant=1,
457)
458"""
459Default dynamic fake_quant for activations.
460"""
461
462default_fixed_qparams_range_neg1to1_fake_quant = FixedQParamsFakeQuantize.with_args(
463    observer=default_fixed_qparams_range_neg1to1_observer
464)
465default_fixed_qparams_range_0to1_fake_quant = FixedQParamsFakeQuantize.with_args(
466    observer=default_fixed_qparams_range_0to1_observer
467)
468# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
469default_symmetric_fixed_qparams_fake_quant = (
470    default_fixed_qparams_range_neg1to1_fake_quant
471)
472default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant
473
474default_per_channel_weight_fake_quant = FakeQuantize.with_args(
475    observer=MovingAveragePerChannelMinMaxObserver,
476    quant_min=-128,
477    quant_max=127,
478    dtype=torch.qint8,
479    qscheme=torch.per_channel_symmetric,
480    reduce_range=False,
481    ch_axis=0,
482)
483"""
484Default fake_quant for per-channel weights.
485Observer is memoryless since averaging_constant is 1.
486"""
487default_embedding_fake_quant = FakeQuantize.with_args(
488    observer=MovingAveragePerChannelMinMaxObserver,
489    qscheme=torch.per_channel_affine_float_qparams,
490    dtype=torch.quint8,
491    quant_min=0,
492    quant_max=255,
493    ch_axis=0,
494    averaging_constant=1,
495)
496"""
497Default fake_quant for embeddings.
498Observer is memoryless since averaging_constant is 1.
499"""
500
501default_embedding_fake_quant_4bit = FakeQuantize.with_args(
502    observer=MovingAveragePerChannelMinMaxObserver,
503    qscheme=torch.per_channel_affine_float_qparams,
504    ch_axis=0,
505    dtype=torch.quint4x2,
506    averaging_constant=1,
507)
508
509default_histogram_fake_quant = FakeQuantize.with_args(
510    observer=HistogramObserver,
511    quant_min=0,
512    quant_max=255,
513    dtype=torch.quint8,
514    qscheme=torch.per_tensor_affine,
515    reduce_range=True,
516)
517"""
518Fake_quant for activations using a histogram..
519"""
520
521
522default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
523    observer=MovingAverageMinMaxObserver,
524    quant_min=0,
525    quant_max=255,
526    dtype=torch.quint8,
527)
528
529"""
530Fused version of `default_fake_quant`, with improved performance.
531"""
532
533
534default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
535    observer=MovingAverageMinMaxObserver,
536    quant_min=-128,
537    quant_max=127,
538    dtype=torch.qint8,
539    qscheme=torch.per_tensor_symmetric,
540)
541"""
542Fused version of `default_weight_fake_quant`, with improved performance.
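
A minimal usage sketch (illustrative only): the fused defaults can be combined into a
QAT ``QConfig`` in the same way as the non-fused defaults above::

    qat_qconfig = torch.ao.quantization.QConfig(
        activation=default_fused_act_fake_quant,
        weight=default_fused_wt_fake_quant,
    )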
543"""
544
545default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
546    observer=MovingAveragePerChannelMinMaxObserver,
547    quant_min=-128,
548    quant_max=127,
549    dtype=torch.qint8,
550    qscheme=torch.per_channel_symmetric,
551)
552"""
553Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
554"""
555
556fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(
557    observer=MovingAverageMinMaxObserver,
558    quant_min=-127,
559    quant_max=127,
560    dtype=torch.qint8,
561    qscheme=torch.per_tensor_symmetric,
562    eps=2**-12,
563)
564"""
565Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
566"""
567
568fused_per_channel_wt_fake_quant_range_neg_127_to_127 = (
569    FusedMovingAvgObsFakeQuantize.with_args(
570        observer=MovingAveragePerChannelMinMaxObserver,
571        quant_min=-127,
572        quant_max=127,
573        dtype=torch.qint8,
574        qscheme=torch.per_channel_symmetric,
575        eps=2**-12,
576    )
577)
578
579"""
580Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
581"""
582
583
584def _is_fake_quant_script_module(mod):
    """Return True if the given mod is an instance of a FakeQuantize script module."""
    if isinstance(mod, torch.jit.RecursiveScriptModule):
        # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
        suffix = mod._c.qualified_name.split(".", 1)[1]
        name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
        return (
            name == "torch.ao.quantization.fake_quantize.FakeQuantize"
            or name
            == "torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize"
        )
    return False


def disable_fake_quant(mod):
    """Disable fake quantization for the module.

    Disable fake quantization for this module, if applicable. Example usage::

      # model is any PyTorch model
      model.apply(torch.ao.quantization.disable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    """Enable fake quantization for the module.

    Enable fake quantization for this module, if applicable. Example usage::

      # model is any PyTorch model
      model.apply(torch.ao.quantization.enable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_fake_quant()


def disable_observer(mod):
    """Disable observation for this module.

    Disable observation for this module, if applicable. Example usage::

      # model is any PyTorch model
      model.apply(torch.ao.quantization.disable_observer)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_observer()


def enable_observer(mod):
    """Enable observation for this module.

    Enable observation for this module, if applicable. Example usage::

      # model is any PyTorch model
      model.apply(torch.ao.quantization.enable_observer)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_observer()
648