# xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/experimental/qconfig.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import torch
from torch.ao.quantization import MinMaxObserver
from torch.ao.quantization.experimental.fake_quantize import APoTFakeQuantize
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig

8"""
9Default symmetric fake_quant for activations.
10"""
11default_symmetric_fake_quant = FakeQuantize.with_args(
12    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
13)
14
15"""
16Default symmetric fake_quant for weights.
17"""
18default_weight_symmetric_fake_quant = FakeQuantize.with_args(
19    observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
20)
21
22# uniform activation and weight, b=8 k=2
23uniform_qconfig_8bit = QConfig(
24    activation=default_symmetric_fake_quant,
25    weight=default_weight_symmetric_fake_quant.with_args,
26)
27
28# uniform activation, APoT weight, b=8 k=2
29apot_weight_qconfig_8bit = QConfig(
30    activation=default_symmetric_fake_quant.with_args,
31    weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8),
32)
33
34# APoT activation and uniform weight, b=8 k=2
35apot_qconfig_8bit = QConfig(
36    activation=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.quint8),
37    weight=APoTFakeQuantize.with_args(b=8, k=2, dtype=torch.qint8),
38)
39
# uniform activation and weight, b=4 k=2
# Reuses the 8-bit symmetric defaults, narrowed to the 4-bit range [0, 15].
uniform_qconfig_4bit = QConfig(
    activation=default_symmetric_fake_quant.with_args(quant_max=15, quant_min=0),
    weight=default_weight_symmetric_fake_quant.with_args(quant_max=15, quant_min=0),
)

# uniform activation, APoT weight, b=4 k=2
# Activation reuses the uniform default clamped to [0, 15]; the weight is an
# APoT fake-quantizer over signed qint8.
apot_weight_qconfig_4bit = QConfig(
    activation=default_symmetric_fake_quant.with_args(quant_max=15, quant_min=0),
    weight=APoTFakeQuantize.with_args(dtype=torch.qint8, b=4, k=2),
)

52# APoT activation and uniform weight, b=4 k=2
53apot_qconfig_4bit = QConfig(
54    activation=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.quint8),
55    weight=APoTFakeQuantize.with_args(b=4, k=2, dtype=torch.qint8),
56)
57