# Owner(s): ["oncall: quantization"]

import copy
import math
import operator
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import (
    default_dynamic_qconfig,
    QConfigMapping,
    get_default_qconfig_mapping,
)
import torch.ao.nn.quantized as nnq
toq = torch.ops.quantized
from torch.ao.quantization.quantize_fx import (
    convert_fx,
    convert_to_reference_fx,
    prepare_fx,
    prepare_qat_fx,
)
from torch.testing._internal.common_quantization import (
    ConvBnModel,
    ConvBnReLUModel,
    ConvModel,
    QuantizationTestCase,
    skipIfNoFBGEMM,
    skipIfNoQNNPACK,
    withQNNPACKBackend,
    SingleLayerLinearDynamicModel,
    SingleLayerLinearModel,
    LSTMwithHiddenDynamicModel,
    SparseNNModel,
    skip_if_no_torchvision,
    TwoLayerLinearModel
)
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.ao.quantization.quantization_mappings import (
    get_default_static_quant_module_mappings,
    get_default_dynamic_quant_module_mappings,
    get_default_float_to_quantized_operator_mappings,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns
import torch.ao.quantization.fx.quantize_handler as qh
from torch.ao.ns.fx.pattern_utils import (
    get_type_a_related_to_b,
)
from torch.ao.ns.fx.graph_matcher import (
    get_matching_subgraph_pairs,
    GraphMatchingException,
)
from torch.ao.ns.fx.utils import (
    compute_sqnr,
    compute_normalized_l2_error,
    compute_cosine_similarity,
)
from torch.ao.ns.fx.mappings import (
    get_node_type_to_io_type_map,
    get_unmatchable_types_map,
    get_base_name_to_sets_of_related_ops,
    get_base_name_for_op,
    add_op_to_sets_of_related_ops,
)
from torch.ao.ns.fx.weight_utils import (
    get_op_to_type_to_weight_extraction_fn,
)
from torch.ao.ns._numeric_suite_fx import (
    extract_weights,
    _extract_weights_impl,
    add_loggers,
    _add_loggers_impl,
    OutputLogger,
    add_shadow_loggers,
    _add_shadow_loggers_impl,
    extract_logger_info,
    extract_shadow_logger_info,
    extend_logger_results_with_comparison,
    prepare_n_shadows_model,
    convert_n_shadows_model,
    extract_results_n_shadows_model,
    OutputComparisonLogger,
    print_comparisons_n_shadows_model,
    loggers_set_enabled,
    loggers_set_save_activations,
    _prepare_n_shadows_add_loggers_model,
    _n_shadows_compare_weights,
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers


# Note: these models are not for use outside of this file. While it's good
# to reuse code, we also need to be able to iterate on tests
# quickly when debugging. If a test model has a large number of callsites
# across various different files, speed of debugging on individual test cases
# decreases.
class LinearReluFunctional(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(4, 4))
        self.b1 = nn.Parameter(torch.zeros(4))
        torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

    def forward(self, x):
        x = F.linear(x, self.w1, self.b1)
        x = F.relu(x)
        return x


class LinearFunctional(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(4, 4))
        self.b1 = nn.Parameter(torch.zeros(4))
        torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

    def forward(self, x):
        x = F.linear(x, self.w1, self.b1)
        return x


class LinearReluLinearFunctional(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w = nn.Parameter(torch.Tensor(4, 4))
        self.b = nn.Parameter(torch.zeros(4))
        torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

    def forward(self, x):
        x = F.linear(x, self.w, self.b)
        x = F.relu(x)
        x = F.linear(x, self.w, self.b)
        return x


class AddMulFunctional(nn.Module):
    def forward(self, x, y):
        x = x + 1.0
        x = x * 1.0
        x = 1.0 + x
        x = 1.0 * x
        x = x + y
        x = x * y
        return x


class AllConvAndLinearFusionModules(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # conv1d
        self.conv1d_0 = nn.Conv1d(1, 1, 1)
        # conv1d - relu
        self.conv1d_1 = nn.Conv1d(1, 1, 1)
        self.relu_0 = nn.ReLU()
        # conv1d - bn (qat only)
        self.conv1d_2 = nn.Conv1d(1, 1, 1)
        self.bn1d_0 = nn.BatchNorm1d(1)
        # conv1d - bn - relu (qat only)
        self.conv1d_3 = nn.Conv1d(1, 1, 1)
        self.bn1d_1 = nn.BatchNorm1d(1)
        self.relu_4 = nn.ReLU()
        # conv2d
        self.conv2d_0 = nn.Conv2d(1, 1, 1)
        # conv2d - relu
        self.conv2d_1 = nn.Conv2d(1, 1, 1)
        self.relu_1 = nn.ReLU()
        # conv2d - bn (qat only)
        self.conv2d_2 = nn.Conv2d(1, 1, 1)
        self.bn2d_0 = nn.BatchNorm2d(1)
        # conv2d - bn - relu (qat only)
        self.conv2d_3 = nn.Conv2d(1, 1, 1)
        self.bn2d_1 = nn.BatchNorm2d(1)
        self.relu_5 = nn.ReLU()
        # conv3d
        self.conv3d_0 = nn.Conv3d(1, 1, 1)
        # conv3d - relu
        self.conv3d_1 = nn.Conv3d(1, 1, 1)
        self.relu_2 = nn.ReLU()
        # conv3d - bn (qat only)
        self.conv3d_2 = nn.Conv3d(1, 1, 1)
        self.bn3d_0 = nn.BatchNorm3d(1)
        # conv3d - bn - relu (qat only)
        self.conv3d_3 = nn.Conv3d(1, 1, 1)
        self.bn3d_1 = nn.BatchNorm3d(1)
        self.relu_6 = nn.ReLU()
        # linear
        self.linear_0 = nn.Linear(1, 1)
        # linear - relu
        self.linear_1 = nn.Linear(1, 1)
        self.relu_3 = nn.ReLU()

    def forward(self, x):
        # conv1d
        x = self.conv1d_0(x)
        x = self.conv1d_1(x)
        x = self.relu_0(x)
        x = self.conv1d_2(x)
        x = self.bn1d_0(x)
        x = self.conv1d_3(x)
        x = self.bn1d_1(x)
        x = self.relu_4(x)
        # conv2d
        x = x.reshape(1, 1, 1, 1)
        x = self.conv2d_0(x)
        x = self.conv2d_1(x)
        x = self.relu_1(x)
        x = self.conv2d_2(x)
        x = self.bn2d_0(x)
        x = self.conv2d_3(x)
        x = self.bn2d_1(x)
        x = self.relu_5(x)
        # conv3d
        x = x.reshape(1, 1, 1, 1, 1)
        x = self.conv3d_0(x)
        x = self.conv3d_1(x)
        x = self.relu_2(x)
        x = self.conv3d_2(x)
        x = self.bn3d_0(x)
        x = self.conv3d_3(x)
        x = self.bn3d_1(x)
        x = self.relu_6(x)
        # linear
        x = x.reshape(1, 1)
        x = self.linear_0(x)
        x = self.linear_1(x)
        x = self.relu_3(x)
        return x


class AllConvFunctional(torch.nn.Module):
    def __init__(self, weight1d, weight2d, weight3d, bias1d, bias2d, bias3d):
        super().__init__()
        self.weight1d = torch.nn.Parameter(weight1d)
        self.weight2d = torch.nn.Parameter(weight2d)
        self.weight3d = torch.nn.Parameter(weight3d)
        self.bias1d = torch.nn.Parameter(bias1d)
        self.bias2d = torch.nn.Parameter(bias2d)
        self.bias3d = torch.nn.Parameter(bias3d)
        self.stride1d = 1
        self.padding1d = 0
        self.dilation1d = 1
        self.stride2d = (1, 1)
        self.padding2d = (0, 0)
        self.dilation2d = (1, 1)
        self.groups = 1
        self.stride3d = (1, 1, 1)
        self.padding3d = (0, 0, 0)
        self.dilation3d = (1, 1, 1)

    def forward(self, x):
        x = F.conv1d(
            x, self.weight1d, self.bias1d, self.stride1d, self.padding1d,
            self.dilation1d, self.groups)
        x = F.conv1d(
            x, self.weight1d, self.bias1d, self.stride1d, self.padding1d,
            self.dilation1d, self.groups)
        x = F.relu(x)
        x = F.conv2d(
            x, self.weight2d, self.bias2d, self.stride2d, self.padding2d,
            self.dilation2d, self.groups)
        x = F.conv2d(
            x, self.weight2d, self.bias2d, self.stride2d, self.padding2d,
            self.dilation2d, self.groups)
        x = F.relu(x)
        x = F.conv3d(
            x, self.weight3d, self.bias3d, self.stride3d, self.padding3d,
            self.dilation3d, self.groups)
        x = F.conv3d(
            x, self.weight3d, self.bias3d, self.stride3d, self.padding3d,
            self.dilation3d, self.groups)
        x = F.relu(x)
        return x

@torch.fx.wrap
def _wrapped_hardswish(x):
    return F.hardswish(x)

@torch.fx.wrap
def _wrapped_hardswish_fp16(x):
    x = x.dequantize()
    x = F.hardswish(x)
    x = x.to(torch.float16)
    return x

@torch.fx.wrap
def _wrapped_sigmoid(x):
    return F.sigmoid(x)

@torch.fx.wrap
def _wrapped_linear(x, w, b):
    return F.linear(x, w, b)

def get_all_quant_patterns():
    """ We are in the process of migrating the frontend of FX graph mode quantization
    to use backend_config_dict, so some of the patterns have moved to backend_config_dict.
    This function adds those patterns back so that we still have all of the patterns.
    """
    # TODO: we can remove this call, and get all patterns from backend_config_dict in
    # the future when the frontend refactor is done in fx graph mode quantization
    all_quant_patterns = get_default_quant_patterns()
    # some of the patterns are moved to (native) backend_config_dict so we need to
    # add them back here
    for pattern, quantize_handler in _get_pattern_to_quantize_handlers(get_native_backend_config()).items():
        all_quant_patterns[pattern] = quantize_handler
    return all_quant_patterns

class TestFXGraphMatcher(QuantizationTestCase):

    @skipIfNoFBGEMM
    def test_simple_mod(self):
        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_subgraph_pairs(mp, mq)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        conv_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, nn.Conv2d) + '_0'

        expected_types = {
            conv_name_0: ((nn.Conv2d, torch.ao.quantization.MinMaxObserver), (nnq.Conv2d, nnq.Conv2d)),
        }
        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)

    @skipIfNoFBGEMM
    def test_simple_fun(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = nn.Parameter(torch.empty(1, 4))
                self.b = nn.Parameter(torch.zeros(1))
                torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

            def forward(self, x):
                return F.linear(x, self.w, self.b)

        m = M().eval()
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_subgraph_pairs(mp, mq)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        linear_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, F.linear) + '_0'

        expected_types = {
            linear_name_0:
                ((F.linear, torch.ao.quantization.MinMaxObserver), (toq.linear, toq.linear))
        }
        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)

    @skipIfNoFBGEMM
    def test_simple_fusion(self):
        m = LinearReluFunctional().eval()
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(4, 4),))
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_subgraph_pairs(mp, mq)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        linear_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, F.linear) + '_0'

        expected_types = {
            linear_name_0:
                ((F.linear, torch.ao.quantization.MinMaxObserver), (toq.linear_relu, toq.linear_relu)),
        }
        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)

    @skipIfNoFBGEMM
    def test_simple_mod_multi(self):
        m = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(1, 1, 1),
            ),
            nn.Conv2d(1, 1, 1),
        ).eval()
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=(torch.randn(1, 1, 1, 1),))
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        # assume success if no exceptions
        results = get_matching_subgraph_pairs(mp, mq)

    @skipIfNoFBGEMM
    def test_simple_tensor_ops(self):
        class M(nn.Module):
            def forward(self, x, y):
                z = x + y
                return z

        m = M().eval()
        example_inputs = (torch.randn(1), torch.randn(1))
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        # assume success if no exceptions
        results = get_matching_subgraph_pairs(mp, mq)

    @skipIfNoFBGEMM
    def test_matching_failure_node_count(self):
        # verify that matching graphs with matching node types but
        # different counts of matchable nodes fails
        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        m2 = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval()
        example_inputs = (torch.randn(1, 1, 1, 1),)
        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        with self.assertRaises(GraphMatchingException) as ex:
            results = get_matching_subgraph_pairs(mp1, mp2)

    @skipIfNoFBGEMM
    def test_matching_failure_node_type(self):
        # verify that matching graphs with non-matching node types fails
        m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        m2 = nn.Sequential(nn.Linear(1, 1)).eval()
        example_inputs = (torch.randn(1, 1, 1, 1),)
        mp1 = prepare_fx(m1, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        example_inputs = (torch.randn(1, 1),)
        mp2 = prepare_fx(m2, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        with self.assertRaises(GraphMatchingException) as ex:
            results = get_matching_subgraph_pairs(mp1, mp2)

    @skipIfNoFBGEMM
    def test_nodes_before_cat(self):
        # verify that nodes before cat get matched
        class M(nn.Module):
            def forward(self, x0):
                x1 = torch.add(x0, 1.0)
                y1 = torch.add(x0, 1.0)
                x2 = torch.cat([x1, y1])
                return x2

        m = M().eval()
        example_inputs = (torch.randn(1),)
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_subgraph_pairs(mp, mq)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        cat_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.cat) + '_0'
        add_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.add) + '_0'
        add_name_1 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.add) + '_1'

        expected_types = {
            cat_name_0: ((torch.cat, torch.cat), (torch.cat, torch.cat)),
            add_name_0: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
            add_name_1: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
        }
        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)

    @skipIfNoFBGEMM
    def test_dict_return_type(self):
        # verify that we can traverse up nodes which return dictionaries
        class M(nn.Module):
            def forward(self, x0):
                x1 = torch.add(x0, 1.0)
                y1 = torch.add(x0, 1.0)
                z1 = torch.add(x0, 1.0)
                a1 = {'x1': x1, 'y1': (y1,), 'z1': [{'key': (z1,)}]}
                return a1

        m = M().eval()
        example_inputs = (torch.randn(1),)
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_subgraph_pairs(mp, mq)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        add_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.add) + '_0'
        add_name_1 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.add) + '_1'
        add_name_2 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.add) + '_2'

        expected_types = {
            add_name_0: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
            add_name_1: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
            add_name_2: ((torch.add, torch.ao.quantization.MinMaxObserver), (toq.add, toq.add)),
        }
        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)

    @skipIfNoFBGEMM
    def test_nodes_with_equal_types_get_matched(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                x = torch.mul(x, x)
                x = torch.sigmoid(x)
                x = F.relu(x)
                return x

        m = M().eval()
        # prevent conv2 from getting quantized, so we can test
        # modules with equal types
        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping().set_module_name("conv2", None)
        example_inputs = (torch.randn(1, 1, 1, 1),)
        mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_subgraph_pairs(mp, mq)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        conv_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, nn.Conv2d) + '_0'
        conv_name_1 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, nn.Conv2d) + '_1'
        mul_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.mul) + '_0'
        relu_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.relu) + '_0'
        sigmoid_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.sigmoid) + '_0'

        # all of these should be matched
        expected_types = {
            conv_name_1:
                ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nnq.Conv2d, nnq.Conv2d)),
            conv_name_0:
                ((nn.Conv2d, torch.ao.quantization.HistogramObserver), (nn.Conv2d, nn.Conv2d)),
            mul_name_0: ((torch.mul, torch.ao.quantization.HistogramObserver), (toq.mul, toq.mul)),
            relu_name_0: ((F.relu, torch.ao.quantization.FixedQParamsObserver), (F.relu, F.relu)),
            sigmoid_name_0:
                ((torch.sigmoid, torch.ao.quantization.FixedQParamsObserver), (torch.sigmoid, torch.sigmoid)),
        }
        self.assert_types_for_matched_subgraph_pairs(results, expected_types, mp, mq)

    def test_methods(self):
        """
        Verify that graph matching works on methods
        """
        class M(nn.Module):
            def forward(self, x):
                x = x.sigmoid()
                return x

        m1 = M().eval()
        m2 = M().eval()
        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
        example_inputs = (torch.randn(1),)
        m1p = prepare_fx(m1, qconfig_mapping, example_inputs=example_inputs)
        m2p = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs)
        results = get_matching_subgraph_pairs(m1p, m2p)
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        sigmoid_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, torch.sigmoid) + '_0'
        expected_types = {
            sigmoid_name_0:
                (('sigmoid', torch.ao.quantization.FixedQParamsObserver), ('sigmoid', torch.ao.quantization.FixedQParamsObserver)),
        }
        self.assert_types_for_matched_subgraph_pairs(
            results, expected_types, m1p, m2p)

    def test_op_relationship_mapping(self):
        """
        Tests that the mapping of op relationships is complete.
        """
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        type_a_related_to_b = \
            get_type_a_related_to_b(base_name_to_sets_of_related_ops)

        # 1. check static quant module mappings
        static_quant_mod_mappings = get_default_static_quant_module_mappings()
        for fp32_type, int8_type in static_quant_mod_mappings.items():
            # skip quants and dequants, for the purposes of Numerical Suite
            types_to_skip = (
                torch.ao.quantization.QuantStub,
                torch.ao.quantization.DeQuantStub,
                nnq.FloatFunctional,
                # the ConvTranspose3d swap is not implemented in FX Graph
                # mode quantization yet
                nn.ConvTranspose3d,
                # the GroupNorm swap is not implemented in FX Graph
                # mode quantization yet
                nn.GroupNorm,
                # nnq.ReLU6 is no longer swapped, because nn.ReLU6 can
                # take quantized inputs
                nn.ReLU6,
            )
            if fp32_type in types_to_skip:
                continue

            # verify relatedness
            in_type_a_related_to_b = \
                (fp32_type, int8_type) in type_a_related_to_b
            self.assertTrue(
                in_type_a_related_to_b,
                f"{fp32_type} and {int8_type} need a relationship mapping")

        # 2. check static quant op mappings
        static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings()
        for fp32_type, int8_type in static_quant_fun_mappings.items():
            # verify relatedness
            in_type_a_related_to_b = \
                (fp32_type, int8_type) in type_a_related_to_b
            self.assertTrue(
                in_type_a_related_to_b,
                f"{fp32_type} and {int8_type} need a relationship mapping")

        # 3. check dynamic quant mappings
        dynamic_quant_mappings = get_default_dynamic_quant_module_mappings()
        for fp32_type, int8_type in dynamic_quant_mappings.items():
            # TODO(future PR): enable correct weight extraction for these
            # and remove from this list.
            types_to_skip = (
                nn.GRUCell,
                nn.GRU,
                nn.LSTMCell,
                nn.RNNCell,
            )
            if fp32_type in types_to_skip:
                continue
            # verify relatedness
            in_type_a_related_to_b = \
                (fp32_type, int8_type) in type_a_related_to_b
            self.assertTrue(
                in_type_a_related_to_b,
                f"{fp32_type} and {int8_type} need a relationship mapping")

        # 4. go through the ops mapped to each QuantizeHandler type, and verify
        # correctness.
        def _op_in_base_sets_of_related_ops(op):
            for ops in base_name_to_sets_of_related_ops.values():
                if op in ops:
                    return True
            return False

        unmatchable_types_map = get_unmatchable_types_map()
        FUNS_UNMATCHABLE = unmatchable_types_map['funs_unmatchable']
        MODS_UNMATCHABLE = unmatchable_types_map['mods_unmatchable']
        METHS_UNMATCHABLE = unmatchable_types_map['meths_unmatchable']

        def _op_is_unmatchable(op):
            return (
                op in FUNS_UNMATCHABLE or
                op in MODS_UNMATCHABLE or
                op in METHS_UNMATCHABLE
            )

        default_quant_patterns = get_all_quant_patterns()
        for pattern, qhandler_cls in default_quant_patterns.items():
            base_op = None
            if isinstance(pattern, tuple):
                base_op = pattern[-1]
            elif isinstance(pattern, str):
                base_op = pattern
            else:
                base_op = pattern

            qhandler_cls_all_ops_quantizeable = [
                qh.CatQuantizeHandler,
                qh.ConvReluQuantizeHandler,
                qh.LinearReLUQuantizeHandler,
                qh.BatchNormQuantizeHandler,
                qh.EmbeddingQuantizeHandler,
                qh.RNNDynamicQuantizeHandler,
            ]

            qhandler_cls_quant_op_same_signature = [
                qh.FixedQParamsOpQuantizeHandler,
                qh.CopyNodeQuantizeHandler,
                qh.GeneralTensorShapeOpQuantizeHandler,
            ]

            if qhandler_cls == qh.BinaryOpQuantizeHandler:
                # these ops do not have quantized equivalents
                ops_to_skip = [
                    torch.bmm,
                    torch.div,
                    torch.sub,
                    operator.truediv,
                    operator.sub
                ]
                if base_op in ops_to_skip:
                    continue
                self.assertTrue(
                    _op_in_base_sets_of_related_ops(base_op),
                    f"{base_op} not in sets of related ops")
            elif qhandler_cls == qh.RNNDynamicQuantizeHandler:
                # TODO(future PR): add support for all classes in
                # RNNDynamicQuantizeHandler
                pass
            elif qhandler_cls == qh.DefaultNodeQuantizeHandler:
                self.assertTrue(
                    _op_in_base_sets_of_related_ops(base_op),
                    f"{base_op} not in sets of related ops")
            elif qhandler_cls in qhandler_cls_quant_op_same_signature:
                # these ops use the same op signature for fp32 and quantized
                # tensors
                self.assertTrue(
                    _op_in_base_sets_of_related_ops(base_op) or
                    _op_is_unmatchable(base_op),
                    f"{base_op} not in sets of related ops or unmatchable")
            elif qhandler_cls in qhandler_cls_all_ops_quantizeable:
                self.assertTrue(
                    _op_in_base_sets_of_related_ops(base_op),
                    f"{base_op} not in sets of related ops")
            else:
                # torch.sum does not have quantized equivalents
                if base_op in [
                    torch.sum,
                    nn.GRUCell,
                    nn.GRU,
                    nn.LSTMCell,
                    nn.RNNCell,
                ]:
                    continue
                if isinstance(base_op, tuple):
                    # skip fusion patterns
                    continue
                # didn't match explicit quantize handler class, we can check if the
                # operator is in the related op set directly
                if not (_op_in_base_sets_of_related_ops(base_op) or _op_is_unmatchable(base_op)):
                    raise AssertionError(
                        f"handling for {qhandler_cls} for op {base_op} not implemented")

    @skipIfNoFBGEMM
    def test_user_defined_function(self):
        """
        Verify that graph matching works on user defined functions
        """
        class M1(nn.Module):
            def forward(self, x):
                x = F.hardswish(x)
                return x

        class M2(nn.Module):
            def forward(self, x):
                x = _wrapped_hardswish(x)
                return x

        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
        example_inputs = (torch.randn(1, 1, 1, 1),)
        m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs)
        m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        add_op_to_sets_of_related_ops(
            base_name_to_sets_of_related_ops, _wrapped_hardswish, F.hardswish)

        results = get_matching_subgraph_pairs(
            m1, m2,
            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops)

        hardswish_name_0 = 'base_op_' + get_base_name_for_op(
            base_name_to_sets_of_related_ops, F.hardswish) + '_0'

        expected_types = {
            hardswish_name_0:
                ((F.hardswish, torch.ao.quantization.HistogramObserver), (_wrapped_hardswish, _wrapped_hardswish)),
        }
        self.assert_types_for_matched_subgraph_pairs(
            results, expected_types, m1, m2)

    @skipIfNoFBGEMM
    def test_results_order(self):
        m = nn.Sequential(
            nn.Conv2d(1, 1, 1),
            nn.Linear(1, 1),
        ).eval()
        example_inputs = (torch.randn(1, 1, 1, 1),)
        mp = prepare_fx(m, {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_subgraph_pairs(mp, mq)
        self.assertTrue(len(results) == 2)
        results_iter = iter(results.items())
        _, (subgraph_a_0, subgraph_b_0) = next(results_iter)
        self.assertTrue(subgraph_a_0.start_node.name == '_0' and
                        subgraph_b_0.start_node.name == '_0')
        _, (subgraph_a_1, subgraph_b_1) = next(results_iter)
        self.assertTrue(subgraph_a_1.start_node.name == '_1' and
                        subgraph_b_1.start_node.name == '_1')


class TestFXGraphMatcherModels(QuantizationTestCase):

    @skipIfTorchDynamo("too slow")
    @skipIfNoFBGEMM
    @skip_if_no_torchvision
    def test_mobilenet_v2(self):
        # verify that mobilenetv2 graph is able to be matched
        import torchvision
        m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).eval().float()
        example_inputs = (torch.randn(1, 3, 224, 224),)
        mp = prepare_fx(copy.deepcopy(m), {'': torch.ao.quantization.default_qconfig}, example_inputs=example_inputs)
        # assume success if no exceptions
        results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        # assume success if no exceptions
        results_mp_mq = get_matching_subgraph_pairs(mp, mq)

    @skipIfNoFBGEMM
    @skip_if_no_torchvision
    def test_mobilenet_v2_qat(self):
        # verify that mobilenetv2 graph is able to be matched
        import torchvision
        m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False).float()
        example_inputs = (torch.randn(1, 3, 224, 224),)
        mp = prepare_qat_fx(
            copy.deepcopy(m),
            {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')},
            example_inputs=example_inputs)
        # assume success if no exceptions
        results_m_mp = get_matching_subgraph_pairs(torch.fx.symbolic_trace(m), mp)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        # assume success if no exceptions
        results_mp_mq = get_matching_subgraph_pairs(mp, mq)


class FXNumericSuiteQuantizationTestCase(QuantizationTestCase):
    def _test_extract_weights(
        self, m, example_inputs, results_len=0, qconfig_dict=None, prepare_fn=prepare_fx
    ):
        m = torch.fx.symbolic_trace(m)
        if qconfig_dict is None:
            qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        # test both the public API as well as the internal GraphModule API
        for extract_weights_fun in (extract_weights, _extract_weights_impl):
            # test both m vs mp and mp vs mq
            for m1, m2 in ((m, mp), (mp, mq)):
                results = extract_weights_fun('a', m1, 'b', m2)
                self.assertTrue(
                    len(results) == results_len,
                    f"expected len {results_len}, got len {len(results)}")
                self.assert_ns_compare_dict_valid(results)
                extend_logger_results_with_comparison(
                    results, 'a', 'b', compute_sqnr, 'sqnr')
                extend_logger_results_with_comparison(
                    results, 'a', 'b', compute_normalized_l2_error, 'l2_error')
                extend_logger_results_with_comparison(
                    results, 'a', 'b', compute_cosine_similarity,
                    'cosine_similarity')

    def _test_match_activations(
        self, m, data, prepared_expected_node_occurrence=None, results_len=0,
        should_log_inputs=False,
        qconfig_dict=None,
        skip_scripting=False,
        prepare_fn=prepare_fx,
    ):
        if qconfig_dict is None:
            qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping()
        if prepare_fn == prepare_fx:
            m.eval()
        else:
            m.train()
        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data)
        mp(*data)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        m_ns, mp_ns2 = add_loggers(
            'a', m, 'b', copy.deepcopy(mp), OutputLogger,
            should_log_inputs=should_log_inputs)
        mp_ns, mq_ns = add_loggers(
            'a', mp, 'b', mq, OutputLogger,
            should_log_inputs=should_log_inputs)

        if prepared_expected_node_occurrence:
            self.checkGraphModuleNodes(
                m_ns, expected_node_occurrence=prepared_expected_node_occurrence)
            self.checkGraphModuleNodes(
                mp_ns2, expected_node_occurrence=prepared_expected_node_occurrence)
            self.checkGraphModuleNodes(
                mp_ns, expected_node_occurrence=prepared_expected_node_occurrence)
            self.checkGraphModuleNodes(
                mq_ns, expected_node_occurrence=prepared_expected_node_occurrence)

        if not skip_scripting:
            m_ns = torch.jit.script(m_ns)
            mp_ns = torch.jit.script(mp_ns)
            mq_ns = torch.jit.script(mq_ns)

        # calibrate
        m_ns(*data)
        mp_ns2(*data)
        mp_ns(*data)
        mq_ns(*data)

        # check activation result correctness
        results = []
        for m1, m2 in ((m_ns, mp_ns2), (mp_ns, mq_ns)):
            act_compare_dict = extract_logger_info(
                m1, m2, OutputLogger, 'b')
            self.assertTrue(
                len(act_compare_dict) == results_len,
                f"expected len {results_len}, got len {len(act_compare_dict)}")
            self.assert_ns_compare_dict_valid(act_compare_dict)
            extend_logger_results_with_comparison(
                act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
            extend_logger_results_with_comparison(
                act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
            extend_logger_results_with_comparison(
                act_compare_dict, 'a', 'b', compute_cosine_similarity,
                'cosine_similarity')
            results.append(act_compare_dict)
        return results

    def _test_match_shadow_activations(
        self, m, data, prepared_expected_node_occurrence=None, results_len=None,
        should_log_inputs=False, qconfig_dict=None, skip_scripting=False,
        prepare_fn=prepare_fx, compare_fp32_vs_fp32_prepared=True,
    ):
        if qconfig_dict is None:
            qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping()
        if prepare_fn == prepare_fx:
            m.eval()
        else:
            m.train()
        print("qconfig_dict:", qconfig_dict)
        mp = prepare_fn(copy.deepcopy(m), qconfig_dict, example_inputs=data)
        print("prepared:", mp)
        mp(*data)
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        print("quantized:", mq)

        if compare_fp32_vs_fp32_prepared:
            m_shadows_mp = add_shadow_loggers(
                'a', copy.deepcopy(m), 'b', copy.deepcopy(mp),
                OutputLogger, should_log_inputs=should_log_inputs)
        mp_shadows_mq = add_shadow_loggers(
            'a', mp, 'b', mq, OutputLogger,
            should_log_inputs=should_log_inputs)

        if prepared_expected_node_occurrence:
            if compare_fp32_vs_fp32_prepared:
                self.checkGraphModuleNodes(
                    m_shadows_mp, expected_node_occurrence=prepared_expected_node_occurrence)
            self.checkGraphModuleNodes(
                mp_shadows_mq, expected_node_occurrence=prepared_expected_node_occurrence)

        if not skip_scripting:
            if compare_fp32_vs_fp32_prepared:
                m_shadows_mp = torch.jit.script(m_shadows_mp)
            mp_shadows_mq = torch.jit.script(mp_shadows_mq)

        # calibrate
        if compare_fp32_vs_fp32_prepared:
            m_shadows_mp(*data)
        mp_shadows_mq(*data)

        # check activation result correctness
        results = []
        models = (m_shadows_mp, mp_shadows_mq) if \
            compare_fp32_vs_fp32_prepared else (mp_shadows_mq,)
        for model in models:
            act_compare_dict = extract_shadow_logger_info(
                model, OutputLogger, 'b')
            if results_len is not None:
                self.assertTrue(
                    len(act_compare_dict) == results_len,
                    f"expected len {results_len}, got len {len(act_compare_dict)}")
            self.assert_ns_compare_dict_valid(act_compare_dict)
            extend_logger_results_with_comparison(
                act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr')
            extend_logger_results_with_comparison(
                act_compare_dict, 'a', 'b', compute_normalized_l2_error, 'l2_error')
            extend_logger_results_with_comparison(
                act_compare_dict, 'a', 'b', compute_cosine_similarity,
                'cosine_similarity')
            results.append(act_compare_dict)
        return results


class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):

    @skipIfNoFBGEMM
    def test_extract_weights_mod_ptq(self):
        m = AllConvAndLinearFusionModules().eval()
        example_inputs = (torch.randn(1, 1, 1, 1),)
        self._test_extract_weights(m, example_inputs, results_len=14)

    @skipIfNoFBGEMM
    def test_extract_weights_mod_qat(self):
        m = AllConvAndLinearFusionModules().train()
        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        example_inputs = (torch.randn(1, 1, 1, 1),)
        self._test_extract_weights(
            m, example_inputs, results_len=14, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)

    @skipIfNoFBGEMM
    def test_extract_weights_linear_fun_ptq(self):
        m = LinearReluLinearFunctional().eval()
        example_inputs = (torch.randn(1, 4),)
        self._test_extract_weights(m, example_inputs, results_len=2)

    @skipIfNoFBGEMM
    def test_extract_weights_linear_fun_qat(self):
        m = LinearReluLinearFunctional().train()
        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        example_inputs = (torch.randn(1, 4),)
        self._test_extract_weights(
            m, example_inputs, results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)

    @skipIfNoFBGEMM
    def test_extract_weights_conv_fun_ptq(self):
        w1d = torch.randn(1, 1, 1)
        w2d = torch.randn(1, 1, 1, 1)
        w3d = torch.randn(1, 1, 1, 1, 1)
        b1d = torch.randn(1)
        b2d = torch.randn(1)
        b3d = torch.randn(1)
        m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).eval()
        example_inputs = (torch.randn(1, 1, 1, 1),)
        self._test_extract_weights(m, example_inputs, results_len=6)

    @skipIfNoFBGEMM
    def test_extract_weights_conv_fun_qat(self):
        w1d = torch.randn(1, 1, 1)
        w2d = torch.randn(1, 1, 1, 1)
        w3d = torch.randn(1, 1, 1, 1, 1)
        b1d = torch.randn(1)
        b2d = torch.randn(1)
        b3d = torch.randn(1)
        m = AllConvFunctional(w1d, w2d, w3d, b1d, b2d, b3d).train()
        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        example_inputs = (torch.randn(1, 1, 1, 1),)
        self._test_extract_weights(
            m, example_inputs, results_len=6, qconfig_dict=qconfig_dict, prepare_fn=prepare_qat_fx)

    @skipIfNoFBGEMM
    def test_extract_weights_dynamic(self):
        # TODO(future PR): add Linear-ReLU, after #55393 is fixed.
        m = nn.Sequential(nn.Linear(1, 1)).eval()
        qconfig_dict = {
            'object_type': [
                (nn.Linear, default_dynamic_qconfig),
            ],
        }
        example_inputs = (torch.randn(1, 1),)
        self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_extract_weights_fqn(self):
        m = nn.Sequential(
            nn.Sequential(nn.Conv2d(1, 1, 1)),
            nn.Conv2d(1, 1, 1),
        ).eval()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        example_inputs = (torch.randn(1, 1, 1, 1),)
        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        mq = convert_fx(copy.deepcopy(mp))
        results = extract_weights('a', mp, 'b', mq)
        fqn_a_0 = results['_0_0']['weight']['a'][0]['fqn']
        fqn_b_0 = results['_0_0']['weight']['b'][0]['fqn']
        self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0)
        fqn_a_1 = results['_1']['weight']['a'][0]['fqn']
        fqn_b_1 = results['_1']['weight']['b'][0]['fqn']
        self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1)

    def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx):
        m = nn.Sequential(
            torch.ao.quantization.QuantStub(),
            nn.Conv2d(1, 1, 1),
            nn.Conv2d(1, 1, 1),
        ).eval()
        qconfig_dict = None
        if prepare_fn == prepare_qat_fx:
            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self._test_match_activations(
            m, (torch.randn(2, 1, 2, 2),),
            prepared_expected_node_occurrence=expected_occurrence,
            results_len=2, qconfig_dict=qconfig_dict, prepare_fn=prepare_fn)

    @skipIfNoFBGEMM
    def test_match_activations_mod_ptq(self):
        self._test_match_activations_mod_impl(prepare_fn=prepare_fx)

    @skipIfNoFBGEMM
    def test_match_activations_mod_qat(self):
        self._test_match_activations_mod_impl(prepare_fn=prepare_qat_fx)

    def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx):
        m = LinearReluLinearFunctional().eval()
        qconfig_dict = None
        if prepare_fn == prepare_qat_fx:
            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self._test_match_activations(
            m, (torch.randn(4, 4),),
            prepared_expected_node_occurrence=expected_occurrence,
            results_len=2, prepare_fn=prepare_fn, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_match_activations_fun_ptq(self):
        self._test_match_activations_fun_impl(prepare_fn=prepare_fx)

    @skipIfNoFBGEMM
    def test_match_activations_fun_qat(self):
        self._test_match_activations_fun_impl(prepare_fn=prepare_qat_fx)

    @skipIfNoFBGEMM
    def test_match_activations_meth_ptq(self):
        """
        Verify that add_loggers works on methods
        """
        class M(nn.Module):
            def forward(self, x):
                x = x.sigmoid()
                return x

        m = M().eval()
        res = self._test_match_activations(
            m, (torch.randn(4, 4),),
            results_len=1)

    @skipIfNoFBGEMM
    def test_match_activations_fqn(self):
        m = nn.Sequential(
            nn.Sequential(nn.Conv2d(1, 1, 1)),
            nn.Conv2d(1, 1, 1),
        ).eval()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        example_inputs = (torch.randn(1, 1, 1, 1),)
        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        mq = convert_fx(copy.deepcopy(mp))
        mp_ns, mq_ns = add_loggers('a', mp, 'b', mq, OutputLogger)
        datum = torch.randn(1, 1, 1, 1)
        mp_ns(datum)
        mq_ns(datum)

        results = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'b')
        fqn_a_0 = results['_0_0']['node_output']['a'][0]['fqn']
        fqn_b_0 = results['_0_0']['node_output']['b'][0]['fqn']
        self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0)
        fqn_a_1 = results['_1']['node_output']['a'][0]['fqn']
        fqn_b_1 = results['_1']['node_output']['b'][0]['fqn']
        self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1)

    def _test_add_shadow_loggers_mod_impl(self, prepare_fn=prepare_fx):
        m = nn.Sequential(
            nn.Conv2d(1, 1, 1),
            nn.Conv2d(1, 1, 1),
        ).eval()
        qconfig_dict = None
        if prepare_fn == prepare_qat_fx:
            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        res = self._test_match_shadow_activations(
            m, (torch.randn(1, 1, 4, 4),), results_len=2,
            prepare_fn=prepare_fn, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_add_shadow_loggers_mod_ptq(self):
        self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_fx)

    @skipIfNoFBGEMM
    def test_add_shadow_loggers_mod_qat(self):
        self._test_add_shadow_loggers_mod_impl(prepare_fn=prepare_qat_fx)

    def _test_add_shadow_loggers_fun_impl(self, prepare_fn=prepare_fx):
        m = LinearReluLinearFunctional()
        qconfig_dict = None
        if prepare_fn == prepare_qat_fx:
            qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
        res = self._test_match_shadow_activations(
            m, (torch.randn(4, 4),), results_len=2, prepare_fn=prepare_fn,
            qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_add_shadow_loggers_fun_ptq(self):
        self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_fx)

    @skipIfNoFBGEMM
    def test_add_shadow_loggers_fun_qat(self):
        self._test_add_shadow_loggers_fun_impl(prepare_fn=prepare_qat_fx)

    @skipIfNoFBGEMM
    def test_add_shadow_loggers_meth_ptq(self):
        """
        Verify that add_shadow_loggers works on methods
        """
        class M(nn.Module):
            def forward(self, x):
                x = x.sigmoid()
                return x

        m = M().eval()
        res = self._test_match_shadow_activations(
            m, (torch.randn(4, 4),),
            # For now, sigmoid is not supported for shadowing because the dtype
            # inference for it is not implemented yet. So, this is just testing
            # that shadowing models with method calls does not crash.
            results_len=0)

    @skipIfNoFBGEMM
    def test_shadow_activations_fqn(self):
        m = nn.Sequential(
            nn.Sequential(nn.Conv2d(1, 1, 1)),
            nn.Conv2d(1, 1, 1),
        ).eval()
        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
        example_inputs = (torch.randn(1, 1, 1, 1),)
        mp = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs)
        mq = convert_fx(copy.deepcopy(mp))
        mp_shadows_mq = add_shadow_loggers('a', mp, 'b', mq, OutputLogger)
        datum = torch.randn(1, 1, 1, 1)
        mp_shadows_mq(datum)

        results = extract_shadow_logger_info(mp_shadows_mq, OutputLogger, 'b')
        fqn_a_0 = results['_0_0']['node_output']['a'][0]['fqn']
        fqn_b_0 = results['_0_0']['node_output']['b'][0]['fqn']
        self.assertTrue(fqn_a_0 == '0.0' and fqn_a_0 == fqn_b_0)
        fqn_a_1 = results['_1']['node_output']['a'][0]['fqn']
        fqn_b_1 = results['_1']['node_output']['b'][0]['fqn']
        self.assertTrue(fqn_a_1 == '1' and fqn_a_1 == fqn_b_1)

    @skipIfNoFBGEMM
    def test_logging_inputs(self):
        """
        Verifies that logging inputs works correctly
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                x = torch.cat([x, x], dim=0)
                return x

        m = M().eval()
        self._test_match_shadow_activations(
            m, (torch.randn(1, 1, 4, 4),),
            results_len=1,
            should_log_inputs=True)

    @skipIfNoFBGEMM
    def test_ops_with_same_fp32_and_int8_signature(self):
        """
        Verifies that we can match pairs of ops which have the same aten
        signature for fp32 and int8 tensors.
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.max_pool_2d = nn.MaxPool2d(2)

            def forward(self, x):
                x = self.max_pool_2d(x)
                x = F.relu(x)
                return x

        m = M().eval()
        self._test_match_activations(
            m, (torch.randn(1, 1, 2, 2),),
            results_len=2)

    @skipIfNoFBGEMM
    def test_add_mul_inputs_activations(self):
        m = AddMulFunctional().eval()
        res = self._test_match_activations(
            m, (torch.randn(2, 2), torch.randn(2, 2)),
            results_len=6, should_log_inputs=True)

    @skipIfNoFBGEMM
    def test_linear_fp16_weights(self):
        qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
        m = LinearReluFunctional().eval()
        example_inputs = (torch.randn(1, 4),)
        self._test_extract_weights(m, example_inputs, results_len=1, qconfig_dict=qconfig_dict)

    @skipIfNoFBGEMM
    def test_linear_fp16_activations(self):
        for should_log_inputs in (True, False):
            qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
            m = LinearReluFunctional().eval()
            num_loggers = 2 if should_log_inputs else 1
            expected_occurrence = {
                ns.call_module(OutputLogger): num_loggers,
            }
            res = self._test_match_activations(
                m, (torch.randn(4, 4),),
                prepared_expected_node_occurrence=expected_occurrence,
                results_len=1,
                qconfig_dict=qconfig_dict,
                should_log_inputs=should_log_inputs)

    @skipIfNoFBGEMM
    def test_linear_fp16_shadow_activations(self):
        for should_log_inputs in (True, False):
            qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
            m = LinearReluFunctional().eval()
            num_loggers = 4 if should_log_inputs else 2
            expected_occurrence = {
                ns.call_module(OutputLogger): num_loggers,
            }
            res2 = self._test_match_shadow_activations(
                m, (torch.randn(4, 4),),
                prepared_expected_node_occurrence=expected_occurrence,
                results_len=1,
                qconfig_dict=qconfig_dict,
                should_log_inputs=should_log_inputs)

    @skipIfNoFBGEMM
    def test_linear_fp16_vs_linear_fp16_shadow_activations(self):
        m = LinearFunctional().eval()
        qconfig_dict = {'': torch.ao.quantization.float16_static_qconfig}
        example_inputs = (torch.randn(1, 4),)
        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        mq1 = convert_fx(copy.deepcopy(mp))
        mq2 = convert_fx(copy.deepcopy(mp))
        mq1_shadows_mq2 = _add_shadow_loggers_impl(
            'a', mq1, 'b', mq2, OutputLogger, should_log_inputs=False)
        mq1_shadows_mq2(torch.randn(4, 4))
        act_compare_dict = extract_shadow_logger_info(
            mq1_shadows_mq2, OutputLogger, 'b')
        self.assertTrue(len(act_compare_dict) == 1)
        self.assert_ns_compare_dict_valid(act_compare_dict)


    @skipIfNoFBGEMM
    def test_op_with_either_fp32_or_int8_input(self):
        """
        Verify that shadowing works with ops which accept either fp32 or
        int8 inputs.
        """
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(x)
                x = F.relu(x)
                return x

        m = M()
        res = self._test_match_shadow_activations(
            m, (torch.randn(4, 4),),
            # Note: shadowing relu by itself is currently not supported,
            # this test is just testing that it does not crash
            results_len=0)

    def _test_int8_shadows_int8_impl(self, m):
        """
        Verify that shadowing works where both modules are int8
        """
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        example_inputs = (torch.randn(4, 1, 4, 4),)
        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        mp(*example_inputs)
        mq1 = convert_fx(copy.deepcopy(mp))
        mq2 = convert_fx(mp)
        mq1_shadows_mq2 = add_shadow_loggers('a', mq1, 'b', mq2, OutputLogger)
        mq1_shadows_mq2(torch.randn(4, 1, 4, 4))
        act_compare_dict = extract_shadow_logger_info(
            mq1_shadows_mq2, OutputLogger, 'b')
        self.assertTrue(len(act_compare_dict) == 1)
        self.assert_ns_compare_dict_valid(act_compare_dict)

    @skipIfNoFBGEMM
    def test_int8_shadows_int8_mod(self):
        m = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        self._test_int8_shadows_int8_impl(m)

    @skipIfNoFBGEMM
    def test_int8_shadows_int8_fun(self):
        m = LinearFunctional().eval()
        self._test_int8_shadows_int8_impl(m)

    @skipIfNoFBGEMM
    def test_user_module_scriptable(self):
        # Logging of the output of this class is not supported, because it is
        # neither a tensor nor an RNN return type.
        class M1(nn.Module):
            def forward(self, x):
                x1 = x * 2
                x2 = x * 4
                return (x1, x2)

        class M2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x):
                x1, x2 = self.m1(x)
                return x1, x2

        m = M2().eval()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        prepare_custom_config_dict = {
            'non_traceable_module_class': [M1],
        }
        example_inputs = (torch.randn(1),)
        mp1 = prepare_fx(
            m,
            qconfig_dict,
            example_inputs=example_inputs,
            prepare_custom_config=prepare_custom_config_dict)
        mp2 = copy.deepcopy(mp1)
        unmatchable_types_map = get_unmatchable_types_map()
        unmatchable_types_map['mods_unmatchable'].add(M1)
        mp1_ns, mp2_ns = _add_loggers_impl(
            'a', mp1, 'b', mp2, OutputLogger, should_log_inputs=False,
            unmatchable_types_map=unmatchable_types_map)

        # Scripting a model with loggers should succeed. If it fails because of
        # incorrect dtypes, we can blocklist the associated types from being instrumented.
        mp1_ns_scripted = torch.jit.script(mp1_ns)
        mp2_ns_scripted = torch.jit.script(mp2_ns)

    @skipIfNoFBGEMM
    def test_user_module(self):
        """
        For user defined modules,
        1. weight extraction should not crash
        2. unshadowed activations should only have loggers for known types
        3. shadowed activations should only have loggers for known types with
           known dtypes
        """
        class UserModule(nn.Module):
            def forward(self, x):
                return x

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.user_module = UserModule()

            def forward(self, x):
                x = self.linear(x)
                x = self.user_module(x)
                return x

        m = M().eval()

        # quantize without tracing through UserModule
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        prepare_custom_config_dict = {'non_traceable_module_name': ['user_module']}
        example_inputs = (torch.randn(1, 1, 1),)
        mp = prepare_fx(
            m,
            qconfig_dict,
            example_inputs=example_inputs,
            prepare_custom_config=prepare_custom_config_dict)
        mp(*example_inputs)
        mq = convert_fx(copy.deepcopy(mp))

        # weight extraction should not crash
        weights = _extract_weights_impl('fp32_prepared', mp, 'int8', mq)

        # unshadowed activations should have loggers

        # add loggers, without retracing
        # note: converting again because we cannot copy a quantized linear
        mp_ns, mq_ns = _add_loggers_impl(
            'fp32_prepared', copy.deepcopy(mp), 'int8',
            convert_fx(copy.deepcopy(mp)), OutputLogger,
            should_log_inputs=True)
        # both fp32 and int8 models should have 2 loggers each, 2 for I/O
        # of linear, and 0 for I/O of user_module
        unshadowed_expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=unshadowed_expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=unshadowed_expected_occurrence)

        # shadowed activations should only have loggers for nodes where
        # the types are known and we can do a dtype cast

        # add shadow loggers, without retracing
        mp_shadows_mq_ns = _add_shadow_loggers_impl(
            'fp32_prepared', mp, 'int8', mq, OutputLogger,
            should_log_inputs=True)
        # 4 loggers for I/O of linear, 0 loggers for I/O of user_module
        shadowed_expected_occurrence = {
            ns.call_module(OutputLogger): 4,
        }
        self.checkGraphModuleNodes(
            mp_shadows_mq_ns, expected_node_occurrence=shadowed_expected_occurrence)

    def test_op_io_dtype_coverage(self):
        """
        Tests that all the ops quantization cares about have input and output
        dtypes defined.
        """
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        type_a_related_to_b = \
            get_type_a_related_to_b(base_name_to_sets_of_related_ops)

        # TODO(future PR): clean this up
        node_type_to_io_type_map = get_node_type_to_io_type_map()
        FUNS_IO_TYPE_FP32 = node_type_to_io_type_map['funs_io_type_fp32']
        FUNS_IO_TYPE_INT8 = node_type_to_io_type_map['funs_io_type_int8']
        FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['funs_io_type_fp32_or_int8']
        MODS_IO_TYPE_FP32 = node_type_to_io_type_map['mods_io_type_fp32']
        MODS_IO_TYPE_INT8 = node_type_to_io_type_map['mods_io_type_int8']
        MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['mods_io_type_fp32_or_int8']
        METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map['meths_io_type_fp32_or_int8']

        unmatchable_types_map = get_unmatchable_types_map()
        FUNS_UNMATCHABLE = unmatchable_types_map['funs_unmatchable']
        MODS_UNMATCHABLE = unmatchable_types_map['mods_unmatchable']
        METHS_UNMATCHABLE = unmatchable_types_map['meths_unmatchable']

        # 1. check static quant module mappings
        static_quant_mod_mappings = get_default_static_quant_module_mappings()
        for fp32_type, int8_type in static_quant_mod_mappings.items():
            types_to_skip = (
                torch.ao.quantization.QuantStub,
                torch.ao.quantization.DeQuantStub,
                nnq.FloatFunctional,
                # TODO(future PR): look into whether shadowing embeddings
                # makes sense
                nn.Embedding,
                nn.EmbeddingBag,
                # the ConvTranspose3d swap is not implemented in FX Graph
                # mode quantization yet
                nn.ConvTranspose3d,
                # the GroupNorm swap is not implemented in FX Graph
                # mode quantization yet
                nn.GroupNorm,
                # nnq.ReLU6 is no longer swapped, because nn.ReLU6 can
                # take quantized inputs
                nn.ReLU6,
            )
            if fp32_type in types_to_skip:
                continue
            self.assertTrue(
                fp32_type in MODS_IO_TYPE_FP32,
                f"missing IO type handling for {fp32_type}")
            self.assertTrue(
                int8_type in MODS_IO_TYPE_INT8,
                f"missing IO type handling for {int8_type}")

        # 2. check static quant op mappings
        static_quant_fun_mappings = get_default_float_to_quantized_operator_mappings()
        for fp32_type, int8_type in static_quant_fun_mappings.items():
            self.assertTrue(
                fp32_type in FUNS_IO_TYPE_FP32,
                f"missing IO type handling for {fp32_type}")
            self.assertTrue(
                int8_type in FUNS_IO_TYPE_INT8,
                f"missing IO type handling for {int8_type}")

        # 3. check dynamic quant mappings
        dynamic_quant_mappings = get_default_dynamic_quant_module_mappings()
        for fp32_type1, fp32_type2 in dynamic_quant_mappings.items():
            # TODO(future PR): verify correct I/O for these and remove from
            # this list.
            types_to_skip = (
                nn.GRUCell,
                nn.GRU,
                nn.LSTMCell,
                nn.RNNCell,
                # TODO(future PR): look into whether shadowing embeddings
                # makes sense
                nn.Embedding,
                nn.EmbeddingBag,
            )
            if fp32_type1 in types_to_skip:
                continue
            self.assertTrue(
                fp32_type1 in MODS_IO_TYPE_FP32,
                f"missing IO type handling for {fp32_type1}")
            self.assertTrue(
                fp32_type2 in MODS_IO_TYPE_FP32,
                f"missing IO type handling for {fp32_type2}")

        # 4. go through the ops mapped to each QuantizeHandler type, and verify
        # correctness.
        default_quant_patterns = get_all_quant_patterns()
        for pattern, qhandler_cls in default_quant_patterns.items():
            base_op = None
            if isinstance(pattern, tuple):
                base_op = pattern[-1]
            elif isinstance(pattern, str):
                base_op = pattern
            else:
                base_op = pattern

            if (
                qhandler_cls in (
                    qh.BinaryOpQuantizeHandler,
                    qh.RNNDynamicQuantizeHandler,
                )
            ):
                # TODO(future PR): implement shadowing for binary ops
                # TODO(future PR): implement shadowing for RNN ops
                continue
            elif qhandler_cls == qh.CatQuantizeHandler:
                self.assertTrue(
                    base_op in FUNS_IO_TYPE_FP32_OR_INT8,
                    f"missing IO type handling for {base_op}")
            elif (
                qhandler_cls in (
                    qh.ConvReluQuantizeHandler,
                    qh.LinearReLUQuantizeHandler,
                    qh.BatchNormQuantizeHandler,
                    qh.DefaultNodeQuantizeHandler,
                )
            ):
                self.assertTrue(
                    (base_op in FUNS_IO_TYPE_FP32) or (base_op in MODS_IO_TYPE_FP32),
                    f"missing IO type handling for {base_op}")
            elif (
                qhandler_cls in (
                    qh.FixedQParamsOpQuantizeHandler,
                    qh.CopyNodeQuantizeHandler,
                    qh.GeneralTensorShapeOpQuantizeHandler,
                )
            ):
                if (
                    base_op in FUNS_UNMATCHABLE or
                    base_op in MODS_UNMATCHABLE or
                    base_op in METHS_UNMATCHABLE
                ):
                    continue

                self.assertTrue(
                    (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
                    (base_op in MODS_IO_TYPE_FP32_OR_INT8) or
                    (base_op in METHS_IO_TYPE_FP32_OR_INT8) or
                    # Softmax has a different signature for the quantized
                    # version, so it does not fit into the cases above.
                    (base_op is torch.nn.Softmax),
                    f"missing IO type handling for {base_op}")
            elif qhandler_cls == qh.EmbeddingQuantizeHandler:
                # embedding shadowing is not implemented, for now
                continue
            else:
                if (
                    base_op in FUNS_UNMATCHABLE or
                    base_op in MODS_UNMATCHABLE or
                    base_op in METHS_UNMATCHABLE
                ):
                    continue
                if qhandler_cls(None, {}).is_general_tensor_value_op():
                    self.assertTrue(
                        (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
                        (base_op in MODS_IO_TYPE_FP32_OR_INT8) or
                        (base_op in METHS_IO_TYPE_FP32_OR_INT8),
                        f"missing IO type handling for {base_op} using {qhandler_cls}")
                else:
                    self.assertTrue(
                        (base_op in FUNS_IO_TYPE_FP32_OR_INT8) or
                        (base_op in MODS_IO_TYPE_FP32_OR_INT8) or
                        (base_op in METHS_IO_TYPE_FP32_OR_INT8) or
                        (base_op in FUNS_IO_TYPE_FP32) or
                        (base_op in MODS_IO_TYPE_FP32),
                        f"missing IO type handling for {base_op} using {qhandler_cls}")

    @skipIfNoFBGEMM
    def test_user_defined_function(self):
        """
        Verify that NS APIs work on user defined functions
        """
        class M1(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = nn.Parameter(torch.empty(1, 1))
                self.b1 = nn.Parameter(torch.zeros(1))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

            def forward(self, x):
                x = F.hardswish(x)
                x = x.sigmoid()
                x = F.linear(x, self.w1, self.b1)
                return x

        class M2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = nn.Parameter(torch.empty(1, 1))
                self.b1 = nn.Parameter(torch.zeros(1))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))

            def forward(self, x):
                x = _wrapped_hardswish(x)
                x = _wrapped_sigmoid(x)
                x = _wrapped_linear(x, self.w1, self.b1)
                return x

        qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping()
        example_inputs = (torch.randn(1, 1),)
        m1 = prepare_fx(M1().eval(), qconfig_mapping, example_inputs=example_inputs)
        m2 = prepare_fx(M2().eval(), qconfig_mapping, example_inputs=example_inputs)
        data = torch.randn(1, 1)

        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
        add_op_to_sets_of_related_ops(
            base_name_to_sets_of_related_ops, _wrapped_hardswish, F.hardswish)
        add_op_to_sets_of_related_ops(
            base_name_to_sets_of_related_ops, _wrapped_sigmoid, F.sigmoid)
        add_op_to_sets_of_related_ops(
            base_name_to_sets_of_related_ops, _wrapped_linear, F.linear)

        op_to_type_to_weight_extraction_fn = \
            get_op_to_type_to_weight_extraction_fn()
        op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
            torch.ao.ns.fx.weight_utils.get_linear_fun_weight

        # test compare weights
        results = extract_weights(
            'a', m1, 'b', m2,
            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
            op_to_type_to_weight_extraction_fn=op_to_type_to_weight_extraction_fn)
        self.assertTrue(len(results) == 1)
        self.assertTrue(len(results['_wrapped_linear']['weight']) == 2)

        # test unshadowed activations

        m1_ns, m2_ns = _add_loggers_impl(
            'a', copy.deepcopy(m1), 'b', copy.deepcopy(m2), OutputLogger,
            should_log_inputs=False,
            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops)

        # calibrate
        m1_ns(data)
        m2_ns(data)

        # check activation result correctness
        act_compare_dict =
extract_logger_info(m1_ns, m2_ns, OutputLogger, 'b') 1750 self.assertTrue(len(act_compare_dict) == 3) 1751 self.assert_ns_compare_dict_valid(act_compare_dict) 1752 1753 # test shadowed activations 1754 1755 node_type_to_io_type_map = get_node_type_to_io_type_map() 1756 node_type_to_io_type_map['funs_io_type_fp32'].add(_wrapped_hardswish) 1757 node_type_to_io_type_map['funs_io_type_fp32'].add(_wrapped_sigmoid) 1758 1759 m2_shadows_m1_ns = _add_shadow_loggers_impl( 1760 'a', m2, 'b', m1, OutputLogger, 1761 should_log_inputs=False, 1762 base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, 1763 node_type_to_io_type_map=node_type_to_io_type_map) 1764 1765 # calibrate 1766 m2_shadows_m1_ns(data) 1767 1768 # check activation result correctness 1769 act_compare_dict = extract_shadow_logger_info( 1770 m2_shadows_m1_ns, OutputLogger, 'b') 1771 self.assertTrue(len(act_compare_dict) == 2) 1772 self.assert_ns_compare_dict_valid(act_compare_dict) 1773 1774 @skipIfNoFBGEMM 1775 def test_layer_names(self): 1776 m = nn.Sequential( 1777 nn.Conv2d(1, 1, 1), 1778 nn.Conv2d(1, 1, 1), 1779 nn.Sigmoid(), 1780 ).eval() 1781 qconfig_mapping = torch.ao.quantization.get_default_qconfig_mapping("fbgemm") 1782 example_inputs = (torch.randn(1, 1, 1, 1),) 1783 mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) 1784 mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1785 1786 # extract weights 1787 results = extract_weights('fp32', mp, 'int8', mq) 1788 mq_node_names = [node.name for node in mq.graph.nodes] 1789 for layer_name in results.keys(): 1790 self.assertTrue(layer_name in mq_node_names) 1791 1792 # match activations 1793 mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1794 mp_ns, mq_ns = add_loggers( 1795 'fp32', copy.deepcopy(mp), 'int8', mq, OutputLogger) 1796 data = torch.randn(1, 1, 1, 1) 1797 mp_ns(data) 1798 mq_ns(data) 1799 results = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'int8') 1800 mq_node_names = [node.name for node in mq_ns.graph.nodes] 1801 for layer_name in results.keys(): 1802 self.assertTrue(layer_name in mq_node_names) 1803 1804 # match shadow activations 1805 mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1806 mp_shadows_mq = add_shadow_loggers( 1807 'fp32', mp, 'int8', mq, OutputLogger) 1808 mp_shadows_mq(data) 1809 results = extract_shadow_logger_info( 1810 mp_shadows_mq, OutputLogger, 'int8') 1811 mq_node_names = [node.name for node in mp_shadows_mq.graph.nodes] 1812 for layer_name in results.keys(): 1813 self.assertTrue(layer_name in mq_node_names) 1814 1815 @skipIfNoFBGEMM 1816 def test_extend_logger_results_with_comparison(self): 1817 m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)).eval() 1818 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 1819 example_inputs = (torch.randn(1, 1, 1, 1),) 1820 mp = torch.ao.quantization.quantize_fx.prepare_fx( 1821 m, qconfig_dict, example_inputs=example_inputs) 1822 mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1823 1824 # extract weights 1825 results = extract_weights('fp32', mp, 'int8', mq) 1826 extend_logger_results_with_comparison( 1827 results, 'fp32', 'int8', compute_sqnr, 'sqnr_int8_vs_fp32') 1828 extend_logger_results_with_comparison( 1829 results, 'fp32', 'int8', compute_normalized_l2_error, 'l2_error_int8_vs_fp32') 1830 extend_logger_results_with_comparison( 1831 results, 'fp32', 'int8', compute_cosine_similarity, 1832 'cosine_similarity_int8_vs_fp32') 1833 1834 for 
layer_results in results.values(): 1835 assert 'sqnr_int8_vs_fp32' in \ 1836 layer_results['weight']['int8'][0].keys() 1837 assert 'l2_error_int8_vs_fp32' in \ 1838 layer_results['weight']['int8'][0].keys() 1839 assert 'cosine_similarity_int8_vs_fp32' in \ 1840 layer_results['weight']['int8'][0].keys() 1841 1842 @skipIfNoFBGEMM 1843 def test_int8_shadows_fp32_simple(self): 1844 m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1), nn.ReLU()).eval() 1845 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 1846 example_inputs = (torch.randn(1, 1, 1, 1),) 1847 mp = torch.ao.quantization.quantize_fx.prepare_fx( 1848 m, qconfig_dict, example_inputs=example_inputs) 1849 mp(torch.randn(1, 1, 1, 1)) 1850 mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1851 mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1852 mp_shadows_mq = add_shadow_loggers( 1853 'int8', mq, 'fp32', mp, OutputLogger) 1854 1855 # verify that scale and zp were extracted correctly 1856 1857 # for the first op, the scale+zp live as attributes on the module 1858 scale_0 = mp_shadows_mq._0_input_scale_0 1859 scale_0_ref = getattr(mq_ref, '0_input_scale_0') 1860 self.assertEqual(scale_0, scale_0_ref) 1861 zp_0 = mp_shadows_mq._0_input_zero_point_0 1862 zp_0_ref = getattr(mq_ref, '0_input_zero_point_0') 1863 self.assertEqual(zp_0, zp_0_ref) 1864 1865 # for the second op, the scale and zp of input to second op 1866 # must equal to scale and zp of output of first op 1867 scale_1 = mp_shadows_mq._1_input_scale_0 1868 scale_1_ref = getattr(mq_ref, '0').scale 1869 self.assertEqual(scale_1, scale_1_ref) 1870 zp_1 = mp_shadows_mq._1_input_zero_point_0 1871 zp_1_ref = getattr(mq_ref, '0').zero_point 1872 self.assertEqual(zp_1, zp_1_ref) 1873 1874 # verify running data works 1875 mp_shadows_mq(torch.randn(1, 1, 1, 1)) 1876 act_compare_dict = extract_shadow_logger_info( 1877 mp_shadows_mq, OutputLogger, 'fp32') 1878 self.assertTrue(len(act_compare_dict) == 2) 1879 self.assert_ns_compare_dict_valid(act_compare_dict) 1880 1881 @skipIfNoFBGEMM 1882 def test_int8_shadows_fp32_coverage(self): 1883 class M(torch.nn.Module): 1884 def __init__(self) -> None: 1885 super().__init__() 1886 self.adaptive_avg_pool = nn.AdaptiveAvgPool2d(1) 1887 self.conv = nn.Conv2d(1, 1, 1) 1888 1889 def forward(self, x): 1890 x = self.adaptive_avg_pool(x) 1891 # input qparams of conv will be input qparams of adaptive_avg_pool 1892 x = self.conv(x) 1893 x = torch.mul(x, x) 1894 x = self.conv(x) 1895 x = torch.add(x, x) 1896 x = F.relu(x) 1897 x = self.conv(x) 1898 return x 1899 1900 m = M().eval() 1901 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 1902 example_inputs = (torch.randn(1, 1, 1, 1),) 1903 mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 1904 mp(*example_inputs) 1905 mq = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1906 mq_ref = torch.ao.quantization.quantize_fx.convert_fx(copy.deepcopy(mp)) 1907 mp_shadows_mq = add_shadow_loggers( 1908 'int8', mq, 'fp32', mp, OutputLogger) 1909 mp_shadows_mq(torch.randn(1, 1, 1, 1)) 1910 act_compare_dict = extract_shadow_logger_info( 1911 mp_shadows_mq, OutputLogger, 'fp32') 1912 self.assertTrue(len(act_compare_dict) == 3) 1913 self.assert_ns_compare_dict_valid(act_compare_dict) 1914 1915 @skipIfNoFBGEMM 1916 def test_loggers_preserve_qat_numerics(self): 1917 m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)) 1918 qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} 1919 example_inputs = 
(torch.randn(1, 1, 1, 1),) 1920 mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) 1921 mp(*example_inputs) 1922 mc = convert_fx(copy.deepcopy(mp)) 1923 mp.apply(torch.ao.quantization.disable_observer) 1924 1925 ref_fp32 = mp(*example_inputs) 1926 ref_int8 = mc(*example_inputs) 1927 1928 mp_ns, mc_ns = add_loggers('fp32', mp, 'int8', mc, OutputLogger) 1929 ref_fp32_ns = mp_ns(*example_inputs) 1930 ref_int8_ns = mc_ns(*example_inputs) 1931 self.assertEqual(ref_fp32, ref_fp32_ns) 1932 self.assertEqual(ref_int8, ref_int8_ns) 1933 1934 @skipIfNoFBGEMM 1935 def test_shadow_loggers_preserve_qat_numerics(self): 1936 m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)) 1937 qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} 1938 example_inputs = (torch.randn(1, 1, 1, 1),) 1939 mp = prepare_qat_fx(m, qconfig_dict, example_inputs=example_inputs) 1940 mp(*example_inputs) 1941 mc = convert_fx(copy.deepcopy(mp)) 1942 mp.apply(torch.ao.quantization.disable_observer) 1943 1944 ref_fp32 = mp(*example_inputs) 1945 ref_int8 = mc(*example_inputs) 1946 1947 mc_shadows_mp = add_shadow_loggers('int8', mc, 'fp32', mp, OutputLogger) 1948 ref_shadow = mc_shadows_mp(*example_inputs) 1949 self.assertEqual(ref_fp32, ref_shadow) 1950 1951 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 1952 def test_extract_weights_cuda(self): 1953 # Note: this is not using quantization because quantized kernels do not 1954 # work on cuda yet. 1955 m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() 1956 m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() 1957 results = extract_weights('a', m1, 'b', m2) 1958 extend_logger_results_with_comparison( 1959 results, 'a', 'b', compute_sqnr, 'sqnr') 1960 self.assert_ns_compare_dict_valid(results) 1961 1962 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 1963 def test_add_loggers_cuda(self): 1964 # Note: this is not using quantization because quantized kernels do not 1965 # work on cuda yet. 1966 m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() 1967 m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() 1968 m1_ns, m2_ns = add_loggers('a', m1, 'b', m2, OutputLogger) 1969 datum = torch.randn(1, 1, 1, 1) 1970 datum = datum.cuda() 1971 1972 m1_ns(datum) 1973 m2_ns(datum) 1974 1975 act_compare_dict = extract_logger_info(m1_ns, m2_ns, OutputLogger, 'b') 1976 extend_logger_results_with_comparison( 1977 act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr') 1978 1979 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 1980 def test_add_shadow_loggers_cuda(self): 1981 # Note: this is not using quantization because quantized kernels do not 1982 # work on cuda yet. 
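        # The shadow logger graph rewrite is device agnostic, so comparing two
        # unquantized fp32 models here still exercises subgraph matching and
        # logger insertion with CUDA tensors.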
1983 m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() 1984 m2 = nn.Sequential(nn.Conv2d(1, 1, 1)).cuda() 1985 m1_shadows_m2 = add_shadow_loggers('a', m1, 'b', m2, OutputLogger) 1986 datum = torch.randn(1, 1, 1, 1) 1987 datum = datum.cuda() 1988 1989 m1_shadows_m2(datum) 1990 1991 act_compare_dict = extract_shadow_logger_info(m1_shadows_m2, OutputLogger, 'b') 1992 extend_logger_results_with_comparison( 1993 act_compare_dict, 'a', 'b', compute_sqnr, 'sqnr') 1994 1995 def test_fp16_shadows_fp32(self): 1996 m = LinearReluFunctional().eval() 1997 example_inputs = (torch.randn(1, 4),) 1998 qconfig_dict = {"": torch.ao.quantization.float16_static_qconfig} 1999 mp = prepare_fx(copy.deepcopy(m), qconfig_dict, example_inputs=example_inputs) 2000 mq = convert_to_reference_fx(mp) 2001 mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger) 2002 2003 def test_mul_add_cat_stack_skips_shadowing(self): 2004 class M(nn.Module): 2005 def forward(self, x): 2006 x = x * x 2007 x = torch.mul(x, x) 2008 x = x + x 2009 x = torch.add(x, x) 2010 x = torch.cat([x]) 2011 x = torch.stack([x]) 2012 return x 2013 2014 m = M().eval() 2015 self._test_match_shadow_activations( 2016 m, (torch.randn(1, 1, 4, 4),), 2017 results_len=0) 2018 2019 def test_op_with_only_kwargs_skips_shadowing(self): 2020 class M(nn.Module): 2021 def forward(self, x): 2022 x = torch.cat(tensors=[x]) 2023 x = torch.stack(tensors=[x]) 2024 return x 2025 2026 m = M().eval() 2027 self._test_match_shadow_activations( 2028 m, (torch.randn(1, 1, 4, 4),), 2029 results_len=0) 2030 2031 def test_unsupported_op_copy_skips_shadowing(self): 2032 """ 2033 Copying a `call_function` node is not implemented, test that this 2034 does not crash shadowing but instead skips the node. 2035 """ 2036 class M(nn.Module): 2037 def forward(self, x): 2038 # the second argument leads to attempting to copy a 2039 # call_function node 2040 x = F.layer_norm(x, x.shape[1:]) 2041 return x 2042 2043 m = M().eval() 2044 self._test_match_shadow_activations( 2045 m, (torch.randn(1, 1, 4, 4),), 2046 results_len=0) 2047 2048 def test_linear_kwargs_shadow(self): 2049 2050 class M(nn.Module): 2051 def __init__(self) -> None: 2052 super().__init__() 2053 self.w1 = nn.Parameter(torch.empty(4, 4)) 2054 self.b1 = nn.Parameter(torch.zeros(4)) 2055 torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) 2056 2057 def forward(self, x): 2058 x = F.linear(input=x, weight=self.w1, bias=self.b1) 2059 return x 2060 2061 # note: FX graph mode quantization does not have good support 2062 # for kwargs-only right now, so we pass in two unquantized 2063 # models 2064 m = M().eval() 2065 mt = torch.fx.symbolic_trace(m) 2066 mt_copy = copy.deepcopy(mt) 2067 2068 mt_shadows_mt_copy = add_shadow_loggers( 2069 'a', mt, 'b', mt_copy, OutputLogger) 2070 2071 mt_shadows_mt_copy(torch.randn(4, 4)) 2072 act_compare_dict = extract_shadow_logger_info( 2073 mt_shadows_mt_copy, OutputLogger, 'b') 2074 self.assertTrue(len(act_compare_dict) == 1) 2075 2076@skipIfNoQNNPACK 2077class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase): 2078 """ 2079 Tests the "n shadows" workflow. 
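
    At a high level: prepare_n_shadows_model inserts, for each matched subgraph,
    one shadow copy per entry in the QConfigMultiMapping; convert_n_shadows_model
    quantizes those shadow copies; and the logger utilities record and compare
    the output of each shadow against the corresponding fp32 subgraph.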
2080 """ 2081 2082 def _test_impl(self, m, example_input, qconfig_mappings): 2083 backend_config = get_native_backend_config() 2084 2085 # test that input is valid 2086 _ = m(*example_input) 2087 2088 msp = prepare_n_shadows_model( 2089 m, example_input, qconfig_mappings, backend_config) 2090 # print('msp', msp) 2091 2092 for _ in range(2): 2093 msp(*example_input) 2094 2095 msq = convert_n_shadows_model(msp) 2096 2097 loggers_set_enabled(msq, True) 2098 msq(*example_input) 2099 2100 results = extract_results_n_shadows_model(msq) 2101 print_comparisons_n_shadows_model(results) 2102 return msq 2103 2104 @withQNNPACKBackend 2105 def test_linear_mod(self): 2106 class M(nn.Module): 2107 def __init__(self) -> None: 2108 super().__init__() 2109 self.fc1 = nn.Linear(2, 2) 2110 2111 def forward(self, x): 2112 x = self.fc1(x) 2113 return x 2114 2115 m = M().eval() 2116 example_input = (torch.randn(2, 2),) 2117 2118 qconfig_mappings = \ 2119 QConfigMultiMapping().set_global([torch.ao.quantization.default_qconfig]) 2120 self._test_impl(m, example_input, qconfig_mappings) 2121 2122 @withQNNPACKBackend 2123 def test_linear_relu_mod(self): 2124 class M(nn.Module): 2125 def __init__(self) -> None: 2126 super().__init__() 2127 self.fc1 = nn.Linear(2, 2) 2128 self.fc2 = nn.Linear(2, 2) 2129 self.relu = nn.ReLU() 2130 2131 def forward(self, x): 2132 x = self.fc1(x) 2133 x = self.fc2(x) 2134 x = self.relu(x) 2135 return x 2136 2137 m = M().eval() 2138 example_input = (torch.randn(2, 2),) 2139 2140 qconfig_mappings = ( 2141 QConfigMultiMapping().set_global([ 2142 torch.ao.quantization.default_qconfig, 2143 torch.ao.quantization.default_dynamic_qconfig 2144 ]) 2145 ) 2146 self._test_impl(m, example_input, qconfig_mappings) 2147 2148 @withQNNPACKBackend 2149 def test_conv_bn_relu_mod(self): 2150 class M(nn.Module): 2151 def __init__(self) -> None: 2152 super().__init__() 2153 self.conv = nn.Conv2d(1, 1, 1) 2154 self.bn = nn.BatchNorm2d(1) 2155 self.relu = nn.ReLU() 2156 2157 def forward(self, x): 2158 x = self.conv(x) 2159 x = self.bn(x) 2160 x = self.relu(x) 2161 return x 2162 2163 m = M().eval() 2164 example_input = (torch.randn(32, 1, 16, 16),) 2165 2166 qconfig_mappings = QConfigMultiMapping() \ 2167 .set_global([ 2168 torch.ao.quantization.default_qconfig, 2169 torch.ao.quantization.default_per_channel_qconfig 2170 ]) 2171 self._test_impl(m, example_input, qconfig_mappings) 2172 2173 @withQNNPACKBackend 2174 def test_functions(self): 2175 class M(nn.Module): 2176 def __init__(self) -> None: 2177 super().__init__() 2178 self.w1 = nn.Parameter(torch.randn(2, 2)) 2179 self.b1 = nn.Parameter(torch.zeros(2)) 2180 torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) 2181 2182 def forward(self, x): 2183 x = F.sigmoid(x) 2184 x = F.linear(x, self.w1, self.b1) 2185 x = F.linear(x, self.w1[:], self.b1) 2186 x = F.relu(x) 2187 x = x + x 2188 x = torch.cat([x]) 2189 x = torch.cat((x,)) 2190 x = torch.cat(tensors=[x]) 2191 # TODO(future PR): enable layernorm 2192 # blocked on FX graph mode quant not inserting observer for 2193 # second arg, if the second arg is a module input 2194 # x = F.layer_norm(x, x.shape) 2195 # x = F.layer_norm(x, x.shape[1:]) 2196 # x = x.reshape(1, -1) * 2 2197 # x = F.layer_norm(x.reshape(1, -1), x.shape[1:]) 2198 x = torch.matmul(x, x.reshape(2, 2)) 2199 x = torch.matmul(x.reshape(2, 2), x.reshape(2, 2)) 2200 # TODO(future PR): enable below after FX graph mode quantization handles 2201 # it, currently this is not supported 2202 # x = F.linear(input=x, weight=self.w1, bias=self.b1) 2203 
return x 2204 2205 m = M().eval() 2206 example_input = (torch.randn(2, 2),) 2207 2208 qconfig_mappings = QConfigMultiMapping() \ 2209 .set_global([torch.ao.quantization.default_qconfig]) 2210 self._test_impl(m, example_input, qconfig_mappings) 2211 2212 @withQNNPACKBackend 2213 def test_partial_qconfig_mapping(self): 2214 class M(nn.Module): 2215 def __init__(self) -> None: 2216 super().__init__() 2217 self.fc = nn.Linear(2, 2) 2218 self.w1 = nn.Parameter(torch.randn(2, 2)) 2219 self.b1 = nn.Parameter(torch.randn(2)) 2220 torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) 2221 2222 def forward(self, x): 2223 x = self.fc(x) 2224 x = F.linear(x, self.w1, self.b1) 2225 x = F.relu(x) 2226 x = x + x 2227 return x 2228 2229 m = M().eval() 2230 example_input = (torch.randn(2, 2),) 2231 qconfig = torch.ao.quantization.default_qconfig 2232 2233 qconfig_mappings = QConfigMultiMapping() \ 2234 .set_object_type(F.linear, [qconfig]) \ 2235 .set_object_type(F.relu, [qconfig]) 2236 self._test_impl(m, example_input, qconfig_mappings) 2237 2238 @withQNNPACKBackend 2239 def test_logger_enabled_and_save_activations_flags(self): 2240 m = nn.Sequential(nn.Linear(1, 1)).eval() 2241 example_input = (torch.randn(1, 1),) 2242 2243 qconfig_mappings = QConfigMultiMapping() \ 2244 .set_global([torch.ao.quantization.default_qconfig]) 2245 backend_config = get_native_backend_config() 2246 2247 msp = prepare_n_shadows_model( 2248 m, example_input, qconfig_mappings, backend_config) 2249 2250 for _ in range(2): 2251 msp(*example_input) 2252 2253 def _check_logger_count(model, exp_count_stats, exp_count_comparisons): 2254 for name, mod in model.named_modules(): 2255 if isinstance(mod, OutputLogger): 2256 self.assertTrue( 2257 len(mod.stats) == exp_count_stats, 2258 f'stats: expected {len(mod.stats)} to equal {exp_count_stats}') 2259 if isinstance(mod, OutputComparisonLogger): 2260 self.assertTrue( 2261 len(mod.comparisons) == exp_count_comparisons, 2262 f'comparisons: expected {len(mod.comparisons)} to equal {exp_count_comparisons}') 2263 2264 # check behavior with save_activations enabled 2265 msq = convert_n_shadows_model(copy.deepcopy(msp)) 2266 loggers_set_enabled(msq, True) 2267 loggers_set_save_activations(msq, True) 2268 # after prepare calibration but before convert calibration, loggers 2269 # should not have anything saved 2270 _check_logger_count(msq, 0, 0) 2271 msq(*example_input) 2272 # loggers should save each item after calibration 2273 _check_logger_count(msq, 1, 1) 2274 2275 # check behavior with save_activations disabled 2276 msq = convert_n_shadows_model(copy.deepcopy(msp)) 2277 loggers_set_enabled(msq, True) 2278 loggers_set_save_activations(msq, False) 2279 # after prepare calibration but before convert calibration, loggers 2280 # should not have anything saved 2281 _check_logger_count(msq, 0, 0) 2282 msq(*example_input) 2283 # stats should be empty, but comparisons should be there 2284 _check_logger_count(msq, 0, 1) 2285 2286 @skipIfTorchDynamo("too slow") 2287 @skip_if_no_torchvision 2288 @withQNNPACKBackend 2289 def test_mobilenet_v2(self): 2290 import torchvision 2291 m = torchvision.models.quantization.mobilenet_v2( 2292 pretrained=False, quantize=False).eval() 2293 example_input = (torch.randn(1, 3, 224, 224),) 2294 2295 qconfig_mappings = QConfigMultiMapping() \ 2296 .set_global([torch.ao.quantization.default_qconfig, torch.ao.quantization.default_dynamic_qconfig]) 2297 2298 self._test_impl(m, example_input, qconfig_mappings) 2299 2300 @withQNNPACKBackend 2301 def 
test_qconfig_multi_mapping_deduplication(self):
        # check that insertion deduplicates qconfigs
        qconfig_multi_mapping = QConfigMultiMapping().set_global(
            [torch.ao.quantization.default_qconfig, torch.ao.quantization.default_qconfig]
        )
        self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 1)

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_insert_padding(self):
        # test that inserting a higher priority qconfig style with fewer
        # elements than a lower priority qconfig style will result in adding
        # None to the extra QConfigMappings at that same style+key
        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_global(
                [
                    torch.ao.quantization.default_qconfig,
                    torch.ao.quantization.default_dynamic_qconfig,
                ]
            )
            .set_object_type(torch.nn.Linear, [torch.ao.quantization.default_qconfig])
            .set_module_name_regex("fc", [torch.ao.quantization.default_qconfig])
            .set_module_name("fc2", [torch.ao.quantization.default_qconfig])
            .set_module_name_object_type_order(
                "", nn.Linear, 0, [torch.ao.quantization.default_qconfig]
            )
        )

        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
                torch.nn.Linear
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
                "fc"
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[
                1
            ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
            None,
        )

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_retroactive_padding(self):
        # test that inserting a lower priority qconfig style with more elements
        # than the existing higher priority qconfig styles will result in the
        # new QConfigMapping having None at all previously existing styles+keys
        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_object_type(torch.nn.Linear, [torch.ao.quantization.default_qconfig])
            .set_module_name_regex("fc", [torch.ao.quantization.default_qconfig])
            .set_module_name("fc2", [torch.ao.quantization.default_qconfig])
            .set_module_name_object_type_order(
                "", nn.Linear, 0, [torch.ao.quantization.default_qconfig]
            )
            .set_global(
                [
                    torch.ao.quantization.default_qconfig,
                    torch.ao.quantization.default_dynamic_qconfig,
                ]
            )
        )

        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[
                torch.nn.Linear
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[
                "fc"
            ],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"],
            None,
        )
        self.assertEqual(
            qconfig_multi_mapping.qconfig_mappings_list[
                1
            ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)],
            None,
        )

    @withQNNPACKBackend
    def test_qconfig_multi_mapping_end_to_end(self):
        # test that the prepare/convert_n_shadows_model works as expected
        # with qconfig_multi_mapping and avoids unwanted matches

        m = TwoLayerLinearModel().eval()
        example_input = m.get_example_inputs()

        qconfig_multi_mapping = (
            QConfigMultiMapping()
            .set_global(
                [
torch.ao.quantization.default_qconfig, 2407 torch.ao.quantization.default_dynamic_qconfig, 2408 ] 2409 ) 2410 .set_module_name("fc2", [None, torch.ao.quantization.default_qconfig]) 2411 ) 2412 self.assertEqual( 2413 qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"], 2414 None, 2415 ) 2416 msq = self._test_impl(m, example_input, qconfig_multi_mapping) 2417 2418 self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0) 2419 self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8) 2420 self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0) 2421 self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2) 2422 2423 @withQNNPACKBackend 2424 def test_qconfig_multi_mapping_from_list(self): 2425 # test QConfigMultiMapping.from_list_qconfig_mapping works as expected 2426 2427 m = TwoLayerLinearModel().eval() 2428 example_input = m.get_example_inputs() 2429 2430 qconfig_mappings_list = [ 2431 QConfigMapping().set_global(torch.ao.quantization.default_qconfig), 2432 QConfigMapping() 2433 .set_global(torch.ao.quantization.default_dynamic_qconfig) 2434 .set_module_name("fc2", torch.ao.quantization.default_qconfig), 2435 ] 2436 2437 qconfig_multi_mapping = QConfigMultiMapping().from_list_qconfig_mapping( 2438 qconfig_mappings_list 2439 ) 2440 self.assertEqual( 2441 qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"], 2442 None, 2443 ) 2444 2445 msq = self._test_impl(m, example_input, qconfig_multi_mapping) 2446 2447 self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0) 2448 self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8) 2449 self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0) 2450 self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2) 2451 2452 @withQNNPACKBackend 2453 def test_qconfig_multi_mapping_ordering(self): 2454 # test that the module ordering ignores None 2455 2456 m = TwoLayerLinearModel().eval() 2457 example_input = m.get_example_inputs() 2458 qconfig_multi_mapping = ( 2459 QConfigMultiMapping() 2460 .set_global( 2461 [ 2462 torch.ao.quantization.default_qconfig, 2463 torch.ao.quantization.default_dynamic_qconfig, 2464 ] 2465 ) 2466 .set_module_name( 2467 "fc2", 2468 [ 2469 None, 2470 torch.ao.quantization.default_dynamic_qconfig, 2471 torch.ao.quantization.default_qat_qconfig_v2, 2472 ], 2473 ) 2474 ) 2475 self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 2) 2476 msq = self._test_impl(m, example_input, qconfig_multi_mapping) 2477 2478 self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0) 2479 self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8) 2480 self.checkDynamicQuantizedLinear(msq.shadow_wrapper_1_1.mod_0, torch.qint8) 2481 self.checkQuantizedLinear(msq.shadow_wrapper_1_2.mod_0) 2482 2483 @withQNNPACKBackend 2484 def test_qconfig_multi_mapping_repr(self): 2485 qconfig_multi_mapping = ( 2486 QConfigMultiMapping() 2487 .set_global( 2488 [ 2489 torch.ao.quantization.default_qconfig, 2490 torch.ao.quantization.default_dynamic_qconfig, 2491 ] 2492 ) 2493 .set_module_name( 2494 "fc2", 2495 [ 2496 None, 2497 torch.ao.quantization.default_dynamic_qconfig, 2498 torch.ao.quantization.default_qat_qconfig_v2, 2499 ], 2500 ) 2501 ) 2502 self.assertTrue(isinstance(qconfig_multi_mapping.__repr__(), str)) 2503 2504 @withQNNPACKBackend 2505 def test_custom_functions_and_tracer(self): 2506 class M(nn.Module): 2507 def __init__(self) -> None: 2508 super().__init__() 2509 self.fc1 = nn.Linear(2, 2) 2510 self.fc2 = nn.Linear(2, 2) 
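                # fc2 is listed as a skipped module name in the custom
                # QuantizationTracer constructed below, so symbolic tracing
                # keeps it as a leaf module instead of tracing into it.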
2511 2512 def forward(self, x): 2513 x = self.fc1(x) 2514 x = self.fc2(x) 2515 return x 2516 2517 m = M().eval() 2518 example_inputs = (torch.randn(2, 2),) 2519 2520 qconfig_mappings = QConfigMultiMapping().set_global( 2521 [torch.ao.quantization.default_qat_qconfig] 2522 ) 2523 2524 custom_tracer = torch.ao.quantization.quantize_fx.QuantizationTracer( 2525 ["fc2"], [] 2526 ) 2527 2528 custom_prepare_fn = torch.ao.quantization.quantize_fx.prepare_qat_fx 2529 2530 def custom_convert_fn(module, to_print): 2531 print(to_print) 2532 mod = torch.ao.quantization.quantize_fx.convert_fx(module) 2533 return mod 2534 2535 backend_config = get_native_backend_config() 2536 2537 # test that input is valid 2538 _ = m(*example_inputs) 2539 2540 kwargs = {"to_print": "working"} 2541 2542 msp = prepare_n_shadows_model( 2543 m, 2544 example_inputs, 2545 qconfig_mappings, 2546 backend_config, 2547 custom_prepare_fn=custom_prepare_fn, 2548 custom_prepare_kwargs=None, 2549 custom_tracer=custom_tracer, 2550 ) 2551 2552 for _ in range(2): 2553 msp(*example_inputs) 2554 2555 msq = convert_n_shadows_model( 2556 msp, custom_convert_fn=custom_convert_fn, custom_convert_kwargs=kwargs 2557 ) 2558 print(msq) 2559 loggers_set_enabled(msq, True) 2560 msq(*example_inputs) 2561 2562 results = extract_results_n_shadows_model(msq) 2563 print_comparisons_n_shadows_model(results) 2564 2565 def _test_extract_weights_impl(self, m, example_input, qconfig_mapping): 2566 backend_config = get_native_backend_config() 2567 results = _n_shadows_compare_weights( 2568 m, example_input, qconfig_mapping, backend_config) 2569 print_comparisons_n_shadows_model(results) 2570 2571 @withQNNPACKBackend 2572 def test_extract_weights_linear(self): 2573 class M(nn.Module): 2574 def __init__(self) -> None: 2575 super().__init__() 2576 self.w1 = nn.Parameter(torch.randn(2, 2)) 2577 self.b1 = nn.Parameter(torch.randn(2)) 2578 torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) 2579 self.w2 = nn.Parameter(torch.randn(2, 2)) 2580 self.b2 = nn.Parameter(torch.randn(2)) 2581 torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5)) 2582 self.w3 = nn.Parameter(torch.randn(2, 2)) 2583 self.b3 = nn.Parameter(torch.randn(2)) 2584 torch.nn.init.kaiming_uniform_(self.w3, a=math.sqrt(5)) 2585 self.w4 = nn.Parameter(torch.randn(2, 2)) 2586 self.b4 = nn.Parameter(torch.randn(2)) 2587 torch.nn.init.kaiming_uniform_(self.w4, a=math.sqrt(5)) 2588 2589 def forward(self, x): 2590 x = F.linear(x, self.w1, self.b1) 2591 x = F.linear(x, self.w2, self.b2) 2592 x = F.relu(x) 2593 x = F.linear(x, self.w3, self.b3) 2594 x = F.linear(x, self.w4, self.b4) 2595 return x 2596 2597 per_tensor_qconfig = torch.ao.quantization.default_qconfig 2598 2599 m = M().eval() 2600 example_input = (torch.randn(2, 2),) 2601 qconfig_mapping = get_default_qconfig_mapping() 2602 # test unquantized 2603 qconfig_mapping.set_module_name_object_type_order( 2604 '', F.linear, 2, None) 2605 # test per-tensor 2606 qconfig_mapping.set_module_name_object_type_order( 2607 '', F.linear, 3, per_tensor_qconfig) 2608 self._test_extract_weights_impl(m, example_input, qconfig_mapping) 2609 2610 2611 def _test_add_loggers_impl(self, m, example_input, qconfig_mapping): 2612 backend_config = get_native_backend_config() 2613 m_copy = copy.deepcopy(m) 2614 2615 # test that input is valid 2616 _ = m(*example_input) 2617 2618 msp = _prepare_n_shadows_add_loggers_model( 2619 m, example_input, qconfig_mapping, backend_config) 2620 # print('msp', msp) 2621 2622 msp(*example_input) 2623 2624 msq = 
convert_n_shadows_model(msp) 2625 # print('msq', msq) 2626 2627 loggers_set_enabled(msq, True) 2628 output_fp32 = msq(*example_input) 2629 2630 results = extract_results_n_shadows_model(msq) 2631 # print(results) 2632 # print_comparisons_n_shadows_model(results) 2633 2634 # get the last quantized output from results 2635 inner_results = results['model']['node_output'] 2636 last_subgraph = list(inner_results.keys())[-1] 2637 output_shadow = inner_results[last_subgraph][0]['values'][-1] 2638 2639 # verify that both fp32 and quantized output matches reference 2640 output_fp32_ref = m_copy(*example_input) 2641 mp_ref = prepare_fx(m_copy, qconfig_mapping, example_input) 2642 for _ in range(2): 2643 mp_ref(*example_input) 2644 mq_ref = convert_fx(mp_ref) 2645 output_shadow_ref = mq_ref(*example_input) 2646 self.assertTrue( 2647 torch.allclose(output_fp32, output_fp32_ref), 2648 f"fp32 comparison: {output_fp32} not close to {output_fp32_ref}") 2649 2650 # print('shadow', output_shadow.shape, output_shadow) 2651 # print('shadow_ref', output_shadow_ref.shape, output_shadow_ref) 2652 2653 self.assertTrue( 2654 torch.allclose(output_shadow, output_shadow_ref), 2655 f"shadow comparison: {output_shadow} not close to {output_shadow_ref}") 2656 2657 return msq 2658 2659 @withQNNPACKBackend 2660 def test_add_loggers_linear_mod_quant_quant(self): 2661 m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 2662 example_input = (torch.randn(2, 2),) 2663 qconfig_mapping = get_default_qconfig_mapping() 2664 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2665 2666 @withQNNPACKBackend 2667 def test_add_loggers_linear_mod_fp32_quant(self): 2668 m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 2669 example_input = (torch.randn(2, 2),) 2670 qconfig_mapping = get_default_qconfig_mapping() 2671 qconfig_mapping.set_module_name('0', None) 2672 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2673 2674 @withQNNPACKBackend 2675 def test_add_loggers_linear_mod_quant_fp32(self): 2676 m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 2677 example_input = (torch.randn(2, 2),) 2678 qconfig_mapping = get_default_qconfig_mapping() 2679 qconfig_mapping.set_module_name('1', None) 2680 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2681 2682 @withQNNPACKBackend 2683 def test_add_loggers_linear_mod_fp32_fp32(self): 2684 m = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 2685 example_input = (torch.randn(2, 2),) 2686 qconfig_mapping = get_default_qconfig_mapping() 2687 qconfig_mapping.set_module_name('0', None) 2688 qconfig_mapping.set_module_name('1', None) 2689 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2690 2691 @withQNNPACKBackend 2692 def test_add_loggers_conv_bn_relu_fusion_quant(self): 2693 m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU()) 2694 m.eval() 2695 example_input = (torch.randn(16, 1, 4, 4),) 2696 qconfig_mapping = get_default_qconfig_mapping() 2697 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2698 2699 @withQNNPACKBackend 2700 def test_add_loggers_conv_bn_relu_fusion_fp32(self): 2701 m = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1), nn.ReLU()) 2702 m.eval() 2703 example_input = (torch.randn(16, 1, 4, 4),) 2704 qconfig_mapping = get_default_qconfig_mapping() 2705 qconfig_mapping.set_module_name('0', None) 2706 qconfig_mapping.set_module_name('1', None) 2707 qconfig_mapping.set_module_name('2', None) 2708 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2709 2710 @withQNNPACKBackend 2711 def 
test_add_loggers_functions(self): 2712 class M(nn.Module): 2713 def __init__(self) -> None: 2714 super().__init__() 2715 self.w1 = nn.Parameter(torch.randn(2, 2)) 2716 self.b1 = nn.Parameter(torch.randn(2)) 2717 torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5)) 2718 2719 def forward(self, x): 2720 x = F.linear(x, self.w1, self.b1) 2721 x = F.relu(x) 2722 x = x + x 2723 x = x + 1 2724 # TODO(future PR): support first arg being a scalar 2725 # x = 1 + x 2726 x = torch.cat([x, x]) 2727 x = torch.cat([x, x]) 2728 x = torch.cat(tensors=[x, x]) 2729 # function not matchable by quantization 2730 x = torch.nn.functional.rrelu(x) 2731 x = F.linear(x, self.w1, self.b1) 2732 return x 2733 2734 m = M().eval() 2735 example_input = (torch.randn(16, 2),) 2736 for qconfig_mapping in ( 2737 get_default_qconfig_mapping(), 2738 QConfigMapping(), 2739 ): 2740 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2741 2742 @skipIfTorchDynamo("too slow") 2743 @skip_if_no_torchvision 2744 @withQNNPACKBackend 2745 def test_add_loggers_mobilenet_v2(self): 2746 import torchvision 2747 m = torchvision.models.quantization.mobilenet_v2( 2748 pretrained=False, quantize=False).eval() 2749 example_input = (torch.randn(8, 3, 224, 224),) 2750 qconfig_mapping = get_default_qconfig_mapping() 2751 self._test_add_loggers_impl(m, example_input, qconfig_mapping) 2752 2753 2754class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase): 2755 """ 2756 Tests numeric suite core APIs on non-toy models. 2757 """ 2758 2759 @skipIfNoFBGEMM 2760 def test_compare_weights_conv(self): 2761 test_cases = ( 2762 (ConvModel(),), 2763 (ConvBnModel(),), 2764 (ConvBnReLUModel(),), 2765 ) 2766 for m, in test_cases: 2767 m.eval() 2768 example_inputs = (torch.randn(1, 3, 5, 5),) 2769 self._test_extract_weights(m, example_inputs, results_len=1) 2770 2771 @skipIfNoFBGEMM 2772 def test_compare_weights_linear(self): 2773 test_cases = ( 2774 (SingleLayerLinearModel(), None), 2775 ( 2776 SingleLayerLinearDynamicModel(), 2777 {"object_type": [(nn.Linear, default_dynamic_qconfig)]}, 2778 ), 2779 ) 2780 for m, qconfig_dict in test_cases: 2781 m.eval() 2782 example_inputs = (torch.randn(1, 3, 5, 5),) 2783 res = self._test_extract_weights( 2784 m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) 2785 2786 @skipIfNoFBGEMM 2787 def test_compare_weights_lstm_dynamic(self): 2788 qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]} 2789 lstm_input = torch.rand((1, 1, 2)) 2790 lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) 2791 example_inputs = (lstm_input, lstm_hidden) 2792 m = LSTMwithHiddenDynamicModel().eval() 2793 res = self._test_extract_weights( 2794 m, example_inputs, results_len=1, qconfig_dict=qconfig_dict) 2795 2796 @skipIfNoFBGEMM 2797 def test_compare_activations_conv(self): 2798 test_cases = ( 2799 (ConvModel(),), 2800 (ConvBnModel(),), 2801 (ConvBnReLUModel(),), 2802 ) 2803 for m, in test_cases: 2804 m.eval() 2805 res = self._test_match_activations( 2806 m, (torch.randn(1, 3, 4, 4),), results_len=1) 2807 2808 @skipIfNoFBGEMM 2809 def test_compare_activations_linear(self): 2810 test_cases = ( 2811 (SingleLayerLinearModel(), None), 2812 ( 2813 SingleLayerLinearDynamicModel(), 2814 {"object_type": [(nn.Linear, default_dynamic_qconfig)]}, 2815 ), 2816 ) 2817 for m, qconfig_dict in test_cases: 2818 m.eval() 2819 res = self._test_match_activations( 2820 m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict) 2821 2822 @skipIfNoFBGEMM 2823 def 
test_compare_activations_lstm_dynamic(self): 2824 qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]} 2825 m = LSTMwithHiddenDynamicModel().eval() 2826 lstm_input = torch.rand((1, 1, 2)) 2827 lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) 2828 # TODO(future PR): enable scripting (quant prepared LSTM not scriptable) 2829 res = self._test_match_activations( 2830 m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict, 2831 skip_scripting=True) 2832 2833 @skipIfNoFBGEMM 2834 def test_compare_shadow_activations_conv(self): 2835 test_cases = ( 2836 (ConvModel(),), 2837 (ConvBnModel(),), 2838 (ConvBnReLUModel(),), 2839 ) 2840 for m, in test_cases: 2841 m.eval() 2842 res = self._test_match_shadow_activations( 2843 m, (torch.randn(1, 3, 4, 4),), results_len=1) 2844 2845 @skipIfNoFBGEMM 2846 def test_compare_shadow_activations_linear(self): 2847 test_cases = ( 2848 (SingleLayerLinearModel(), None), 2849 ( 2850 SingleLayerLinearDynamicModel(), 2851 {"object_type": [(nn.Linear, default_dynamic_qconfig)]}, 2852 ), 2853 ) 2854 for m, qconfig_dict in test_cases: 2855 m.eval() 2856 res = self._test_match_shadow_activations( 2857 m, (torch.randn(5, 5),), results_len=1, qconfig_dict=qconfig_dict) 2858 2859 @skipIfNoFBGEMM 2860 def test_compare_shadow_activations_lstm_dynamic(self): 2861 qconfig_dict = {"object_type": [(nn.LSTM, default_dynamic_qconfig)]} 2862 m = LSTMwithHiddenDynamicModel().eval() 2863 lstm_input = torch.rand((1, 1, 2)) 2864 lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) 2865 # TODO(future PR): enable scripting (quant prepared LSTM not scriptable) 2866 res = self._test_match_shadow_activations( 2867 m, (lstm_input, lstm_hidden), results_len=1, qconfig_dict=qconfig_dict, 2868 skip_scripting=True) 2869 2870 @skipIfNoFBGEMM 2871 def test_sparsenn_compare_activations(self): 2872 for should_log_inputs in (True, False): 2873 sparse_nn = SparseNNModel().eval() 2874 idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) 2875 offsets = torch.LongTensor([0, 4]) 2876 x = torch.randn(2, 4) 2877 self._test_match_activations( 2878 sparse_nn, (idx, offsets, x), 2879 results_len=5, 2880 should_log_inputs=should_log_inputs) 2881 2882 @skipIfNoFBGEMM 2883 def test_sparsenn_shadow(self): 2884 for should_log_inputs in (True, False): 2885 sparse_nn = SparseNNModel().eval() 2886 idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9]) 2887 offsets = torch.LongTensor([0, 4]) 2888 x = torch.randn(2, 4) 2889 self._test_match_shadow_activations( 2890 sparse_nn, (idx, offsets, x), 2891 results_len=3, 2892 should_log_inputs=should_log_inputs) 2893 2894 @skipIfTorchDynamo("too slow") 2895 @skip_if_no_torchvision 2896 @skipIfNoFBGEMM 2897 def test_resnet18(self): 2898 import torchvision 2899 m = torchvision.models.quantization.resnet18(pretrained=False, quantize=False).eval() 2900 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 2901 self._test_match_shadow_activations( 2902 m, (torch.randn(1, 3, 224, 224),), 2903 qconfig_dict=qconfig_dict, 2904 should_log_inputs=False) 2905 2906 @skipIfTorchDynamo("too slow") 2907 @skip_if_no_torchvision 2908 @skipIfNoFBGEMM 2909 def test_mobilenet_v2(self): 2910 import torchvision 2911 m = torchvision.models.quantization.mobilenet_v2(pretrained=False, quantize=False).eval() 2912 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 2913 self._test_match_shadow_activations( 2914 m, (torch.randn(1, 3, 224, 224),), 2915 qconfig_dict=qconfig_dict, 2916 should_log_inputs=False) 2917