# mypy: allow-untyped-defs
import dataclasses
import itertools
import operator
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

import torch
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    EdgeOrNode,
    QuantizationSpecBase,
    SharedQuantizationSpec,
)
from torch.fx import Graph, GraphModule, Node
from torch.fx.subgraph_rewriter import replace_pattern_with_filters, ReplacedPatterns

from .utils import (
    _conv1d_bn_example_inputs,
    _conv2d_bn_example_inputs,
    _get_aten_graph_module_for_pattern,
    _is_bn_node,
    _is_conv_or_conv_transpose_node,
    _is_conv_transpose_fn,
    fold_bn_weights_into_conv_node,
)


if TYPE_CHECKING:
    from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch

__all__ = []  # type: ignore[var-annotated]


# Example inputs for quantized and folded conv-bn1d patterns used in convert
_quantized_conv1d_bn_example_inputs = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # conv_weight
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)

# Example inputs for quantized and folded conv-bn2d patterns used in convert
_quantized_conv2d_bn_example_inputs = (
    torch.randn(1, 1, 3, 3),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)


def _get_quantized_conv_bn_example_inputs_kwargs(
    is_per_channel: bool,
    has_bias: bool,
    bias_is_quantized: bool,
    is_cuda: bool,
) -> Dict[str, Any]:
    """
    Optional example inputs for quantized and folded conv-bn patterns
    used in convert, expressed as kwargs.
    """
    kwargs = {}
    # Per tensor quantization uses literals to represent scale and zero
    # point, so there is no need to include them here as kwargs
    if is_per_channel:
        kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
        kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
        if has_bias and bias_is_quantized:
            kwargs["bias_scale"] = torch.tensor([1], dtype=torch.float)
            kwargs["bias_zero_point"] = torch.tensor([0], dtype=torch.int)
    if has_bias:
        kwargs["conv_bias"] = torch.randn(1)
    if is_cuda:
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                kwargs[k] = v.cuda()
    return kwargs
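
# Illustrative sketch (comments only, not executed): with is_per_channel=True,
# has_bias=True, bias_is_quantized=True, is_cuda=False, the helper above would
# return roughly
#
#     {
#         "weight_scale": torch.tensor([1.0]),
#         "weight_zero_point": torch.tensor([0], dtype=torch.int),
#         "bias_scale": torch.tensor([1.0]),
#         "bias_zero_point": torch.tensor([0], dtype=torch.int),
#         "conv_bias": torch.randn(1),
#     }
#
# These placeholder values only need the right types/shapes for tracing the
# match/replacement patterns; the real quantization parameters come from the
# original graph after rewriting (see `_copy_over_q_dq_args` below).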


def _get_conv_bn_pattern(conv_fn: Callable) -> Callable:
    def _conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        x = conv_fn(x, conv_weight, conv_bias)
        x = F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
        )
        return x

    return _WrapperModule(_conv_bn_pattern)


# TODO: merge this with the `no_conv_bias` case
def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
    def _qat_conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        """
        Approximate method to fuse conv and bn. It requires only one forward pass.
        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std.
        This is based on `nniqat.ConvBn2d._forward_approximate`.
        """
        # TODO: allow setting eps
        bn_eps = 1e-5
        running_std = torch.sqrt(bn_running_var + bn_eps)
        scale_factor = bn_weight / running_std
        weight_shape = [1] * len(conv_weight.shape)
        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
        weight_shape[weight_in_channel_axis] = -1
        bias_shape = [1] * len(conv_weight.shape)
        bias_shape[1] = -1
        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
        zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
        x = conv_fn(x, scaled_weight, zero_bias)
        x = x / scale_factor.reshape(bias_shape)
        x = x + conv_bias.reshape(bias_shape)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=True,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_qat_conv_bn_pattern)

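# Sketch of the algebra behind the approximation above (comments only):
#
#     scale_factor = bn_weight / sqrt(running_var + eps)        # per output channel
#     conv_fn(x, W * scale_factor) / scale_factor == conv_fn(x, W)
#
# Since conv is linear in its weight and the division is broadcast along the
# output-channel dim, scaling the weight and then dividing the output cancels
# exactly in floating point, and the subsequent `F.batch_norm` sees the same
# statistics as the original conv + bn. The intent (per
# `nniqat.ConvBn2d._forward_approximate`) is that any observer/fake-quantize
# later placed on the conv's weight input sees `scaled_weight`, i.e. the weight
# as it will look after bn folding at convert time.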

def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
    def _qat_conv_bn_pattern_no_conv_bias(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        # Not used, only for matching convenience
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        """
        Same as `_get_qat_conv_bn_pattern`, but handles the case with no conv bias.
        """
        # TODO: allow setting eps
        bn_eps = 1e-5
        running_std = torch.sqrt(bn_running_var + bn_eps)
        scale_factor = bn_weight / running_std
        weight_shape = [1] * len(conv_weight.shape)
        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
        weight_shape[weight_in_channel_axis] = -1
        bias_shape = [1] * len(conv_weight.shape)
        bias_shape[1] = -1
        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
        x = conv_fn(x, scaled_weight, None)
        x = x / scale_factor.reshape(bias_shape)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=True,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias)


def _append_qdq(x, is_per_channel, is_bias, kwargs):
    """
    Helper function to append q-dq ops after `x`, using dummy values for the qparams
    and qmin/qmax. We use dummy values here because we match with `ignore_literals=True`
    and will manually replace these values after subgraph rewriting.

    Return the dq node.
    """
    # Dummy args to be passed into q-dq ops
    per_channel_axis = 0
    scale_key = "bias_scale" if is_bias else "weight_scale"
    zp_key = "bias_zero_point" if is_bias else "weight_zero_point"
    scale = kwargs[scale_key] if is_per_channel else 1.0
    zp = kwargs[zp_key] if is_per_channel else 0
    qmin = -127
    qmax = 127
    dtype = torch.int8

    qd = torch.ops.quantized_decomposed
    if is_per_channel:
        x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
        x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype)
    else:
        x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
        x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
    return x

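# Illustrative sketch (comments only): for the per-tensor case, `_append_qdq`
# appends the pair
#
#     x = torch.ops.quantized_decomposed.quantize_per_tensor(
#         x, 1.0, 0, -127, 127, torch.int8)
#     x = torch.ops.quantized_decomposed.dequantize_per_tensor(
#         x, 1.0, 0, -127, 127, torch.int8)
#
# The literal scale/zero_point/qmin/qmax/dtype here are placeholders; the values
# matched from the original graph are copied back by `_copy_over_q_dq_args`.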

def _get_quantized_qat_conv_bn_pattern(
    is_per_channel: bool,
    has_bias: bool,
    bias_is_quantized: bool,
    conv_fn: Callable,
    bn_is_training: bool,
) -> Callable:
    """
    Return the quantized version of QAT conv + BN pattern.
    This is based on `nniqat.ConvBn2d._forward_approximate`,
    used in QAT convert. We first match this pattern and replace
    it with the normal [conv - bn] pattern, then fold the BN
    weights into conv.
    """
    # TODO: allow setting eps
    bn_eps = 1e-5

    def _quantized_qat_conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        running_std = torch.sqrt(bn_running_var + bn_eps)
        scale_factor = bn_weight / running_std
        weight_shape = [1] * len(conv_weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(conv_weight.shape)
        bias_shape[1] = -1
        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
        scaled_weight = _append_qdq(
            scaled_weight,
            is_per_channel,
            is_bias=False,
            kwargs=kwargs,
        )
        if has_bias:
            zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
            if bias_is_quantized:
                zero_bias = _append_qdq(
                    zero_bias,
                    is_per_channel,
                    is_bias=True,
                    kwargs=kwargs,
                )
            x = conv_fn(x, scaled_weight, zero_bias)
        else:
            x = conv_fn(x, scaled_weight, None)
        x = x / scale_factor.reshape(bias_shape)
        if has_bias:
            x = x + kwargs["conv_bias"].reshape(bias_shape)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=bn_is_training,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_quantized_qat_conv_bn_pattern)


def _get_folded_quantized_qat_conv_bn_pattern(
    is_per_channel: bool,
    has_bias: bool,
    bias_is_quantized: bool,
    conv_fn: Callable,
    bn_is_training: bool,
) -> Callable:
    """
    Quantized QAT conv - bn pattern with bn weights being folded into conv.
    """
    # TODO: allow setting eps
    bn_eps = 1e-5

    def _folded_quantized_qat_conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        conv_weight = _append_qdq(
            conv_weight,
            is_per_channel,
            is_bias=False,
            kwargs=kwargs,
        )
        if has_bias:
            bias = kwargs["conv_bias"]
            if bias_is_quantized:
                bias = _append_qdq(
                    bias,
                    is_per_channel,
                    is_bias=True,
                    kwargs=kwargs,
                )
        else:
            bias = None
        x = conv_fn(x, conv_weight, bias)
        x = F.batch_norm(
            x,
            bn_running_mean,
            bn_running_var,
            bn_weight,
            bn_bias,
            training=bn_is_training,
            eps=bn_eps,
        )
        return x

    return _WrapperModule(_folded_quantized_qat_conv_bn_pattern)


def _has_conv_bias_filter(
    match: "InternalMatch",
    original_graph: Graph,
    pattern_graph: Graph,
) -> bool:
    """
    Match filter for the subgraph rewriter that returns True if the conv node in
    the original graph has bias.
    """
    for n in match.nodes_map.values():
        if _is_conv_or_conv_transpose_node(n):
            return len(n.args) > 2 and n.args[2] is not None
    raise ValueError("Could not find conv node in matched conv + bn pattern")


def _no_conv_bias_filter(
    match: "InternalMatch",
    original_graph: Graph,
    pattern_graph: Graph,
) -> bool:
    """
    Match filter for the subgraph rewriter that returns True if the conv node in
    the original graph does NOT have bias.
    """
    return not _has_conv_bias_filter(match, original_graph, pattern_graph)


def _is_quantize(n: Node) -> bool:
    return n.target in [
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.quantize_per_channel.default,
    ]


def _is_dequantize(n: Node) -> bool:
    return n.target in [
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
    ]


def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Node]]:
    """
    Helper function to extract the nodes in the conv-bn fusion pattern after
    subgraph rewriting, in the form of a map:

        {name: (original_node, replacement_node)}

    The following names must exist in the map:

        "conv", "conv_weight", "conv_input", "bn", "getitem"

    The following names may exist in the map:

        "conv_weight_q", "conv_weight_dq", "conv_bias",
        "conv_bias_q", "conv_bias_dq"
    """

    def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Optional[Node]]:
        """
        Return a 3-tuple of (conv_node, bn_node, getitem_node).
        This asserts that the match contains exactly one of each node.
        """
        conv_node, bn_node, getitem_node = None, None, None
        for n in nodes:
            if n.op != "call_function":
                continue
            if _is_conv_or_conv_transpose_node(n):
                assert conv_node is None
                conv_node = n
            if _is_bn_node(n):
                assert bn_node is None
                bn_node = n
            if n.target == operator.getitem:
                assert getitem_node is None
                getitem_node = n
        assert conv_node is not None
        assert bn_node is not None
        # getitem_node might be None in new training IR
        return (conv_node, bn_node, getitem_node)

    def _get_q_dq_nodes(n: Node) -> Tuple[Node, Node, Node]:
        """
        Return a 3-tuple of (orig_node, q_node, dq_node).
        """
        assert _is_dequantize(n)
        q_node = n.args[0]
        assert isinstance(q_node, Node)
        assert _is_quantize(q_node)
        orig_node = q_node.args[0]
        assert isinstance(orig_node, Node)
        return (orig_node, q_node, n)

    original_nodes = list(_filter_nodes_map(r.nodes_map).values())
    o_conv, o_bn, o_getitem = _get_nodes(original_nodes)
    r_conv, r_bn, r_getitem = _get_nodes(r.replacements)

    # Create the mapping from original node to replacement node
    if o_getitem is None:
        # getitem is None in the new training IR
        assert r_getitem is None
        mapping = {
            "conv": (o_conv, r_conv),
            "bn": (o_bn, r_bn),
        }
    else:
        # TODO: This branch handles a deprecated flow and should be deleted soon,
        # after capture_pre_autograd_graph fully migrates to the training IR
        # T199018392
        assert r_getitem is not None
        assert o_getitem is not None
        mapping = {
            "conv": (o_conv, r_conv),
            "bn": (o_bn, r_bn),
            "getitem": (o_getitem, r_getitem),
        }

    # Extract conv input and weight
    # Note: here we extract the original nodes indirectly through the pattern nodes
    # because the args of the original nodes are no longer available after replacement
    (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys()))
    (p_conv_input, p_conv_weight, *_) = p_conv.args
    (r_conv_input, r_conv_weight, *_) = r_conv.args
    assert isinstance(p_conv_input, Node)
    assert isinstance(p_conv_weight, Node)
    assert isinstance(r_conv_input, Node)
    assert isinstance(r_conv_weight, Node)
    o_conv_input = r.nodes_map[p_conv_input]
    o_conv_weight = r.nodes_map[p_conv_weight]

    # If conv weight is quantized, extract the q - dq nodes
    if _is_dequantize(p_conv_weight):
        p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes(
            p_conv_weight
        )
        r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes(
            r_conv_weight
        )
        o_conv_weight = r.nodes_map[p_conv_weight]
        o_conv_weight_q = r.nodes_map[p_conv_weight_q]
        o_conv_weight_dq = r.nodes_map[p_conv_weight_dq]
        mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q)
        mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq)
    mapping["conv_input"] = (o_conv_input, r_conv_input)
    mapping["conv_weight"] = (o_conv_weight, r_conv_weight)

    # Extract conv bias
    if len(p_conv.args) > 2 and len(r_conv.args) > 2:
        p_conv_bias = p_conv.args[2]
        r_conv_bias = r_conv.args[2]
        assert isinstance(p_conv_bias, Node)
        assert isinstance(r_conv_bias, Node)
        o_conv_bias = r.nodes_map[p_conv_bias]

        # If conv bias is quantized, extract the q - dq nodes
        if _is_dequantize(p_conv_bias):
            p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias)
            r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias)
            o_conv_bias = r.nodes_map[p_conv_bias]
            o_conv_bias_q = r.nodes_map[p_conv_bias_q]
            o_conv_bias_dq = r.nodes_map[p_conv_bias_dq]
            mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q)
            mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq)
        mapping["conv_bias"] = (o_conv_bias, r_conv_bias)
    return mapping


def _filter_nodes_map(nodes_map: Dict[Node, Node]) -> Dict[Node, Node]:
    """
    Return a filtered `nodes_map` returned from the subgraph rewriter.
    The filtered `nodes_map` will contain only nodes that are actually
    matched in the pattern, excluding None or placeholder nodes.
    """
    new_nodes_map: Dict[Node, Node] = {}
    for pattern_node, graph_node in nodes_map.items():
        # bias can be None
        if graph_node is None:
            continue
        # skip pattern placeholder nodes
        if pattern_node.op == "placeholder":
            continue
        new_nodes_map[pattern_node] = graph_node
    return new_nodes_map


# TODO: this is error prone, use the replace_literals_with_placeholders hack instead
def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
    """
    Copy over literal args in conv, such as stride and padding, from the matched node
    in the original graph to its replacement in the new graph.

    This is needed due to the following limitation in the subgraph rewriter when used
    with dynamo export: literal (non-tensor) args are not supported in the match and
    replacement patterns. This is because dynamo export automatically inlines these
    literal args, making them dead placeholder nodes. In the future, we should check
    if dynamo export can optionally disable this inlining, or if subgraph rewriter
    can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.

    Note: Unlike other tensor args like conv weights and biases, literal args are
    preserved in the original nodes after replacement, so we can access them here.
    """
    assert _is_conv_or_conv_transpose_node(original_node)
    assert _is_conv_or_conv_transpose_node(new_node)
    # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
    new_args = list(new_node.args)
    if len(new_args) < 3:
        # bias is optional; when it is not present, it is None
        new_args.append(None)
    new_node.args = tuple(new_args[:3]) + original_node.args[3:]

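# Illustrative sketch with made-up values (comments only), following the arg
# order listed in the comment above: if the matched original conv node had args
#
#     (x, w, b, [2, 2], [1, 1], [1, 1], False, [0, 0], 1)
#
# and the freshly rewritten conv node only has (x_new, w_new, b_new), the line
# above extends the latter to
#
#     (x_new, w_new, b_new, [2, 2], [1, 1], [1, 1], False, [0, 0], 1)
#
# so stride/padding/dilation/groups survive the rewrite.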

def _update_conv_input_qspec_map_after_replacement(
    original_node: Node, replacement_node: Node
):
    """
    Update the `input_qspec_map` in the annotation after subgraph rewriting.

    The original annotation referred to the nodes in the original graph,
    so the keys in the `input_qspec_map` will need to be updated to reflect
    the corresponding nodes in the replacement graph.
    """
    assert _is_conv_or_conv_transpose_node(original_node)
    assert _is_conv_or_conv_transpose_node(replacement_node)
    if "quantization_annotation" not in original_node.meta:
        return
    original_input_qspec_map = original_node.meta[
        "quantization_annotation"
    ].input_qspec_map
    input_qspec_map = {}
    # get the list of configs, it should be ordered as input, weight, bias
    # note: this is really hacky, we need a better solution, hopefully
    # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
    all_configs = list(original_input_qspec_map.items())
    # input activation
    input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
    # weight
    input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
    # bias
    if len(replacement_node.args) > 2 and len(all_configs) > 2:
        input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
    replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map


def _update_special_qspecs_after_replacement(
    node: Node,
    original_to_replacement_node: Dict[Node, Node],
):
    """
    Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s
    used in `node`'s quantization annotation after subgraph rewriting.

    The original annotation referred to the nodes in the original graph,
    so the nodes used in these special quantization specs will need to
    be updated to the corresponding nodes in the replacement graph.
    """

    def _get_new_edge_or_node(edge_or_node: EdgeOrNode):
        if isinstance(edge_or_node, Node):
            _node = edge_or_node
            return original_to_replacement_node.get(_node, _node)
        elif (
            isinstance(edge_or_node, tuple)
            and len(edge_or_node) == 2
            and all(isinstance(x, Node) for x in edge_or_node)
        ):
            src, dest = edge_or_node
            return (
                original_to_replacement_node.get(src, src),
                original_to_replacement_node.get(dest, dest),
            )
        else:
            raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node))

    def _get_new_qspec(qspec: QuantizationSpecBase):
        if isinstance(qspec, SharedQuantizationSpec):
            new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node)
            return SharedQuantizationSpec(new_edge_or_node)
        elif isinstance(qspec, DerivedQuantizationSpec):
            new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from]
            return dataclasses.replace(qspec, derived_from=new_derived_from)
        else:
            return qspec

    if "quantization_annotation" not in node.meta:
        return
    annotation = node.meta["quantization_annotation"]
    for input_node, qspec in annotation.input_qspec_map.items():
        annotation.input_qspec_map[input_node] = _get_new_qspec(qspec)
    annotation.output_qspec = _get_new_qspec(annotation.output_qspec)
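
# Illustrative sketch (comments only, hypothetical node names): if some node was
# annotated with
#
#     SharedQuantizationSpec((old_conv_node, user_node))
#
# and `old_conv_node` was replaced during rewriting, the helper above rewrites
# the spec to point at the replacement node:
#
#     SharedQuantizationSpec((new_conv_node, user_node))
#
# Entries in `DerivedQuantizationSpec.derived_from` are remapped the same way.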


def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return m
    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
    for is_cuda in is_cuda_options:
        m = _fuse_conv_bn_qat_helper(
            m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fuse_conv_bn_qat_helper(
            m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fuse_conv_bn_qat_helper(
            m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fuse_conv_bn_qat_helper(
            m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda
        )
    return m


def _fuse_conv_bn_qat_helper(
    m: GraphModule,
    conv_fn: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool,
) -> GraphModule:
    """
    Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
    the fused QAT subgraph equivalent. The input graph should already be annotated.
    The annotations in the original nodes will be preserved in the corresponding
    nodes in the new subgraph.

    Note: This also handles the (conv + bn + relu) pattern.
    """
    m.graph.eliminate_dead_code()
    m.recompile()
    conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
    match_pattern = _get_aten_graph_module_for_pattern(
        conv_bn_pattern, example_inputs, is_cuda
    )

    # Step (1): Replace patterns with conv bias
    #
    # Here we do replacement separately for cases with and without conv bias, since
    # the replacement patterns for these two cases are substantially different.
    # TODO: use the public replace_pattern API once it also returns replacement nodes

    qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
    replacement_pattern_with_conv_bias = _get_aten_graph_module_for_pattern(
        qat_conv_bn_pattern,
        example_inputs,
        is_cuda,
    )
    replacements_with_conv_bias = replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern_with_conv_bias,
        match_filters=[_has_conv_bias_filter],
        ignore_literals=True,
    )
    m.recompile()

    # Step (2): Replace patterns without conv bias

    qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
    replacement_pattern_no_conv_bias = _get_aten_graph_module_for_pattern(
        qat_conv_bn_pattern_no_conv_bias,
        example_inputs,
        is_cuda,
    )
    replacements_no_conv_bias = replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern_no_conv_bias,
        match_filters=[_no_conv_bias_filter],
        ignore_literals=True,
    )
    m.recompile()

    # Step (3): Post processing
    #
    # Due to limited functionality in the subgraph rewriter, here we manually
    # update the replacement graph as follows:
    #
    #   (a) Copy over metadata from original subgraph. This ensures the stack traces
    #       and annotations are preserved in the new subgraph
    #
    #   (b) Copy over literal args for conv from the original subgraph
    #       TODO: do this for literal args for batchnorm as well
    #
    #   (c) Update all references of the old nodes in the original subgraph to refer
    #       to the corresponding nodes in the new subgraph in the annotations
    #
    # In the future, we should try to push as much of this functionality into the
    # subgraph rewriter as possible, so we don't have to manually copy anything over.
    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.

    all_original_to_replacement_nodes = {}
    for r in replacements_with_conv_bias + replacements_no_conv_bias:
        for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
            # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
            replacement_node.meta = original_node.meta
            if _is_conv_or_conv_transpose_node(original_node):
                # Step (3b): Copy over conv literal args
                _copy_over_literal_conv_args(original_node, replacement_node)
                # Step (3c): Update old references in the conv node's input_qspec_map
                _update_conv_input_qspec_map_after_replacement(
                    original_node, replacement_node
                )
            all_original_to_replacement_nodes[original_node] = replacement_node

    # Step (3c): Update old references in the special qspecs for all nodes in the graph
    for n in m.graph.nodes:
        _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes)

    return m


def _duplicate_dequantize_node(m: GraphModule):
    """
    Helper function to duplicate all dequantize nodes in the graph if the
    node has more than one user. For example:

    Before:
      quantize -> dequantize -> a
                          \\--> b
                          \\--> c

    After:
      quantize -> dequantize_1 -> a
            \\--> dequantize_2 -> b
            \\--> dequantize_3 -> c

    This is useful for subgraph rewriting. E.g. if we wish to match the
    pattern [dequantize - a] above, subgraph matching would fail because
    the dequantize node has users outside the matched portion of the graph.
    Instead, we match [dequantize_1 - a], which is safe.
    """
    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target != dq_op or len(n.users) == 1:
            continue
        for user in list(n.users):
            with m.graph.inserting_before(n):
                new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs)
            user.replace_input_with(n, new_node)
        m.graph.erase_node(n)
    m.recompile()


def _remove_extra_dequantize(m: GraphModule):
    """
    Remove duplicate dequantize nodes in the graph: for an operator that has
    multiple dequantize nodes as users, replace them with a single dequantize
    node that can be shared across all the uses. This should be seen as the
    "reverse" of `_duplicate_dequantize_node`.
    """
    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor
    for n in m.graph.nodes:
        dq_users = [
            user
            for user in n.users
            if user.op == "call_function" and user.target == dq_op
        ]
        if len(dq_users) > 1:
            with m.graph.inserting_after(dq_users[0]):
                new_node = m.graph.create_node(
                    "call_function", dq_op, dq_users[0].args, {}
                )
            for dq_user in dq_users:
                dq_user.replace_all_uses_with(new_node)
                m.graph.erase_node(dq_user)
    m.recompile()
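
# Sketch of the effect of `_remove_extra_dequantize` (comments only, the reverse
# of the diagram in `_duplicate_dequantize_node`):
#
#   Before:
#     quantize -> dequantize_1 -> a
#           \--> dequantize_2 -> b
#           \--> dequantize_3 -> c
#
#   After:
#     quantize -> dequantize -> a
#                         \--> b
#                         \--> c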


def _copy_over_q_dq_args(original_node: Node, replacement_node: Node):
    """
    Given a pair of quantize or dequantize nodes, copy over all literal args
    from the original node to the replacement node.
    """
    # For quantize_per_tensor, scale and zp are literals and need to be copied
    # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped
    assert original_node.target == replacement_node.target
    if original_node.target in (
        torch.ops.quantized_decomposed.quantize_per_tensor.default,
        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    ):
        # Args: input, [scale, zp, qmin, qmax, dtype]
        start_copy_arg_index = 1
    elif original_node.target in (
        torch.ops.quantized_decomposed.quantize_per_channel.default,
        torch.ops.quantized_decomposed.dequantize_per_channel.default,
    ):
        # Args: input, scale, zp, [axis, qmin, qmax, dtype]
        start_copy_arg_index = 3
    else:
        raise ValueError(
            f"Expected quantize/dequantize nodes, got '{original_node.target}'"
        )
    replacement_node.args = (
        replacement_node.args[:start_copy_arg_index]
        + original_node.args[start_copy_arg_index:]
    )
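
# Illustrative sketch with made-up qparams (comments only): for a per-tensor
# quantize node, the dummy args traced from the replacement pattern, e.g.
#
#     (x, 1.0, 0, -127, 127, torch.int8)
#
# are patched to reuse the literal qparams from the matched original node, e.g.
#
#     (x, 0.02, -5, -128, 127, torch.int8)
#
# For per-channel nodes, scale/zero_point are get_attr nodes rather than
# literals, so only axis/qmin/qmax/dtype are copied (start_copy_arg_index = 3).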


def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return m
    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
    for is_cuda in is_cuda_options:
        m = _fold_conv_bn_qat_helper(
            m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fold_conv_bn_qat_helper(
            m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fold_conv_bn_qat_helper(
            m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda
        )
        m = _fold_conv_bn_qat_helper(
            m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda
        )

    # Remove the in-place add that batchnorm uses to track training stats
    for node in m.graph.nodes:
        if (
            node.target == torch.ops.aten.add_.Tensor
            and node.args[0].op == "get_attr"
            and node.args[1] == 1
            and torch.nn.modules.batchnorm.BatchNorm2d
            in [val[1] for val in node.meta["source_fn_stack"]]
        ):
            m.graph.erase_node(node)

    m.graph.eliminate_dead_code()
    m.recompile()

    return m


def _fold_conv_bn_qat_helper(
    m: GraphModule,
    conv_fn: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool,
) -> GraphModule:
    """
    Replace the quantized (conv + bn) pattern with a conv whose weights have the bn weights folded in.
    """
    m.graph.eliminate_dead_code()
    m.recompile()
    _duplicate_dequantize_node(m)

    # Step (1): Replace QAT pattern with simple [conv - bn] pattern
    replacements = []
    replacement_options = itertools.product(
        [True, False],  # is_per_channel
        [True, False],  # has_bias
        [True, False],  # bias_is_quantized
        [True, False],  # bn_is_training
    )
    for (
        is_per_channel,
        has_bias,
        bias_is_quantized,
        bn_is_training,
    ) in replacement_options:
        # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily
        # filter out one of the values for this flag to avoid having duplicate patterns
        if not has_bias and bias_is_quantized:
            continue
        kwargs = _get_quantized_conv_bn_example_inputs_kwargs(
            is_per_channel, has_bias, bias_is_quantized, is_cuda
        )
        match_pattern = _get_quantized_qat_conv_bn_pattern(
            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
        )
        match_pattern = _get_aten_graph_module_for_pattern(
            match_pattern, example_inputs, is_cuda, **kwargs
        )
        replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern(
            is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training
        )
        replacement_pattern = _get_aten_graph_module_for_pattern(
            replacement_pattern, example_inputs, is_cuda, **kwargs
        )
        replacements.extend(
            replace_pattern_with_filters(
                m,
                match_pattern,
                replacement_pattern,
                ignore_literals=True,
            )
        )
    m.recompile()
    _remove_extra_dequantize(m)

    for r in replacements:
        node_map = _get_conv_bn_pattern_nodes(r)

        # Step (2): Copy over metadata from original subgraph
        for original_node, replacement_node in node_map.values():
            replacement_node.meta = original_node.meta

        # Step (3): Copy over args for weight (and optionally bias) q - dq nodes
        _copy_over_q_dq_args(*node_map["conv_weight_q"])
        _copy_over_q_dq_args(*node_map["conv_weight_dq"])
        if "conv_bias_q" in node_map:
            assert "conv_bias_dq" in node_map
            _copy_over_q_dq_args(*node_map["conv_bias_q"])
            _copy_over_q_dq_args(*node_map["conv_bias_dq"])

        # Step (4): Fold BN weights into conv
        conv_bias = None
        (_, conv_node) = node_map["conv"]
        (_, bn_node) = node_map["bn"]
        (_, conv_weight) = node_map["conv_weight"]
        if "conv_bias" in node_map:
            (_, conv_bias) = node_map["conv_bias"]
        fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m)

        # Copy over literal args for conv
        for original_node in _filter_nodes_map(r.nodes_map).values():
            if _is_conv_or_conv_transpose_node(original_node):
                _copy_over_literal_conv_args(original_node, conv_node)

    m.graph.eliminate_dead_code()
    m.recompile()
    return m
963