1# Owner(s): ["oncall: quantization"] 2 3import io 4from typing import Dict 5 6import torch 7import torch._C 8from torch.ao.quantization import default_dynamic_qconfig, per_channel_dynamic_qconfig 9from torch.ao.quantization.quantize_jit import ( 10 _prepare_ondevice_dynamic_jit, 11 _quantize_ondevice_dynamic_jit, 12 convert_dynamic_jit, 13 prepare_dynamic_jit, 14) 15from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule 16from torch.testing import FileCheck 17from torch.testing._internal.common_quantization import ( 18 get_script_module, 19 LinearAddModel, 20) 21from torch.testing._internal.common_utils import TestCase 22from torch.utils import bundled_inputs as bundled_inputs 23 24 25class myMod(torch.nn.Module): 26 def __init__(self, weight): 27 super().__init__() 28 self.fc1 = torch.nn.Linear(5, 5).float() 29 self.fc1.weight = weight 30 self.fc2 = torch.nn.Linear(5, 5).float() 31 32 def forward(self, x): 33 return self.fc2(self.fc1(x)) 34 35 36class MyConvLinearModule(torch.nn.Module): 37 def __init__(self) -> None: 38 super().__init__() 39 self.conv = torch.nn.Conv2d(3, 5, 3) 40 weight = torch.nn.Parameter(torch.ones(5, 5)) 41 self.weight1 = torch.nn.Parameter(torch.ones(5, 5)) 42 self.mymod = myMod(weight) 43 44 def forward(self, x): 45 conv_output = self.conv(x) 46 y = self.mymod(conv_output) 47 z = torch.nn.functional.linear(y, self.weight1) 48 return z 49 50 def get_example_inputs(self): 51 return (torch.rand(1, 3, 12, 7),) 52 53 54class OnDevicePTQUtils: 55 observer_module_name = ["MinMaxObserver", "PerChannelMinMaxObserver"] 56 57 @staticmethod 58 def insert_observers(model, qconfig_dict): 59 inputs = model.get_example_inputs() 60 scripted_model = get_script_module(model, False, inputs) 61 scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict) 62 return scripted_model 63 64 @staticmethod 65 def ptq_dynamic_quantize(model, qconfig_dict): 66 inputs = model.get_example_inputs() 67 m = get_script_module(model, False, inputs) 68 m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, "forward", True) 69 return m 70 71 @staticmethod 72 def find_observer_modules(m): 73 observer_modules = [] 74 for child_module in m.children(): 75 if child_module.original_name in OnDevicePTQUtils.observer_module_name: 76 observer_modules.append(child_module) 77 return observer_modules 78 79 @staticmethod 80 def is_value_type_observer(value): 81 type_name = value.type() 82 for observer_type in OnDevicePTQUtils.observer_module_name: 83 if observer_type in type_name.str(): 84 return True 85 return False 86 87 @staticmethod 88 def is_calculate_qparam(node): 89 if node.kind() == "prim::CallMethod": 90 if node.s("name") == "calculate_qparams": 91 return True 92 return False 93 94 @staticmethod 95 def get_linear_packed_param_fp_weight(node): 96 weight = node.inputsAt(0).node() 97 if ( 98 weight.kind() != "aten::quantize_per_tensor" 99 and weight.kind() != "aten::quantize_per_channel" 100 ): 101 raise ValueError("Quantized weight must be produced.") 102 fp_weight = weight.inputsAt(0).node() 103 assert ( 104 fp_weight.kind() == "prim::GetAttr" 105 ), "Weight must be an attribute of the module." 106 fp_weight_name = fp_weight.s("name") 107 return fp_weight_name 108 109 @staticmethod 110 def is_per_channel_quantized_packed_param(node): 111 assert ( 112 node.kind() == "quantized::linear_prepack" 113 ), "Node must corresponds to linear_prepack." 
        weight = node.inputsAt(0).node()
        assert weight.kind() in (
            "aten::quantize_per_tensor",
            "aten::quantize_per_channel",
        ), "Weight must be quantized per tensor or per channel."
        return weight.kind() == "aten::quantize_per_channel"


class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    def _check_num_and_type_of_observers(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == "MinMaxObserver")

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == "PerChannelMinMaxObserver")

    def _check_observer_method(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # Exact graph matching is difficult, so compare the number of lines
        # instead of implementing graph matching.
        self.assertEqual(
            len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines())
        )
        observe_method = scripted_model.observe_forward.graph
        FileCheck().check_count(
            'prim::CallMethod[name="forward"](%_observer', num_observers, exactly=True
        ).run(observe_method)
        reset_observers_method = scripted_model.reset_observers_forward.graph
        FileCheck().check_count(
            'prim::CallMethod[name="reset_min_max_vals"](%_observer',
            num_observers,
            exactly=True,
        ).run(reset_observers_method)

    def _observer_is_weight_only(self, node):
        if (node.kind() == "prim::CallMethod") and node.s("name") == "forward":
            if OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0)):
                return node.inputsAt(1).node().kind() == "prim::GetAttr"
        return False

    def test_num_observers(self):
        model = LinearAddModel()
        self._check_num_and_type_of_observers(model, 2)
        model = MyConvLinearModule()
        self._check_num_and_type_of_observers(model, 3)

    def test_observe_method(self):
        model = MyConvLinearModule()
        self._check_observer_method(model, 3)

    def test_weight_only_observers(self):
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = 0
        for node in observe_forward_graph.nodes():
            if self._observer_is_weight_only(node):
                num_weight_only_observers += 1
        self.assertEqual(num_weight_only_observers, 3)


class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)

    def _validate_calculate_qparams(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        num_calculate_qparams = 0
        for n in quantize_forward_graph.nodes():
            if OnDevicePTQUtils.is_calculate_qparam(n):
                num_calculate_qparams += 1
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
                if OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0)):
                    return False
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

    def _check_quantize_forward_runs(self, model):
        inputs = model.get_example_inputs()
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        # observe_forward must run first to record the stats needed to
        # produce correct scales and zero points.
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)


class TestOnDeviceDynamicPTQFinalize(TestCase):
    def _validate_packed_params(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        linear_prepack = 0
        linear_prepack_uses = 0
        for n in quantize_forward_graph.nodes():
            if n.kind() == "prim::SetAttr":
                maybe_packed_param_value = n.inputsAt(1)
                maybe_packed_param = maybe_packed_param_value.node()
                if maybe_packed_param.kind() == "quantized::linear_prepack":
                    linear_prepack += 1
                    linear_prepack_uses += len(maybe_packed_param_value.uses())
                    if OnDevicePTQUtils.is_per_channel_quantized_packed_param(
                        maybe_packed_param
                    ):
                        quantize_per_channel += 1
                    else:
                        quantize_per_tensor += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)
        self.assertEqual(linear_prepack, num_nodes)
        self.assertEqual(linear_prepack_uses, num_nodes)

    def _validate_no_linear_unpack(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if n.kind() == "quantized::linear_unpack":
                return False
        return True

    def _validate_setattr_fp_weights(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        fp_weights_setattr = 0
        fp_weight_names = []
        for n in quantize_forward_graph.nodes():
            if n.kind() == "prim::SetAttr":
                maybe_packed_param = n.inputsAt(1).node()
                if maybe_packed_param.kind() == "quantized::linear_prepack":
                    weight_name = OnDevicePTQUtils.get_linear_packed_param_fp_weight(
                        maybe_packed_param
                    )
                    fp_weight_names.append(weight_name)

        for n in quantize_forward_graph.nodes():
            # This detects the pattern
            #   %x = prim::Constant
            #   prim::SetAttr[name=<weight_name>](%module, %x)
            # which confirms that the original fp weights are reset.
            if n.kind() == "prim::SetAttr":
                weight_name = n.s("name")
                if weight_name in fp_weight_names:
                    maybe_constant = n.inputsAt(1).node()
                    if maybe_constant.kind() == "prim::Constant":
                        fp_weights_setattr += 1
        self.assertEqual(fp_weights_setattr, num_nodes)

    def _validate_quantized_forward(self, model, num_nodes):
        quantized_forward_graph = model.quantized_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        quantized_linear_dynamic = 0
        linear_packed_params = 0
        num_setattr = 0
        for n in quantized_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "quantized::linear_dynamic" in n.kind():
                quantized_linear_dynamic += 1
            if n.kind() == "prim::GetAttr":
                output = n.outputsAt(0)
                output_type = output.type()
                if "LinearPackedParamsBase" in output_type.str():
                    linear_packed_params += 1
            if n.kind() == "prim::SetAttr":
                num_setattr += 1
        self.assertEqual(quantize_per_tensor, 0)
        self.assertEqual(quantize_per_channel, 0)
        self.assertEqual(quantized_linear_dynamic, num_nodes)
        self.assertEqual(linear_packed_params, num_nodes)
        # self.assertEqual(num_setattr, 0)

    def _check_quantize_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

    def _check_quantized_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

    def _check_against_ref_dynamic_ptq(self, model):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        # After quantize_forward, the original fp weights have been reset,
        # so the original forward is expected to fail.
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

        # test with per channel quant
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

    def _check_serdes_and_device_side_api_helper(
        self, model, check_device_side_api=False
    ):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # check the lite interpreter (device-side) API
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            (first_input,) = inputs
            rand_input = bundled_inputs.bundle_randn(
                first_input.size(), dtype=first_input.dtype
            )
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))

            # Now serialize to flatbuffer, load back from it, and check.
            extra_files: Dict[str, str] = {}
            module_bytes = torch._C._save_mobile_module_to_bytes(m._c, extra_files)
            m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(module_bytes))
            fb_output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, fb_output))

        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # check the lite interpreter (device-side) API
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            (first_input,) = inputs
            rand_input = bundled_inputs.bundle_randn(
                first_input.size(), dtype=first_input.dtype
            )
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))

    def _check_serialization_deserialization(self, model):
        self._check_serdes_and_device_side_api_helper(model, False)

    def _check_device_side_api(self, model):
        self._check_serdes_and_device_side_api_helper(model, True)

    def test_quantize_forward(self):
        model = LinearAddModel()
        self._check_quantize_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantize_forward(model, 3)

    def test_quantized_forward(self):
        model = LinearAddModel()
        self._check_quantized_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantized_forward(model, 3)

    def test_against_offdevice_dynamic_ptq(self):
        model = LinearAddModel()
        self._check_against_ref_dynamic_ptq(model)
        model = MyConvLinearModule()
        self._check_against_ref_dynamic_ptq(model)

    def test_serialization_deserialization(self):
        model = MyConvLinearModule()
        self._check_serialization_deserialization(model)

    def test_device_side_api(self):
        model = MyConvLinearModule()
        self._check_device_side_api(model)
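

# NOTE: the entry point below is a convenience sketch for running this file
# standalone; in the PyTorch tree, quantization test files are normally
# invoked through test/test_quantization.py rather than executed directly.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()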