xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/_learnable_fake_quantize.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import List
3
4import torch
5from torch.nn.parameter import Parameter
6
7
# Nothing in this module is part of the public star-import surface; the
# underscore-prefixed class below has to be imported by its explicit name.
__all__: List[str] = list()
9
10
class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
    r"""Generalized extension of the FakeQuantize module in fake_quantize.py.

    This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and supports learning of the scale
    and zero point parameters through backpropagation.

    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
    module also includes the following constructor arguments and attributes to support
    quantization parameter learning.

    * :attr:`channel_len` (constructor argument) defines the length of the channel when
      initializing scale and zero point for the per channel case; ``-1`` means per tensor.

    * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
      scaled by the factor ``1 / sqrt(X.numel() * quant_max)``. The related literature
      justifying the use of this particular constant can be found here:
      https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.

    * :attr:`static_enabled` defines the flag for using observer's static estimation for
      scale and zero point.

    * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
    """

    def __init__(
        self,
        observer,
        quant_min=0,
        quant_max=255,
        scale=1.0,
        zero_point=0.0,
        channel_len=-1,
        use_grad_scaling=False,
        **observer_kwargs,
    ):
        """Create the learnable scale/zero point parameters and the observer.

        Args:
            observer: Observer class (or factory) called as
                ``observer(**observer_kwargs)``; the instance is stored as
                ``activation_post_process``.
            quant_min: Lower bound of the quantized range.
            quant_max: Upper bound of the quantized range; must be strictly
                greater than ``quant_min``.
            scale: Initial value for every element of ``self.scale``. Ints are
                accepted and converted to float.
            zero_point: Initial value for every element of ``self.zero_point``.
                Ints are accepted and converted to float.
            channel_len: ``-1`` for a single (per-tensor) scale/zero point,
                otherwise the positive number of per-channel entries.
            use_grad_scaling: If True, ``forward`` passes a gradient factor of
                ``1 / sqrt(X.numel() * quant_max)`` to the fake-quant op.
            **observer_kwargs: Extra keyword arguments for ``observer``;
                ``quant_min``/``quant_max`` are always overwritten with the
                values above.
        """
        super().__init__()
        assert quant_min < quant_max, "quant_min must be strictly less than quant_max."
        self.quant_min = quant_min
        self.quant_max = quant_max
        # also pass quant_min and quant_max to observer so both agree on the range
        observer_kwargs["quant_min"] = quant_min
        observer_kwargs["quant_max"] = quant_max
        self.use_grad_scaling = use_grad_scaling
        # Coerce to Python floats: an int argument (e.g. ``zero_point=0``)
        # would otherwise build an integer tensor, and ``Parameter`` cannot
        # require gradients on a non-floating-point tensor.
        init_scale = float(scale)
        init_zero_point = float(zero_point)
        if channel_len == -1:
            self.scale = Parameter(torch.tensor([init_scale]))
            self.zero_point = Parameter(torch.tensor([init_zero_point]))
        else:
            assert (
                isinstance(channel_len, int) and channel_len > 0
            ), "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([init_scale] * channel_len))
            self.zero_point = Parameter(torch.tensor([init_zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert (
            torch.iinfo(self.activation_post_process.dtype).min <= quant_min
        ), "quant_min out of bound"
        assert (
            quant_max <= torch.iinfo(self.activation_post_process.dtype).max
        ), "quant_max out of bound"
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        # Observers without a channel axis (per-tensor ones) fall back to -1.
        self.ch_axis = (
            self.activation_post_process.ch_axis
            if hasattr(self.activation_post_process, "ch_axis")
            else -1
        )
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("static_enabled", torch.tensor([1], dtype=torch.uint8))
        self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        # Lower bound applied to the scale so it never reaches zero.
        self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))

    @torch.jit.export
    def enable_param_learning(self):
        r"""Enable parameter learning over static observer estimates.

        Enables learning of quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        Returns ``self`` so calls can be chained.
        """
        self.toggle_qparam_learning(enabled=True).toggle_fake_quant(
            enabled=True
        ).toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        """Enable static estimates of quantization parameters.

        Enables static observer estimates and disables learning of
        quantization parameters. Forward path returns fake quantized X.
        Returns ``self`` so calls can be chained.
        """
        self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
            enabled=True
        ).toggle_observer_update(enabled=True)
        return self

    @torch.jit.export
    def enable_static_observation(self):
        """Enable accumulation of data without updating quantization parameters.

        Enables static observer accumulating data from input but doesn't
        update the quantization parameters. Forward path returns the original X.
        Returns ``self`` so calls can be chained.
        """
        self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
            enabled=False
        ).toggle_observer_update(enabled=True)
        return self

    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        """Set the ``static_enabled`` flag; returns ``self``."""
        self.static_enabled[0] = int(enabled)  # type: ignore[operator]
        return self

    @torch.jit.export
    def enable_observer(self, enabled=True):
        """Base-class hook: forwards to :meth:`toggle_observer_update`."""
        self.toggle_observer_update(enabled)

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        """Set ``learning_enabled`` and ``requires_grad`` on scale/zero point; returns ``self``."""
        self.learning_enabled[0] = int(enabled)  # type: ignore[operator]
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        """Set the ``fake_quant_enabled`` flag; returns ``self``."""
        self.fake_quant_enabled[0] = int(enabled)  # type: ignore[operator]
        return self

    @torch.jit.export
    def observe_quant_params(self):
        """Print the current (detached) scale and zero point to stdout."""
        print(f"_LearnableFakeQuantize Scale: {self.scale.detach()}")
        print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}")

    @torch.jit.export
    def calculate_qparams(self):
        """Return ``(scale, zero_point)`` derived from the learnable parameters.

        Side effect: ``self.scale`` is clamped in place to at least ``eps``.
        The returned zero point is rounded, clamped to
        ``[quant_min, quant_max]`` and cast to ``torch.long``.
        """
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]
        scale = self.scale.detach()
        zero_point = (
            self.zero_point.detach()
            .round()
            .clamp(self.quant_min, self.quant_max)
            .long()
        )
        return scale, zero_point

    def forward(self, X):
        """Optionally refresh qparams from the observer, then fake-quantize ``X``.

        If ``static_enabled`` is set, the detached input is fed to the
        observer and its qparams are copied into ``scale``/``zero_point``;
        otherwise ``scale`` is clamped in place to at least ``eps``. If
        ``fake_quant_enabled`` is set, ``X`` goes through the learnable
        per-channel or per-tensor fake-quant op (the zero point is zeroed
        first for symmetric schemes). Otherwise ``X`` is returned unchanged.
        """
        if self.static_enabled[0] == 1:  # type: ignore[index]
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale = _scale.to(self.scale.device)
            _zero_point = _zero_point.to(self.zero_point.device)
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)
        else:
            self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]

        if self.fake_quant_enabled[0] == 1:
            if self.qscheme in (
                torch.per_channel_symmetric,
                torch.per_tensor_symmetric,
            ):
                self.zero_point.data.zero_()

            if self.use_grad_scaling:
                # max(..., 1) keeps an empty input from dividing by zero.
                grad_factor = 1.0 / (max(X.numel(), 1) * self.quant_max) ** 0.5
            else:
                grad_factor = 1.0
            if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine):
                X = torch._fake_quantize_learnable_per_channel_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.ch_axis,
                    self.quant_min,
                    self.quant_max,
                    grad_factor,
                )
            else:
                X = torch._fake_quantize_learnable_per_tensor_affine(
                    X,
                    self.scale,
                    self.zero_point,
                    self.quant_min,
                    self.quant_max,
                    grad_factor,
                )

        return X
203