1# Owner(s): ["oncall: quantization"] 2import copy 3import unittest 4from typing import List 5 6import torch 7import torch._export 8from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e 9from torch.ao.quantization.quantizer import QuantizationAnnotation, Quantizer 10from torch.ao.quantization.quantizer.xnnpack_quantizer import ( 11 get_symmetric_quantization_config, 12) 13from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR 14from torch.fx import Node 15from torch.testing._internal.common_quantization import QuantizationTestCase 16from torch.testing._internal.common_utils import IS_WINDOWS 17 18 19class TestHelperModules: 20 class Conv2dWithObsSharingOps(torch.nn.Module): 21 def __init__(self) -> None: 22 super().__init__() 23 self.conv = torch.nn.Conv2d(3, 3, 3) 24 self.hardtanh = torch.nn.Hardtanh() 25 self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) 26 self.linear = torch.nn.Linear(3, 3) 27 28 def forward(self, x): 29 x = self.conv(x) 30 x = self.adaptive_avg_pool2d(x) 31 x = self.hardtanh(x) 32 x = x.view(-1, 3) 33 x = self.linear(x) 34 return x 35 36 37def _tag_partitions( 38 backend_name: str, op_name: str, annotated_partitions: List[List[Node]] 39): 40 for index, partition_nodes in enumerate(annotated_partitions): 41 tag_name = backend_name + "_" + op_name + "_" + str(index) 42 for node in partition_nodes: 43 assert "quantization_tag" not in node.meta, f"{node} is already tagged" 44 node.meta["quantization_tag"] = tag_name 45 46 47_QUANT_OPS = { 48 torch.ops.quantized_decomposed.quantize_per_tensor.default, 49 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 50 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 51 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 52 torch.ops.quantized_decomposed.quantize_per_channel.default, 53 torch.ops.quantized_decomposed.dequantize_per_channel.default, 54 torch.ops.quantized_decomposed.choose_qparams.tensor, 55} 56 57 58# 
# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
    """Tests for ``quantization_tag`` metadata on graphs produced by prepare/convert PT2E.

    Each test builds a throw-away ``Quantizer`` that annotates and tags some
    nodes, runs the capture -> prepare -> calibrate -> convert flow, and then
    compares the tags found on the resulting quantize/dequantize/get_attr
    nodes against an expected mapping.
    """

    def _test_quant_tag_preservation_through_decomp(
        self, model, example_inputs, from_node_to_tags
    ):
        """Export ``model`` and check tags on nodes derived from given targets.

        Args:
            model: Module passed to ``torch.export.export``.
            example_inputs: Tuple of example inputs used for the export.
            from_node_to_tags: Mapping from an originating target (compared
                against index 1 of each ``from_node`` metadata entry) to the
                ``quantization_tag`` the derived node is expected to carry.

        Raises:
            ValueError: If a node's ``from_node`` metadata is not a list.

        Note:
            A target that matches no exported node passes vacuously, since
            only mismatching tags are recorded as failures.
        """
        ep = torch.export.export(model, example_inputs)
        found_tags = True
        not_found_nodes = ""
        for from_node, tag in from_node_to_tags.items():
            for n in ep.graph_module.graph.nodes:
                from_node_meta = n.meta.get("from_node", None)
                if from_node_meta is None:
                    continue
                if not isinstance(from_node_meta, list):
                    raise ValueError(
                        f"from_node metadata is of type {type(from_node_meta)}, but expected list"
                    )
                for meta in from_node_meta:
                    node_target = meta[1]
                    if node_target == from_node:
                        node_tag = n.meta.get("quantization_tag", None)
                        if node_tag is None or tag != node_tag:
                            not_found_nodes += str(n.target) + ", "
                            found_tags = False
                            break
                # Stop scanning nodes for this target after the first mismatch;
                # the outer loop still moves on to the next target.
                if not found_tags:
                    break
        self.assertTrue(
            found_tags,
            f"Decomposition did not preserve quantization tag for {not_found_nodes}",
        )

    def _test_metadata_porting(
        self,
        model,
        example_inputs,
        quantizer,
        node_tags=None,
    ) -> torch.fx.GraphModule:
        """Quantize ``model`` with ``quantizer`` and verify the resulting tags.

        Args:
            model: Eager module to capture and quantize.
            example_inputs: Tuple of inputs used for capture and calibration.
            quantizer: ``Quantizer`` instance handed to ``prepare_pt2e``.
            node_tags: Expected mapping from a key (an op overload listed in
                ``_QUANT_OPS``, or the string ``"get_attr"``) to the set of
                tags found on nodes of that kind. ``None`` means that no
                tagged node is expected.

        Returns:
            The converted graph module.

        Raises:
            ValueError: If two ``call_function`` nodes with the same target
                carry the same tag.
        """
        # The signature advertises a None default; treat it as "no tags
        # expected" instead of failing on `None.keys()` below.
        expected_tags = {} if node_tags is None else node_tags
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = torch._export.capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        # Run the converted graph once; the output itself is not inspected.
        m(*example_inputs)

        recorded_node_tags = {}
        for n in m.graph.nodes:
            if "quantization_tag" not in n.meta:
                continue
            if n.op == "call_function" and n.target in _QUANT_OPS:
                key = n.target
            elif n.op == "get_attr":
                key = "get_attr"
            else:
                continue

            tag = n.meta["quantization_tag"]
            seen_tags = recorded_node_tags.setdefault(key, set())
            # Only call_function nodes are checked for duplicates; several
            # get_attr nodes may share one tag.
            if n.op == "call_function" and tag in seen_tags:
                raise ValueError(
                    f"{key} {n.format_node()} has tag {tag} that "
                    "is associated with another node of the same type"
                )
            seen_tags.add(tag)

        self.assertEqual(set(recorded_node_tags.keys()), set(expected_tags.keys()))
        for k, v in recorded_node_tags.items():
            self.assertEqual(v, expected_tags[k])
        return m

    def test_simple_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )
                # Honor the declared return type.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {
            torch.ops.aten.adaptive_avg_pool2d.default: "BackendA_adaptive_avg_pool2d_0",
            torch.ops.aten.linear.default: "BackendA_linear_0",
        }
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_metadata_porting_with_no_quant_inbetween(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Dont quantize avgpool
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                # Honor the declared return type.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    @unittest.skip("Temporarily disabled")
    def test_metadata_porting_for_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize all except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # static quantization
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                # Honor the declared return type.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        # TODO: add get_attr_tags when the test is re-enabled
        # Must be a set (not `{}`, which is a dict) because recorded tags are
        # collected into sets before being compared.
        get_attr_tags = set()
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
        }
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
        }
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_two_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize linear and conv with dynamic quantization
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                # Honor the declared return type.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        choose_qparams_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        quantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_channel_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_dq_no_static_q(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Dont quantize anything except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                # Honor the declared return type.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_no_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Annotate linear, conv2d and avgpool without tagging any node
        Check that no quantization tag shows up after convert or export
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                # Annotate only; the returned partitions are deliberately not tagged.
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
                # Honor the declared return type.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_tags = {}
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {}
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_no_metadata_porting_through_unknown_ops(self):
        """
        Model under test
        matmul -> add -> relu
        matmul has get_attr as first input, but the quantization_tag should not be
        propagated to add even if it's part of a chain that ends at get_attr
        """

        class MatmulWithConstInput(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_parameter("w", torch.nn.Parameter(torch.rand(8, 16)))

            def forward(self, x, y):
                x = torch.matmul(self.w, x)
                z = x + y
                return torch.nn.functional.relu(z)

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                qconfig = get_symmetric_quantization_config()
                for n in gm.graph.nodes:
                    if n.op != "call_function":
                        continue

                    n.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map={n.args[0]: qconfig.input_activation},
                        output_qspec=qconfig.output_activation,
                    )

                    # Tag the op with its own target name, and copy that tag
                    # onto any get_attr node feeding it directly.
                    tag = str(n.target)
                    n.meta["quantization_tag"] = tag
                    for arg in n.args:
                        if arg.op == "get_attr":
                            arg.meta["quantization_tag"] = tag
                # Honor the declared return type.
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(16, 24), torch.randn(8, 24))
        get_attr_tags = {"aten.matmul.default"}
        # These expectations are keyed on the `.default` overloads below.
        quantize_per_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        dequantize_per_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
        }
        self._test_metadata_porting(
            MatmulWithConstInput(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )