1# Owner(s): ["oncall: quantization"] 2 3import torch 4import torch.ao.nn.intrinsic as nni 5import torch.ao.nn.qat as nnqat 6import torch.ao.nn.quantized.reference as nnqr 7from torch.testing._internal.common_quantization import QuantizationTestCase 8 9from torch.ao.quantization.backend_config import ( 10 BackendConfig, 11 BackendPatternConfig, 12 DTypeConfig, 13 DTypeWithConstraints, 14 ObservationType, 15) 16from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2 17from torch.ao.quantization.fx.quantize_handler import _default_root_node_getter 18 19 20class TestBackendConfig(QuantizationTestCase): 21 22 # ============= 23 # DTypeConfig 24 # ============= 25 26 dtype_config1 = DTypeConfig( 27 input_dtype=torch.quint8, 28 output_dtype=torch.quint8, 29 weight_dtype=torch.qint8, 30 bias_dtype=torch.float 31 ) 32 33 dtype_config2 = DTypeConfig( 34 input_dtype=torch.float16, 35 output_dtype=torch.float, 36 is_dynamic=True 37 ) 38 39 activation_dtype_with_constraints = DTypeWithConstraints( 40 dtype=torch.quint8, 41 quant_min_lower_bound=0, 42 quant_max_upper_bound=127, 43 scale_min_lower_bound=2 ** -12, 44 ) 45 46 weight_dtype_with_constraints = DTypeWithConstraints( 47 dtype=torch.qint8, 48 quant_min_lower_bound=-128, 49 quant_max_upper_bound=127, 50 scale_min_lower_bound=2 ** -12, 51 ) 52 53 dtype_config3 = DTypeConfig( 54 input_dtype=activation_dtype_with_constraints, 55 output_dtype=activation_dtype_with_constraints, 56 weight_dtype=weight_dtype_with_constraints, 57 ) 58 59 dtype_config_dict1_legacy = { 60 "input_dtype": torch.quint8, 61 "output_dtype": torch.quint8, 62 "weight_dtype": torch.qint8, 63 "bias_dtype": torch.float, 64 } 65 66 dtype_config_dict2_legacy = { 67 "input_dtype": torch.float16, 68 "output_dtype": torch.float, 69 "is_dynamic": True, 70 } 71 72 dtype_config_dict1 = { 73 "input_dtype": DTypeWithConstraints(dtype=torch.quint8), 74 "output_dtype": DTypeWithConstraints(torch.quint8), 75 "weight_dtype": DTypeWithConstraints(torch.qint8), 76 "bias_dtype": torch.float, 77 } 78 79 dtype_config_dict2 = { 80 "input_dtype": DTypeWithConstraints(dtype=torch.float16), 81 "output_dtype": DTypeWithConstraints(dtype=torch.float), 82 "is_dynamic": True, 83 } 84 85 dtype_config_dict3 = { 86 "input_dtype": activation_dtype_with_constraints, 87 "output_dtype": activation_dtype_with_constraints, 88 "weight_dtype": weight_dtype_with_constraints, 89 } 90 91 def test_dtype_config_from_dict(self): 92 self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1_legacy), self.dtype_config1) 93 self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2_legacy), self.dtype_config2) 94 self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict1), self.dtype_config1) 95 self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict2), self.dtype_config2) 96 self.assertEqual(DTypeConfig.from_dict(self.dtype_config_dict3), self.dtype_config3) 97 98 def test_dtype_config_to_dict(self): 99 self.assertEqual(self.dtype_config1.to_dict(), self.dtype_config_dict1) 100 self.assertEqual(self.dtype_config2.to_dict(), self.dtype_config_dict2) 101 self.assertEqual(self.dtype_config3.to_dict(), self.dtype_config_dict3) 102 103 # ====================== 104 # BackendPatternConfig 105 # ====================== 106 107 _fuser_method = _sequential_wrapper2(nni.LinearReLU) 108 109 _num_tensor_args_to_observation_type = { 110 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, 111 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, 112 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, 113 } 114 _input_type_to_index = { 115 "bias": 0, 116 "input": 1, 117 "weight": 2, 118 } 119 120 def _extra_inputs_getter(self, p): 121 return (torch.rand(3, 3),) 122 123 def _get_backend_op_config1(self): 124 return BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) \ 125 .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ 126 .add_dtype_config(self.dtype_config1) \ 127 .add_dtype_config(self.dtype_config2) \ 128 .set_root_module(torch.nn.Linear) \ 129 .set_qat_module(nnqat.Linear) \ 130 .set_reference_quantized_module(nnqr.Linear) \ 131 .set_fused_module(nni.LinearReLU) \ 132 .set_fuser_method(self._fuser_method) 133 134 def _get_backend_op_config2(self): 135 return BackendPatternConfig(torch.add) \ 136 .add_dtype_config(self.dtype_config2) \ 137 ._set_root_node_getter(_default_root_node_getter) \ 138 ._set_extra_inputs_getter(self._extra_inputs_getter) \ 139 ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) \ 140 ._set_input_type_to_index(self._input_type_to_index) 141 142 def _get_backend_pattern_config_dict1(self): 143 return { 144 "pattern": (torch.nn.Linear, torch.nn.ReLU), 145 "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, 146 "dtype_configs": [self.dtype_config_dict1, self.dtype_config_dict2], 147 "root_module": torch.nn.Linear, 148 "qat_module": nnqat.Linear, 149 "reference_quantized_module_for_root": nnqr.Linear, 150 "fused_module": nni.LinearReLU, 151 "fuser_method": self._fuser_method, 152 } 153 154 def _get_backend_pattern_config_dict2(self): 155 return { 156 "pattern": torch.add, 157 "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, 158 "dtype_configs": [self.dtype_config_dict2], 159 "root_node_getter": _default_root_node_getter, 160 "extra_inputs_getter": self._extra_inputs_getter, 161 "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type, 162 "input_type_to_index": self._input_type_to_index, 163 } 164 165 def test_backend_op_config_set_observation_type(self): 166 conf = BackendPatternConfig(torch.nn.Linear) 167 self.assertEqual(conf.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) 168 conf.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) 169 self.assertEqual(conf.observation_type, ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) 170 171 def test_backend_op_config_add_dtype_config(self): 172 conf = BackendPatternConfig(torch.nn.Linear) 173 self.assertEqual(len(conf.dtype_configs), 0) 174 conf.add_dtype_config(self.dtype_config1) 175 conf.add_dtype_config(self.dtype_config2) 176 self.assertEqual(len(conf.dtype_configs), 2) 177 self.assertEqual(conf.dtype_configs[0], self.dtype_config1) 178 self.assertEqual(conf.dtype_configs[1], self.dtype_config2) 179 180 def test_backend_op_config_set_root_module(self): 181 conf = BackendPatternConfig(nni.LinearReLU) 182 self.assertTrue(conf.root_module is None) 183 conf.set_root_module(torch.nn.Linear) 184 self.assertEqual(conf.root_module, torch.nn.Linear) 185 186 def test_backend_op_config_set_qat_module(self): 187 conf = BackendPatternConfig(torch.nn.Linear) 188 self.assertTrue(conf.qat_module is None) 189 conf.set_qat_module(nnqat.Linear) 190 self.assertEqual(conf.qat_module, nnqat.Linear) 191 192 def test_backend_op_config_set_reference_quantized_module(self): 193 conf = BackendPatternConfig(torch.nn.Linear) 194 self.assertTrue(conf.reference_quantized_module is None) 195 conf.set_reference_quantized_module(nnqr.Linear) 196 self.assertEqual(conf.reference_quantized_module, nnqr.Linear) 197 198 def test_backend_op_config_set_fused_module(self): 199 conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) 200 self.assertTrue(conf.fused_module is None) 201 conf.set_fused_module(nni.LinearReLU) 202 self.assertEqual(conf.fused_module, nni.LinearReLU) 203 204 def test_backend_op_config_set_fuser_method(self): 205 conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) 206 self.assertTrue(conf.fuser_method is None) 207 conf.set_fuser_method(self._fuser_method) 208 self.assertEqual(conf.fuser_method, self._fuser_method) 209 210 def test_backend_op_config_set_root_node_getter(self): 211 conf = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) 212 self.assertTrue(conf._root_node_getter is None) 213 conf._set_root_node_getter(_default_root_node_getter) 214 self.assertEqual(conf._root_node_getter, _default_root_node_getter) 215 216 def test_backend_op_config_set_extra_inputs_getter(self): 217 conf = BackendPatternConfig(torch.nn.Linear) 218 self.assertTrue(conf._extra_inputs_getter is None) 219 conf._set_extra_inputs_getter(self._extra_inputs_getter) 220 self.assertEqual(conf._extra_inputs_getter, self._extra_inputs_getter) 221 222 def test_backend_op_config_set_num_tensor_args_to_observation_type(self): 223 conf = BackendPatternConfig(torch.add) 224 self.assertEqual(len(conf._num_tensor_args_to_observation_type), 0) 225 conf._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) 226 self.assertEqual(conf._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type) 227 228 def test_backend_op_config_set_input_type_to_index(self): 229 conf = BackendPatternConfig(torch.addmm) 230 self.assertEqual(len(conf._input_type_to_index), 0) 231 conf._set_input_type_to_index(self._input_type_to_index) 232 self.assertEqual(conf._input_type_to_index, self._input_type_to_index) 233 234 def test_backend_op_config_from_dict(self): 235 conf_dict1 = self._get_backend_pattern_config_dict1() 236 conf1 = BackendPatternConfig.from_dict(conf_dict1) 237 self.assertEqual(conf1.pattern, (torch.nn.Linear, torch.nn.ReLU)) 238 self.assertEqual(conf1.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) 239 self.assertEqual(conf1.root_module, torch.nn.Linear) 240 self.assertEqual(conf1.qat_module, nnqat.Linear) 241 self.assertEqual(conf1.reference_quantized_module, nnqr.Linear) 242 self.assertEqual(conf1.fused_module, nni.LinearReLU) 243 self.assertEqual(conf1.fuser_method, self._fuser_method) 244 self.assertTrue(conf1._root_node_getter is None) 245 self.assertTrue(conf1._extra_inputs_getter is None) 246 self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0) 247 self.assertEqual(len(conf1._input_type_to_index), 0) 248 # Test temporary/internal keys 249 conf_dict2 = self._get_backend_pattern_config_dict2() 250 conf2 = BackendPatternConfig.from_dict(conf_dict2) 251 self.assertEqual(conf2.pattern, torch.add) 252 self.assertEqual(conf2.observation_type, ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) 253 self.assertTrue(conf2.root_module is None) 254 self.assertTrue(conf2.qat_module is None) 255 self.assertTrue(conf2.reference_quantized_module is None) 256 self.assertTrue(conf2.fused_module is None) 257 self.assertTrue(conf2.fuser_method is None) 258 self.assertEqual(conf2._root_node_getter, _default_root_node_getter) 259 self.assertEqual(conf2._extra_inputs_getter, self._extra_inputs_getter) 260 self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type) 261 self.assertEqual(conf2._input_type_to_index, self._input_type_to_index) 262 263 def test_backend_op_config_to_dict(self): 264 conf1 = self._get_backend_op_config1() 265 conf2 = self._get_backend_op_config2() 266 conf_dict1 = self._get_backend_pattern_config_dict1() 267 conf_dict2 = self._get_backend_pattern_config_dict2() 268 self.assertEqual(conf1.to_dict(), conf_dict1) 269 self.assertEqual(conf2.to_dict(), conf_dict2) 270 271 # =============== 272 # BackendConfig 273 # =============== 274 275 def test_backend_config_set_name(self): 276 conf = BackendConfig("name1") 277 self.assertEqual(conf.name, "name1") 278 conf.set_name("name2") 279 self.assertEqual(conf.name, "name2") 280 281 def test_backend_config_set_backend_pattern_config(self): 282 conf = BackendConfig("name1") 283 self.assertEqual(len(conf.configs), 0) 284 backend_op_config1 = self._get_backend_op_config1() 285 backend_op_config2 = self._get_backend_op_config2() 286 conf.set_backend_pattern_config(backend_op_config1) 287 self.assertEqual(conf._pattern_complex_format_to_config, { 288 (torch.nn.ReLU, torch.nn.Linear): backend_op_config1, 289 }) 290 conf.set_backend_pattern_config(backend_op_config2) 291 self.assertEqual(conf._pattern_complex_format_to_config, { 292 (torch.nn.ReLU, torch.nn.Linear): backend_op_config1, 293 torch.add: backend_op_config2 294 }) 295 296 def test_backend_config_from_dict(self): 297 op1 = self._get_backend_op_config1() 298 op2 = self._get_backend_op_config2() 299 op_dict1 = self._get_backend_pattern_config_dict1() 300 op_dict2 = self._get_backend_pattern_config_dict2() 301 conf_dict = { 302 "name": "name1", 303 "configs": [op_dict1, op_dict2], 304 } 305 conf = BackendConfig.from_dict(conf_dict) 306 self.assertEqual(conf.name, "name1") 307 self.assertEqual(len(conf.configs), 2) 308 key1 = (torch.nn.ReLU, torch.nn.Linear) 309 key2 = torch.add 310 self.assertTrue(key1 in conf._pattern_complex_format_to_config) 311 self.assertTrue(key2 in conf._pattern_complex_format_to_config) 312 self.assertEqual(conf._pattern_complex_format_to_config[key1].to_dict(), op_dict1) 313 self.assertEqual(conf._pattern_complex_format_to_config[key2].to_dict(), op_dict2) 314 315 def test_backend_config_to_dict(self): 316 op1 = self._get_backend_op_config1() 317 op2 = self._get_backend_op_config2() 318 op_dict1 = self._get_backend_pattern_config_dict1() 319 op_dict2 = self._get_backend_pattern_config_dict2() 320 conf = BackendConfig("name1").set_backend_pattern_config(op1).set_backend_pattern_config(op2) 321 conf_dict = { 322 "name": "name1", 323 "configs": [op_dict1, op_dict2], 324 } 325 self.assertEqual(conf.to_dict(), conf_dict) 326 327if __name__ == '__main__': 328 raise RuntimeError("This _test file is not meant to be run directly, use:\n\n" 329 "\tpython _test/_test_quantization.py TESTNAME\n\n" 330 "instead.") 331