xref: /aosp_15_r20/external/pytorch/test/quantization/pt2e/test_quantize_pt2e_qat.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2import copy
3import operator
4import unittest
5from typing import Any, Optional, Tuple, Type
6
7import torch
8from torch._export import capture_pre_autograd_graph
9from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
10from torch.ao.quantization import (
11    default_fake_quant,
12    FusedMovingAvgObsFakeQuantize,
13    MovingAverageMinMaxObserver,
14    MovingAveragePerChannelMinMaxObserver,
15    QConfigMapping,
16)
17from torch.ao.quantization.backend_config import get_qnnpack_backend_config
18from torch.ao.quantization.qconfig import (
19    default_per_channel_symmetric_qnnpack_qat_qconfig,
20    default_symmetric_qnnpack_qat_qconfig,
21)
22from torch.ao.quantization.quantize_fx import prepare_qat_fx
23from torch.ao.quantization.quantize_pt2e import (
24    _convert_to_reference_decomposed_fx,
25    convert_pt2e,
26    prepare_pt2e,
27    prepare_qat_pt2e,
28)
29from torch.ao.quantization.quantizer import (
30    DerivedQuantizationSpec,
31    QuantizationAnnotation,
32    QuantizationSpec,
33    Quantizer,
34)
35from torch.ao.quantization.quantizer.xnnpack_quantizer import (
36    get_symmetric_quantization_config,
37    XNNPACKQuantizer,
38)
39from torch.testing._internal.common_cuda import TEST_CUDA
40from torch.testing._internal.common_quantization import (
41    NodeSpec as ns,
42    QuantizationTestCase,
43    skip_if_no_torchvision,
44    skipIfNoQNNPACK,
45)
46from torch.testing._internal.common_quantized import override_quantized_engine
47
48
49class PT2EQATTestCase(QuantizationTestCase):
50    """
51    Base QuantizationTestCase for PT2E QAT with some helper methods.
52    """
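    # The helpers below exercise the standard PT2E QAT flow, roughly:
    #   m = capture_pre_autograd_graph(model, example_inputs)  # program capture
    #   m = prepare_qat_pt2e(m, quantizer)                      # insert fake quantizes
    #   m(*example_inputs)                                      # calibrate / train
    #   m = convert_pt2e(m)                                     # lower to q/dq reference ops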
53
54    class _BaseConvBnModel(torch.nn.Module):
55        def __init__(
56            self,
57            conv_class: Type[torch.nn.Module],
58            bn_class: Type[torch.nn.Module],
59            has_conv_bias: bool,
60            has_bn: bool,
61            has_relu: bool,
62            **conv_kwargs,
63        ):
64            super().__init__()
65            conv_kwargs.setdefault("in_channels", 3)
66            conv_kwargs.setdefault("out_channels", 3)
67            conv_kwargs.setdefault("kernel_size", 3)
68            conv_kwargs.setdefault("bias", has_conv_bias)
69            self.conv = conv_class(**conv_kwargs)
70            self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None
71            self.relu = torch.nn.ReLU() if has_relu else None
72
73        def forward(self, x):
74            x = self.conv(x)
75            if self.bn is not None:
76                x = self.bn(x)
77            if self.relu is not None:
78                x = self.relu(x)
79            return x
80
81    def _get_conv_bn_model(
82        self,
83        has_conv_bias: bool = True,
84        has_bn: bool = True,
85        has_relu: bool = False,
86        transpose: bool = False,
87        **conv_kwargs,
88    ):
89        """
90        Return an instance of a simple test model containing the
91        conv[-bn][-relu] pattern. By default, this returns a
92        conv-bn model with conv bias.
93        """
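        # Hypothetical example calls (the concrete conv/bn classes and example
        # inputs are supplied by the Conv1d/Conv2d subclasses below):
        #   self._get_conv_bn_model()                # conv -> bn, with conv bias
        #   self._get_conv_bn_model(has_relu=True)   # conv -> bn -> relu
        #   self._get_conv_bn_model(transpose=True, in_channels=3, out_channels=5)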
94        return self._BaseConvBnModel(
95            self.conv_transpose_class if transpose else self.conv_class,
96            self.bn_class,
97            has_conv_bias,
98            has_bn,
99            has_relu,
100            **conv_kwargs,
101        )
102
103    def _verify_symmetric_xnnpack_qat_numerics(
104        self,
105        model: torch.nn.Module,
106        example_inputs: Tuple[Any, ...],
107    ):
108        self._verify_symmetric_xnnpack_qat_numerics_helper(
109            model,
110            example_inputs,
111            is_per_channel=True,
112        )
113        self._verify_symmetric_xnnpack_qat_numerics_helper(
114            model,
115            example_inputs,
116            is_per_channel=False,
117        )
118
119    def _verify_symmetric_xnnpack_qat_numerics_helper(
120        self,
121        model: torch.nn.Module,
122        example_inputs: Tuple[Any, ...],
123        is_per_channel: bool,
124        verify_convert: bool = True,
125    ):
126        """
127        Helper method to verify that the QAT numerics for PT2E quantization match those of
128        FX graph mode quantization for symmetric qnnpack.
129        """
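        # Both flows run the same model with the same seed and are expected to
        # produce matching fake-quantized (and, optionally, converted) outputs:
        #   PT2E: capture_pre_autograd_graph -> prepare_qat_pt2e -> convert_pt2e
        #   FX:   prepare_qat_fx -> _convert_to_reference_decomposed_fx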
130        # resetting dynamo cache
131        torch._dynamo.reset()
132        MANUAL_SEED = 100
133
134        # PT2 export
135
136        model_pt2e = copy.deepcopy(model)
137        quantizer = XNNPACKQuantizer()
138        quantizer.set_global(
139            get_symmetric_quantization_config(
140                is_per_channel=is_per_channel, is_qat=True
141            )
142        )
143        model_pt2e = capture_pre_autograd_graph(
144            model_pt2e,
145            example_inputs,
146        )
147        model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
148        torch.manual_seed(MANUAL_SEED)
149        after_prepare_result_pt2e = model_pt2e(*example_inputs)
150
151        model_fx = copy.deepcopy(model)
152        if is_per_channel:
153            default_qconfig = default_per_channel_symmetric_qnnpack_qat_qconfig
154        else:
155            default_qconfig = default_symmetric_qnnpack_qat_qconfig
156        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
157        backend_config = get_qnnpack_backend_config()
158        model_fx = prepare_qat_fx(
159            model_fx, qconfig_mapping, example_inputs, backend_config=backend_config
160        )
161        torch.manual_seed(MANUAL_SEED)
162        after_prepare_result_fx = model_fx(*example_inputs)
163
164        # Verify that numerics match
165        self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)
166
167        if verify_convert:
168            # We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
169            torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
170            model_pt2e = convert_pt2e(model_pt2e)
171            quant_result_pt2e = model_pt2e(*example_inputs)
172            model_fx.eval()
173            model_fx = _convert_to_reference_decomposed_fx(
174                model_fx,
175                backend_config=backend_config,
176            )
177            quant_result_fx = model_fx(*example_inputs)
178            self.assertEqual(quant_result_pt2e, quant_result_fx)
179
180    def _verify_symmetric_xnnpack_qat_graph(
181        self,
182        m: torch.fx.GraphModule,
183        example_inputs: Tuple[Any, ...],
184        has_relu: bool,
185        has_bias: bool = True,
186        is_cuda: bool = False,
187        expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
188        # TODO: set this to True by default
189        verify_convert: bool = False,
190    ):
191        self._verify_symmetric_xnnpack_qat_graph_helper(
192            m,
193            example_inputs,
194            is_per_channel=True,
195            has_relu=has_relu,
196            has_bias=has_bias,
197            is_cuda=is_cuda,
198            expected_conv_literal_args=expected_conv_literal_args,
199            verify_convert=verify_convert,
200        )
201        self._verify_symmetric_xnnpack_qat_graph_helper(
202            m,
203            example_inputs,
204            is_per_channel=False,
205            has_relu=has_relu,
206            has_bias=has_bias,
207            is_cuda=is_cuda,
208            expected_conv_literal_args=expected_conv_literal_args,
209            verify_convert=verify_convert,
210        )
211
212    def _verify_symmetric_xnnpack_qat_graph_helper(
213        self,
214        m: torch.fx.GraphModule,
215        example_inputs: Tuple[Any, ...],
216        is_per_channel: bool,
217        has_relu: bool,
218        has_bias: bool = True,
219        is_cuda: bool = False,
220        expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
221        verify_convert: bool = False,
222    ):
223        """
224        Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
225        with fake quantizes inserted into the correct places.
226        # TODO: also verify that metadata is copied over to the new nodes.
227        """
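        # The assertions below walk the graph produced by the QAT conv-bn rewrite,
        # which (ignoring the fake quantizes) computes roughly:
        #   scale_factor = bn.weight / sqrt(bn.running_var + eps)
        #   x = conv(input, weight * scale_factor.reshape(...), zeros_like(bias))
        #   x = x / scale_factor.reshape(...) [+ bias.reshape(...)]
        #   x = batch_norm(x, ...) [-> getitem in the non-training IR]
        # so that the conv output still flows through bn with batch statistics.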
228        m = copy.deepcopy(m)
229        quantizer = XNNPACKQuantizer()
230        quantizer.set_global(
231            get_symmetric_quantization_config(is_per_channel, is_qat=True)
232        )
233        m = capture_pre_autograd_graph(
234            m,
235            example_inputs,
236        )
237        m = prepare_qat_pt2e(m, quantizer)
238        m(*example_inputs)
239
240        # Verify: getitem output activation fake quantize
241        output_node = list(m.graph.nodes)[-1]
242        output_fq_node = output_node.args[0][0]
243        self.assertTrue(output_fq_node.target.startswith("activation_post_process_"))
244        output_fq_mod = getattr(m, output_fq_node.target)
245        self.assertEqual(type(output_fq_mod), FusedMovingAvgObsFakeQuantize)
246        self.assertEqual(
247            type(output_fq_mod.activation_post_process), MovingAverageMinMaxObserver
248        )
249        self.assertEqual(output_fq_mod.dtype, torch.int8)
250        self.assertEqual(output_fq_mod.quant_min, -128)
251        self.assertEqual(output_fq_mod.quant_max, 127)
252
253        # Verify: getitem(bn, 0) or relu(getitem(bn, 0))
254        if has_relu:
255            relu_node = output_fq_node.args[0]
256            getitem_node = relu_node.args[0]
257            self.assertEqual(relu_node.target, torch.ops.aten.relu.default)
258        else:
259            relu_node = None
260            getitem_node = output_fq_node.args[0]
261
262        is_training_ir_flag = capture_pre_autograd_graph_using_training_ir()
263        if is_training_ir_flag:
264            # In the training IR, relu (if present) takes the bn output directly.
265            # See NOTE [training ir has no getitem for bn node].
266            bn_node = getitem_node
267            self.assertEqual(bn_node.target, torch.ops.aten.batch_norm.default)
268        else:
269            # TODO: This is a deprecated code path and should be deleted soon,
270            # once capture_pre_autograd_graph fully migrates to the training IR
271            # T199018392
272            self.assertEqual(getitem_node.target, operator.getitem)
273            bn_node = getitem_node.args[0]
274
275            expected_bn_op = None
276            if is_cuda:
277                if torch.version.cuda is not None:
278                    expected_bn_op = torch.ops.aten.cudnn_batch_norm.default
279                elif torch.version.hip is not None:
280                    expected_bn_op = torch.ops.aten.miopen_batch_norm.default
281            else:
282                expected_bn_op = torch.ops.aten._native_batch_norm_legit.default
283            self.assertEqual(bn_node.target, expected_bn_op)
284
285        # Verify: conv / scale_factor.reshape [+ bias.reshape]
286        if has_bias:
287            add_bias_node = bn_node.args[0]
288            (div_scale_factor_node, bias_reshape_node) = add_bias_node.args
289            self.assertEqual(add_bias_node.target, torch.ops.aten.add.Tensor)
290            self.assertEqual(bias_reshape_node.target, torch.ops.aten.reshape.default)
291        else:
292            div_scale_factor_node = bn_node.args[0]
293        (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
294        conv_op = conv_node.target
295        self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
296        self.assertTrue(_is_conv_node(conv_node))
297        self.assertEqual(
298            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
299        )
300
301        # Verify: conv literal args
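        # The six expected literal args are (stride, padding, dilation, transposed,
        # output_padding, groups); they follow the input/weight/bias args, and only
        # as many of them as actually appear on the conv node are compared.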
302        if expected_conv_literal_args is not None:
303            assert (
304                len(expected_conv_literal_args) == 6
305            ), "wrong num conv args, bad test setup"
306            for i in range(6):
307                if i + 3 < len(conv_node.args):
308                    self.assertEqual(
309                        conv_node.args[i + 3], expected_conv_literal_args[i]
310                    )
311
312        # Verify: conv input activation fake quantize
313        conv_input_fq_node = conv_node.args[0]
314        conv_input_node = conv_input_fq_node.args[0]
315        self.assertTrue(
316            conv_input_fq_node.target.startswith("activation_post_process_")
317        )
318        conv_input_fq_mod = getattr(m, conv_input_fq_node.target)
319        self.assertEqual(type(conv_input_fq_mod), FusedMovingAvgObsFakeQuantize)
320        self.assertEqual(
321            type(conv_input_fq_mod.activation_post_process), MovingAverageMinMaxObserver
322        )
323        self.assertEqual(conv_input_fq_mod.dtype, torch.int8)
324        self.assertEqual(conv_input_fq_mod.quant_min, -128)
325        self.assertEqual(conv_input_fq_mod.quant_max, 127)
326        self.assertEqual(conv_input_node.op, "placeholder")
327
328        # Verify: conv weight fake quantize
329        conv_weight_fq_node = conv_node.args[1]
330        self.assertTrue(
331            conv_weight_fq_node.target.startswith("activation_post_process_")
332        )
333        conv_weight_fq_mod = getattr(m, conv_weight_fq_node.target)
334        if is_per_channel:
335            expected_weight_observer_type = MovingAveragePerChannelMinMaxObserver
336        else:
337            expected_weight_observer_type = MovingAverageMinMaxObserver
338        self.assertEqual(type(conv_weight_fq_mod), FusedMovingAvgObsFakeQuantize)
339        self.assertEqual(
340            type(conv_weight_fq_mod.activation_post_process),
341            expected_weight_observer_type,
342        )
343        self.assertEqual(conv_weight_fq_mod.dtype, torch.int8)
344        self.assertEqual(conv_weight_fq_mod.quant_min, -127)
345        self.assertEqual(conv_weight_fq_mod.quant_max, 127)
346
347        # Verify: conv(fq(input), fq(weight * scale_factor.reshape), zero_bias)
348        zero_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
349        mul_weight_scale_factor_node = conv_weight_fq_node.args[0]
350        (
351            conv_weight_fq_node,
352            scale_factor_reshape_node,
353        ) = mul_weight_scale_factor_node.args
354        if has_bias:
355            self.assertEqual(zero_bias_node.target, torch.ops.aten.zeros_like.default)
356        else:
357            self.assertTrue(zero_bias_node is None)
358        self.assertEqual(mul_weight_scale_factor_node.target, torch.ops.aten.mul.Tensor)
359        self.assertEqual(
360            scale_factor_reshape_node.target, torch.ops.aten.reshape.default
361        )
362
363        # Verify: scale_factor = bn_weight / sqrt(bn_running_var + eps)
364        scale_factor_node = scale_factor_reshape_node.args[0]
365        (bn_weight_node, sqrt_node) = scale_factor_node.args
366        bn_running_var_add_node = sqrt_node.args[0]
367        (bn_running_var_node, eps) = bn_running_var_add_node.args
368        self.assertEqual(scale_factor_node.target, torch.ops.aten.div.Tensor)
369        if is_training_ir_flag:
370            self.assertTrue("bn.weight" in bn_weight_node.target)
371            self.assertTrue("bn.running_var" in bn_running_var_node.target)
372        else:
373            self.assertTrue("bn_weight" in bn_weight_node.target)
374            self.assertTrue("bn_running_var" in bn_running_var_node.target)
375        self.assertEqual(sqrt_node.target, torch.ops.aten.sqrt.default)
376        self.assertEqual(bn_running_var_add_node.target, torch.ops.aten.add.Tensor)
377        self.assertEqual(eps, 1e-5)
378
379        # Optionally check the converted graph
380        if verify_convert:
381            m = convert_pt2e(m)
382            m(*example_inputs)
383
384            if is_per_channel:
385                conv_weight_dq_op = (
386                    torch.ops.quantized_decomposed.dequantize_per_channel.default
387                )
388                node_occurrence = {
389                    ns.call_function(
390                        torch.ops.quantized_decomposed.quantize_per_tensor.default
391                    ): 2,
392                    ns.call_function(
393                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
394                    ): 2,
395                    ns.call_function(
396                        torch.ops.quantized_decomposed.dequantize_per_channel.default
397                    ): 1,
398                }
399            else:
400                conv_weight_dq_op = (
401                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
402                )
403                node_occurrence = {
404                    ns.call_function(
405                        torch.ops.quantized_decomposed.quantize_per_tensor.default
406                    ): 2,
407                    ns.call_function(
408                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
409                    ): 3,
410                }
411            node_list = [
412                ns.call_function(
413                    torch.ops.quantized_decomposed.quantize_per_tensor.default
414                ),
415                ns.call_function(
416                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
417                ),
418                ns.call_function(conv_weight_dq_op),
419                ns.call_function(conv_op),
420                ns.call_function(
421                    torch.ops.quantized_decomposed.quantize_per_tensor.default
422                ),
423                ns.call_function(
424                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
425                ),
426            ]
427
428            self.checkGraphModuleNodes(
429                m,
430                expected_node_list=node_list,
431                expected_node_occurrence=node_occurrence,
432            )
433
434
435class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
436    """
437    Base TestCase to be used for all conv-bn[-relu] fusion patterns.
438    """
439
440    # TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
441    # Otherwise it fails with the following error:
442    #   torch._dynamo.exc.InternalTorchDynamoError:
443    #   'QuantizationConfig' object has no attribute '__bool__'
444
445    def setUp(self):
446        # NB: Skip the test if this is a base class. This handles the test
447        # discovery logic in buck, which finds and runs all tests here, including
448        # the base class, which we don't want to run
449        if self.id() and "_Base" in self.id():
450            self.skipTest("Skipping test running from base class")
451
452    def test_qat_conv_no_bias(self):
453        m1 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=True)
454        m2 = self._get_conv_bn_model(has_conv_bias=False, has_bn=False, has_relu=False)
455        self._verify_symmetric_xnnpack_qat_numerics(m1, self.example_inputs)
456        self._verify_symmetric_xnnpack_qat_numerics(m2, self.example_inputs)
457
458    def test_qat_conv_bn_fusion(self):
459        m = self._get_conv_bn_model()
460        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False)
461        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
462
463    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
464    def test_qat_conv_bn_fusion_cuda(self):
465        m = self._get_conv_bn_model().cuda()
466        example_inputs = (self.example_inputs[0].cuda(),)
467        self._verify_symmetric_xnnpack_qat_graph(
468            m,
469            example_inputs,
470            has_relu=False,
471            is_cuda=True,
472        )
473        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
474
475    def test_qat_conv_bn_fusion_literal_args(self):
476        class M(torch.nn.Module):
477            def __init__(self, conv_class, bn_class):
478                super().__init__()
479                self.conv = conv_class(3, 3, 3, stride=2, padding=4)
480                self.bn = bn_class(3)
481
482            def forward(self, x):
483                x = self.conv(x)
484                x = self.bn(x)
485                return x
486
487        assert self.dim in [1, 2]
488        if self.dim == 1:
489            # stride, padding, dilation, transposed, output_padding, groups
490            conv_args = ((2,), (4,), (1,), False, (0,), 1)
491            example_inputs = (torch.randn(1, 3, 5),)
492        else:
493            # stride, padding, dilation, transposed, output_padding, groups
494            conv_args = ((2, 2), (4, 4), (1, 1), False, (0, 0), 1)
495            example_inputs = (torch.randn(1, 3, 5, 5),)
496
497        m = M(self.conv_class, self.bn_class)
498
499        self._verify_symmetric_xnnpack_qat_graph(
500            m,
501            example_inputs,
502            has_relu=False,
503            expected_conv_literal_args=conv_args,
504        )
505        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
506
507    def test_qat_conv_bn_fusion_no_conv_bias(self):
508        class M2(torch.nn.Module):
509            """
510            Mixed conv + BN with and without conv bias.
511            """
512
513            def __init__(self, conv_class, bn_class):
514                super().__init__()
515                self.conv1 = conv_class(3, 3, 3, bias=False)
516                self.bn1 = bn_class(3)
517                self.conv2 = conv_class(3, 3, 3, bias=True)
518                self.bn2 = bn_class(3)
519
520            def forward(self, x):
521                x = self.conv1(x)
522                x = self.bn1(x)
523                x = self.conv2(x)
524                x = self.bn2(x)
525                return x
526
527        m1 = self._get_conv_bn_model(has_conv_bias=False)
528        m2 = M2(self.conv_class, self.bn_class)
529
530        assert self.dim in [1, 2]
531        if self.dim == 1:
532            example_inputs = (torch.randn(3, 3, 5),)
533        else:
534            example_inputs = (torch.randn(3, 3, 5, 5),)
535
536        self._verify_symmetric_xnnpack_qat_graph(
537            m1,
538            example_inputs,
539            has_relu=False,
540            has_bias=False,
541        )
542        self._verify_symmetric_xnnpack_qat_numerics(m1, example_inputs)
543        self._verify_symmetric_xnnpack_qat_numerics(m2, example_inputs)
544
545    def test_qat_conv_bn_relu_fusion(self):
546        m = self._get_conv_bn_model(has_relu=True)
547        self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True)
548        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
549
550    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
551    def test_qat_conv_bn_relu_fusion_cuda(self):
552        m = self._get_conv_bn_model(has_relu=True).cuda()
553        example_inputs = (self.example_inputs[0].cuda(),)
554        self._verify_symmetric_xnnpack_qat_graph(
555            m,
556            example_inputs,
557            has_relu=True,
558            is_cuda=True,
559        )
560        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
561
562    def test_qat_conv_bn_relu_fusion_no_conv_bias(self):
563        m = self._get_conv_bn_model(has_conv_bias=False, has_relu=True)
564        self._verify_symmetric_xnnpack_qat_graph(
565            m,
566            self.example_inputs,
567            has_relu=True,
568            has_bias=False,
569        )
570        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
571
572    def test_qat_inplace_add_relu(self):
573        class M(torch.nn.Module):
574            def __init__(self, conv_class):
575                super().__init__()
576                self.conv = conv_class(1, 1, 1)
577                self.relu = torch.nn.ReLU(inplace=True)
578
579            def forward(self, x):
580                x0 = x
581                x = self.conv(x)
582                x += x0
583                x = self.relu(x)
584                return x
585
586        assert self.dim in [1, 2]
587        if self.dim == 1:
588            example_inputs = (torch.randn(1, 1, 3),)
589        else:
590            example_inputs = (torch.randn(1, 1, 3, 3),)
591
592        m = M(self.conv_class)
593        self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
594
595    def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self):
596        """
597        Test the case where the placeholder node for the [conv - bn - getitem] pattern
598        is also a getitem node:
599
600          some_op -> unrelated_getitem -> conv -> bn -> conv_bn_getitem
601
602        We want the metadata to be copied from the `conv_bn_getitem` node, not from
603        the `unrelated_getitem` node, which is not part of the conv-bn pattern but
604        is returned as part of the match anyway (as a placeholder).
605        """
606
607        class M(torch.nn.Module):
608            def __init__(self, conv_class, bn_class):
609                super().__init__()
610                self.bn1 = bn_class(3)
611                self.conv = conv_class(3, 3, 3)
612                self.bn2 = bn_class(3)
613
614            def forward(self, x):
615                x = self.bn1(x)
616                x = self.conv(x)
617                x = self.bn2(x)
618                return x
619
620        def _get_getitem_nodes(m: torch.fx.GraphModule):
621            """
622            Return a 2-tuple of (unrelated_getitem_node, conv_bn_getitem_node) from the graph.
623            """
624            unrelated_getitem_node, conv_bn_getitem_node = None, None
625            for node in m.graph.nodes:
626                if (
627                    node.target != operator.getitem
628                    or node.args[0].target
629                    != torch.ops.aten._native_batch_norm_legit.default
630                ):
631                    continue
632                if node.args[0].args[0].op == "placeholder":
633                    unrelated_getitem_node = node
634                else:
635                    conv_bn_getitem_node = node
636            assert (
637                unrelated_getitem_node is not None
638            ), "did not find unrelated getitem node, bad test setup"
639            assert (
640                conv_bn_getitem_node is not None
641            ), "did not find conv bn getitem node, bad test setup"
642            return (unrelated_getitem_node, conv_bn_getitem_node)
643
644        # Program capture
645        m = M(self.conv_class, self.bn_class)
646        m = capture_pre_autograd_graph(m, self.example_inputs)
647        m.graph.eliminate_dead_code()
648        m.recompile()
649        (_, original_conv_bn_getitem_node) = _get_getitem_nodes(m)
650
651        # Prepare QAT
652        quantizer = XNNPACKQuantizer()
653        quantizer.set_global(
654            get_symmetric_quantization_config(is_per_channel=False, is_qat=True)
655        )
656        m = prepare_qat_pt2e(m, quantizer)
657        (unrelated_getitem_node, conv_bn_getitem_node) = _get_getitem_nodes(m)
658
659        # Verify that the metadata was copied from `conv_bn_getitem`, not `unrelated_getitem`
660        original_conv_bn_getitem_meta = original_conv_bn_getitem_node.meta[
661            "quantization_annotation"
662        ]
663        conv_bn_getitem_meta = conv_bn_getitem_node.meta["quantization_annotation"]
664        self.assertEqual(conv_bn_getitem_meta, original_conv_bn_getitem_meta)
665        self.assertTrue("quantization_annotation" not in unrelated_getitem_node.meta)
666
667    def test_qat_update_shared_qspec(self):
668        """
669        Test the case where nodes used in SharedQuantizationSpec were replaced
670        during QAT subgraph rewriting.
671        """
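        # With the XNNPACK quantizer, hardtanh is typically annotated with a
        # SharedQuantizationSpec tied to its input, so the reference it holds to the
        # bn output node must be updated when QAT rewriting replaces that node.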
672
673        class M(torch.nn.Module):
674            def __init__(self, conv_class, bn_class):
675                super().__init__()
676                self.conv = conv_class(3, 3, 3)
677                self.bn = bn_class(3)
678                self.hardtanh = torch.nn.Hardtanh()
679
680            def forward(self, x):
681                x = self.conv(x)
682                x = self.bn(x)
683                x = self.hardtanh(x)
684                return x
685
686        m = M(self.conv_class, self.bn_class)
687        self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs)
688
689    def test_qat_preserve_source_fn_stack(self):
690        """
691        Test whether `source_fn_stack` is preserved after QAT fusion.
692        """
693
694        class M(torch.nn.Module):
695            def __init__(self, conv_class, bn_class, backbone):
696                super().__init__()
697                self.conv = conv_class(5, 3, 3)
698                self.bn = bn_class(3)
699                self.relu = torch.nn.ReLU()
700                self.backbone = backbone
701
702            def forward(self, x):
703                x = self.conv(x)
704                x = self.bn(x)
705                x = self.relu(x)
706                x = self.backbone(x)
707                return x
708
709        assert self.dim in [1, 2]
710        if self.dim == 1:
711            example_inputs = (torch.randn(1, 5, 10),)
712        else:
713            example_inputs = (torch.randn(1, 5, 10, 10),)
714
715        # QAT prepare + convert
716        backbone = self._get_conv_bn_model(has_relu=True)
717        m = M(self.conv_class, self.bn_class, backbone)
718        quantizer = XNNPACKQuantizer()
719        quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
720        m = capture_pre_autograd_graph(m, example_inputs)
721        m = prepare_qat_pt2e(m, quantizer)
722        m(*example_inputs)
723        m = convert_pt2e(m)
724
725        # Extract the conv and relu nodes (bn was folded into conv)
726        first_conv, first_relu, second_conv, second_relu = None, None, None, None
727        for n in m.graph.nodes:
728            if n.target == torch.ops.aten.relu.default:
729                if first_relu is None:
730                    assert first_conv is None, "bad test setup"
731                    first_relu = n
732                    first_conv = n.args[0]
733                else:
734                    assert second_conv is None, "bad test setup"
735                    second_relu = n
736                    second_conv = n.args[0]
737
738        # Extract the conv weight and bias nodes
739        def get_conv_weight_and_bias(conv_node: torch.fx.Node):
740            weight_dq_node = conv_node.args[1]
741            qweight_node = weight_dq_node.args[0]
742            bias_node = conv_node.args[2]
743            assert isinstance(qweight_node, torch.fx.Node)
744            assert isinstance(bias_node, torch.fx.Node)
745            return (qweight_node, bias_node)
746
747        first_conv_qweight, first_conv_bias = get_conv_weight_and_bias(first_conv)
748        second_conv_qweight, second_conv_bias = get_conv_weight_and_bias(second_conv)
749
750        # Assert that each set of conv, conv weight, and conv bias are in the same partition
751        def get_source_fn(node: torch.fx.Node):
752            # E.g. [('l__self___backbone1_conv', <class 'torch.nn.modules.conv.Conv2d'>)]
753            return node.meta["source_fn_stack"][0][0]
754
755        # we don't currently preserve source_fn_stack for the quantized weight since it's folded,
756        # but the user can attach "quantization_tag" to the node and it will be preserved
757        # self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_qweight))
758        # self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_qweight))
759
760        self.assertEqual(get_source_fn(first_conv), get_source_fn(first_conv_bias))
761        self.assertEqual(get_source_fn(second_conv), get_source_fn(second_conv_bias))
762
763        # Assert that different sets of convs and relus have different partitions
764        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(first_relu))
765        self.assertNotEqual(get_source_fn(first_conv), get_source_fn(second_conv))
766        self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu))
767        self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu))
768
769        # Assert that "backbone" exists only in the second set of conv and relu's partition
770        self.assertTrue("backbone" not in get_source_fn(first_conv))
771        self.assertTrue("backbone" not in get_source_fn(first_relu))
772        self.assertTrue("backbone" in get_source_fn(second_conv))
773        self.assertTrue("backbone" in get_source_fn(second_relu))
774
775    def test_qat_conv_bn_bias_derived_qspec(self):
776        m = self._get_conv_bn_model()
777        example_inputs = self.example_inputs
778        m = capture_pre_autograd_graph(m, example_inputs)
779        quantizer = ConvBnDerivedBiasQuantizer()
780        m = prepare_qat_pt2e(m, quantizer)
781        m(*example_inputs)
782        m = convert_pt2e(m)
783        m(*example_inputs)
784
785        # Assert that both weight and bias are quantized
786        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
787        weight_dq = conv_node.args[1]
788        bias_dq = conv_node.args[2]
789        self.assertEqual(
790            weight_dq.target,
791            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
792        )
793        self.assertEqual(
794            bias_dq.target,
795            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
796        )
797        weight_getattr = weight_dq.args[0]
798        bias_getattr = bias_dq.args[0]
799        self.assertEqual(
800            weight_getattr.op,
801            "get_attr",
802        )
803        self.assertEqual(
804            bias_getattr.op,
805            "get_attr",
806        )
807
808        # Assert that bias scale = weight scale * input scale
809        input_dq = conv_node.args[0]
810        input_scale = input_dq.args[1]
811        bias_scale = bias_dq.args[1]
812        weight_scale = weight_dq.args[1]
813        self.assertEqual(bias_scale, input_scale * weight_scale)
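        # This follows the usual convention for integer conv bias: quantizing the
        # bias with scale = input_scale * weight_scale lets the int32 bias be added
        # directly to the conv's int32 accumulator.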
814
815        # Assert that args for the bias' quantize and dequantize ops
816        # are copied correctly after subgraph rewriting
817        (bias_qmin, bias_qmax, bias_dtype) = bias_dq.args[3:]
818        self.assertEqual(bias_qmin, -(2**31))
819        self.assertEqual(bias_qmax, 2**31 - 1)
820        self.assertEqual(bias_dtype, torch.int32)
821
822    def test_qat_per_channel_weight_custom_dtype(self):
823        m = self._get_conv_bn_model()
824        example_inputs = self.example_inputs
825        m = capture_pre_autograd_graph(m, example_inputs)
826        quantizer = ConvBnInt32WeightQuantizer()
827        m = prepare_qat_pt2e(m, quantizer)
828        m(*example_inputs)
829        m = convert_pt2e(m)
830        m(*example_inputs)
831
832        # Assert that conv weight is quantized per channel
833        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
834        weight_dq = conv_node.args[1]
835        self.assertEqual(
836            weight_dq.target,
837            torch.ops.quantized_decomposed.dequantize_per_channel.default,
838        )
839        weight_getattr = weight_dq.args[0]
840        self.assertEqual(
841            weight_getattr.op,
842            "get_attr",
843        )
844
845        # Assert that args for the weight's dequantize ops
846        # are copied correctly after subgraph rewriting
847        (dq_axis, dq_qmin, dq_qmax, dq_dtype) = weight_dq.args[3:]
848        self.assertEqual(dq_axis, 0)
849        self.assertEqual(dq_qmin, 0)
850        self.assertEqual(dq_qmax, 2**31 - 1)
851        self.assertEqual(dq_dtype, torch.int32)
852
853    def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
854        # Use different in/out channel sizes to test if conv weight is
855        # properly transposed in the QAT pattern
856        m = self._get_conv_bn_model(
857            has_relu=has_relu,
858            transpose=True,
859            in_channels=3,
860            out_channels=5,
861            kernel_size=3,
862        )
863        self._verify_symmetric_xnnpack_qat_graph(
864            m,
865            self.example_inputs,
866            has_relu=has_relu,
867            verify_convert=True,
868        )
869
870    def test_qat_conv_transpose_bn(self):
871        self._do_test_qat_conv_transpose_bn(has_relu=False)
872
873    def test_qat_conv_transpose_bn_relu(self):
874        self._do_test_qat_conv_transpose_bn(has_relu=True)
875
876    def test_qat_conv_bn_per_channel_weight_bias(self):
877        m = self._get_conv_bn_model()
878        example_inputs = self.example_inputs
879        m = capture_pre_autograd_graph(m, example_inputs)
880        quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
881        m = prepare_qat_pt2e(m, quantizer)
882        m(*example_inputs)
883        m = convert_pt2e(m)
884        m(*example_inputs)
885
886        # Expected graph:
887        #      x -> q_tensor -> dq_tensor -> conv -> q_tensor -> dq_tensor -> output
888        #  weight -> q_channel -> dq_channel /
889        #    bias -> q_channel -> dq_channel /
890
891        (conv_node, _, _) = _get_conv_bn_getitem_nodes(m)
892        conv_op = conv_node.target
893        conv_weight_dq_op = (
894            torch.ops.quantized_decomposed.dequantize_per_channel.default
895        )
896        node_occurrence = {
897            ns.call_function(
898                torch.ops.quantized_decomposed.quantize_per_tensor.default
899            ): 2,
900            ns.call_function(
901                torch.ops.quantized_decomposed.dequantize_per_tensor.default
902            ): 2,
903            ns.call_function(
904                torch.ops.quantized_decomposed.dequantize_per_channel.default
905            ): 2,
906        }
907        node_list = [
908            ns.call_function(
909                torch.ops.quantized_decomposed.quantize_per_tensor.default
910            ),
911            ns.call_function(
912                torch.ops.quantized_decomposed.dequantize_per_tensor.default
913            ),
914            ns.call_function(conv_weight_dq_op),
915            ns.call_function(conv_weight_dq_op),
916            ns.call_function(conv_op),
917            ns.call_function(
918                torch.ops.quantized_decomposed.quantize_per_tensor.default
919            ),
920            ns.call_function(
921                torch.ops.quantized_decomposed.dequantize_per_tensor.default
922            ),
923        ]
924        self.checkGraphModuleNodes(
925            m,
926            expected_node_list=node_list,
927            expected_node_occurrence=node_occurrence,
928        )
929
930    def test_fold_bn_erases_bn_node(self):
931        """
932        Ensure the BN node is erased from the graph after folding
933        it into conv in `convert_pt2e` even in train mode.
934        """
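        # convert_pt2e folds the bn statistics into the conv weight and bias, so no
        # standalone batch_norm node should remain even though the model was never
        # switched to eval mode.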
935        m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
936        m = capture_pre_autograd_graph(m, self.example_inputs)
937        quantizer = XNNPACKQuantizer()
938        quantizer.set_global(
939            get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
940        )
941        m = prepare_qat_pt2e(m, quantizer)
942        m = convert_pt2e(m)
943        (conv_node, bn_node, _) = _get_conv_bn_getitem_nodes(m)
944        self.assertTrue(conv_node is not None)
945        self.assertTrue(bn_node is None)
946
947
948@skipIfNoQNNPACK
949class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
950    dim = 1
951    example_inputs = (torch.randn(1, 3, 5),)
952    conv_class = torch.nn.Conv1d
953    conv_transpose_class = torch.nn.ConvTranspose1d
954    bn_class = torch.nn.BatchNorm1d
955
956
957@skipIfNoQNNPACK
958class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
959    dim = 2
960    example_inputs = (torch.randn(1, 3, 5, 5),)
961    conv_class = torch.nn.Conv2d
962    conv_transpose_class = torch.nn.ConvTranspose2d
963    bn_class = torch.nn.BatchNorm2d
964
965
966def _is_conv_node(n: torch.fx.Node):
967    return n.op == "call_function" and n.target in [
968        torch.ops.aten.conv1d.default,
969        torch.ops.aten.conv2d.default,
970        torch.ops.aten.conv_transpose1d,
971        torch.ops.aten.conv_transpose1d.default,
972        torch.ops.aten.conv_transpose2d,
973        torch.ops.aten.conv_transpose2d.input,
974    ]
975
976
977def _get_conv_bn_getitem_nodes(model: torch.fx.GraphModule):
978    """
979    Return a 3-tuple of (conv, bn, getitem) nodes from the graph.
980    """
981    model.graph.eliminate_dead_code()
982    model.recompile()
983    conv_node = None
984    bn_node = None
985    getitem_node = None
986    for n in model.graph.nodes:
987        if _is_conv_node(n):
988            conv_node = n
989        if n.target in (
990            torch.ops.aten._native_batch_norm_legit.default,
991            torch.ops.aten.batch_norm.default,
992        ):
993            bn_node = n
994        if n.target == operator.getitem:
995            getitem_node = n
996    assert conv_node is not None, "bad test setup"
997    return (conv_node, bn_node, getitem_node)
998
999
1000class ConvBnInt32WeightQuantizer(Quantizer):
1001    """
1002    Dummy quantizer that annotates conv bn in such a way that the weights
1003    are quantized per channel to int32.
1004    """
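    # Sketch of how a Quantizer plugs into the flow: prepare_qat_pt2e(m, quantizer)
    # calls quantizer.annotate(m), which attaches QuantizationAnnotation objects to
    # node.meta["quantization_annotation"]; fake quantizes are then inserted based
    # on those specs.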
1005
1006    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1007        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
1008        act_qspec = QuantizationSpec(
1009            dtype=torch.uint8,
1010            quant_min=0,
1011            quant_max=255,
1012            qscheme=torch.per_tensor_affine,
1013            observer_or_fake_quant_ctr=default_fake_quant,
1014        )
1015        weight_qspec = QuantizationSpec(
1016            dtype=torch.int32,
1017            quant_min=0,
1018            quant_max=2**31 - 1,
1019            qscheme=torch.per_channel_affine,
1020            observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
1021                observer=MovingAveragePerChannelMinMaxObserver,
1022            ),
1023        )
1024        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
1025            input_qspec_map={
1026                conv_node.args[0]: act_qspec,
1027                conv_node.args[1]: weight_qspec,
1028            },
1029            _annotated=True,
1030        )
1031        if getitem_node is not None:
1032            # TODO: This is a deprecated code path and should be deleted soon,
1033            # once capture_pre_autograd_graph fully migrates to the training IR
1034            # T199018392
1035            getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
1036                output_qspec=act_qspec,
1037                _annotated=True,
1038            )
1039        else:
1040            # See NOTE [training ir has no getitem for bn node].
1041            assert capture_pre_autograd_graph_using_training_ir()
1042            bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
1043                output_qspec=act_qspec,
1044                _annotated=True,
1045            )
1046        return model
1047
1048    def validate(self, model: torch.fx.GraphModule):
1049        pass
1050
1051
1052class ConvBnDerivedBiasQuantizer(Quantizer):
1053    """
1054    Dummy quantizer that annotates conv bn in such a way that the bias qparams are
1055    derived from the conv input activation and weight qparams.
1056    """
1057
1058    def __init__(self, is_per_channel: bool = False):
1059        super().__init__()
1060        self.is_per_channel = is_per_channel
1061
1062    def _derive_bias_qparams_from_act_and_weight_qparams(self, obs_or_fqs):
1063        act_scale, _ = obs_or_fqs[0].calculate_qparams()
1064        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
1065        if self.is_per_channel:
1066            bias_scale = act_scale * weight_scale
1067            bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
1068        else:
1069            bias_scale = torch.tensor([act_scale * weight_scale], dtype=torch.float32)
1070            bias_zero_point = torch.tensor([0], dtype=torch.int32)
1071        return bias_scale, bias_zero_point
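    # The obs_or_fqs list above corresponds, in order, to the `derived_from` pairs
    # declared in the DerivedQuantizationSpec below: first the conv input
    # activation's observer/fake quantize, then the conv weight's, so
    # bias_scale = act_scale * weight_scale.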
1072
1073    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1074        if self.is_per_channel:
1075            weight_qscheme = torch.per_channel_symmetric
1076            weight_fq = FusedMovingAvgObsFakeQuantize.with_args(
1077                observer=MovingAveragePerChannelMinMaxObserver,
1078            )
1079        else:
1080            weight_qscheme = torch.per_tensor_affine
1081            weight_fq = default_fake_quant
1082        conv_node, bn_node, getitem_node = _get_conv_bn_getitem_nodes(model)
1083        act_qspec = QuantizationSpec(
1084            dtype=torch.uint8,
1085            quant_min=0,
1086            quant_max=255,
1087            qscheme=torch.per_tensor_affine,
1088            observer_or_fake_quant_ctr=default_fake_quant,
1089        )
1090        weight_qspec = QuantizationSpec(
1091            dtype=torch.uint8,
1092            quant_min=0,
1093            quant_max=255,
1094            qscheme=weight_qscheme,
1095            observer_or_fake_quant_ctr=weight_fq,
1096        )
1097        bias_qspec = DerivedQuantizationSpec(
1098            derived_from=[
1099                (conv_node.args[0], conv_node),
1100                (conv_node.args[1], conv_node),
1101            ],
1102            derive_qparams_fn=self._derive_bias_qparams_from_act_and_weight_qparams,
1103            dtype=torch.int32,
1104            quant_min=-(2**31),
1105            quant_max=2**31 - 1,
1106            qscheme=weight_qscheme,
1107            ch_axis=0 if self.is_per_channel else None,
1108        )
1109        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
1110            input_qspec_map={
1111                conv_node.args[0]: act_qspec,
1112                conv_node.args[1]: weight_qspec,
1113                conv_node.args[2]: bias_qspec,
1114            },
1115            _annotated=True,
1116        )
1117
1118        if getitem_node is not None:
1119            # TODO: This is a deprecated code path and should be deleted soon,
1120            # once capture_pre_autograd_graph fully migrates to the training IR
1121            # T199018392
1122            getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
1123                output_qspec=act_qspec,
1124                _annotated=True,
1125            )
1126        else:
1127            # NOTE [training ir has no getitem for bn node].
1128            # getitem is None when we use the training IR, which produces
1129            # aten.batch_norm.default and does not need any getitem node.
1130            # In this case, we need to annotate the batch norm node instead.
1131            # getitem_node should only be None when we are using the training IR.
1132            assert capture_pre_autograd_graph_using_training_ir()
1133            bn_node.meta["quantization_annotation"] = QuantizationAnnotation(
1134                output_qspec=act_qspec,
1135                _annotated=True,
1136            )
1137        return model
1138
1139    def validate(self, model: torch.fx.GraphModule):
1140        pass
1141
1142
1143@skipIfNoQNNPACK
1144class TestQuantizePT2EQATModels(PT2EQATTestCase):
1145    @skip_if_no_torchvision
1146    @skipIfNoQNNPACK
1147    def test_qat_resnet18(self):
1148        import torchvision
1149
1150        with override_quantized_engine("qnnpack"):
1151            example_inputs = (torch.randn(1, 3, 224, 224),)
1152            m = torchvision.models.resnet18()
1153            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
1154
1155    @skip_if_no_torchvision
1156    @skipIfNoQNNPACK
1157    def test_qat_mobilenet_v2(self):
1158        import torchvision
1159
1160        with override_quantized_engine("qnnpack"):
1161            example_inputs = (torch.randn(1, 3, 224, 224),)
1162            m = torchvision.models.mobilenet_v2()
1163            self._verify_symmetric_xnnpack_qat_numerics(m, example_inputs)
1164
1165
1166class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
1167    class TwoLinear(torch.nn.Module):
1168        def __init__(self) -> None:
1169            super().__init__()
1170            self.linear1 = torch.nn.Linear(16, 8, bias=False)
1171            self.linear2 = torch.nn.Linear(8, 8)
1172
1173        def forward(self, x):
1174            return self.linear2(self.linear1(x))
1175
1176    class QATPTQTestModule(torch.nn.Module):
1177        def __init__(self) -> None:
1178            super().__init__()
1179            self.conv = torch.nn.Conv2d(3, 16, 3)
1180            self.linears = TestQuantizeMixQATAndPTQ.TwoLinear()
1181            self.my_linear = torch.nn.Linear(8, 8)
1182
1183        def forward(self, x):
1184            conv_out = self.conv(x)
1185            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
1186            linear_out = self.linears(permute_out)
1187            my_linear_out = self.my_linear(linear_out)
1188            # Hardtanh doesn't get quantized via the xnnpack quantizer in this test
1189            # because it relies on the propagation rules.
1190            # Need to fix this
1191            return torch.nn.functional.hardtanh(my_linear_out)
1192
1193    def _prepare_qat_linears(self, model):
1194        for name, child in model.named_children():
1195            if isinstance(child, (torch.nn.Linear, TestQuantizeMixQATAndPTQ.TwoLinear)):
1196                if isinstance(child, torch.nn.Linear):
1197                    in_channels = child.weight.size(1)
1198                else:
1199                    in_channels = child.linear1.weight.size(1)
1200
1201                example_input = (torch.rand((1, in_channels)),)
1202                traced_child = capture_pre_autograd_graph(child, example_input)
1203                quantizer = XNNPACKQuantizer()
1204                quantization_config = get_symmetric_quantization_config(
1205                    is_per_channel=True, is_qat=True
1206                )
1207                quantizer.set_global(quantization_config)
1208                traced_child_prepared = prepare_qat_pt2e(traced_child, quantizer)
1209                setattr(model, name, traced_child_prepared)
1210            else:
1211                self._prepare_qat_linears(child)
1212
1213    def _convert_qat_linears(self, model):
1214        for name, child in model.named_children():
1215            if isinstance(child, torch.fx.GraphModule):
1216                torch.ao.quantization.move_exported_model_to_eval(child)
1217                converted_child = convert_pt2e(child)
1218                setattr(model, name, converted_child)
1219            else:
1220                self._convert_qat_linears(child)
1221
1222    def test_mixing_qat_ptq(self):
1223        example_inputs = (torch.randn(2, 3, 4, 4),)
1224        model = TestQuantizeMixQATAndPTQ.QATPTQTestModule()
1225
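        # Mixed QAT + PTQ flow being exercised, roughly:
        #   1. QAT-prepare each Linear submodule in isolation
        #      (capture_pre_autograd_graph + prepare_qat_pt2e per child), then
        #      convert those children.
        #   2. Capture the whole model (the converted children get inlined) and run
        #      a PTQ prepare/convert pass, with nn.Linear excluded via
        #      quantizer.set_module_type(torch.nn.Linear, None).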
1226        self._prepare_qat_linears(model)
1227
1228        after_prepare_result_pt2e = model(*example_inputs)
1229        # must be fixed: model.eval()
1230        self._convert_qat_linears(model)
1231        quant_result_pt2e = model(*example_inputs)
1232
1233        model_pt2e = capture_pre_autograd_graph(
1234            model,
1235            example_inputs,
1236        )
1237
1238        quantizer = XNNPACKQuantizer()
1239        quantizer.set_module_type(torch.nn.Linear, None)
1240        quantization_config = get_symmetric_quantization_config()
1241        quantizer.set_global(quantization_config)
1242        model_pt2e = prepare_pt2e(model_pt2e, quantizer)
1243        after_prepare_result_pt2e = model_pt2e(*example_inputs)
1244        model_pt2e = convert_pt2e(model_pt2e)
1245        quant_result_pt2e = model_pt2e(*example_inputs)
1246
1247        exported_model = torch.export.export(model_pt2e, example_inputs)
1248
1249        node_occurrence = {
1250            # conv2d: 1 for act, 1 for weight, 1 for output
1251            # 3 x linear: 1 for act, 1 for output
1252            ns.call_function(
1253                torch.ops.quantized_decomposed.quantize_per_tensor.default
1254            ): 8,
1255            ns.call_function(
1256                torch.ops.quantized_decomposed.dequantize_per_tensor.default
1257            ): 9,
1258            ns.call_function(
1259                torch.ops.quantized_decomposed.dequantize_per_channel.default
1260            ): 3,
1261            # There needs to be one for hardtanh
1262        }
1263        self.checkGraphModuleNodes(
1264            exported_model.graph_module, expected_node_occurrence=node_occurrence
1265        )
1266