# xref: /aosp_15_r20/external/pytorch/torch/ao/nn/quantized/reference/modules/utils.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import typing

import torch


__all__ = [
    "ReferenceQuantizedModule",
]


class ReferenceQuantizedModule(torch.nn.Module):
    def _init_weight_qparams(self, weight_qparams, device):
        if weight_qparams is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }
        self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
        self.weight_dtype = weight_qparams["dtype"]
        assert self.weight_qscheme in [
            None,
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams,
        ], f"qscheme: {self.weight_qscheme} is not supported in reference quantized {self._get_name()}"
        if self.weight_dtype in [
            torch.quint8,
            torch.qint8,
            torch.quint4x2,
            torch.qint32,
        ]:
            zero_point_dtype = (
                weight_qparams["zero_point"].dtype
                if isinstance(weight_qparams["zero_point"], torch.Tensor)
                else torch.int
            )
            w_scale = weight_qparams["scale"]
            w_scale_tensor = (
                w_scale.clone().detach()
                if isinstance(w_scale, torch.Tensor)
                else torch.tensor(w_scale, dtype=torch.float, device=device)
            )
            self.register_buffer("weight_scale", w_scale_tensor)
            w_zp = weight_qparams["zero_point"]
            w_zp_tensor = (
                w_zp.clone().detach()
                if isinstance(w_zp, torch.Tensor)
                else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
            )
            self.register_buffer("weight_zero_point", w_zp_tensor)
            if self.weight_qscheme in [
                torch.per_channel_affine,
                torch.per_channel_affine_float_qparams,
            ]:
                w_axis = weight_qparams["axis"]
                w_axis_tensor = (
                    w_axis.clone().detach()
                    if isinstance(w_axis, torch.Tensor)
                    else torch.tensor(w_axis, dtype=torch.int, device=device)
                )
                self.register_buffer("weight_axis", w_axis_tensor)
            else:
                # added for TorchScriptability, not used
                self.register_buffer(
                    "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
                )
        else:
            # added for TorchScriptability, and for torch.float
            self.register_buffer(
                "weight_scale", torch.tensor(1.0, dtype=torch.float, device=device)
            )
            self.register_buffer(
                "weight_zero_point", torch.tensor(0, dtype=torch.int, device=device)
            )
            self.register_buffer(
                "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
            )
        self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
        # store weight_axis as weight_axis_int, since torchdynamo.export has
        # constraints on capturing `.item()` calls
        self.weight_axis_int: int = self.weight_axis.item()  # type: ignore[operator, assignment]
        self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
            "quant_min", None
        )
        self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
            "quant_max", None
        )

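    # A minimal sketch of the `weight_qparams` dict this method consumes for
    # the per-channel case. The keys match the reads above; the concrete
    # values (4 output channels on axis 0) are illustrative only, and
    # "is_decomposed", "quant_min" and "quant_max" are optional:
    #
    #     weight_qparams = {
    #         "qscheme": torch.per_channel_affine,
    #         "dtype": torch.qint8,
    #         "scale": torch.ones(4),
    #         "zero_point": torch.zeros(4, dtype=torch.int),
    #         "axis": 0,
    #     }
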
    def get_weight(self):
        """
        Fake quantize (quantize and dequantize) the weight using the
        weight's quantization parameters. This is used to simulate the
        numerics of the quantized weight in a quantized model.
        """
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        if self.is_decomposed:
            return _quantize_and_dequantize_weight_decomposed(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
                self.weight_quant_min,
                self.weight_quant_max,
            )
        else:
            return _quantize_and_dequantize_weight(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
            )

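    # Worked example of the fake-quant numerics `get_weight` simulates
    # (per-tensor affine, scale 0.1, zero_point 0; values illustrative):
    #
    #     q  = clamp(round(1.23 / 0.1) + 0, -128, 127)  # -> 12
    #     fq = (q - 0) * 0.1                            # -> 1.2
    #
    # so the returned weight stays float, but carries the rounding error the
    # quantized kernel would see.
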
    def get_quantized_weight(self):
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        # assert isinstance(self.weight_axis, torch.Tensor)
        if self.is_decomposed:
            return _quantize_weight_decomposed(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
                self.weight_quant_min,
                self.weight_quant_max,
            )
        else:
            return _quantize_weight(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int,
            )

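    # Note: `get_weight` returns a float tensor with the quantization error
    # baked in, while `get_quantized_weight` returns the quantized tensor
    # itself (e.g. dtype torch.qint8, or a plain torch.int8 tensor when
    # `is_decomposed` is True).
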
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        _save_weight_qparams(
            destination,
            prefix,
            self.weight_qscheme,
            self.weight_dtype,
            self.weight_scale,
            self.weight_zero_point,
            self.weight_axis,
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        for key in _get_weight_qparam_keys(state_dict, prefix):
            setattr(self, key, state_dict[prefix + key])
            state_dict.pop(prefix + key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


def _quantize_weight_decomposed(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis: int,
    weight_quant_min: typing.Optional[int],
    weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    _DTYPE_TO_QVALUE_BOUNDS = {
        torch.uint8: (0, 255),
        torch.int8: (-128, 127),
        torch.int32: (-(2**31), 2**31 - 1),
    }
    # TODO: add a util function for converting qdtype to dtype
    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
    }
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            if weight_quant_min is None or weight_quant_max is None:
                weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[
                    weight_dtype_
                ]
            weight = torch.ops.quantized_decomposed.quantize_per_tensor(
                weight,
                weight_scale,
                weight_zero_point,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )
            return weight
    elif weight_qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        # TODO: torch.quint4x2 is not supported
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            if weight_quant_min is None or weight_quant_max is None:
                weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[
                    weight_dtype_
                ]
            weight = torch.ops.quantized_decomposed.quantize_per_channel(
                weight,
                weight_scale,
                weight_zero_point,
                weight_axis,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )  # type: ignore[arg-type]
            return weight
    raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")

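# A hedged usage sketch for `_quantize_weight_decomposed`: with a qint8 qparam
# config, the decomposed op returns a plain torch.int8 tensor (the underlying
# integer representation) rather than a torch.qint8 quantized tensor. Values
# are illustrative:
#
#     w = torch.randn(4, 4)
#     q = _quantize_weight_decomposed(
#         w, torch.per_tensor_affine, torch.qint8,
#         torch.tensor(0.1), torch.tensor(0, dtype=torch.int),
#         0, None, None,
#     )
#     assert q.dtype == torch.int8
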

def _dequantize_weight_decomposed(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis: int,
    weight_quant_min: typing.Optional[int],
    weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    # TODO: get the quant_min and quant_max from activation_post_process
    _DTYPE_TO_QVALUE_BOUNDS = {
        torch.uint8: (0, 255),
        torch.int8: (-128, 127),
        torch.int32: (-(2**31), 2**31 - 1),
    }
    # TODO: add a util function for converting qdtype to dtype
    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
    }
    weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
    if weight_quant_min is None or weight_quant_max is None:
        weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
                weight,
                weight_scale,
                weight_zero_point,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )
            return weight
    elif weight_qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        # TODO: torch.quint4x2 is not supported
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.ops.quantized_decomposed.dequantize_per_channel(
                weight,
                weight_scale,
                weight_zero_point,
                weight_axis,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_,
            )  # type: ignore[arg-type]
            return weight
    raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")

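# Dequantization here follows the usual affine formula, sketched with
# illustrative per-tensor values (scale 0.1, zero_point 0):
#
#     fq = (q - zero_point) * scale    # e.g. q = 12  ->  (12 - 0) * 0.1 = 1.2
#
# which recovers a float tensor carrying the quantization rounding error.
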

def _quantize_weight(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis_int: int,
) -> torch.Tensor:
    if weight_dtype == torch.float16:
        weight = weight.to(weight_dtype)
        return weight

    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.quantize_per_tensor(
                weight, weight_scale, weight_zero_point, weight_dtype
            )
            return weight
    elif weight_qscheme in [
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            weight = torch.quantize_per_channel(
                weight, weight_scale, weight_zero_point, weight_axis_int, weight_dtype
            )  # type: ignore[arg-type]
            return weight
    raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")

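# Usage sketch for `_quantize_weight` (values illustrative): the
# non-decomposed path returns a quantized tensor whose integer values are
# accessible via `.int_repr()` and whose float view comes back via
# `.dequantize()`:
#
#     w = torch.randn(4, 4)
#     qw = _quantize_weight(
#         w, torch.per_tensor_affine, torch.qint8,
#         torch.tensor(0.1), torch.tensor(0, dtype=torch.int), 0,
#     )
#     qw.dtype         # torch.qint8
#     qw.int_repr()    # torch.int8 tensor
#     qw.dequantize()  # float32 tensor with rounding error applied
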

def _quantize_and_dequantize_weight_decomposed(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis_int: int,
    weight_quant_min: typing.Optional[int],
    weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
    """Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme in [
        torch.per_tensor_affine,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        weight_quant = _quantize_weight_decomposed(
            weight,
            weight_qscheme,
            weight_dtype,
            weight_scale,
            weight_zero_point,
            weight_axis_int,
            weight_quant_min,
            weight_quant_max,
        )
        weight_dequant = _dequantize_weight_decomposed(
            weight_quant,
            weight_qscheme,
            weight_dtype,
            weight_scale,
            weight_zero_point,
            weight_axis_int,
            weight_quant_min,
            weight_quant_max,
        )
    else:
        weight_dequant = weight
    return weight_dequant

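# Note: this decomposed round trip is intended to match the fake-quant
# numerics of `_quantize_and_dequantize_weight` below, while staying in plain
# integer dtypes (e.g. torch.int8) rather than quantized tensor dtypes.
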

def _quantize_and_dequantize_weight(
    weight: torch.Tensor,
    weight_qscheme: torch.qscheme,
    weight_dtype: torch.dtype,
    weight_scale: torch.Tensor,
    weight_zero_point: torch.Tensor,
    weight_axis_int: int,
) -> torch.Tensor:
    """Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme in [
        torch.per_tensor_affine,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        weight_quant = _quantize_weight(
            weight,
            weight_qscheme,
            weight_dtype,
            weight_scale,
            weight_zero_point,
            weight_axis_int,
        )
        weight_dequant = weight_quant.dequantize()
    else:
        weight_dequant = weight
    return weight_dequant


def _save_weight_qparams(
    destination,
    prefix,
    weight_qscheme,
    weight_dtype,
    weight_scale,
    weight_zero_point,
    weight_axis,
):
    destination[prefix + "weight_qscheme"] = weight_qscheme
    destination[prefix + "weight_dtype"] = weight_dtype
    if weight_qscheme is not None:
        destination[prefix + "weight_scale"] = weight_scale
        destination[prefix + "weight_zero_point"] = weight_zero_point
        if weight_qscheme == torch.per_channel_affine:
            destination[prefix + "weight_axis"] = weight_axis

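# Sketch of the extra state_dict entries `_save_weight_qparams` produces for a
# per-channel module saved under prefix "fc." (names follow the code above;
# values illustrative):
#
#     "fc.weight_qscheme":    torch.per_channel_affine
#     "fc.weight_dtype":      torch.qint8
#     "fc.weight_scale":      tensor of per-channel scales
#     "fc.weight_zero_point": tensor of per-channel zero points
#     "fc.weight_axis":       tensor(0)
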

def _get_weight_qparam_keys(state_dict: typing.Dict[str, typing.Any], prefix: str):
    keys = ["weight_qscheme", "weight_dtype"]
    weight_qscheme = state_dict[prefix + "weight_qscheme"]
    if weight_qscheme is not None:
        keys.append("weight_scale")
        keys.append("weight_zero_point")
        # must mirror the per-channel check in _save_weight_qparams
        if weight_qscheme == torch.per_channel_affine:
            keys.append("weight_axis")
    return keys

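
# ---------------------------------------------------------------------------
# Minimal smoke test: a hedged sketch exercising the per-tensor path end to
# end. `_DemoModule` is a hypothetical module invented here for illustration;
# it is not part of the reference-module API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DemoModule(ReferenceQuantizedModule):
        def __init__(self, weight_qparams=None):
            super().__init__()
            # plain tensor weight keeps the demo autograd-free
            self.weight = torch.randn(4, 4)
            self._init_weight_qparams(weight_qparams, device="cpu")

    demo = _DemoModule(
        {
            "qscheme": torch.per_tensor_affine,
            "dtype": torch.qint8,
            "scale": 0.1,
            "zero_point": 0,
        }
    )
    fq = demo.get_weight()  # float tensor with quantization error applied
    qw = demo.get_quantized_weight()  # torch.qint8 quantized tensor
    assert fq.dtype == torch.float32 and qw.dtype == torch.qint8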