xref: /aosp_15_r20/external/pytorch/test/quantization/pt2e/test_xnnpack_quantizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: mobile"]
2import copy
3import operator
4
5import torch
6import torch._dynamo as torchdynamo
7from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
8from torch.ao.ns.fx.utils import compute_sqnr
9from torch.ao.quantization import (
10    default_dynamic_fake_quant,
11    default_dynamic_qconfig,
12    observer,
13    QConfig,
14    QConfigMapping,
15)
16from torch.ao.quantization.backend_config import get_qnnpack_backend_config
17from torch.ao.quantization.qconfig import (
18    default_per_channel_symmetric_qnnpack_qconfig,
19    default_symmetric_qnnpack_qconfig,
20    per_channel_weight_observer_range_neg_127_to_127,
21    weight_observer_range_neg_127_to_127,
22)
23from torch.ao.quantization.quantize_fx import (
24    _convert_to_reference_decomposed_fx,
25    convert_to_reference_fx,
26    prepare_fx,
27)
28from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
29from torch.ao.quantization.quantizer.xnnpack_quantizer import (
30    get_symmetric_quantization_config,
31    XNNPACKQuantizer,
32)
33from torch.export import export_for_training
34from torch.testing._internal.common_quantization import (
35    NodeSpec as ns,
36    PT2EQuantizationTestCase,
37    skip_if_no_torchvision,
38    skipIfNoQNNPACK,
39    TestHelperModules,
40)
41from torch.testing._internal.common_quantized import override_quantized_engine
42
43
44@skipIfNoQNNPACK
45class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
46    def test_conv1d(self):
47        quantizer = XNNPACKQuantizer()
48        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
49        quantizer.set_global(quantization_config)
50        example_inputs = (torch.randn(1, 3, 5),)
51        node_occurrence = {
52            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
53            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
54            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
55            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
56            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
57        }
58        node_list = [
59            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
60            torch.ops.aten.conv1d.default,
61            torch.ops.quantized_decomposed.quantize_per_tensor.default,
62        ]
63        self._test_quantizer(
64            TestHelperModules.ConvWithBNRelu(dim=1, relu=False, bn=False),
65            example_inputs,
66            quantizer,
67            node_occurrence,
68            node_list,
69        )
70
71    def test_conv2d(self):
72        quantizer = XNNPACKQuantizer()
73        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
74        quantizer.set_global(quantization_config)
75        example_inputs = (torch.randn(1, 3, 5, 5),)
76        node_occurrence = {
77            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
78            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
79            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
80            # quantize_per_channel for weights are const propagated
81            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
82            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
83        }
84        node_list = [
85            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
86            torch.ops.aten.conv2d.default,
87            torch.ops.quantized_decomposed.quantize_per_tensor.default,
88        ]
89        self._test_quantizer(
90            TestHelperModules.ConvWithBNRelu(relu=False, bn=False),
91            example_inputs,
92            quantizer,
93            node_occurrence,
94            node_list,
95        )
96
97    def test_conv1d_with_conv2d(self):
98        quantizer = XNNPACKQuantizer()
99        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
100        quantizer.set_global(quantization_config)
101        node_occurrence = {
102            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
103            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
104            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
105            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
106            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
107        }
108        node_list = [
109            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
110            torch.ops.aten.conv2d.default,
111            torch.ops.quantized_decomposed.quantize_per_tensor.default,
112            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
113            torch.ops.aten.conv1d.default,
114            torch.ops.quantized_decomposed.quantize_per_tensor.default,
115        ]
116        m = TestHelperModules.Conv2dThenConv1d()
117        self._test_quantizer(
118            m,
119            m.example_inputs(),
120            quantizer,
121            node_occurrence,
122            node_list,
123        )
124
125    def test_linear(self):
126        quantizer = XNNPACKQuantizer()
127        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
128        quantizer.set_global(quantization_config)
129        m_eager = TestHelperModules.TwoLinearModule().eval()
130
131        # Test with 2d inputs
132        example_inputs_2d = (torch.randn(9, 8),)
133        example_inputs_3d = (torch.randn(9, 10, 8),)
134        example_inputs_4d = (torch.randn(9, 10, 11, 8),)
135        node_occurrence = {
136            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
137            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
138            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
139            # quantize_per_channel for weights are const propagated
140            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
141            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
142        }
143        qconfig = default_per_channel_symmetric_qnnpack_qconfig
144        qconfig_mapping = QConfigMapping().set_global(qconfig)
145        for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]:
146            self._test_quantizer(
147                m_eager,
148                example_inputs,
149                quantizer,
150                node_occurrence,
151                [],
152                True,
153                qconfig_mapping,
154            )
155
156    def test_linear_relu(self):
157        quantizer = XNNPACKQuantizer()
158        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
159        quantizer.set_global(quantization_config)
160        m_eager = TestHelperModules.LinearReluModel().eval()
161
162        # Test with 2d inputs
163        example_inputs_2d = (torch.randn(1, 5),)
164        example_inputs_3d = (torch.randn(1, 2, 5),)
165        example_inputs_4d = (torch.randn(1, 2, 3, 5),)
166
167        node_occurrence = {
168            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
169            # There should not be extra quantize_per_tensor or dequantize_per_tensors for relu
170            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
171            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
172            # quantize_per_channel for weights are const propagated
173            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
174            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
175        }
176        qconfig = default_per_channel_symmetric_qnnpack_qconfig
177        qconfig_mapping = QConfigMapping().set_global(qconfig)
178        for example_inputs in [example_inputs_2d, example_inputs_3d, example_inputs_4d]:
179            self._test_quantizer(
180                m_eager,
181                example_inputs,
182                quantizer,
183                node_occurrence,
184                [],  # node_list
185                False,  # executorch_backend_config() does not fuse linear-relu
186                qconfig_mapping,
187            )
188
189    def test_conv_linear_no_permute(self):
190        quantizer = XNNPACKQuantizer()
191        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
192        quantizer.set_global(quantization_config)
193        node_occurrence = {
194            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
195            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
196            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
197            # quantize_per_channel for weights are const propagated
198            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
199            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
200        }
201        qconfig = default_per_channel_symmetric_qnnpack_qconfig
202        qconfig_mapping = QConfigMapping().set_global(qconfig)
203        # Test with 2d inputs
204        example_inputs = (torch.randn(2, 3, 4, 4),)
205        self._test_quantizer(
206            TestHelperModules.Conv2dWithTwoLinear(),
207            example_inputs,
208            quantizer,
209            node_occurrence,
210            [],
211            True,
212            qconfig_mapping,
213        )
214
215    def test_conv_linear(self):
216        quantizer = XNNPACKQuantizer()
217        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
218        quantizer.set_global(quantization_config)
219
220        # Test with 2d inputs
221        example_inputs = (torch.randn(2, 3, 4, 4),)
222        node_occurrence = {
223            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
224            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
225            # quantize_per_channel for weights are const propagated
226            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
227            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
228        }
229        qconfig = default_per_channel_symmetric_qnnpack_qconfig
230        qconfig_mapping = QConfigMapping().set_global(qconfig)
231        self._test_quantizer(
232            TestHelperModules.Conv2dWithTwoLinearPermute(),
233            example_inputs,
234            quantizer,
235            node_occurrence,
236            [],
237            True,
238            qconfig_mapping,
239        )
240
241    def test_linear_with_dynamic_shape(self):
242        quantizer = XNNPACKQuantizer()
243        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
244        quantizer.set_global(quantization_config)
245        m_eager = TestHelperModules.TwoLinearModule().eval()
246
247        # Test with 2d inputs
248        example_inputs_3d = (torch.randn(9, 10, 8),)
249        node_occurrence = {
250            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
251            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
252            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
253            # quantize_per_channel for weights are const propagated
254            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
255            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
256        }
257        qconfig = default_per_channel_symmetric_qnnpack_qconfig
258        qconfig_mapping = QConfigMapping().set_global(qconfig)
259        self._test_quantizer(
260            m_eager,
261            example_inputs_3d,
262            quantizer,
263            node_occurrence,
264            [],
265            True,
266            qconfig_mapping,
267            export_with_dynamic_shape=True,
268        )
269
270    def test_obs_sharing_ops(self):
271        quantizer = XNNPACKQuantizer()
272        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
273        quantizer.set_global(quantization_config)
274        m = TestHelperModules.Conv2dWithObsSharingOps().eval()
275        example_inputs = (torch.randn(1, 3, 5, 5),)
276        node_occurrence = {
277            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
278            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
279            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
280            # quantize_per_channel for weights are const propagated
281            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
282            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
283        }
284        node_list = [
285            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
286            torch.ops.aten.conv2d.default,
287            torch.ops.quantized_decomposed.quantize_per_tensor.default,
288            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
289            torch.ops.aten.adaptive_avg_pool2d.default,
290            torch.ops.quantized_decomposed.quantize_per_tensor.default,
291            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
292            torch.ops.aten.hardtanh.default,
293            torch.ops.quantized_decomposed.quantize_per_tensor.default,
294            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
295            torch.ops.aten.mean.default,
296            torch.ops.quantized_decomposed.quantize_per_tensor.default,
297            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
298        ]
299        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
300
301    def test_set_module_name(self):
302        class Sub(torch.nn.Module):
303            def __init__(self) -> None:
304                super().__init__()
305                self.linear = torch.nn.Linear(5, 5)
306
307            def forward(self, x):
308                return self.linear(x)
309
310        class M(torch.nn.Module):
311            def __init__(self) -> None:
312                super().__init__()
313                self.linear = torch.nn.Linear(5, 5)
314                self.sub = Sub()
315
316            def forward(self, x):
317                x = self.linear(x)
318                x = self.sub(x)
319                return x
320
321        m = M().eval()
322        example_inputs = (torch.randn(3, 5),)
323        quantizer = XNNPACKQuantizer()
324        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
325        quantizer.set_module_name("sub", quantization_config)
326        node_occurrence = {
327            torch.ops.aten.linear.default: 2,
328            # input and output for the second linear
329            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
330            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
331        }
332        node_list = [
333            # first linear is not quantized
334            torch.ops.aten.linear.default,
335            # second linear is quantized
336            torch.ops.quantized_decomposed.quantize_per_tensor.default,
337            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
338            torch.ops.aten.linear.default,
339            torch.ops.quantized_decomposed.quantize_per_tensor.default,
340            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
341        ]
342        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
343
344    def test_set_module_name_with_underscores(self) -> None:
345        """Test that if a module name has an underscore, we can still quantize it"""
346
347        class M(torch.nn.Module):
348            def __init__(self) -> None:
349                super().__init__()
350                # This module name has underscores, which can be part of a mangled
351                # name.
352                self.foo_bar = torch.nn.Linear(2, 2)
353                self.baz = torch.nn.Linear(2, 2)
354
355            def forward(self, x):
356                return self.baz(self.foo_bar(x))
357
358        quantizer = XNNPACKQuantizer()
359        # Set global to no quantization and then per-channel for a specific submodule.
360        quantizer.set_module_name(
361            "foo_bar", get_symmetric_quantization_config(is_per_channel=True)
362        )
363        example_inputs = (torch.randn(2, 2),)
364        m = M().eval()
365        m = export_for_training(m, example_inputs).module()
366        m = prepare_pt2e(m, quantizer)
367        # Use a linear count instead of names because the names might change, but
368        # the order should be the same.
369        count = 0
370        for n in m.graph.nodes:
371            if n.op == "call_function" and n.target == torch.ops.aten.linear.default:
372                # Get the weight observer to see the per-channel vs per-tensor.
373                weight_observer_node = n.args[1]
374                if count == 0:
375                    # The weight tensor should be per-tensor and not per-channel
376                    # for foo_bar.
377                    self.assertEqual(weight_observer_node.op, "call_module")
378                    observer_instance = getattr(m, weight_observer_node.target)
379                    self.assertEqual(
380                        observer_instance.qscheme, torch.per_channel_symmetric
381                    )
382                else:
383                    # For baz it should have no observer at all.
384                    self.assertNotEqual(weight_observer_node.op, "call_module")
385                count += 1
386
387    def test_set_module_type(self):
388        class Sub(torch.nn.Module):
389            def __init__(self) -> None:
390                super().__init__()
391                self.linear = torch.nn.Linear(5, 5)
392
393            def forward(self, x):
394                return self.linear(x)
395
396        class M(torch.nn.Module):
397            def __init__(self) -> None:
398                super().__init__()
399                self.linear = torch.nn.Linear(5, 5)
400                self.sub = Sub()
401
402            def forward(self, x):
403                x = self.linear(x)
404                x = self.sub(x)
405                return x
406
407        m = M().eval()
408        example_inputs = (torch.randn(3, 5),)
409        quantizer = XNNPACKQuantizer()
410        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
411        quantizer.set_module_type(Sub, quantization_config)
412        node_occurrence = {
413            torch.ops.aten.linear.default: 2,
414            # input and output for the second linear
415            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
416            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
417        }
418        node_list = [
419            # first linear is not quantized
420            torch.ops.aten.linear.default,
421            # second linear is quantized
422            torch.ops.quantized_decomposed.quantize_per_tensor.default,
423            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
424            torch.ops.aten.linear.default,
425            torch.ops.quantized_decomposed.quantize_per_tensor.default,
426            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
427        ]
428        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
429
430    def test_set_module_type_case_2(self):
431        class M(torch.nn.Module):
432            def __init__(self) -> None:
433                super().__init__()
434                self.conv = torch.nn.Conv2d(
435                    in_channels=3,
436                    out_channels=3,
437                    kernel_size=3,
438                    stride=1,
439                    padding=1,
440                    bias=True,
441                )
442                self.conv2 = torch.nn.Conv2d(
443                    in_channels=3,
444                    out_channels=3,
445                    kernel_size=3,
446                    stride=1,
447                    padding=1,
448                    bias=True,
449                )
450                self.conv3 = torch.nn.Conv2d(
451                    in_channels=3,
452                    out_channels=3,
453                    kernel_size=3,
454                    stride=1,
455                    padding=1,
456                    bias=True,
457                )
458                self.relu = torch.nn.ReLU()
459                self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
460                self.fc = torch.nn.Linear(3, 16)
461
462            def forward(self, x):
463                x1 = self.conv(x)
464                x2 = self.relu(self.conv2(x1) + self.conv3(x1))
465                x3 = self.avgpool(x2)
466                x4 = torch.flatten(x3, 1)
467                x5 = self.fc(x4)
468                return x5
469
470        m = M().eval()
471        example_inputs = (torch.randn(1, 3, 16, 16),)
472        quantizer = XNNPACKQuantizer()
473        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
474        # We only want to annotate Linear type
475        quantizer.set_module_type(torch.nn.Linear, quantization_config)
476        node_occurrence = {
477            torch.ops.aten.conv2d.default: 3,
478            torch.ops.aten.linear.default: 1,
479            # input and output for the linear
480            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
481            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
482        }
483        node_list = [
484            # only the linear is quantized
485            torch.ops.quantized_decomposed.quantize_per_tensor.default,
486            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
487            torch.ops.aten.linear.default,
488            torch.ops.quantized_decomposed.quantize_per_tensor.default,
489            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
490        ]
491        self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
492
493    def test_propagate_annotation(self):
494        quantizer = XNNPACKQuantizer()
495        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
496        quantizer.set_global(quantization_config)
497        m = TestHelperModules.Conv2dPropAnnotaton().eval()
498        example_inputs = (torch.randn(1, 3, 5, 5),)
499
500        # program capture
501        m = export_for_training(
502            m,
503            example_inputs,
504        ).module()
505
506        m = prepare_pt2e(m, quantizer)
507        m(*example_inputs)
508        act_post_processes_pairs = []
509        for n in m.graph.nodes:
510            if n.target in [
511                torch.ops.aten.view.default,
512                torch.ops.aten.hardtanh.default,
513            ]:
514                input_act = getattr(m, n.args[0].target)
515                output_act = getattr(m, next(iter(n.users)).target)
516                self.assertIs(input_act, output_act)
517
518        m = convert_pt2e(m)
519        node_occurrence = {
520            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
521            ns.call_function(
522                torch.ops.quantized_decomposed.quantize_per_tensor.default
523            ): 5,
524            ns.call_function(
525                torch.ops.quantized_decomposed.dequantize_per_tensor.default
526            ): 5,
527            # note: quantize op for weights are const propagated
528            ns.call_function(
529                torch.ops.quantized_decomposed.quantize_per_channel.default
530            ): 0,
531            ns.call_function(
532                torch.ops.quantized_decomposed.dequantize_per_channel.default
533            ): 2,
534        }
535        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
536
537    def test_dynamic_linear(self):
538        quantizer = XNNPACKQuantizer()
539        quantization_config = get_symmetric_quantization_config(
540            is_per_channel=True, is_dynamic=True
541        )
542        quantizer.set_global(quantization_config)
543        m_eager = TestHelperModules.TwoLinearModule().eval()
544
545        node_occurrence = {
546            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
547            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
548            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
549            # note: quantize op for weights are const propagated
550            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
551            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
552        }
553        act_affine_quant_obs = observer.PlaceholderObserver.with_args(
554            dtype=torch.qint8,
555            qscheme=torch.per_tensor_affine,
556            quant_min=-128,
557            quant_max=127,
558            eps=2**-12,
559            is_dynamic=True,
560        )
561        qconfig = QConfig(
562            activation=act_affine_quant_obs,
563            weight=per_channel_weight_observer_range_neg_127_to_127,
564        )
565        qconfig_mapping = QConfigMapping().set_global(qconfig)
566        # Test with 2d inputs
567        example_inputs_2d = (torch.randn(9, 8),)
568        example_inputs_4d = (torch.randn(9, 10, 11, 8),)
569        for example_inputs in [example_inputs_2d, example_inputs_4d]:
570            self._test_quantizer(
571                m_eager,
572                example_inputs,
573                quantizer,
574                node_occurrence,
575                [],
576                True,
577                qconfig_mapping,
578            )
579
580    def test_dynamic_linear_int4_weight(self):
581        quantizer = XNNPACKQuantizer()
582        quantization_config = get_symmetric_quantization_config(
583            is_per_channel=True,
584            is_dynamic=True,
585            weight_qmin=0,
586            weight_qmax=15,
587        )
588        quantizer.set_global(quantization_config)
589        m_eager = TestHelperModules.TwoLinearModule().eval()
590
591        node_occurrence = {
592            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
593            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
594            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
595            # note: quantize op for weights are const propagated
596            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
597            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
598        }
599        act_affine_quant_obs = observer.PlaceholderObserver.with_args(
600            dtype=torch.qint8,
601            qscheme=torch.per_tensor_affine,
602            quant_min=-128,
603            quant_max=127,
604            eps=2**-12,
605            is_dynamic=True,
606        )
607        qconfig = QConfig(
608            activation=act_affine_quant_obs,
609            weight=per_channel_weight_observer_range_neg_127_to_127.with_args(
610                quant_min=0, quant_max=15
611            ),
612        )
613        qconfig_mapping = QConfigMapping().set_global(qconfig)
614        # Test with 2d inputs
615        example_inputs_2d = (torch.randn(9, 8),)
616        example_inputs_4d = (torch.randn(9, 10, 11, 8),)
617        for example_inputs in [example_inputs_2d, example_inputs_4d]:
618            self._test_quantizer(
619                m_eager,
620                example_inputs,
621                quantizer,
622                node_occurrence,
623                [],
624                True,
625                qconfig_mapping,
626            )
627
628    def test_qat_dynamic_linear(self):
629        quantizer = XNNPACKQuantizer()
630        quantization_config = get_symmetric_quantization_config(
631            is_per_channel=True,
632            is_dynamic=True,
633            is_qat=True,
634        )
635        quantizer.set_global(quantization_config)
636        m_eager = TestHelperModules.TwoLinearModule().eval()
637
638        node_occurrence = {
639            torch.ops.quantized_decomposed.choose_qparams.tensor: 2,
640            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
641            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
642            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
643            # note: quantize op for weights are const propagated
644            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
645            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
646        }
647        act_affine_quant_obs = default_dynamic_fake_quant
648        qconfig = QConfig(
649            activation=act_affine_quant_obs,
650            weight=per_channel_weight_observer_range_neg_127_to_127,
651        )
652        qconfig_mapping = QConfigMapping().set_global(qconfig)
653        # Test with 2d inputs
654        example_inputs_2d = (torch.randn(9, 8),)
655        example_inputs_4d = (torch.randn(9, 10, 11, 8),)
656        for example_inputs in [example_inputs_2d, example_inputs_4d]:
657            self._test_quantizer(
658                m_eager,
659                example_inputs,
660                quantizer,
661                node_occurrence,
662                [],
663                True,
664                qconfig_mapping,
665                is_qat=True,
666            )
667
668    def test_dynamic_linear_with_conv(self):
669        quantizer = XNNPACKQuantizer()
670        quantization_config = get_symmetric_quantization_config(
671            is_per_channel=False, is_dynamic=True
672        )
673        quantizer.set_global(quantization_config)
674        m_eager = TestHelperModules.ConvLinearWPermute().eval()
675
676        node_occurrence = {
677            # input and output are using quantize_per_tensor and weight is using quantize_per_channel
678            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
679            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
680            # note: quantize op for weights are const propagated
681            torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
682            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
683        }
684
685        capture_pre_autograd_graph_node_occurrence = None
686        if capture_pre_autograd_graph_using_training_ir():
687            capture_pre_autograd_graph_node_occurrence = {
688                # input and output are using quantize_per_tensor and weight is using quantize_per_channel
689                # In training IR, the decomposition is different.
690                # `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes
691                # `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes.
692                torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
693                torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
694                # note: quantize op for weights are const propagated
695                torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
696                torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
697            }
698        act_affine_quant_obs = observer.PlaceholderObserver.with_args(
699            dtype=torch.qint8,
700            qscheme=torch.per_tensor_affine,
701            quant_min=-128,
702            quant_max=127,
703            eps=2**-12,
704            is_dynamic=True,
705        )
706        qconfig = QConfig(
707            activation=act_affine_quant_obs,
708            weight=weight_observer_range_neg_127_to_127,
709        )
710        # Test with 2d inputs
711        example_inputs = (torch.randn(2, 3, 4, 4),)
712        qconfig_mapping = QConfigMapping().set_global(qconfig)
713        self._test_quantizer(
714            m_eager,
715            example_inputs,
716            quantizer,
717            node_occurrence,
718            [],
719            True,
720            qconfig_mapping,
721            capture_pre_autograd_graph_node_occurrence=capture_pre_autograd_graph_node_occurrence,
722        )
723
724    def test_gru(self):
725        """this is a test for annotating fp32 GRU so that it produces
726        q -> dq -> fp32_gru -> q -> dq, this is currently enough for our use cases,
727        but we may change the annotation to be more precise in the future
728        """
729
730        class RNNDynamicModel(torch.nn.Module):
731            def __init__(self, mod_type):
732                super().__init__()
733                self.qconfig = default_dynamic_qconfig
734                if mod_type == "GRU":
735                    self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
736                if mod_type == "LSTM":
737                    self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)
738
739            def forward(self, input_tensor, hidden_tensor):
740                input_tensor = 1 * input_tensor
741                hidden_tensor = 1 * hidden_tensor
742                output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor)
743                return 1 * output_tensor, 1 * hidden_out
744
745        with override_quantized_engine("qnnpack"):
746            model_fx = RNNDynamicModel("GRU")
747            module_types = [torch.nn.GRU]
748            niter = 10
749            example_inputs = (
750                # input_tensor
751                torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float)
752                .unsqueeze(0)
753                .repeat(niter, 1, 1),
754                # hidden_tensor
755                # (D * num_layers, N, H_out)
756                torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1),
757            )
758            model_graph = copy.deepcopy(model_fx)
759
760            qconfig_mapping = QConfigMapping().set_object_type(
761                operator.mul, default_symmetric_qnnpack_qconfig
762            )
763            model_fx = prepare_fx(
764                model_fx,
765                qconfig_mapping,
766                example_inputs,
767                backend_config=get_qnnpack_backend_config(),
768            )
769            model_fx(*example_inputs)
770            model_fx = _convert_to_reference_decomposed_fx(model_fx)
771
772            with torchdynamo.config.patch(allow_rnn=True):
773                model_graph = export_for_training(
774                    model_graph,
775                    example_inputs,
776                ).module()
777            quantizer = XNNPACKQuantizer()
778            quantization_config = get_symmetric_quantization_config(
779                is_per_channel=False, is_dynamic=False
780            )
781            quantizer.set_global(quantization_config)
782            model_graph = prepare_pt2e(model_graph, quantizer)
783            model_graph(*example_inputs)
784            model_graph = convert_pt2e(model_graph)
785            self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))
786
787    def test_linear_gru(self):
788        """this test is to make sure GRU annotation does not interfere with linear annotation"""
789
790        class RNNDynamicModel(torch.nn.Module):
791            def __init__(self, mod_type):
792                super().__init__()
793                self.qconfig = default_dynamic_qconfig
794                self.linear = torch.nn.Linear(2, 2)
795                if mod_type == "GRU":
796                    self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
797                if mod_type == "LSTM":
798                    self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)
799
800            def forward(self, input_tensor, hidden_tensor):
801                input_tensor = self.linear(input_tensor)
802                input_tensor = 1 * input_tensor
803                hidden_tensor = 1 * hidden_tensor
804                output_tensor, hidden_out = self.mod(input_tensor, hidden_tensor)
805                return 1 * output_tensor, 1 * hidden_out
806
807        with override_quantized_engine("qnnpack"):
808            model_fx = RNNDynamicModel("GRU")
809            module_types = [torch.nn.GRU]
810            niter = 10
811            example_inputs = (
812                # input_tensor
813                torch.tensor([[100, -155], [-155, 100], [100, -155]], dtype=torch.float)
814                .unsqueeze(0)
815                .repeat(niter, 1, 1),
816                # hidden_tensor
817                # (D * num_layers, N, H_out)
818                torch.tensor([[[100, -155]]], dtype=torch.float).repeat(1, 3, 1),
819            )
820            model_graph = copy.deepcopy(model_fx)
821
822            qconfig_mapping = (
823                QConfigMapping()
824                .set_object_type(operator.mul, default_symmetric_qnnpack_qconfig)
825                .set_object_type(torch.nn.Linear, default_symmetric_qnnpack_qconfig)
826            )
827            model_fx = prepare_fx(
828                model_fx,
829                qconfig_mapping,
830                example_inputs,
831                backend_config=get_qnnpack_backend_config(),
832            )
833            model_fx(*example_inputs)
834            model_fx = _convert_to_reference_decomposed_fx(model_fx)
835
836            with torchdynamo.config.patch(allow_rnn=True):
837                model_graph = export_for_training(
838                    model_graph,
839                    example_inputs,
840                ).module()
841            quantizer = XNNPACKQuantizer()
842            quantization_config = get_symmetric_quantization_config(
843                is_per_channel=False, is_dynamic=False
844            )
845            quantizer.set_global(quantization_config)
846            model_graph = prepare_pt2e(model_graph, quantizer)
847            model_graph(*example_inputs)
848            model_graph = convert_pt2e(model_graph)
849            self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))
850
851    def test_add_and_inplace_add(self):
852        quantizer = XNNPACKQuantizer()
853        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
854        quantizer.set_global(quantization_config)
855        example_inputs = (
856            torch.randn(1, 3, 5, 5),
857            torch.randn(1, 3, 5, 5),
858        )
859        node_occurrence = {
860            # two input and one output for first add, and output for second add
861            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
862            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
863        }
864        node_list = [
865            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
866            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
867            torch.ops.aten.add.Tensor,
868            torch.ops.quantized_decomposed.quantize_per_tensor.default,
869            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
870            # TODO torch.ops.aten.add.Tensor,
871            torch.ops.quantized_decomposed.quantize_per_tensor.default,
872        ]
873        self._test_quantizer(
874            TestHelperModules.AddInplaceAdd(),
875            example_inputs,
876            quantizer,
877            node_occurrence,
878            node_list,
879        )
880
881    def test_mul_and_inplace_mul(self):
882        quantizer = XNNPACKQuantizer()
883        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
884        quantizer.set_global(quantization_config)
885        example_inputs = (
886            torch.randn(1, 3, 5, 5),
887            torch.randn(1, 3, 5, 5),
888        )
889        node_occurrence = {
890            # two input and one output for first add, and output for second add
891            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
892            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5,
893        }
894        node_list = [
895            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
896            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
897            torch.ops.aten.mul.Tensor,
898            torch.ops.quantized_decomposed.quantize_per_tensor.default,
899            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
900            # TODO torch.ops.aten.mul.Tensor,
901            torch.ops.quantized_decomposed.quantize_per_tensor.default,
902        ]
903        self._test_quantizer(
904            TestHelperModules.MulInplaceMul(),
905            example_inputs,
906            quantizer,
907            node_occurrence,
908            node_list,
909        )
910
911    def test_add_mul_scalar(self):
912        quantizer = XNNPACKQuantizer()
913        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
914        quantizer.set_global(quantization_config)
915        example_inputs = (torch.randn(1, 3, 5, 5),)
916        node_occurrence = {
917            # two input and one output for first add, and output for second add
918            torch.ops.quantized_decomposed.quantize_per_tensor.default: 5,
919            # TODO torch.ops.quantized_decomposed.dequantize_per_tensor.default: 9,
920        }
921        node_list = [
922            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
923            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
924            torch.ops.aten.add.Tensor,
925            torch.ops.quantized_decomposed.quantize_per_tensor.default,
926            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
927            torch.ops.aten.mul.Tensor,
928            torch.ops.quantized_decomposed.quantize_per_tensor.default,
929            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
930            # TODO torch.ops.aten.add.Tensor,
931            torch.ops.quantized_decomposed.quantize_per_tensor.default,
932            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
933            # TODO torch.ops.aten.mul.Tensor,
934            torch.ops.quantized_decomposed.quantize_per_tensor.default,
935        ]
936        self._test_quantizer(
937            TestHelperModules.AddMulScalar(),
938            example_inputs,
939            quantizer,
940            node_occurrence,
941            node_list,
942        )
943
944    def test_mul_float32_max(self):
945        class M(torch.nn.Module):
946            def forward(self, x):
947                return x * 3.4028235e38
948
949        quantizer = XNNPACKQuantizer()
950        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
951        quantizer.set_global(quantization_config)
952        example_inputs = (torch.randn(1, 3, 5, 5),)
953        # not quantized
954        node_occurrence = {
955            torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
956            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
957        }
958        node_list = [
959            torch.ops.aten.mul.Tensor,
960        ]
961        self._test_quantizer(
962            M(),
963            example_inputs,
964            quantizer,
965            node_occurrence,
966            node_list,
967        )
968
969    def test_add_mul_long(self):
970        class M(torch.nn.Module):
971            def __init__(self) -> None:
972                super().__init__()
973                self.t = torch.tensor([100])
974
975            def forward(self, x):
976                x = x + self.t
977                x = x * self.t
978                return x
979
980        quantizer = XNNPACKQuantizer()
981        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
982        quantizer.set_global(quantization_config)
983        example_inputs = (torch.randn(1, 3, 5, 5),)
984        # not quantized
985        node_occurrence = {
986            torch.ops.quantized_decomposed.quantize_per_tensor.default: 0,
987            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0,
988        }
989        node_list = [
990            torch.ops.aten.add.Tensor,
991            torch.ops.aten.mul.Tensor,
992        ]
993        self._test_quantizer(
994            M(),
995            example_inputs,
996            quantizer,
997            node_occurrence,
998            node_list,
999        )
1000
1001
1002# TODO: express this using self._test_quantizer, add test for inception_v4
class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
    """Whole-model checks for XNNPACKQuantizer."""

    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_resnet18(self):
        """Quantize torchvision resnet18 via pt2e and compare with the FX flow.

        The pt2e flow (export_for_training -> prepare_pt2e -> convert_pt2e with
        a per-channel symmetric XNNPACKQuantizer config) is compared against
        the FX graph mode reference flow using
        default_per_channel_symmetric_qnnpack_qconfig.
        """
        import torchvision

        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            m_copy = copy.deepcopy(m)
            # program capture
            m = export_for_training(
                m,
                example_inputs,
            ).module()

            quantizer = XNNPACKQuantizer()
            quantization_config = get_symmetric_quantization_config(is_per_channel=True)
            quantizer.set_global(quantization_config)
            m = prepare_pt2e(m, quantizer)
            # checking that we inserted observers correctly for maxpool operator
            # (input and output share observer instance)
            self.assertEqual(
                id(m.activation_post_process_3), id(m.activation_post_process_2)
            )
            after_prepare_result = m(*example_inputs)
            m = convert_pt2e(m)

            after_quant_result = m(*example_inputs)

            # comparing with existing fx graph mode quantization reference flow
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_qnnpack_backend_config()
            m_fx = prepare_fx(
                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            after_prepare_result_fx = m_fx(*example_inputs)
            m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)

            after_quant_result_fx = m_fx(*example_inputs)

            # the result matches exactly after prepare
            # Note: this currently will always be true since we are inserting
            # observers; the check becomes useful when we add qat examples,
            # but we can still manually inspect the printed observers to make
            # sure it matches
            self.assertEqual(after_prepare_result, after_prepare_result_fx)
            self.assertEqual(
                compute_sqnr(after_prepare_result, after_prepare_result_fx),
                torch.tensor(float("inf")),
            )
            # there are slight differences after convert due to different
            # implementations of quant/dequant. Bound the absolute deviation:
            # a signed max would let arbitrarily large negative errors pass.
            max_abs_diff = (after_quant_result - after_quant_result_fx).abs().max()
            self.assertLess(max_abs_diff, 1e-1)
            self.assertGreater(
                compute_sqnr(after_quant_result, after_quant_result_fx), 35
            )
1062            )
1063