# mypy: allow-untyped-defs
import dataclasses
import itertools
import operator
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

import torch
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    EdgeOrNode,
    QuantizationSpecBase,
    SharedQuantizationSpec,
)
from torch.fx import Graph, GraphModule, Node
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns

from .utils import (
    _conv1d_bn_example_inputs,
    _conv2d_bn_example_inputs,
    _get_aten_graph_module_for_pattern,
    _is_bn_node,
    _is_conv_or_conv_transpose_node,
    _is_conv_transpose_fn,
    fold_bn_weights_into_conv_node,
)


if TYPE_CHECKING:
    from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch

__all__ = []  # type: ignore[var-annotated]
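

# Note: the passes below are internal helpers rather than public API. At the
# time of writing, `_fuse_conv_bn_qat` is called from `prepare_qat_pt2e` and
# `_fold_conv_bn_qat` from `convert_pt2e` (torch.ao.quantization.quantize_pt2e).
# A rough usage sketch (the surrounding workflow, and `quantizer` being any
# Quantizer implementation, are assumptions of this example):
#
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
#
#     m = torch.export.export_for_training(model, example_inputs).module()
#     m = prepare_qat_pt2e(m, quantizer)  # fuses conv + bn via _fuse_conv_bn_qat
#     ...                                 # QAT fine-tuning loop
#     m = convert_pt2e(m)                 # folds bn into conv via _fold_conv_bn_qat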


# Example inputs for quantized and folded conv-bn1d patterns used in convert
_quantized_conv1d_bn_example_inputs = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # conv_weight
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)

# Example inputs for quantized and folded conv-bn2d patterns used in convert
_quantized_conv2d_bn_example_inputs = (
    torch.randn(1, 1, 3, 3),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)


def _get_quantized_conv_bn_example_inputs_kwargs(
    is_per_channel: bool,
    has_bias: bool,
    bias_is_quantized: bool,
    is_cuda: bool,
) -> Dict[str, Any]:
    """
    Optional example inputs for quantized and folded conv-bn patterns
    used in convert, expressed as kwargs.
    """
    kwargs = {}
    # Per tensor quantization uses literals to represent scale and zero
    # point, so there is no need to include them here as kwargs
    if is_per_channel:
        kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
        kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
        if has_bias and bias_is_quantized:
            kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float)
            kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int)
    if has_bias:
        kwargs["conv_bias"] = torch.randn(1)
    if is_cuda:
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                kwargs[k] = v.cuda()
    return kwargs


def _get_conv_bn_pattern(conv_fn: Callable) -> Callable:
    def _conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        x = conv_fn(x, conv_weight, conv_bias)
        x = F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
        )
        return x

    return _WrapperModule(_conv_bn_pattern)


# TODO: merge this with the `no_conv_bias` case
def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
    def _qat_conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        """
        Approximate method to fuse conv and bn that requires only one forward pass.
        conv_orig = conv / scale_factor, where scale_factor = bn.weight / running_std.
        This is based on `nniqat.ConvBn2d._forward_approximate`.
        """
        # TODO: allow setting eps
        bn_eps = 1e-5
        running_std = torch.sqrt(bn_running_var + bn_eps)
        scale_factor = bn_weight / running_std
        weight_shape = [1] * len(conv_weight.shape)
        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
        weight_shape[weight_in_channel_axis] = -1
        bias_shape = [1] * len(conv_weight.shape)
        bias_shape[1] = -1
        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
        zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
        x = conv_fn(x, scaled_weight, zero_bias)
        x = x / scale_factor.reshape(bias_shape)
        x = x + conv_bias.reshape(bias_shape)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=True,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_qat_conv_bn_pattern)
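

# How the approximate fusion above works, informally (a sketch only, using the
# same names as `_qat_conv_bn_pattern`, with std = sqrt(bn_running_var + bn_eps)):
#
#     scale_factor = bn_weight / std
#     y = conv(x, conv_weight * scale_factor) / scale_factor + conv_bias
#       ~= conv(x, conv_weight) + conv_bias      # numerically equivalent up to fp error
#     y = batch_norm(y, ...)
#
# Scaling the weight by `scale_factor` means any fake-quantize observers inserted
# on the weight see it as it will look after bn folding at convert time, while the
# division afterwards keeps the training-time forward numerically equivalent to a
# plain conv + bn.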


def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
    def _qat_conv_bn_pattern_no_conv_bias(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        # Not used, only for matching convenience
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        """
        Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias.
        """
        # TODO: allow setting eps
        bn_eps = 1e-5
        running_std = torch.sqrt(bn_running_var + bn_eps)
        scale_factor = bn_weight / running_std
        weight_shape = [1] * len(conv_weight.shape)
        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
        weight_shape[weight_in_channel_axis] = -1
        bias_shape = [1] * len(conv_weight.shape)
        bias_shape[1] = -1
        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
        x = conv_fn(x, scaled_weight, None)
        x = x / scale_factor.reshape(bias_shape)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=True,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias)


def _append_qdq(x, is_per_channel, is_bias, kwargs):
    """
    Helper function to append q-dq ops after `x`, using dummy values for the qparams
    and qmin/qmax. We use dummy values here because we match with `ignore_literals=True`
    and will manually replace these values after subgraph rewriting.

    Return the dq node.
    """
    # Dummy args to be passed into q-dq ops
    per_channel_axis = 0
    scale_key = "bias_scale" if is_bias else "weight_scale"
    zp_key = "bias_zero_point" if is_bias else "weight_zero_point"
    scale = kwargs[scale_key] if is_per_channel else 1.0
    zp = kwargs[zp_key] if is_per_channel else 0
    qmin = -127
    qmax = 127
    dtype = torch.int8

    qd = torch.ops.quantized_decomposed
    if is_per_channel:
        x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
        x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
    else:
        x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
        x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
    return x
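

# For reference, in the per-tensor case `_append_qdq(x, is_per_channel=False,
# is_bias=False, kwargs={})` appends the following pair of decomposed ops (the
# literal qparams are dummies that get patched after subgraph rewriting, see
# `_copy_over_q_dq_args` below):
#
#     x -> quantized_decomposed.quantize_per_tensor(x, 1.0, 0, -127, 127, torch.int8)
#       -> quantized_decomposed.dequantize_per_tensor(..., 1.0, 0, -127, 127, torch.int8)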


def _get_quantized_qat_conv_bn_pattern(
    is_per_channel: bool,
    has_bias: bool,
    bias_is_quantized: bool,
    conv_fn: Callable,
    bn_is_training: bool,
) -> Callable:
    """
    Return the quantized version of the QAT conv + BN pattern.
    This is based on `nniqat.ConvBn2d._forward_approximate` and is
    used in QAT convert. We first match this pattern and replace
    it with the normal [conv - bn] pattern, then fold the BN
    weights into conv.
    """
    # TODO: allow setting eps
    bn_eps = 1e-5

    def _quantized_qat_conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        running_std = torch.sqrt(bn_running_var + bn_eps)
        scale_factor = bn_weight / running_std
        weight_shape = [1] * len(conv_weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(conv_weight.shape)
        bias_shape[1] = -1
        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
        scaled_weight = _append_qdq(
            scaled_weight,
            is_per_channel,
            is_bias=False,
            kwargs=kwargs,
        )
        if has_bias:
            zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
            if bias_is_quantized:
                zero_bias = _append_qdq(
                    zero_bias,
                    is_per_channel,
                    is_bias=True,
                    kwargs=kwargs,
                )
            x = conv_fn(x, scaled_weight, zero_bias)
        else:
            x = conv_fn(x, scaled_weight, None)
        x = x / scale_factor.reshape(bias_shape)
        if has_bias:
            x = x + kwargs["conv_bias"].reshape(bias_shape)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=bn_is_training,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_quantized_qat_conv_bn_pattern)


def _get_folded_quantized_qat_conv_bn_pattern(
    is_per_channel: bool,
    has_bias: bool,
    bias_is_quantized: bool,
    conv_fn: Callable,
    bn_is_training: bool,
) -> Callable:
    """
    Quantized QAT conv - bn pattern with bn weights being folded into conv.
    """
    # TODO: allow setting eps
    bn_eps = 1e-5

    def _folded_quantized_qat_conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        conv_weight = _append_qdq(
            conv_weight,
            is_per_channel,
            is_bias=False,
            kwargs=kwargs,
        )
        if has_bias:
            bias = kwargs["conv_bias"]
            if bias_is_quantized:
                bias = _append_qdq(
                    bias,
                    is_per_channel,
                    is_bias=True,
                    kwargs=kwargs,
                )
        else:
            bias = None
        x = conv_fn(x, conv_weight, bias)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=bn_is_training,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_folded_quantized_qat_conv_bn_pattern)


def _has_conv_bias_filter(
    match: "InternalMatch",
    original_graph: Graph,
    pattern_graph: Graph,
) -> bool:
    """
    Match filter for the subgraph rewriter that returns True if the conv node in
    the original graph has bias.
    """
    for n in match.nodes_map.values():
        if _is_conv_or_conv_transpose_node(n):
            return len(n.args) > 2 and n.args[2] is not None
    raise ValueError("Could not find conv node in matched conv + bn pattern")


def _no_conv_bias_filter(
    match: "InternalMatch",
    original_graph: Graph,
    pattern_graph: Graph,
) -> bool:
    """
    Match filter for the subgraph rewriter that returns True if the conv node in
    the original graph does NOT have bias.
    """
    return not _has_conv_bias_filter(match, original_graph, pattern_graph)
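

# These filters plug directly into the subgraph rewriter. `_fuse_conv_bn_qat_helper`
# below uses them roughly like this (a sketch of the existing call, not new behavior):
#
#     replace_pattern_with_filters(
#         m,
#         match_pattern,
#         replacement_pattern_with_conv_bias,
#         match_filters=[_has_conv_bias_filter],  # only rewrite matches whose conv has a bias
#         ignore_literals=True,
#     )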


def _is_quantize(n: Node) -> bool:
    return n.target in [
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default,
    ]


def _is_dequantize(n: Node) -> bool:
    return n.target in [
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
    ]


def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Node]]:
    """
    Helper function to extract the nodes in the conv-bn fusion pattern after
    subgraph rewriting, in the form of a map:

        {name: (original_node, replacement_node)}

    The following names must exist in the map:

        "conv", "conv_weight", "conv_input", "bn", "getitem"

    ("getitem" is omitted in the new training IR, where there is no getitem node.)

    The following names may exist in the map:

        "conv_weight_q", "conv_weight_dq", "conv_bias",
        "conv_bias_q", "conv_bias_dq"
    """

    def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Optional[Node]]:
        """
        Return a 3-tuple of (conv_node, bn_node, getitem_node).
        This asserts that the match contains exactly one conv node and one bn node,
        and at most one getitem node.
        """
        conv_node, bn_node, getitem_node = None, None, None
        for n in nodes:
            if n.op != "call_function":
                continue
            if _is_conv_or_conv_transpose_node(n):
                assert conv_node is None
                conv_node = n
            if _is_bn_node(n):
                assert bn_node is None
                bn_node = n
            if n.target == operator.getitem:
                assert getitem_node is None
                getitem_node = n
        assert conv_node is not None
        assert bn_node is not None
        # getitem_node might be None in the new training IR
        return (conv_node, bn_node, getitem_node)

    def _get_q_dq_nodes(n: Node) -> Tuple[Node, Node, Node]:
        """
        Return a 3-tuple of (orig_node, q_node, dq_node).
        """
        assert _is_dequantize(n)
        q_node = n.args[0]
        assert isinstance(q_node, Node)
        assert _is_quantize(q_node)
        orig_node = q_node.args[0]
        assert isinstance(orig_node, Node)
        return (orig_node, q_node, n)

    original_nodes = list(_filter_nodes_map(r.nodes_map).values())
    o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
    r_conv, r_bn, r_getitem = _get_nodes(r.replacements)

    # Create the mapping from original node to replacement node
    if o_getitem is None:
        # getitem is None in the new training IR
        assert r_getitem is None
        mapping = {
            "conv": (o_conv, r_conv),
            "bn": (o_bn, r_bn),
        }
    else:
        # TODO: This is a deprecated branch and should be deleted soon, after
        # capture_pre_autograd_graph fully migrates to the training IR
        # T199018392
        assert r_getitem is not None
        assert o_getitem is not None
        mapping = {
            "conv": (o_conv, r_conv),
            "bn": (o_bn, r_bn),
            "getitem": (o_getitem, r_getitem),
        }

    # Extract conv input and weight
    # Note: here we extract the original nodes indirectly through the pattern nodes
    # because the args of the original nodes are no longer available after replacement
    (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
    (p_conv_input, p_conv_weight, *_) = p_conv.args
    (r_conv_input, r_conv_weight, *_) = r_conv.args
    assert isinstance(p_conv_input, Node)
    assert isinstance(p_conv_weight, Node)
    assert isinstance(r_conv_input, Node)
    assert isinstance(r_conv_weight, Node)
    o_conv_input = r.nodes_map[p_conv_input]
    o_conv_weight = r.nodes_map[p_conv_weight]

    # If conv weight is quantized, extract the q - dq nodes
    if _is_dequantize(p_conv_weight):
        p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes(
            p_conv_weight
        )
        r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes(
            r_conv_weight
        )
        o_conv_weight = r.nodes_map[p_conv_weight]
        o_conv_weight_q = r.nodes_map[p_conv_weight_q]
        o_conv_weight_dq = r.nodes_map[p_conv_weight_dq]
        mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q)
        mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq)
    mapping["conv_input"] = (o_conv_input, r_conv_input)
    mapping["conv_weight"] = (o_conv_weight, r_conv_weight)

    # Extract conv bias
    if len(p_conv.args) > 2 and len(r_conv.args) > 2:
        p_conv_bias = p_conv.args[2]
        r_conv_bias = r_conv.args[2]
        assert isinstance(p_conv_bias, Node)
        assert isinstance(r_conv_bias, Node)
        o_conv_bias = r.nodes_map[p_conv_bias]

        # If conv bias is quantized, extract the q - dq nodes
        if _is_dequantize(p_conv_bias):
            p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias)
            r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias)
            o_conv_bias = r.nodes_map[p_conv_bias]
            o_conv_bias_q = r.nodes_map[p_conv_bias_q]
            o_conv_bias_dq = r.nodes_map[p_conv_bias_dq]
            mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q)
            mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq)
        mapping["conv_bias"] = (o_conv_bias, r_conv_bias)
    return mapping
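

# As an illustration (node names here are hypothetical), for a per-tensor
# quantized conv weight the returned map might look like:
#
#     {
#         "conv":           (conv_old, conv_new),
#         "bn":             (bn_old, bn_new),
#         "getitem":        (getitem_old, getitem_new),  # absent in the new training IR
#         "conv_input":     (x_old, x_new),
#         "conv_weight":    (w_old, w_new),
#         "conv_weight_q":  (w_q_old, w_q_new),
#         "conv_weight_dq": (w_dq_old, w_dq_new),
#     }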


def _filter_nodes_map(nodes_map: Dict[Node, Node]) -> Dict[Node, Node]:
    """
    Return a filtered `nodes_map` returned from the subgraph rewriter.
    The filtered `nodes_map` will contain only nodes that are actually
    matched in the pattern, excluding None or placeholder nodes.
    """
    new_nodes_map: Dict[Node, Node] = {}
    for pattern_node, graph_node in nodes_map.items():
        # bias can be None
        if graph_node is None:
            continue
        # skip pattern placeholder nodes
        if pattern_node.op == "placeholder":
            continue
        new_nodes_map[pattern_node] = graph_node
    return new_nodes_map


# TODO: this is error prone, use the replace_literals_with_placeholders hack instead
def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
    """
    Copy over literal args in conv, such as stride and padding, from the matched node
    in the original graph to its replacement in the new graph.

    This is needed due to the following limitation in the subgraph rewriter when used
    with dynamo export: literal (non-tensor) args are not supported in the match and
    replacement patterns. This is because dynamo export automatically inlines these
    literal args, making them dead placeholder nodes. In the future, we should check
    if dynamo export can optionally disable this inlining, or if subgraph rewriter
    can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.

    Note: Unlike other tensor args like conv weights and biases, literal args are
    preserved in the original nodes after replacement, so we can access them here.
    """
    assert _is_conv_or_conv_transpose_node(original_node)
    assert _is_conv_or_conv_transpose_node(new_node)
    # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
    new_args = list(new_node.args)
    if len(new_args) < 3:
        # bias is optional; when it is not present, it means bias is None
        new_args.append(None)
    new_node.args = tuple(new_args[:3]) + original_node.args[3:]
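

# For example (values are hypothetical), if the matched aten.conv2d node had args
#
#     (x, w, b, [2, 2], [1, 1], [1, 1], 1)    # ..., stride, padding, dilation, groups
#
# while the freshly traced replacement conv only carried (x_new, w_new, b_new),
# the call above rewrites the replacement's args to
#
#     (x_new, w_new, b_new, [2, 2], [1, 1], [1, 1], 1)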


def _update_conv_input_qspec_map_after_replacement(
    original_node: Node, replacement_node: Node
):
    """
    Update the `input_qspec_map` in the annotation after subgraph rewriting.

    The original annotation referred to the nodes in the original graph,
    so the keys in the `input_qspec_map` will need to be updated to reflect
    the corresponding nodes in the replacement graph.
    """
    assert _is_conv_or_conv_transpose_node(original_node)
    assert _is_conv_or_conv_transpose_node(replacement_node)
    if "quantization_annotation" not in original_node.meta:
        return
    original_input_qspec_map = original_node.meta[
        "quantization_annotation"
    ].input_qspec_map
    input_qspec_map = {}
    # get the list of configs, it should be ordered as input, weight, bias
    # note: this is really hacky, we need a better solution, hopefully
    # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
    all_configs = list(original_input_qspec_map.items())
    # input activation
    input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
    # weight
    input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
    # bias
    if len(replacement_node.args) > 2 and len(all_configs) > 2:
        input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
    replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map


def _update_special_qspecs_after_replacement(
    node: Node,
    original_to_replacement_node: Dict[Node, Node],
):
    """
    Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s
    used in `node`'s quantization annotation after subgraph rewriting.

    The original annotation referred to the nodes in the original graph,
    so the nodes used in these special quantization specs will need to
    be updated to the corresponding nodes in the replacement graph.
    """

    def _get_new_edge_or_node(edge_or_node: EdgeOrNode):
        if isinstance(edge_or_node, Node):
            _node = edge_or_node
            return original_to_replacement_node.get(_node, _node)
        elif (
            isinstance(edge_or_node, tuple)
            and len(edge_or_node) == 2
            and all(isinstance(x, Node) for x in edge_or_node)
        ):
            src, dest = edge_or_node
            return (
                original_to_replacement_node.get(src, src),
                original_to_replacement_node.get(dest, dest),
            )
        else:
            raise ValueError(f"unexpected type for edge_or_node: {type(edge_or_node)}")

    def _get_new_qspec(qspec: QuantizationSpecBase):
        if isinstance(qspec, SharedQuantizationSpec):
            new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node)
            return SharedQuantizationSpec(new_edge_or_node)
        elif isinstance(qspec, DerivedQuantizationSpec):
            new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from]
            return dataclasses.replace(qspec, derived_from=new_derived_from)
        else:
            return qspec

    if "quantization_annotation" not in node.meta:
        return
    annotation = node.meta["quantization_annotation"]
    for input_node, qspec in annotation.input_qspec_map.items():
        annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
    annotation.output_qspec = _get_new_qspec(annotation.output_qspec)


def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return m
    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
    for is_cuda in is_cuda_options:
        m = _fuse_conv_bn_qat_helper(
            m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fuse_conv_bn_qat_helper(
            m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fuse_conv_bn_qat_helper(
            m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fuse_conv_bn_qat_helper(
            m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
        )
    return m


def _fuse_conv_bn_qat_helper(
    m: GraphModule,
    conv_fn: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool,
) -> GraphModule:
    """
    Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
    the fused QAT subgraph equivalent. The input graph should already be annotated.
    The annotations in the original nodes will be preserved in the corresponding
    nodes in the new subgraph.

    Note: This also handles the (conv + bn + relu) pattern.
    """
    m.graph.eliminate_dead_code()
    m.recompile()
    conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
    match_pattern = _get_aten_graph_module_for_pattern(
        conv_bn_pattern, example_inputs, is_cuda
    )

    # Step (1): Replace patterns with conv bias
    #
    # Here we do replacement separately for cases with and without conv bias, since
    # the replacement patterns for these two cases are substantially different.
    # TODO: use the public replace_pattern API once it also returns replacement nodes

    qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
    replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
        qat_conv_bn_pattern,
        example_inputs,
        is_cuda,
    )
    replacements_with_conv_bias = replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern_with_conv_bias,
        match_filters=[_has_conv_bias_filter],
        ignore_literals=True,
    )
    m.recompile()

    # Step (2): Replace patterns without conv bias

    qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
    replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
        qat_conv_bn_pattern_no_conv_bias,
        example_inputs,
        is_cuda,
    )
    replacements_no_conv_bias = replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern_no_conv_bias,
        match_filters=[_no_conv_bias_filter],
        ignore_literals=True,
    )
    m.recompile()

    # Step (3): Post processing
    #
    # Due to limited functionality in the subgraph rewriter, here we manually
    # update the replacement graph as follows:
    #
    #   (a) Copy over metadata from original subgraph. This ensures the stack traces
    #       and annotations are preserved in the new subgraph
    #
    #   (b) Copy over literal args for conv from the original subgraph
    #       TODO: do this for literal args for batchnorm as well
    #
    #   (c) Update all references of the old nodes in the original subgraph to refer
    #       to the corresponding nodes in the new subgraph in the annotations
    #
    # In the future, we should try to push as much of this functionality into the
    # subgraph rewriter as possible, so we don't have to manually copy anything over.
    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.

    all_original_to_replacement_nodes = {}
    for r in replacements_with_conv_bias + replacements_no_conv_bias:
        for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
            # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
            replacement_node.meta = original_node.meta
            if _is_conv_or_conv_transpose_node(original_node):
                # Step (3b): Copy over conv literal args
                _copy_over_literal_conv_args(original_node, replacement_node)
                # Step (3c): Update old references in the conv node's input_qspec_map
                _update_conv_input_qspec_map_after_replacement(
                    original_node, replacement_node
                )
            all_original_to_replacement_nodes[original_node] = replacement_node

    # Step (3c): Update old references in the special qspecs for all nodes in the graph
    for n in m.graph.nodes:
        _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes)

    return m


def _duplicate_dequantize_node(m: GraphModule):
    """
    Helper function to duplicate all dequantize nodes in the graph if the
    node has more than one user. For example:

    Before:
      quantize -> dequantize -> a
                      \\--> b
                      \\--> c

    After:
      quantize -> dequantize_1 -> a
            \\--> dequantize_2 -> b
            \\--> dequantize_3 -> c

    This is useful for subgraph rewriting. E.g. if we wish to match the
    pattern [dequantize - a] above, subgraph matching would fail because
    the dequantize node has users outside the matched portion of the graph.
    Instead, we match [dequantize_1 - a], which is safe.
    """
    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
            continue
        for user in list(n.users):
            with m.graph.inserting_before(n):
                new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs)
                user.replace_input_with(n, new_node)
        m.graph.erase_node(n)
    m.recompile()


def _remove_extra_dequantize(m: GraphModule):
    """
    Remove duplicate dequantize nodes in the graph: for a node that has
    multiple dequantize nodes as users, replace them with a single dequantize
    node that can be shared across all the uses. This should be seen as the
    "reverse" of `_duplicate_dequantize_node`.
    """
    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
    for n in m.graph.nodes:
        dq_users = [
            user
            for user in n.users
            if user.op == "call_function" and user.target == dq_op
        ]
        if len(dq_users) > 1:
            with m.graph.inserting_after(dq_users[0]):
                new_node = m.graph.create_node(
                    "call_function", dq_op, dq_users[0].args, {}
                )
            for dq_user in dq_users:
                dq_user.replace_all_uses_with(new_node)
                m.graph.erase_node(dq_user)
    m.recompile()


def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
    """
    Given a pair of quantize or dequantize nodes, copy over all literal args
    from the original node to the replacement node.
    """
    # For quantize_per_tensor, scale and zp are literals and need to be copied
    # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped
    assert original_node.target == replacement_node.target
    if original_node.target in (
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    ):
        # Args: input, [scale, zp, qmin, qmax, dtype]
        start_copy_arg_index = 1
    elif original_node.target in (
        torch.ops.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
    ):
        # Args: input, scale, zp, [axis, qmin, qmax, dtype]
        start_copy_arg_index = 3
    else:
        raise ValueError(
            f"Expected quantize/dequantize nodes, got '{original_node.target}'"
        )
    replacement_node.args = (
        replacement_node.args[:start_copy_arg_index]
        + original_node.args[start_copy_arg_index:]
    )


def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return m
    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
    for is_cuda in is_cuda_options:
        m = _fold_conv_bn_qat_helper(
            m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fold_conv_bn_qat_helper(
            m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fold_conv_bn_qat_helper(
            m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fold_conv_bn_qat_helper(
            m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
        )

    # Remove the in-place add that batch norm uses to track training stats
    # (i.e. `num_batches_tracked += 1`), since it is no longer needed
    for node in m.graph.nodes:
        if (
            node.target == torch.ops.aten.add_.Tensor
            and node.args[0].op == "get_attr"
            and node.args[1] == 1
            and torch.nn.modules.batchnorm.BatchNorm2d
            in [val[1] for val in node.meta["source_fn_stack"]]
        ):
            m.graph.erase_node(node)

    m.graph.eliminate_dead_code()
    m.recompile()

    return m


def _fold_conv_bn_qat_helper(
    m: GraphModule,
    conv_fn: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool,
) -> GraphModule:
    """
    Replace the quantized (conv + bn) pattern with a conv whose weights
    have the bn weights folded into them.
    """
    m.graph.eliminate_dead_code()
    m.recompile()
    _duplicate_dequantize_node(m)

    # Step (1): Replace QAT pattern with simple [conv - bn] pattern
    replacements = []
    replacement_options = itertools.product(
        [True, False],  # is_per_channel
        [True, False],  # has_bias
        [True, False],  # bias_is_quantized
        [True, False],  # bn_is_training
    )
    for (
        is_per_channel,
        has_bias,
        bias_is_quantized,
        bn_is_training,
    ) in replacement_options:
        # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
        # filter out one of the values for this flag to avoid having duplicate patterns
        if not has_bias and bias_is_quantized:
            continue
        kwargs = _get_quantized_conv_bn_example_inputs_kwargs(
            is_per_channel, has_bias, bias_is_quantized, is_cuda
        )
        match_pattern = _get_quantized_qat_conv_bn_pattern(
            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
        )
        match_pattern = _get_aten_graph_module_for_pattern(
            match_pattern, example_inputs, is_cuda, **kwargs
        )
        replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
        )
        replacement_pattern = _get_aten_graph_module_for_pattern(
            replacement_pattern, example_inputs, is_cuda, **kwargs
        )
        replacements.extend(
            replace_pattern_with_filters(
                m,
                match_pattern,
                replacement_pattern,
                ignore_literals=True,
            )
        )
    m.recompile()
    _remove_extra_dequantize(m)

    for r in replacements:
        node_map = _get_conv_bn_pattern_nodes(r)

        # Step (2): Copy over metadata from original subgraph
        for original_node, replacement_node in node_map.values():
            replacement_node.meta = original_node.meta

        # Step (3): Copy over args for weight (and optionally bias) q - dq nodes
        _copy_over_q_dq_args(*node_map["conv_weight_q"])
        _copy_over_q_dq_args(*node_map["conv_weight_dq"])
        if "conv_bias_q" in node_map:
            assert "conv_bias_dq" in node_map
            _copy_over_q_dq_args(*node_map["conv_bias_q"])
            _copy_over_q_dq_args(*node_map["conv_bias_dq"])

        # Step (4): Fold BN weights into conv
        conv_bias = None
        (_, conv_node) = node_map["conv"]
        (_, bn_node) = node_map["bn"]
        (_, conv_weight) = node_map["conv_weight"]
        if "conv_bias" in node_map:
            (_, conv_bias) = node_map["conv_bias"]
        fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

        # Copy over literal args for conv
        for original_node in _filter_nodes_map(r.nodes_map).values():
            if _is_conv_or_conv_transpose_node(original_node):
                _copy_over_literal_conv_args(original_node, conv_node)

    m.graph.eliminate_dead_code()
    m.recompile()
    return m
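

# For context, the `fold_bn_weights_into_conv_node` call in Step (4) above (defined
# in .utils) applies the standard batch norm folding; informally, with
# std = sqrt(running_var + eps):
#
#     w_folded = conv_weight * (bn_weight / std)   # broadcast over the output-channel axis
#     b_folded = (conv_bias - running_mean) * (bn_weight / std) + bn_bias
#
# after which the bn node is no longer needed and the conv - bn pair collapses to a
# single conv carrying the folded weight and bias.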