xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/modules/activation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from warnings import warn
3
4import torch
5
6
# Public API of this module: the quantized activation module classes defined
# in this file.
__all__ = [
    "ReLU6",
    "Hardswish",
    "ELU",
    "LeakyReLU",
    "Sigmoid",
    "Softmax",
    "MultiheadAttention",
    "PReLU",
]
17
18
class ReLU6(torch.nn.ReLU):
    r"""Applies the element-wise function:

    :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
    zero_point, and :math:`q(6)` is the quantized representation of number 6.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.quantized.ReLU6()
        >>> input = torch.randn(2)
        >>> # xdoctest: +SKIP
        >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
        >>> output = m(input)
    """

    def __init__(self, inplace=False):
        # torch.nn.ReLU.__init__ already stores ``inplace`` on the module.
        super().__init__(inplace)

    def forward(self, input):
        """Apply ReLU6 to a quantized tensor via ``torch.ops.quantized.relu6``."""
        return torch.ops.quantized.relu6(input, self.inplace)

    def _get_name(self):
        return "QuantizedReLU6"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build a quantized ReLU6 from a float module.

        Only ``mod.inplace`` is read; no observed qparams are needed.
        ``use_precomputed_fake_quant`` is accepted for API compatibility and
        is ignored.  Uses ``cls`` so subclasses construct their own type.
        """
        return cls(mod.inplace)
57
58
class Hardswish(torch.nn.Hardswish):
    r"""This is the quantized version of :class:`~torch.nn.Hardswish`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        device: device on which the ``scale`` / ``zero_point`` buffers are created
        dtype: dtype used for both the ``scale`` and ``zero_point`` buffers
    """

    def __init__(self, scale, zero_point, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Output qparams are registered as buffers so they follow .to()/state_dict.
        self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
        self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        """Apply quantized hardswish, requantizing to ``scale`` / ``zero_point``."""
        return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)

    def _get_name(self):
        return "QuantizedHardswish"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build from an observed float module using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and
        is ignored.  Uses ``cls`` (like ``from_reference``) so subclasses
        construct their own type.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(float(scale), int(zero_point))

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point))
87
88
class ELU(torch.nn.ELU):
    r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        alpha: the alpha constant
    """

    def __init__(self, scale, zero_point, alpha=1.0):
        super().__init__(alpha)
        # Output qparams are kept as plain Python attributes (not buffers).
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        """Apply quantized ELU via ``torch.ao.nn.quantized.functional.elu``."""
        return torch.ao.nn.quantized.functional.elu(
            input, self.scale, self.zero_point, self.alpha
        )

    def _get_name(self):
        return "QuantizedELU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build from an observed float ELU using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and
        is ignored.  Uses ``cls`` (like ``from_reference``) so subclasses
        construct their own type.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(float(scale), int(zero_point), mod.alpha)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(float(scale), int(zero_point), mod.alpha)
119
120
class LeakyReLU(torch.nn.LeakyReLU):
    r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
        inplace: forwarded to the quantized kernel's ``inplace`` flag. Default: ``False``
        device: device on which the ``scale`` / ``zero_point`` buffers are created
        dtype: dtype used for both the ``scale`` and ``zero_point`` buffers
    """

    def __init__(
        self,
        scale: float,
        zero_point: int,
        negative_slope: float = 1e-2,
        inplace: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        buffer_opts = {"device": device, "dtype": dtype}
        super().__init__(negative_slope, inplace)
        # Register the output qparams as buffers, scale first then zero_point.
        for buf_name, buf_value in (("scale", scale), ("zero_point", zero_point)):
            self.register_buffer(buf_name, torch.tensor(buf_value, **buffer_opts))

    def forward(self, input):
        """Run the quantized leaky ReLU kernel with this module's settings."""
        kernel = torch.ops.quantized.leaky_relu
        return kernel(
            input, self.negative_slope, self.inplace, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedLeakyReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build from an observed float module using its output observer's qparams."""
        out_scale, out_zp = mod.activation_post_process.calculate_qparams()
        return cls(float(out_scale), int(out_zp), mod.negative_slope, mod.inplace)

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build from a reference module plus explicitly supplied output qparams."""
        out_scale, out_zp = float(scale), int(zero_point)
        return cls(out_scale, out_zp, mod.negative_slope, mod.inplace)
160
161
class Sigmoid(torch.nn.Sigmoid):
    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.

    Args:
        output_scale: quantization scale of the output tensor
        output_zero_point: quantization zero point of the output tensor
    """

    def __init__(self, output_scale: float, output_zero_point: int):
        super().__init__()
        # Output qparams are kept as plain Python attributes (not buffers).
        self.output_scale = output_scale
        self.output_zero_point = output_zero_point

    def forward(self, input):
        """Apply quantized sigmoid, requantizing to the configured output qparams."""
        return torch.ops.quantized.sigmoid(
            input, self.output_scale, self.output_zero_point
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build from an observed float module using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and
        is ignored.
        """
        (
            output_scale,
            output_zero_point,
        ) = mod.activation_post_process.calculate_qparams()
        return cls(float(output_scale), int(output_zero_point))
187
188
class Softmax(torch.nn.Softmax):
    r"""This is the quantized version of :class:`~torch.nn.Softmax`.

    Args:
        dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
    """

    def __init__(self, dim=None, scale=1.0, zero_point=0):
        # torch.nn.Softmax.__init__ stores ``dim`` on the module.
        super().__init__(dim)
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, input):
        """Apply quantized softmax along ``self.dim``.

        When ``dim`` is ``None`` it is resolved from the input rank using
        ``torch.nn.functional._get_softmax_dim``.
        """
        dim = self.dim
        if dim is None:
            # Presumably the stacklevel for any warning _get_softmax_dim emits
            # about the implicit dim choice — TODO confirm.
            stacklevel = 3
            # Note: adding the mypy ignore on _get_softmax_dim seems less bad
            # than making `_get_softmax_dim` an official API.
            dim = torch.nn.functional._get_softmax_dim(  # type: ignore[attr-defined]
                "softmax", input.dim(), stacklevel
            )
        return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point)

    def _get_name(self):
        return "QuantizedSoftmax"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build from an observed float Softmax using its output observer's qparams.

        ``use_precomputed_fake_quant`` is accepted for API compatibility and
        is ignored.  Uses ``cls`` (like ``from_reference``) so subclasses
        construct their own type.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        return cls(mod.dim, float(scale), int(zero_point))

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(mod.dim, float(scale), int(zero_point))
226
227
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
    r"""Quantized multi-head attention converted from an observed quantizable MHA.

    Instances are produced by :meth:`from_observed`; converting a float
    (non-observed) module directly via :meth:`from_float` is not supported.
    """

    _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention

    def _get_name(self):
        return "QuantizedMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does observed -> quantized only
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-observed MHA module. Please, see "
            "the examples on quantizable MHAs."
        )

    @staticmethod
    def _quantize_bias(bias):
        """Quantize ``bias`` to ``torch.quint8`` using qparams from its own range."""
        sc, zp = torch._choose_qparams_per_tensor(bias, reduce_range=False)
        return torch.quantize_per_tensor(bias, sc, zp, torch.quint8)

    @classmethod
    def from_observed(cls, other):
        """Convert an observed quantizable MHA into this quantized class.

        Submodules are converted with ``torch.ao.quantization.convert`` (on a
        copy, ``inplace=False``).  The ``bias_k`` / ``bias_v`` parameters, when
        present, are replaced by quantized tensors, and ``in_proj_weight`` /
        ``in_proj_bias`` are deleted from the result.
        """
        converted = torch.ao.quantization.convert(
            other,
            mapping=None,
            inplace=False,
            remove_qconfig=True,
            convert_custom_config_dict=None,
        )
        converted.__class__ = cls
        # Remove the parameters for the bias_k and bias_v to quantize them.
        # Each bias gets qparams chosen from its OWN values (previously
        # bias_v was quantized with qparams derived from bias_k).
        # TODO: This is a potential source of accuracy drop.
        #       quantized cat takes the scale and zp of the first
        #       element, which might lose the precision in the bias_k
        #       and the bias_v (which are cat'ed with k/v being first).
        for bias_name in ("bias_k", "bias_v"):
            if getattr(converted, bias_name) is not None:
                float_bias = converted._parameters.pop(bias_name)
                # Popped from _parameters, so setattr stores a plain attribute.
                setattr(converted, bias_name, cls._quantize_bias(float_bias))

        del converted.in_proj_weight
        del converted.in_proj_bias

        return converted
277
278
class PReLU(torch.nn.Module):
    r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.

    Args:
        output_scale: quantization scale of the output tensor
        output_zero_point: quantization zero point of the output tensor
        num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
    """

    def __init__(
        self, output_scale: float, output_zero_point: int, num_parameters: int = 1
    ) -> None:
        super().__init__()
        self.num_parameters = num_parameters
        self.scale = output_scale
        self.zero_point = output_zero_point
        # Placeholder weight (random values, quantized with scale=1.0 and
        # zero_point=0); from_float / from_reference overwrite it via set_weight.
        w = torch.randn(num_parameters, dtype=torch.float)
        qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
        self.set_weight(qw)

    def set_weight(self, w: torch.Tensor) -> None:
        """Store the (quantized) slope tensor used by ``forward``."""
        self.weight = w

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply quantized PReLU, requantizing to ``self.scale`` / ``self.zero_point``."""
        return torch.ops.quantized.prelu(
            input, self.weight, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedPReLU"

    @staticmethod
    def _quantize_float_weight(mod) -> torch.Tensor:
        """Quantize ``mod.weight`` to quint8 using a fresh ``mod.qconfig.weight()`` observer.

        Warns (but still quantizes to ``torch.quint8``) when the configured
        weight observer's dtype is not ``torch.quint8``.
        """
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        observer(float_wt)
        if observer.dtype != torch.quint8:
            warn(
                f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
            )
        wt_scale, wt_zp = observer.calculate_qparams()
        return torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Build from an observed float PReLU.

        Output qparams come from ``mod.activation_post_process``; the weight is
        quantized with ``mod.qconfig.weight()``.  ``use_precomputed_fake_quant``
        is accepted for API compatibility and is ignored.
        """
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        qprelu.set_weight(cls._quantize_float_weight(mod))
        return qprelu

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        """Build from a reference PReLU with explicitly supplied output qparams.

        The weight is quantized with ``mod.qconfig.weight()``.
        """
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        qprelu.set_weight(cls._quantize_float_weight(mod))
        return qprelu
344