# Owner(s): ["oncall: quantization"]

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.ao.nn.quantized.reference as nnqr
from torch.testing._internal.common_quantization import QuantizationTestCase

from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType,
)
from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2
from torch.ao.quantization.fx.quantize_handler import _default_root_node_getter

class TestBackendConfig(QuantizationTestCase):

    # =============
    #  DTypeConfig
    # =============

    dtype_config1 = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float
    )

    dtype_config2 = DTypeConfig(
        input_dtype=torch.float16,
        output_dtype=torch.float,
        is_dynamic=True
    )

    activation_dtype_with_constraints = DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=127,
        scale_min_lower_bound=2 ** -12,
    )

    weight_dtype_with_constraints = DTypeWithConstraints(
        dtype=torch.qint8,
        quant_min_lower_bound=-128,
        quant_max_upper_bound=127,
        scale_min_lower_bound=2 ** -12,
    )

    dtype_config3 = DTypeConfig(
        input_dtype=activation_dtype_with_constraints,
        output_dtype=activation_dtype_with_constraints,
        weight_dtype=weight_dtype_with_constraints,
    )

    dtype_config_dict1_legacy = {
        "input_dtype": torch.quint8,
        "output_dtype": torch.quint8,
        "weight_dtype": torch.qint8,
        "bias_dtype": torch.float,
    }

    dtype_config_dict2_legacy = {
        "input_dtype": torch.float16,
        "output_dtype": torch.float,
        "is_dynamic": True,
    }

    dtype_config_dict1 = {
        "input_dtype": DTypeWithConstraints(dtype=torch.quint8),
        "output_dtype": DTypeWithConstraints(dtype=torch.quint8),
        "weight_dtype": DTypeWithConstraints(dtype=torch.qint8),
        "bias_dtype": torch.float,
    }

    dtype_config_dict2 = {
        "input_dtype": DTypeWithConstraints(dtype=torch.float16),
        "output_dtype": DTypeWithConstraints(dtype=torch.float),
        "is_dynamic": True,
    }

    dtype_config_dict3 = {
        "input_dtype": activation_dtype_with_constraints,
        "output_dtype": activation_dtype_with_constraints,
        "weight_dtype": weight_dtype_with_constraints,
    }

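    # `DTypeConfig.from_dict` should accept both the legacy dicts above (plain
    # torch.dtype values) and the newer dicts that wrap the activation/weight
    # dtypes in `DTypeWithConstraints`, producing equivalent DTypeConfigs.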
    def test_dtype_config_from_dict(self):
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1_legacy), self.dtype_config1)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2_legacy), self.dtype_config2)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1), self.dtype_config1)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2), self.dtype_config2)
        self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict3), self.dtype_config3)

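    # `to_dict` emits the `DTypeWithConstraints` form for the activation and weight
    # dtypes (and a plain dtype for the bias) even when the config was built from
    # plain dtypes, so the expected dicts are the non-legacy ones above.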
    def test_dtype_config_to_dict(self):
        self.assertEqual(self.dtype_config1.to_dict(), self.dtype_config_dict1)
        self.assertEqual(self.dtype_config2.to_dict(), self.dtype_config_dict2)
        self.assertEqual(self.dtype_config3.to_dict(), self.dtype_config_dict3)

    # ======================
    #  BackendPatternConfig
    # ======================

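    # `_sequential_wrapper2(nni.LinearReLU)` returns a fuser method with signature
    # (is_qat, linear, relu) -> nni.LinearReLU(linear, relu), i.e. it simply wraps
    # the two matched modules in the fused module class.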
    _fuser_method = _sequential_wrapper2(nni.LinearReLU)

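    # Maps the number of tensor arguments an op like torch.add receives to how its
    # output should be observed (with a single tensor argument the output can share
    # the input's observer; otherwise a separate observer is used).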
    _num_tensor_args_to_observation_type = {
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
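    # Maps logical input names to their positional index in the op's signature,
    # e.g. for torch.addmm (used below), whose bias is the first argument.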
    _input_type_to_index = {
        "bias": 0,
        "input": 1,
        "weight": 2,
    }

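    # Stand-in extra_inputs_getter for the tests below: it simply supplies one
    # additional tensor input for the matched pattern.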
    def _extra_inputs_getter(self, p):
        return (torch.rand(3, 3),)

    def _get_backend_op_config1(self):
        return BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) \
            .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
            .add_dtype_config(self.dtype_config1) \
            .add_dtype_config(self.dtype_config2) \
            .set_root_module(torch.nn.Linear) \
            .set_qat_module(nnqat.Linear) \
            .set_reference_quantized_module(nnqr.Linear) \
            .set_fused_module(nni.LinearReLU) \
            .set_fuser_method(self._fuser_method)

    def _get_backend_op_config2(self):
        return BackendPatternConfig(torch.add) \
            .add_dtype_config(self.dtype_config2) \
            ._set_root_node_getter(_default_root_node_getter) \
            ._set_extra_inputs_getter(self._extra_inputs_getter) \
            ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) \
            ._set_input_type_to_index(self._input_type_to_index)

    def _get_backend_pattern_config_dict1(self):
        return {
            "pattern": (torch.nn.Linear, torch.nn.ReLU),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2],
            "root_module": torch.nn.Linear,
            "qat_module": nnqat.Linear,
            "reference_quantized_module_for_root": nnqr.Linear,
            "fused_module": nni.LinearReLU,
            "fuser_method": self._fuser_method,
        }

    def _get_backend_pattern_config_dict2(self):
        return {
            "pattern": torch.add,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [self.dtype_config_dict2],
            "root_node_getter": _default_root_node_getter,
            "extra_inputs_getter": self._extra_inputs_getter,
            "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type,
            "input_type_to_index": self._input_type_to_index,
        }

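    # A freshly constructed BackendPatternConfig defaults to
    # OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT until set_observation_type is called.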
    def test_backend_op_config_set_observation_type(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        conf.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        self.assertEqual(conf.observation_type, ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)

    def test_backend_op_config_add_dtype_config(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertEqual(len(conf.dtype_configs), 0)
        conf.add_dtype_config(self.dtype_config1)
        conf.add_dtype_config(self.dtype_config2)
        self.assertEqual(len(conf.dtype_configs), 2)
        self.assertEqual(conf.dtype_configs[0], self.dtype_config1)
        self.assertEqual(conf.dtype_configs[1], self.dtype_config2)

    def test_backend_op_config_set_root_module(self):
        conf = BackendPatternConfig(nni.LinearReLU)
        self.assertTrue(conf.root_module is None)
        conf.set_root_module(torch.nn.Linear)
        self.assertEqual(conf.root_module, torch.nn.Linear)

    def test_backend_op_config_set_qat_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertTrue(conf.qat_module is None)
        conf.set_qat_module(nnqat.Linear)
        self.assertEqual(conf.qat_module, nnqat.Linear)

    def test_backend_op_config_set_reference_quantized_module(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertTrue(conf.reference_quantized_module is None)
        conf.set_reference_quantized_module(nnqr.Linear)
        self.assertEqual(conf.reference_quantized_module, nnqr.Linear)

    def test_backend_op_config_set_fused_module(self):
        conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
        self.assertTrue(conf.fused_module is None)
        conf.set_fused_module(nni.LinearReLU)
        self.assertEqual(conf.fused_module, nni.LinearReLU)

    def test_backend_op_config_set_fuser_method(self):
        conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
        self.assertTrue(conf.fuser_method is None)
        conf.set_fuser_method(self._fuser_method)
        self.assertEqual(conf.fuser_method, self._fuser_method)

    def test_backend_op_config_set_root_node_getter(self):
        conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
        self.assertTrue(conf._root_node_getter is None)
        conf._set_root_node_getter(_default_root_node_getter)
        self.assertEqual(conf._root_node_getter, _default_root_node_getter)

    def test_backend_op_config_set_extra_inputs_getter(self):
        conf = BackendPatternConfig(torch.nn.Linear)
        self.assertTrue(conf._extra_inputs_getter is None)
        conf._set_extra_inputs_getter(self._extra_inputs_getter)
        self.assertEqual(conf._extra_inputs_getter, self._extra_inputs_getter)

    def test_backend_op_config_set_num_tensor_args_to_observation_type(self):
        conf = BackendPatternConfig(torch.add)
        self.assertEqual(len(conf._num_tensor_args_to_observation_type), 0)
        conf._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type)
        self.assertEqual(conf._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)

    def test_backend_op_config_set_input_type_to_index(self):
        conf = BackendPatternConfig(torch.addmm)
        self.assertEqual(len(conf._input_type_to_index), 0)
        conf._set_input_type_to_index(self._input_type_to_index)
        self.assertEqual(conf._input_type_to_index, self._input_type_to_index)

    def test_backend_op_config_from_dict(self):
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf1 = BackendPatternConfig.from_dict(conf_dict1)
        self.assertEqual(conf1.pattern, (torch.nn.Linear, torch.nn.ReLU))
        self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertEqual(conf1.root_module, torch.nn.Linear)
        self.assertEqual(conf1.qat_module, nnqat.Linear)
        self.assertEqual(conf1.reference_quantized_module, nnqr.Linear)
        self.assertEqual(conf1.fused_module, nni.LinearReLU)
        self.assertEqual(conf1.fuser_method, self._fuser_method)
        self.assertTrue(conf1._root_node_getter is None)
        self.assertTrue(conf1._extra_inputs_getter is None)
        self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0)
        self.assertEqual(len(conf1._input_type_to_index), 0)
        # Test temporary/internal keys
        conf_dict2 = self._get_backend_pattern_config_dict2()
        conf2 = BackendPatternConfig.from_dict(conf_dict2)
        self.assertEqual(conf2.pattern, torch.add)
        self.assertEqual(conf2.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        self.assertTrue(conf2.root_module is None)
        self.assertTrue(conf2.qat_module is None)
        self.assertTrue(conf2.reference_quantized_module is None)
        self.assertTrue(conf2.fused_module is None)
        self.assertTrue(conf2.fuser_method is None)
        self.assertEqual(conf2._root_node_getter, _default_root_node_getter)
        self.assertEqual(conf2._extra_inputs_getter, self._extra_inputs_getter)
        self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type)
        self.assertEqual(conf2._input_type_to_index, self._input_type_to_index)

    def test_backend_op_config_to_dict(self):
        conf1 = self._get_backend_op_config1()
        conf2 = self._get_backend_op_config2()
        conf_dict1 = self._get_backend_pattern_config_dict1()
        conf_dict2 = self._get_backend_pattern_config_dict2()
        self.assertEqual(conf1.to_dict(), conf_dict1)
        self.assertEqual(conf2.to_dict(), conf_dict2)

    # ===============
    #  BackendConfig
    # ===============

    def test_backend_config_set_name(self):
        conf = BackendConfig("name1")
        self.assertEqual(conf.name, "name1")
        conf.set_name("name2")
        self.assertEqual(conf.name, "name2")

    def test_backend_config_set_backend_pattern_config(self):
        conf = BackendConfig("name1")
        self.assertEqual(len(conf.configs), 0)
        backend_op_config1 = self._get_backend_op_config1()
        backend_op_config2 = self._get_backend_op_config2()
        conf.set_backend_pattern_config(backend_op_config1)
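        # Internally, patterns are keyed in the "reversed nested tuple" format, so
        # the (torch.nn.Linear, torch.nn.ReLU) pattern shows up here as
        # (torch.nn.ReLU, torch.nn.Linear).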
        self.assertEqual(conf._pattern_complex_format_to_config, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
        })
        conf.set_backend_pattern_config(backend_op_config2)
        self.assertEqual(conf._pattern_complex_format_to_config, {
            (torch.nn.ReLU, torch.nn.Linear): backend_op_config1,
            torch.add: backend_op_config2
        })

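    # BackendConfig.from_dict rebuilds each BackendPatternConfig from its dict and
    # stores it under the complex-format pattern key; serializing each entry back
    # with to_dict should round-trip to the original dicts.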
    def test_backend_config_from_dict(self):
        op1 = self._get_backend_op_config1()
        op2 = self._get_backend_op_config2()
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        conf = BackendConfig.from_dict(conf_dict)
        self.assertEqual(conf.name, "name1")
        self.assertEqual(len(conf.configs), 2)
        key1 = (torch.nn.ReLU, torch.nn.Linear)
        key2 = torch.add
        self.assertTrue(key1 in conf._pattern_complex_format_to_config)
        self.assertTrue(key2 in conf._pattern_complex_format_to_config)
        self.assertEqual(conf._pattern_complex_format_to_config[key1].to_dict(), op_dict1)
        self.assertEqual(conf._pattern_complex_format_to_config[key2].to_dict(), op_dict2)

    def test_backend_config_to_dict(self):
        op1 = self._get_backend_op_config1()
        op2 = self._get_backend_op_config2()
        op_dict1 = self._get_backend_pattern_config_dict1()
        op_dict2 = self._get_backend_pattern_config_dict2()
        conf = BackendConfig("name1").set_backend_pattern_config(op1).set_backend_pattern_config(op2)
        conf_dict = {
            "name": "name1",
            "configs": [op_dict1, op_dict2],
        }
        self.assertEqual(conf.to_dict(), conf_dict)

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")
331