# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.utils import get_fqn_to_example_inputs
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.ao.quantization import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver


class TestUtils(TestCase):
    def _test_get_fqn_to_example_inputs(self, M, example_inputs, expected_fqn_to_dim):
        m = M().eval()
        fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
        for fqn, expected_dims in expected_fqn_to_dim.items():
            assert fqn in fqn_to_example_inputs
            recorded_inputs = fqn_to_example_inputs[fqn]
            for example_input, expected_dim in zip(recorded_inputs, expected_dims):
                assert example_input.dim() == expected_dim

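    # For orientation, a sketch of the mapping shape this helper relies on (inferred
    # from the assertions above rather than from the `get_fqn_to_example_inputs` docs):
    # the utility is expected to return a dict keyed by fully qualified module names,
    # where each value is the tuple of example arguments that submodule was called
    # with, e.g. {"": (x,), "linear1": (x,), "sub": (x, ...), ...}.
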
    def test_get_fqn_to_example_inputs_simple(self):
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.sub(x)
                return x

        expected_fqn_to_dim = {
            "": (2,),
            "linear1": (2,),
            "linear2": (2,),
            "sub": (2,),
            "sub.linear1": (2,),
            "sub.linear2": (2,)
        }
        example_inputs = (torch.rand(1, 5),)
        self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)

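    # Illustrative sketch (not exercised by the tests in this file): since each
    # recorded tuple mirrors the arguments a submodule was actually called with, it
    # should be replayable through that submodule.  `get_submodule` is the standard
    # nn.Module lookup; the rest is hypothetical usage.
    def _replay_example_inputs_sketch(self, m, fqn_to_example_inputs):
        for fqn, recorded_inputs in fqn_to_example_inputs.items():
            submodule = m if fqn == "" else m.get_submodule(fqn)
            submodule(*recorded_inputs)
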
    def test_get_fqn_to_example_inputs_default_kwargs(self):
        """ Test that we can get example inputs for functions with default keyword arguments
        """
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x, key1=torch.rand(1), key2=torch.rand(1)):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                # only override `key2`; `key1` will use its default
                x = self.sub(x, key2=torch.rand(1, 2))
                return x

        expected_fqn_to_dim = {
            "": (2,),
            "linear1": (2,),
            "linear2": (2,),
            # second arg is `key1`, which uses its default value
            # third arg is `key2`, overridden at the call site
            "sub": (2, 1, 2),
            "sub.linear1": (2,),
            "sub.linear2": (2,)
        }
        example_inputs = (torch.rand(1, 5),)
        self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)

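    # A hedged sketch (hypothetical, not part of the test above) of checking the same
    # default-kwarg behavior by shape rather than by number of dimensions:
    #
    #     fqn_to_example_inputs = get_fqn_to_example_inputs(M().eval(), example_inputs)
    #     _, key1_val, key2_val = fqn_to_example_inputs["sub"]
    #     assert key1_val.shape == (1,)    # filled in from the signature default
    #     assert key2_val.shape == (1, 2)  # overridden at the call site
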
    def test_get_fqn_to_example_inputs_complex_args(self):
        """ Test that we can record complex example inputs such as lists and dicts
        """
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x, list_arg, dict_arg):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.sub(x, [x], {"3": x})
                return x

        example_inputs = (torch.rand(1, 5),)
        m = M().eval()
        fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
        assert "sub" in fqn_to_example_inputs
        assert isinstance(fqn_to_example_inputs["sub"][1], list)
        assert isinstance(fqn_to_example_inputs["sub"][2], dict) and \
            "3" in fqn_to_example_inputs["sub"][2]

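    # Note on the checks above: the container arguments (the list and dict passed to
    # `sub`) appear to be recorded as-is rather than flattened, so only their presence
    # and types are asserted here.
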
    def test_quantize_weight_clamping_per_tensor(self):
        """ Test that quant_{min, max} from the per tensor observer are honored by
        the `_quantize_weight` method
        """
        fp_min, fp_max = -1000.0, 1000.0
        q8_min, q8_max = -10, 10

        float_tensor = torch.tensor([fp_min, fp_max])

        observer = MovingAverageMinMaxObserver(
            averaging_constant=1.0,
            dtype=torch.qint8,
            quant_min=q8_min,
            quant_max=q8_max,
            qscheme=torch.per_tensor_symmetric,
        )

        observer(float_tensor)
        assert observer.min_val == fp_min
        assert observer.max_val == fp_max

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

        # Actual weight values can fall outside the observer's [min_val, max_val]
        # range for the moving average observer
        float_tensor *= 1.2

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

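    # A rough arithmetic sketch of the clamping the tests above and below rely on (an
    # illustration, not the internals of `_quantize_weight`): with a symmetric observer
    # that saw [-1000, 1000] and a quant range of [-10, 10], the scale is
    # 1000 / 10 = 100, so a weight of 1200 would naively map to round(1200 / 100) = 12;
    # the quantized int representation is expected to be clamped back into
    # [quant_min, quant_max], i.e. to 10.
    @staticmethod
    def _manual_symmetric_quant_sketch(x, scale, quant_min, quant_max):
        """Hypothetical reference helper, for illustration only."""
        return torch.clamp(torch.round(x / scale), quant_min, quant_max)
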
    def test_quantize_weight_clamping_per_channel(self):
        """ Test that quant_{min, max} from the per channel observer are honored by
        the `_quantize_weight` method
        """
        fp_min, fp_max = -1000.0, 1000.0
        q8_min, q8_max = -10, 10

        float_tensor = torch.tensor([[fp_min, fp_max]])

        observer = MovingAveragePerChannelMinMaxObserver(
            averaging_constant=1.0,
            dtype=torch.qint8,
            quant_min=q8_min,
            quant_max=q8_max,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        observer(float_tensor)
        assert observer.min_val == fp_min
        assert observer.max_val == fp_max

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

        # Actual weight values can fall outside the observer's [min_val, max_val]
        # range for the moving average observer
        float_tensor *= 1.2

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

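    # Note on the per-channel case above: with ch_axis=0 and a single row
    # [[-1000, 1000]] there is only one channel, so the per-channel scale works out to
    # the same 1000 / 10 = 100 as in the per-tensor test and the same clamping
    # arithmetic applies per channel.
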
    def test_uint1_7_dtype(self):

        def up_size(size):
            # each uint8 element packs two 4-bit values, so the last dimension doubles
            return (*size[:-1], size[-1] * 2)

        class UInt4Tensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, **kwargs):
                assert elem.dtype is torch.uint8
                assert not kwargs.get("requires_grad", False)
                kwargs["requires_grad"] = False
                return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs)

            def __init__(self, elem):
                self.elem = elem

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs=None):
                # intentionally a no-op: only construction and the dtype are checked here
                pass

        # make sure construction runs and the wrapper reports the sub-byte dtype
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        assert x.dtype == torch.uint4
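
    # The test name refers to the full torch.uint1 .. torch.uint7 family, while only
    # torch.uint4 is exercised above.  A hedged sketch of how the remaining sub-byte
    # dtypes could be smoke-checked, guarded in case a build does not expose them:
    def _sketch_other_sub_byte_dtypes(self):
        for bits in range(1, 8):
            dtype = getattr(torch, f"uint{bits}", None)
            if dtype is not None:
                assert str(dtype) == f"torch.uint{bits}"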