xref: /aosp_15_r20/external/pytorch/test/quantization/pt2e/test_quantize_pt2e.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2from typing import List, Tuple
3
4import torch
5from torch import Tensor
6from torch._export import capture_pre_autograd_graph
7from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
8from torch.ao.quantization import observer, ObserverOrFakeQuantize, QConfigMapping
9from torch.ao.quantization.qconfig import (
10    default_per_channel_symmetric_qnnpack_qconfig,
11    float_qparams_weight_only_qconfig,
12    per_channel_weight_observer_range_neg_127_to_127,
13    QConfig,
14    weight_observer_range_neg_127_to_127,
15)
16from torch.ao.quantization.quantize_pt2e import (
17    convert_pt2e,
18    prepare_pt2e,
19    prepare_qat_pt2e,
20)
21from torch.ao.quantization.quantizer import (
22    DerivedQuantizationSpec,
23    FixedQParamsQuantizationSpec,
24    QuantizationAnnotation,
25    QuantizationSpec,
26    Quantizer,
27    SharedQuantizationSpec,
28)
29from torch.ao.quantization.quantizer.composable_quantizer import (  # noqa: F811
30    ComposableQuantizer,
31)
32from torch.ao.quantization.quantizer.embedding_quantizer import (  # noqa: F811
33    EmbeddingQuantizer,
34)
35from torch.ao.quantization.quantizer.xnnpack_quantizer import (
36    get_symmetric_quantization_config,
37    XNNPACKQuantizer,
38)
39from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
40    OP_TO_ANNOTATOR,
41    QuantizationConfig,
42)
43from torch.fx import Node
44from torch.testing._internal.common_quantization import (
45    NodeSpec as ns,
46    PT2EQuantizationTestCase,
47    skipIfNoQNNPACK,
48    TestHelperModules,
49)
50from torch.testing._internal.common_utils import (
51    instantiate_parametrized_tests,
52    parametrize,
53    TemporaryFileName,
54    TEST_CUDA,
55    TEST_WITH_ROCM,
56)
57
58
59@skipIfNoQNNPACK
60class TestQuantizePT2E(PT2EQuantizationTestCase):
61    def test_simple_quantizer(self):
62        # TODO: use OP_TO_ANNOTATOR
63        class BackendAQuantizer(Quantizer):
64            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
65                for node in model.graph.nodes:
66                    if (
67                        node.op == "call_function"
68                        and node.target == torch.ops.aten.conv2d.default
69                    ):
70                        input_act = node.args[0]
71                        assert isinstance(input_act, Node)
72                        weight = node.args[1]
73                        assert isinstance(weight, Node)
74                        bias = node.args[2]
75                        assert isinstance(bias, Node)
76                        act_qspec = QuantizationSpec(
77                            dtype=torch.uint8,
78                            quant_min=0,
79                            quant_max=255,
80                            qscheme=torch.per_tensor_affine,
81                            is_dynamic=False,
82                            observer_or_fake_quant_ctr=observer.default_observer,
83                        )
84                        weight_qspec = QuantizationSpec(
85                            dtype=torch.int8,
86                            quant_min=-128,
87                            quant_max=127,
88                            qscheme=torch.per_tensor_affine,
89                            is_dynamic=False,
90                            observer_or_fake_quant_ctr=observer.default_weight_observer,
91                        )
92                        bias_qspec = QuantizationSpec(
93                            dtype=torch.float32,
94                            is_dynamic=False,
95                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
96                        )
97                        node.meta["quantization_annotation"] = QuantizationAnnotation(
98                            input_qspec_map={
99                                input_act: act_qspec,
100                                weight: weight_qspec,
101                                bias: bias_qspec,
102                            },
103                            output_qspec=act_qspec,
104                            _annotated=True,
105                        )
106
107            def validate(self, model: torch.fx.GraphModule) -> None:
108                pass
109
110        example_inputs = (torch.randn(1, 3, 5, 5),)
111        node_occurrence = {
112            # two for input of the first conv, one for output for the first conv
113            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
114            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
115        }
116        node_list = [
117            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
118            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
119            torch.ops.aten.conv2d.default,
120            torch.ops.quantized_decomposed.quantize_per_tensor.default,
121        ]
122        self._test_quantizer(
123            TestHelperModules.ConvWithBNRelu(relu=False, bn=False),
124            example_inputs,
125            BackendAQuantizer(),
126            node_occurrence,
127            node_list,
128        )
129
130    def test_wo_annotate_conv_output_quantizer(self):
131        # TODO: use OP_TO_ANNOTATOR
132        class BackendAQuantizer(Quantizer):
133            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
134                act_qspec = QuantizationSpec(
135                    dtype=torch.uint8,
136                    quant_min=0,
137                    quant_max=255,
138                    qscheme=torch.per_tensor_affine,
139                    is_dynamic=False,
140                    observer_or_fake_quant_ctr=observer.default_observer,
141                )
142                weight_qspec = QuantizationSpec(
143                    dtype=torch.int8,
144                    quant_min=-128,
145                    quant_max=127,
146                    qscheme=torch.per_tensor_affine,
147                    is_dynamic=False,
148                    observer_or_fake_quant_ctr=observer.default_weight_observer,
149                )
150                bias_qspec = QuantizationSpec(
151                    dtype=torch.float32,
152                    is_dynamic=False,
153                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
154                )
155                for node in model.graph.nodes:
156                    if (
157                        node.op == "call_function"
158                        and node.target == torch.ops.aten.conv2d.default
159                    ):
160                        input_act = node.args[0]
161                        assert isinstance(input_act, Node)
162                        weight = node.args[1]
163                        assert isinstance(weight, Node)
164                        bias = node.args[2]
165                        assert isinstance(bias, Node)
166                        node.meta["quantization_annotation"] = QuantizationAnnotation(
167                            input_qspec_map={
168                                input_act: act_qspec,
169                                weight: weight_qspec,
170                                bias: bias_qspec,
171                            },
172                            _annotated=True,
173                        )
174
175            def validate(self, model: torch.fx.GraphModule) -> None:
176                pass
177
178        m = torch.nn.Conv2d(2, 2, 1)
179        x = torch.rand(1, 2, 14, 14)
180        example_inputs = (x,)
181        m = self._quantize(m, BackendAQuantizer(), example_inputs)
182        # Ensure the conv has no observer inserted at output
183        node_occurrence = {
184            # two for input of conv
185            ns.call_function(
186                torch.ops.quantized_decomposed.quantize_per_tensor.default
187            ): 1,
188            ns.call_function(
189                torch.ops.quantized_decomposed.dequantize_per_tensor.default
190            ): 2,
191        }
192        node_list = [
193            ns.call_function(
194                torch.ops.quantized_decomposed.dequantize_per_tensor.default
195            ),
196            ns.call_function(
197                torch.ops.quantized_decomposed.dequantize_per_tensor.default
198            ),
199            ns.call_function(torch.ops.aten.conv2d.default),
200        ]
201        self.checkGraphModuleNodes(
202            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
203        )
204
205    def test_max_pool2d_quantizer(self):
206        # TODO: use OP_TO_ANNOTATOR
207        class BackendAQuantizer(Quantizer):
208            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
209                act_qspec = QuantizationSpec(
210                    dtype=torch.uint8,
211                    quant_min=0,
212                    quant_max=255,
213                    qscheme=torch.per_tensor_affine,
214                    is_dynamic=False,
215                    observer_or_fake_quant_ctr=observer.default_observer,
216                )
217                weight_qspec = QuantizationSpec(
218                    dtype=torch.int8,
219                    quant_min=-128,
220                    quant_max=127,
221                    qscheme=torch.per_tensor_affine,
222                    is_dynamic=False,
223                    observer_or_fake_quant_ctr=observer.default_weight_observer,
224                )
225                bias_qspec = QuantizationSpec(
226                    dtype=torch.float32,
227                    is_dynamic=False,
228                    observer_or_fake_quant_ctr=observer.PlaceholderObserver,
229                )
230                for node in model.graph.nodes:
231                    if (
232                        node.op == "call_function"
233                        and node.target == torch.ops.aten.conv2d.default
234                    ):
235                        input_act = node.args[0]
236                        assert isinstance(input_act, Node)
237                        weight = node.args[1]
238                        assert isinstance(weight, Node)
239                        bias = node.args[2]
240                        assert isinstance(bias, Node)
241                        node.meta["quantization_annotation"] = QuantizationAnnotation(
242                            input_qspec_map={
243                                input_act: act_qspec,
244                                weight: weight_qspec,
245                                bias: bias_qspec,
246                            },
247                            _annotated=True,
248                        )
249                    if (
250                        node.op == "call_function"
251                        and node.target == torch.ops.aten.max_pool2d.default
252                    ):
253                        maxpool_node = node
254                        input_act = maxpool_node.args[0]
255                        assert isinstance(input_act, Node)
256                        maxpool_node.meta[
257                            "quantization_annotation"
258                        ] = QuantizationAnnotation(
259                            input_qspec_map={
260                                input_act: act_qspec,
261                            },
262                            output_qspec=SharedQuantizationSpec(
263                                (input_act, maxpool_node)
264                            ),
265                            _annotated=True,
266                        )
267
268            def validate(self, model: torch.fx.GraphModule) -> None:
269                pass
270
271        m = TestHelperModules.ConvMaxPool2d()
272        x = torch.rand(1, 2, 14, 14)
273        example_inputs = (x,)
274        m = self._quantize(m, BackendAQuantizer(), example_inputs)
275        node_occurrence = {
276            # two for input of conv
277            # one for input of maxpool
278            # one for output of maxpool
279            ns.call_function(
280                torch.ops.quantized_decomposed.quantize_per_tensor.default
281            ): 3,
282            ns.call_function(
283                torch.ops.quantized_decomposed.dequantize_per_tensor.default
284            ): 4,
285        }
286        node_list = [
287            ns.call_function(
288                torch.ops.quantized_decomposed.dequantize_per_tensor.default
289            ),
290            ns.call_function(
291                torch.ops.quantized_decomposed.dequantize_per_tensor.default
292            ),
293            ns.call_function(torch.ops.aten.conv2d.default),
294            ns.call_function(
295                torch.ops.quantized_decomposed.quantize_per_tensor.default
296            ),
297            ns.call_function(
298                torch.ops.quantized_decomposed.dequantize_per_tensor.default
299            ),
300            ns.call_function(torch.ops.aten.max_pool2d.default),
301        ]
302        self.checkGraphModuleNodes(
303            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
304        )
305
306    def test_derived_qspec(self):
307        # TODO: use OP_TO_ANNOTATOR
308        class BackendAQuantizer(Quantizer):
309            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
310                for node in model.graph.nodes:
311                    if (
312                        node.op == "call_function"
313                        and node.target == torch.ops.aten.conv2d.default
314                    ):
315                        input_act = node.args[0]
316                        assert isinstance(input_act, Node)
317                        weight = node.args[1]
318                        assert isinstance(weight, Node)
319                        bias = node.args[2]
320                        assert isinstance(bias, Node)
321                        act_qspec = QuantizationSpec(
322                            dtype=torch.uint8,
323                            quant_min=0,
324                            quant_max=255,
325                            qscheme=torch.per_tensor_affine,
326                            is_dynamic=False,
327                            observer_or_fake_quant_ctr=observer.default_observer,
328                        )
329                        weight_qspec = QuantizationSpec(
330                            dtype=torch.int8,
331                            quant_min=-128,
332                            quant_max=127,
333                            qscheme=torch.per_tensor_affine,
334                            is_dynamic=False,
335                            observer_or_fake_quant_ctr=observer.default_weight_observer,
336                        )
337
338                        def derive_qparams_fn(
339                            obs_or_fqs: List[ObserverOrFakeQuantize],
340                        ) -> Tuple[Tensor, Tensor]:
341                            assert (
342                                len(obs_or_fqs) == 2
343                            ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
344                            act_obs_or_fq = obs_or_fqs[0]
345                            weight_obs_or_fq = obs_or_fqs[1]
346                            act_scale, act_zp = act_obs_or_fq.calculate_qparams()
347                            (
348                                weight_scale,
349                                weight_zp,
350                            ) = weight_obs_or_fq.calculate_qparams()
351                            return torch.tensor([act_scale * weight_scale]).to(
352                                torch.float32
353                            ), torch.tensor([0]).to(torch.int32)
354
355                        bias_qspec = DerivedQuantizationSpec(
356                            derived_from=[(input_act, node), (weight, node)],
357                            derive_qparams_fn=derive_qparams_fn,
358                            dtype=torch.int32,
359                            quant_min=-(2**31),
360                            quant_max=2**31 - 1,
361                            qscheme=torch.per_tensor_symmetric,
362                        )
363                        node.meta["quantization_annotation"] = QuantizationAnnotation(
364                            input_qspec_map={
365                                input_act: act_qspec,
366                                weight: weight_qspec,
367                                bias: bias_qspec,
368                            },
369                            output_qspec=act_qspec,
370                            _annotated=True,
371                        )
372
373            def validate(self, model: torch.fx.GraphModule) -> None:
374                pass
375
376        m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
377        example_inputs = (torch.randn(1, 3, 5, 5),)
378
379        m = self._quantize(m, BackendAQuantizer(), example_inputs)
380        node_occurrence = {
381            # input, weight, bias, output for the conv
382            # note: quantize op for weight and bias are const propagated
383            ns.call_function(
384                torch.ops.quantized_decomposed.quantize_per_tensor.default
385            ): 2,
386            ns.call_function(
387                torch.ops.quantized_decomposed.dequantize_per_tensor.default
388            ): 4,
389        }
390        node_list = [
391            ns.call_function(
392                torch.ops.quantized_decomposed.dequantize_per_tensor.default
393            ),
394            ns.call_function(
395                torch.ops.quantized_decomposed.dequantize_per_tensor.default
396            ),
397            ns.call_function(
398                torch.ops.quantized_decomposed.dequantize_per_tensor.default
399            ),
400            ns.call_function(torch.ops.aten.conv2d.default),
401            ns.call_function(
402                torch.ops.quantized_decomposed.quantize_per_tensor.default
403            ),
404        ]
405        self.checkGraphModuleNodes(
406            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
407        )
408
409    def test_derived_qspec_per_channel(self):
410        class BackendAQuantizer(Quantizer):
411            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
412                for node in model.graph.nodes:
413                    if (
414                        node.op == "call_function"
415                        and node.target == torch.ops.aten.conv2d.default
416                    ):
417                        input_act = node.args[0]
418                        assert isinstance(input_act, Node)
419                        weight = node.args[1]
420                        assert isinstance(weight, Node)
421                        bias = node.args[2]
422                        assert isinstance(bias, Node)
423                        act_qspec = QuantizationSpec(
424                            dtype=torch.uint8,
425                            quant_min=0,
426                            quant_max=255,
427                            qscheme=torch.per_tensor_affine,
428                            is_dynamic=False,
429                            observer_or_fake_quant_ctr=observer.default_observer,
430                        )
431                        weight_qspec = QuantizationSpec(
432                            dtype=torch.int8,
433                            quant_min=-128,
434                            quant_max=127,
435                            qscheme=torch.per_channel_affine,
436                            is_dynamic=False,
437                            ch_axis=0,
438                            observer_or_fake_quant_ctr=observer.default_per_channel_weight_observer,
439                        )
440
441                        def derive_qparams_fn(
442                            obs_or_fqs: List[ObserverOrFakeQuantize],
443                        ) -> Tuple[Tensor, Tensor]:
444                            assert (
445                                len(obs_or_fqs) == 1
446                            ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
447                            weight_obs_or_fq = obs_or_fqs[0]
448                            (
449                                weight_scale,
450                                weight_zp,
451                            ) = weight_obs_or_fq.calculate_qparams()
452                            return weight_scale, torch.zeros_like(weight_scale)
453
454                        bias_qspec = DerivedQuantizationSpec(
455                            derived_from=[(weight, node)],
456                            derive_qparams_fn=derive_qparams_fn,
457                            dtype=torch.int32,
458                            quant_min=-(2**31),
459                            quant_max=2**31 - 1,
460                            qscheme=torch.per_channel_symmetric,
461                            ch_axis=0,
462                        )
463                        node.meta["quantization_annotation"] = QuantizationAnnotation(
464                            input_qspec_map={
465                                input_act: act_qspec,
466                                weight: weight_qspec,
467                                bias: bias_qspec,
468                            },
469                            output_qspec=act_qspec,
470                            _annotated=True,
471                        )
472
473            def validate(self, model: torch.fx.GraphModule) -> None:
474                pass
475
476        m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval()
477        example_inputs = (torch.randn(1, 3, 5, 5),)
478
479        m = self._quantize(m, BackendAQuantizer(), example_inputs)
480
481        node_occurrence = {
482            # input, output for the conv
483            ns.call_function(
484                torch.ops.quantized_decomposed.quantize_per_tensor.default
485            ): 2,
486            ns.call_function(
487                torch.ops.quantized_decomposed.dequantize_per_tensor.default
488            ): 2,
489            # weight and bias for conv
490            # note: quantize op for weight and bias are const propagated
491            ns.call_function(
492                torch.ops.quantized_decomposed.quantize_per_channel.default
493            ): 0,
494            ns.call_function(
495                torch.ops.quantized_decomposed.dequantize_per_channel.default
496            ): 2,
497        }
498        node_list = [
499            ns.call_function(
500                torch.ops.quantized_decomposed.dequantize_per_channel.default
501            ),
502            ns.call_function(
503                torch.ops.quantized_decomposed.dequantize_per_channel.default
504            ),
505            ns.call_function(torch.ops.aten.conv2d.default),
506            ns.call_function(
507                torch.ops.quantized_decomposed.quantize_per_tensor.default
508            ),
509        ]
510        self.checkGraphModuleNodes(
511            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
512        )
513
514    def test_fixed_qparams_qspec_ptq(self):
515        self._test_fixed_qparams_qspec(is_qat=False)
516
517    # TODO: refactor and move this to test_quantize_pt2_qat.py
518    def test_fixed_qparams_qspec_qat(self):
519        self._test_fixed_qparams_qspec(is_qat=True)
520
521    def _test_fixed_qparams_qspec(self, is_qat: bool):
522        class M(torch.nn.Module):
523            def forward(self, x):
524                return torch.sigmoid(x)
525
526        class BackendAQuantizer(Quantizer):
527            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
528                for node in model.graph.nodes:
529                    if (
530                        node.op == "call_function"
531                        and node.target == torch.ops.aten.sigmoid.default
532                    ):
533                        input_act = node.args[0]
534                        assert isinstance(input_act, Node)
535                        act_qspec = FixedQParamsQuantizationSpec(
536                            dtype=torch.uint8,
537                            quant_min=0,
538                            quant_max=255,
539                            qscheme=torch.per_tensor_affine,
540                            scale=1.0 / 256.0,
541                            zero_point=0,
542                        )
543                        node.meta["quantization_annotation"] = QuantizationAnnotation(
544                            input_qspec_map={
545                                input_act: act_qspec,
546                            },
547                            output_qspec=act_qspec,
548                            _annotated=True,
549                        )
550
551            def validate(self, model: torch.fx.GraphModule) -> None:
552                pass
553
554        m = M().eval()
555        example_inputs = (torch.randn(1, 3, 5, 5),)
556
557        m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat)
558        fixed_scale = 1.0 / 256.0
559        fixed_zero_point = 0
560        for n in m.graph.nodes:
561            if n.op == "call_function":
562                if (
563                    n.target
564                    == torch.ops.quantized_decomposed.quantize_per_tensor.default
565                ):
566                    scale_0 = n.args[1]
567                    zero_point_0 = n.args[2]
568                if (
569                    n.target
570                    == torch.ops.quantized_decomposed.dequantize_per_tensor.default
571                ):
572                    scale_1 = n.args[1]
573                    zero_point_1 = n.args[2]
574        self.assertEqual(scale_0, fixed_scale)
575        self.assertEqual(zero_point_0, fixed_zero_point)
576        self.assertEqual(scale_1, fixed_scale)
577        self.assertEqual(zero_point_1, fixed_zero_point)
578        node_occurrence = {
579            # two for input of the first conv, one for output for the first conv
580            ns.call_function(
581                torch.ops.quantized_decomposed.quantize_per_tensor.default
582            ): 2,
583            ns.call_function(
584                torch.ops.quantized_decomposed.dequantize_per_tensor.default
585            ): 2,
586        }
587        node_list = [
588            ns.call_function(
589                torch.ops.quantized_decomposed.dequantize_per_tensor.default
590            ),
591            ns.call_function(torch.ops.aten.sigmoid.default),
592            ns.call_function(
593                torch.ops.quantized_decomposed.quantize_per_tensor.default
594            ),
595        ]
596        self.checkGraphModuleNodes(
597            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
598        )
599
600    def test_fixed_qparams_qspec_observer_dedup(self):
601        class BackendAQuantizer(Quantizer):
602            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
603                for node in model.graph.nodes:
604                    if (
605                        node.op == "call_function"
606                        and node.target == torch.ops.aten.sigmoid.default
607                    ):
608                        input_act = node.args[0]
609                        assert isinstance(input_act, Node)
610                        act_qspec = FixedQParamsQuantizationSpec(
611                            dtype=torch.uint8,
612                            quant_min=0,
613                            quant_max=255,
614                            qscheme=torch.per_tensor_affine,
615                            scale=1.0 / 256.0,
616                            zero_point=0,
617                        )
618                        node.meta["quantization_annotation"] = QuantizationAnnotation(
619                            input_qspec_map={
620                                input_act: act_qspec,
621                            },
622                            output_qspec=act_qspec,
623                            _annotated=True,
624                        )
625                    elif (
626                        node.op == "call_function"
627                        and node.target == torch.ops.aten.add.Tensor
628                    ):
629                        input_act0 = node.args[0]
630                        assert isinstance(input_act, Node)
631                        input_act1 = node.args[1]
632                        assert isinstance(input_act, Node)
633                        act_qspec = QuantizationSpec(
634                            observer_or_fake_quant_ctr=observer.default_observer,
635                            dtype=torch.uint8,
636                            quant_min=0,
637                            quant_max=255,
638                            qscheme=torch.per_tensor_affine,
639                        )
640                        node.meta["quantization_annotation"] = QuantizationAnnotation(
641                            input_qspec_map={
642                                input_act0: act_qspec,
643                                input_act1: act_qspec,
644                            },
645                            output_qspec=act_qspec,
646                            _annotated=True,
647                        )
648
649            def validate(self, model: torch.fx.GraphModule) -> None:
650                pass
651
652        class M(torch.nn.Module):
653            def forward(self, x, y):
654                return torch.sigmoid(x) + y
655
656            def example_inputs(self):
657                return (
658                    torch.randn(1, 3, 5, 5),
659                    torch.randn(1, 3, 5, 5),
660                )
661
662        m = M().eval()
663        example_inputs = m.example_inputs()
664        m = self._quantize(m, BackendAQuantizer(), example_inputs, is_qat=False)
665
666        node_occurrence = {
667            # two for input of the first conv, one for output for the first conv
668            ns.call_function(
669                torch.ops.quantized_decomposed.quantize_per_tensor.default
670            ): 4,
671            ns.call_function(
672                torch.ops.quantized_decomposed.dequantize_per_tensor.default
673            ): 4,
674        }
675        node_list = [
676            ns.call_function(
677                torch.ops.quantized_decomposed.dequantize_per_tensor.default
678            ),
679            ns.call_function(torch.ops.aten.sigmoid.default),
680            ns.call_function(
681                torch.ops.quantized_decomposed.quantize_per_tensor.default
682            ),
683            ns.call_function(
684                torch.ops.quantized_decomposed.dequantize_per_tensor.default
685            ),
686            ns.call_function(torch.ops.aten.add.Tensor),
687            ns.call_function(
688                torch.ops.quantized_decomposed.quantize_per_tensor.default
689            ),
690        ]
691        self.checkGraphModuleNodes(
692            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
693        )
694
695    def test_shared_qspec(self):
696        class BackendAQuantizer(Quantizer):
697            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
698                for node in model.graph.nodes:
699                    if (
700                        node.op == "call_function"
701                        and node.target == torch.ops.aten.conv2d.default
702                    ):
703                        input_act = node.args[0]
704                        assert isinstance(input_act, Node)
705                        weight = node.args[1]
706                        assert isinstance(weight, Node)
707                        bias = node.args[2]
708                        assert isinstance(bias, Node)
709                        act_qspec = QuantizationSpec(
710                            dtype=torch.uint8,
711                            quant_min=0,
712                            quant_max=255,
713                            qscheme=torch.per_tensor_affine,
714                            is_dynamic=False,
715                            observer_or_fake_quant_ctr=observer.default_observer,
716                        )
717                        weight_qspec = QuantizationSpec(
718                            dtype=torch.int8,
719                            quant_min=-128,
720                            quant_max=127,
721                            qscheme=torch.per_tensor_affine,
722                            is_dynamic=False,
723                            observer_or_fake_quant_ctr=observer.default_weight_observer,
724                        )
725                        bias_qspec = QuantizationSpec(
726                            dtype=torch.float32,
727                            is_dynamic=False,
728                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
729                        )
730                        node.meta["quantization_annotation"] = QuantizationAnnotation(
731                            input_qspec_map={
732                                input_act: act_qspec,
733                                weight: weight_qspec,
734                                bias: bias_qspec,
735                            },
736                            output_qspec=act_qspec,
737                            _annotated=True,
738                        )
739                    elif node.target is torch.ops.aten.cat.default:
740                        cat_node = node
741                        input_nodes = cat_node.args[0]
742                        first_input_node = input_nodes[0]
743                        input_qspec_map = {}
744                        act_qspec = QuantizationSpec(
745                            dtype=torch.uint8,
746                            quant_min=0,
747                            quant_max=255,
748                            qscheme=torch.per_tensor_affine,
749                            is_dynamic=False,
750                            observer_or_fake_quant_ctr=observer.default_observer,
751                        )
752                        input_qspec_map[first_input_node] = act_qspec
753                        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
754                            (first_input_node, cat_node)
755                        )
756                        for input_node in input_nodes[1:]:
757                            input_qspec_map[
758                                input_node
759                            ] = share_qparams_with_input_act0_qspec
760
761                        cat_node.meta[
762                            "quantization_annotation"
763                        ] = QuantizationAnnotation(
764                            input_qspec_map=input_qspec_map,
765                            output_qspec=share_qparams_with_input_act0_qspec,
766                            _annotated=True,
767                        )
768
769            def validate(self, model: torch.fx.GraphModule) -> None:
770                pass
771
772        m = TestHelperModules.Conv2dWithCat().eval()
773        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
774
775        # program capture
776        m = capture_pre_autograd_graph(
777            m,
778            example_inputs,
779        )
780        m = prepare_pt2e(m, BackendAQuantizer())
781        # make sure the two observers for input are shared
782        conv_output_obs = []
783        for n in m.graph.nodes:
784            if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default:
785                conv_output_obs.append(getattr(m, next(iter(n.users)).target))
786            if n.op == "call_function" and n.target == torch.ops.aten.cat.default:
787                inputs = n.args[0]
788                input0 = inputs[0]
789                input1 = inputs[1]
790                assert input0.op == "call_module"
791                assert input1.op == "call_module"
792                obs_ins0 = getattr(m, input0.target)
793                obs_ins1 = getattr(m, input1.target)
794                assert obs_ins0 == obs_ins1
795        assert (
796            len(conv_output_obs) == 2
797        ), "expecting two observer that follows conv2d ops"
798        # checking that the output observers for the two convs are shared as well
799        assert conv_output_obs[0] == conv_output_obs[1]
800
801        m(*example_inputs)
802        m = convert_pt2e(m)
803
804        node_occurrence = {
805            # two for input of the first conv, one for output for the first conv
806            ns.call_function(
807                torch.ops.quantized_decomposed.quantize_per_tensor.default
808            ): 5,
809            ns.call_function(
810                torch.ops.quantized_decomposed.dequantize_per_tensor.default
811            ): 7,
812        }
813        node_list = [
814            ns.call_function(
815                torch.ops.quantized_decomposed.dequantize_per_tensor.default
816            ),
817            ns.call_function(
818                torch.ops.quantized_decomposed.dequantize_per_tensor.default
819            ),
820            ns.call_function(torch.ops.aten.cat.default),
821            ns.call_function(
822                torch.ops.quantized_decomposed.quantize_per_tensor.default
823            ),
824        ]
825        self.checkGraphModuleNodes(
826            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
827        )
828
829    def _test_transitive_sharing_with_cat_helper(self, quantizer):
830        m = TestHelperModules.Conv2dWithTwoCat().eval()
831        example_inputs = (
832            torch.randn(1, 3, 5, 5),
833            torch.randn(1, 3, 5, 5),
834            torch.randn(1, 6, 3, 3),
835            torch.randn(1, 6, 3, 3),
836        )
837
838        # program capture
839        m = capture_pre_autograd_graph(
840            m,
841            example_inputs,
842        )
843        m = prepare_pt2e(m, quantizer)
844        m(*example_inputs)
845        # make sure the two input observers and output are shared
846        conv_output_obs = []
847        for n in m.graph.nodes:
848            if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default:
849                conv_output_obs.append(getattr(m, next(iter(n.users)).target))
850            if n.op == "call_function" and n.target == torch.ops.aten.cat.default:
851                inputs = n.args[0]
852                input0 = inputs[0]
853                input1 = inputs[1]
854                assert input0.op == "call_module"
855                assert input1.op == "call_module"
856                obs_ins0 = getattr(m, input0.target)
857                obs_ins1 = getattr(m, input1.target)
858                assert obs_ins0 == obs_ins1
859
860                output_obs = next(iter(n.users))
861                assert output_obs.op == "call_module"
862                obs_ins2 = getattr(m, output_obs.target)
863                assert obs_ins0 == obs_ins2, "input observer does not match output"
864
865        assert (
866            len(conv_output_obs) == 2
867        ), "expecting two observer that follows conv2d ops"
868        # checking that the output observers for the two convs are shared as well
869        assert conv_output_obs[0] == conv_output_obs[1]
870
871        m(*example_inputs)
872        m = convert_pt2e(m)
873
874        node_occurrence = {
875            # two for input of the first conv, one for output for the first conv
876            ns.call_function(
877                torch.ops.quantized_decomposed.quantize_per_tensor.default
878            ): 7,
879            ns.call_function(
880                torch.ops.quantized_decomposed.dequantize_per_tensor.default
881            ): 9,
882        }
883        node_list = [
884            ns.call_function(
885                torch.ops.quantized_decomposed.dequantize_per_tensor.default
886            ),
887            ns.call_function(
888                torch.ops.quantized_decomposed.dequantize_per_tensor.default
889            ),
890            ns.call_function(torch.ops.aten.cat.default),
891            ns.call_function(
892                torch.ops.quantized_decomposed.quantize_per_tensor.default
893            ),
894            ns.call_function(
895                torch.ops.quantized_decomposed.dequantize_per_tensor.default
896            ),
897            ns.call_function(torch.ops.aten.cat.default),
898            ns.call_function(
899                torch.ops.quantized_decomposed.quantize_per_tensor.default
900            ),
901        ]
902        self.checkGraphModuleNodes(
903            m, expected_node_list=node_list, expected_node_occurrence=node_occurrence
904        )
905
906    def test_shared_qspec_transitivity(self):
907        """This tests the transitivity of SharedQuantizationSpec, that is
908        if A is shared with B, B is shared with C, then C should be shared with A as well
909
910        x1 -> conv1 -> cat1 -----> cat2
911        x2 -> conv2 -/            /
912                       x3 -> add /
913                       x4  /
914
915        both cat has shared input and output, and because of cat and (cat1 -> cat2) is the same Tensor
916        so there is an implicit sharing here, all tensors connect to cat1 and cat2 are in the same
917        sharing group after transitive sharing
918        """
919
920        # TODO: refactor this to a common util
921        class BackendAQuantizer(Quantizer):
922            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
923                for node in model.graph.nodes:
924                    if (
925                        node.op == "call_function"
926                        and node.target == torch.ops.aten.conv2d.default
927                    ):
928                        input_act = node.args[0]
929                        assert isinstance(input_act, Node)
930                        weight = node.args[1]
931                        assert isinstance(weight, Node)
932                        bias = node.args[2]
933                        assert isinstance(bias, Node)
934                        act_qspec = QuantizationSpec(
935                            dtype=torch.uint8,
936                            quant_min=0,
937                            quant_max=255,
938                            qscheme=torch.per_tensor_affine,
939                            is_dynamic=False,
940                            observer_or_fake_quant_ctr=observer.default_observer,
941                        )
942                        weight_qspec = QuantizationSpec(
943                            dtype=torch.int8,
944                            quant_min=-128,
945                            quant_max=127,
946                            qscheme=torch.per_tensor_affine,
947                            is_dynamic=False,
948                            observer_or_fake_quant_ctr=observer.default_weight_observer,
949                        )
950                        bias_qspec = QuantizationSpec(
951                            dtype=torch.float32,
952                            is_dynamic=False,
953                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
954                        )
955                        node.meta["quantization_annotation"] = QuantizationAnnotation(
956                            input_qspec_map={
957                                input_act: act_qspec,
958                                weight: weight_qspec,
959                                bias: bias_qspec,
960                            },
961                            output_qspec=act_qspec,
962                            _annotated=True,
963                        )
964                    elif node.target is torch.ops.aten.cat.default:
965                        cat_node = node
966                        input_nodes = cat_node.args[0]
967                        first_input_node = input_nodes[0]
968                        input_qspec_map = {}
969                        act_qspec = QuantizationSpec(
970                            dtype=torch.uint8,
971                            quant_min=0,
972                            quant_max=255,
973                            qscheme=torch.per_tensor_affine,
974                            is_dynamic=False,
975                            observer_or_fake_quant_ctr=observer.default_observer,
976                        )
977                        input_qspec_map[first_input_node] = act_qspec
978                        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
979                            (first_input_node, cat_node)
980                        )
981                        for input_node in input_nodes[1:]:
982                            input_qspec_map[
983                                input_node
984                            ] = share_qparams_with_input_act0_qspec
985
986                        cat_node.meta[
987                            "quantization_annotation"
988                        ] = QuantizationAnnotation(
989                            input_qspec_map=input_qspec_map,
990                            output_qspec=share_qparams_with_input_act0_qspec,
991                            _annotated=True,
992                        )
993
994            def validate(self, model: torch.fx.GraphModule) -> None:
995                pass
996
997        self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())
998
999    def test_shared_qspec_transitivity_case_2(self):
1000        """This tests the transitivity of SharedQuantizationSpec, that is
1001        if A is shared with B, B is shared with C, then C should be shared with A as well
1002
1003        x1 -> conv1 -> cat1 -----> cat2
1004        x2 -> conv2 -/            /
1005                       x3 -> add /
1006                       x4  /
1007
1008        both cat has shared input and output, and because of cat and (cat1 -> cat2) is the same Tensor
1009        so there is an implicit sharing here, all tensors connect to cat1 and cat2 are in the same
1010        sharing group after transitive sharing
1011
1012        the difference is that for this one, all edges and nodes are shared with the second input edge of cat
1013        instead of the first input edge of cat as in previous example
1014        """
1015
1016        # TODO: refactor this to a common util
1017        class BackendAQuantizer(Quantizer):
1018            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1019                for node in model.graph.nodes:
1020                    if (
1021                        node.op == "call_function"
1022                        and node.target == torch.ops.aten.conv2d.default
1023                    ):
1024                        input_act = node.args[0]
1025                        assert isinstance(input_act, Node)
1026                        weight = node.args[1]
1027                        assert isinstance(weight, Node)
1028                        bias = node.args[2]
1029                        assert isinstance(bias, Node)
1030                        act_qspec = QuantizationSpec(
1031                            dtype=torch.uint8,
1032                            quant_min=0,
1033                            quant_max=255,
1034                            qscheme=torch.per_tensor_affine,
1035                            is_dynamic=False,
1036                            observer_or_fake_quant_ctr=observer.default_observer,
1037                        )
1038                        weight_qspec = QuantizationSpec(
1039                            dtype=torch.int8,
1040                            quant_min=-128,
1041                            quant_max=127,
1042                            qscheme=torch.per_tensor_affine,
1043                            is_dynamic=False,
1044                            observer_or_fake_quant_ctr=observer.default_weight_observer,
1045                        )
1046                        bias_qspec = QuantizationSpec(
1047                            dtype=torch.float32,
1048                            is_dynamic=False,
1049                            observer_or_fake_quant_ctr=observer.PlaceholderObserver,
1050                        )
1051                        node.meta["quantization_annotation"] = QuantizationAnnotation(
1052                            input_qspec_map={
1053                                input_act: act_qspec,
1054                                weight: weight_qspec,
1055                                bias: bias_qspec,
1056                            },
1057                            output_qspec=act_qspec,
1058                            _annotated=True,
1059                        )
1060                    elif node.target is torch.ops.aten.cat.default:
1061                        cat_node = node
1062                        input_nodes = cat_node.args[0]
1063                        first_input_node = input_nodes[0]
1064                        second_input_node = input_nodes[1]
1065                        input_qspec_map = {}
1066                        act_qspec = QuantizationSpec(
1067                            dtype=torch.uint8,
1068                            quant_min=0,
1069                            quant_max=255,
1070                            qscheme=torch.per_tensor_affine,
1071                            is_dynamic=False,
1072                            observer_or_fake_quant_ctr=observer.default_observer,
1073                        )
1074                        input_qspec_map[second_input_node] = act_qspec
1075                        share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
1076                            (second_input_node, cat_node)
1077                        )
1078                        input_qspec_map[
1079                            first_input_node
1080                        ] = share_qparams_with_input_act1_qspec
1081
1082                        cat_node.meta[
1083                            "quantization_annotation"
1084                        ] = QuantizationAnnotation(
1085                            input_qspec_map=input_qspec_map,
1086                            output_qspec=share_qparams_with_input_act1_qspec,
1087                            _annotated=True,
1088                        )
1089
1090            def validate(self, model: torch.fx.GraphModule) -> None:
1091                pass
1092
1093        self._test_transitive_sharing_with_cat_helper(BackendAQuantizer())
1094
1095    def test_allow_implicit_sharing(self):
1096        """This tests the allow_transitive_sharing flag of QuantizationAnnotation, that is
1097        if a node is configured with allow_implicit_sharing=False, we will not have implicit sharing
1098        for node and (node, consumer) even they refer to the same Tensor
1099
1100        x1 -> add1 -----> add3
1101        x2 -/              /
1102               x3 -> add2 /
1103               x4 -/
1104
1105        all add has shared input and output, and second input is using shared quantization spec pointing
1106        to first input, but we set allow_implicit_sharing to False for all add nodes so input and output of add1,
1107        add2 and add3 will each belong to one sharing group, so we'll have:
1108
1109        x1 -> obs1 -> add1 -> obs1 -> obs3--> add3 -> obs3
1110        x2 -> obs1 -/                         /
1111               x3 -> obs2 -> add2 -> obs2 -> obs3
1112               x4 -> obs2 -/
1113        """
1114
1115        # TODO: refactor this to a common util
1116        class BackendAQuantizer(Quantizer):
1117            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1118                for node in model.graph.nodes:
1119                    if node.target is torch.ops.aten.add.Tensor:
1120                        add_node = node
1121                        first_input_node = add_node.args[0]
1122                        second_input_node = add_node.args[1]
1123                        input_qspec_map = {}
1124                        act_qspec = QuantizationSpec(
1125                            dtype=torch.uint8,
1126                            quant_min=0,
1127                            quant_max=255,
1128                            qscheme=torch.per_tensor_affine,
1129                            is_dynamic=False,
1130                            observer_or_fake_quant_ctr=observer.default_observer,
1131                        )
1132                        input_qspec_map[second_input_node] = act_qspec
1133                        share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
1134                            (second_input_node, add_node)
1135                        )
1136                        input_qspec_map[
1137                            first_input_node
1138                        ] = share_qparams_with_input_act1_qspec
1139
1140                        add_node.meta[
1141                            "quantization_annotation"
1142                        ] = QuantizationAnnotation(
1143                            input_qspec_map=input_qspec_map,
1144                            output_qspec=share_qparams_with_input_act1_qspec,
1145                            allow_implicit_sharing=False,
1146                            _annotated=True,
1147                        )
1148
1149            def validate(self, model: torch.fx.GraphModule) -> None:
1150                pass
1151
1152        m = TestHelperModules.ThreeAdd().eval()
1153        example_inputs = (
1154            torch.randn(1, 3, 5, 5),
1155            torch.randn(1, 3, 5, 5),
1156            torch.randn(1, 3, 5, 5),
1157            torch.randn(1, 3, 5, 5),
1158        )
1159
1160        # program capture
1161        m = capture_pre_autograd_graph(
1162            m,
1163            example_inputs,
1164        )
1165        quantizer = BackendAQuantizer()
1166        m = prepare_pt2e(m, quantizer)
1167        m(*example_inputs)
1168        observers = []
1169        for n in m.graph.nodes:
1170            if n.target == torch.ops.aten.add.Tensor:
1171                input_obs1 = getattr(m, n.args[0].target)
1172                input_obs2 = getattr(m, n.args[1].target)
1173                output_obs = getattr(m, next(iter(n.users)).target)
1174                self.assertIs(input_obs1, input_obs2)
1175                self.assertIs(input_obs1, output_obs)
1176                observers.append(input_obs1)
1177        assert len(observers) == 3
1178        self.assertIsNot(observers[0], observers[1])
1179        self.assertIsNot(observers[0], observers[2])
1180        self.assertIsNot(observers[1], observers[2])
1181
1182    @parametrize("dtype", (torch.float32, torch.bfloat16))
1183    @parametrize("quant_dtype", (torch.int16, torch.float8_e5m2, torch.float8_e4m3fn))
1184    def test_quantization_dtype(self, dtype, quant_dtype):
1185        class DtypeActQuantizer(Quantizer):
1186            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1187                info_fun = torch.iinfo if quant_dtype == torch.int16 else torch.finfo
1188                activate_qspec = QuantizationSpec(
1189                    dtype=quant_dtype,
1190                    quant_min=int(info_fun(quant_dtype).min),
1191                    quant_max=int(info_fun(quant_dtype).max),
1192                    qscheme=torch.per_tensor_affine,
1193                    is_dynamic=False,
1194                    observer_or_fake_quant_ctr=observer.default_observer,
1195                )
1196                int8_qspec = QuantizationSpec(
1197                    dtype=torch.int8,
1198                    quant_min=-128,
1199                    quant_max=127,
1200                    qscheme=torch.per_tensor_symmetric,
1201                    is_dynamic=False,
1202                    observer_or_fake_quant_ctr=observer.default_weight_observer,
1203                )
1204                quantization_config = QuantizationConfig(
1205                    input_activation=activate_qspec,
1206                    weight=int8_qspec,
1207                    bias=None,
1208                    output_activation=activate_qspec,
1209                )
1210                OP_TO_ANNOTATOR["conv"](model, quantization_config)
1211
1212            def validate(self, model: torch.fx.GraphModule) -> None:
1213                pass
1214
1215        class M(torch.nn.Module):
1216            def __init__(self, dtype):
1217                super().__init__()
1218                self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype)
1219
1220            def forward(self, x):
1221                return self.conv(x)
1222
1223        quantizer = DtypeActQuantizer()
1224        node_occurrence = {
1225            # one for input of the first conv, one for output for the first conv
1226            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
1227            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
1228        }
1229        node_list = [
1230            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1231            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1232            torch.ops.aten.conv2d.default,
1233            torch.ops.quantized_decomposed.quantize_per_tensor.default,
1234        ]
1235        example_inputs = (torch.randn(1, 3, 3, 3, dtype=dtype),)
1236        m = self._test_quantizer(
1237            M(dtype).eval(),
1238            example_inputs,
1239            quantizer,
1240            node_occurrence,
1241            node_list,
1242        )
1243
1244        def verify_quant_dequant_iotypes(m):
1245            for node in m.graph.nodes:
1246                if (
1247                    node.op == "call_function"
1248                    and node.target.__name__ == "dequantize_per_tensor.default"
1249                ):
1250                    # Check dequantize node
1251                    dequant_node = node
1252                    dequant_in_dtype = dequant_node.args[5]
1253                    dequant_out_dtype = torch.float32
1254                    if "out_dtype" in dequant_node.kwargs:
1255                        dequant_out_dtype = dequant_node.kwargs["out_dtype"]
1256
1257                    # Check preceding quantize node
1258                    # Depending on fold_quantize flag, quantize node may be absent
1259                    quant_node = node.args[0]
1260                    if (
1261                        quant_node.op == "call_function"
1262                        and quant_node.target.__name__ == "quantize_per_tensor.default"
1263                    ):
1264                        quant_in_dtype = torch.float32
1265                        if "val" in quant_node.args[0].meta:
1266                            quant_in_dtype = quant_node.args[0].meta["val"].dtype
1267                        quant_out_dtype = quant_node.args[5]
1268                        assert (
1269                            quant_in_dtype == dequant_out_dtype
1270                            and quant_out_dtype == dequant_in_dtype
1271                        ), "quant dequant io dtype check failed!"
1272
1273        verify_quant_dequant_iotypes(m)
1274
1275    def test_input_edge_sanity_check(self):
1276        class M(torch.nn.Module):
1277            def forward(self, x):
1278                return x + 6
1279
1280        class BackendAQuantizer(Quantizer):
1281            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1282                for node in model.graph.nodes:
1283                    if (
1284                        node.op == "call_function"
1285                        and node.target == torch.ops.aten.add.Tensor
1286                    ):
1287                        input_act1 = node.args[0]
1288                        # this is a constant, so not valid for annotation
1289                        input_act2 = node.args[1]
1290                        act_qspec = QuantizationSpec(
1291                            dtype=torch.uint8,
1292                            quant_min=0,
1293                            quant_max=255,
1294                            qscheme=torch.per_tensor_affine,
1295                            is_dynamic=False,
1296                            observer_or_fake_quant_ctr=observer.default_observer,
1297                        )
1298                        node.meta["quantization_annotation"] = QuantizationAnnotation(
1299                            input_qspec_map={
1300                                input_act1: act_qspec,
1301                                # this is supposed to error out
1302                                input_act2: act_qspec,
1303                            },
1304                            output_qspec=act_qspec,
1305                            _annotated=True,
1306                        )
1307
1308            def validate(self, model: torch.fx.GraphModule) -> None:
1309                pass
1310
1311        m = M().eval()
1312        example_inputs = torch.randn(1, 2, 3, 3)
1313        m = capture_pre_autograd_graph(m, (example_inputs,))
1314        with self.assertRaises(Exception):
1315            m = prepare_pt2e(m, BackendAQuantizer())
1316
1317    def test_fold_quantize(self):
1318        """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded)"""
1319        m = self._get_pt2e_quantized_linear()
1320        node_occurrence = {
1321            # quantize op for weight node is folded
1322            ns.call_function(
1323                torch.ops.quantized_decomposed.quantize_per_tensor.default
1324            ): 2,
1325            ns.call_function(
1326                torch.ops.quantized_decomposed.dequantize_per_tensor.default
1327            ): 3,
1328        }
1329        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
1330
1331    def test_fold_quantize_per_channel(self):
1332        """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded)"""
1333        m = self._get_pt2e_quantized_linear(is_per_channel=True)
1334        node_occurrence = {
1335            # quantize op for weight node is folded
1336            ns.call_function(
1337                torch.ops.quantized_decomposed.quantize_per_tensor.default
1338            ): 2,
1339            ns.call_function(
1340                torch.ops.quantized_decomposed.dequantize_per_channel.default
1341            ): 1,
1342            ns.call_function(
1343                torch.ops.quantized_decomposed.dequantize_per_tensor.default
1344            ): 2,
1345        }
1346        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
1347
1348    def test_dont_fold_other_constant(self):
1349        """Make sure the constant propagation does not apply to things unrelated to
1350        quantization
1351        """
1352
1353        class M(torch.nn.Module):
1354            def __init__(self) -> None:
1355                super().__init__()
1356                self.linear = torch.nn.Linear(2, 2)
1357                self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2))
1358
1359            def forward(self, x):
1360                t = self.dont_fold_me.t()
1361                return self.linear(x) + t
1362
1363        quantizer = XNNPACKQuantizer()
1364        operator_config = get_symmetric_quantization_config(is_per_channel=False)
1365        # only quantize linear, so add is not quantized and the constant Tensor
1366        # should not be folded
1367        quantizer.set_module_type(torch.nn.Linear, operator_config)
1368        example_inputs = (torch.randn(2, 2),)
1369        m = M().eval()
1370        m = self._quantize(m, quantizer, example_inputs)
1371        node_occurrence = {
1372            # quantize op for weight node is folded
1373            ns.call_function(
1374                torch.ops.quantized_decomposed.quantize_per_tensor.default
1375            ): 2,
1376            ns.call_function(
1377                torch.ops.quantized_decomposed.dequantize_per_tensor.default
1378            ): 3,
1379            # transpose op not folded
1380            ns.call_function(torch.ops.aten.t.default): 1,
1381        }
1382        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
1383
1384    def test_fold_all_ops_before_quantize(self):
1385        """Test folding all ops that's before quantized operator:
1386        Before:
1387            get_attr(weight) -> transpose -> quantize -> dequantize
1388        After:
1389            get_attr(folded_weight) -> dequantize
1390        """
1391
1392        class M(torch.nn.Module):
1393            def __init__(self) -> None:
1394                super().__init__()
1395                self.weight = torch.randn(2, 2)
1396
1397            def forward(self, x):
1398                t = self.weight.t()
1399                return torch.nn.functional.linear(x, t)
1400
1401        quantizer = XNNPACKQuantizer()
1402        operator_config = get_symmetric_quantization_config(is_per_channel=False)
1403        quantizer.set_global(operator_config)
1404        example_inputs = (torch.randn(2, 2),)
1405        m = M().eval()
1406        m = self._quantize(m, quantizer, example_inputs)
1407        node_occurrence = {
1408            # quantize op for weight node is folded
1409            ns.call_function(
1410                torch.ops.quantized_decomposed.quantize_per_tensor.default
1411            ): 2,
1412            ns.call_function(
1413                torch.ops.quantized_decomposed.dequantize_per_tensor.default
1414            ): 3,
1415        }
1416        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
1417
1418    def test_constant_prop_preserve_metadata(self):
1419        """Test to make sure the get_attr node for const propagated weight Tensor gets the correct
1420        metadata (from original get_attr node from weight)
1421        """
1422
1423        class M(torch.nn.Module):
1424            def __init__(self) -> None:
1425                super().__init__()
1426                self.linear = torch.nn.Linear(2, 2)
1427
1428            def forward(self, x):
1429                return self.linear(x)
1430
1431        quantizer = XNNPACKQuantizer()
1432        operator_config = get_symmetric_quantization_config()
1433        quantizer.set_global(operator_config)
1434        example_inputs = (torch.randn(2, 2),)
1435        m = M().eval()
1436        m = capture_pre_autograd_graph(
1437            m,
1438            example_inputs,
1439        )
1440        weight_meta = None
1441        for n in m.graph.nodes:
1442            if (
1443                n.op == "get_attr"
1444                and next(iter(n.users)).target == torch.ops.aten.linear.default
1445            ):
1446                weight_meta = n.meta
1447                break
1448        assert weight_meta is not None, "Expect to find metadata for weight node"
1449
1450        m = prepare_pt2e(m, quantizer)
1451        m(*example_inputs)
1452        m = convert_pt2e(m)
1453
1454        for n in m.graph.nodes:
1455            if n.op == "get_attr" and "frozen_param" in n.target:
1456                for key in n.meta:
1457                    self.assertEqual(n.meta[key], weight_meta[key])
1458
1459    def test_save_load(self):
1460        """Test save/load a quantized model"""
1461        m = self._get_pt2e_quantized_linear()
1462        example_inputs = (torch.randn(2, 2),)
1463        ref_res = m(*example_inputs)
1464
1465        with TemporaryFileName() as fname:
1466            # serialization
1467            quantized_ep = torch.export.export(m, example_inputs)
1468            torch.export.save(quantized_ep, fname)
1469            # deserialization
1470            loaded_ep = torch.export.load(fname)
1471            loaded_quantized_model = loaded_ep.module()
1472            res = loaded_quantized_model(*example_inputs)
1473            self.assertEqual(ref_res, res)
1474
1475    def test_composable_quantizer_throw(self):
1476        class BadQuantizer(Quantizer):
1477            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
1478                for n in gm.graph.nodes:
1479                    n.meta["quantization_annotation"] = None
1480
1481            def validate(self, model: torch.fx.GraphModule) -> None:
1482                pass
1483
1484        quantizer = XNNPACKQuantizer()
1485        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
1486        quantizer.set_global(quantization_config)
1487        bad_quantizer = BadQuantizer()
1488        composable_quantizer = ComposableQuantizer([quantizer, bad_quantizer])
1489        m_eager = TestHelperModules.ConvLinearWPermute().eval()
1490        example_inputs = (torch.randn(2, 3, 4, 4),)
1491        self.assertRaises(
1492            RuntimeError,
1493            lambda: self._test_quantizer(
1494                m_eager, example_inputs, composable_quantizer, {}
1495            ),
1496        )
1497
1498    def test_transform_for_annotation(self):
1499        class TestQuantizer(Quantizer):
1500            def transform_for_annotation(
1501                self, model: torch.fx.GraphModule
1502            ) -> torch.fx.GraphModule:
1503                # Make a copy of the graph to ensure that we are using the
1504                # return value of this function.
1505                graph = torch.fx.Graph()
1506                graph.graph_copy(model.graph, {})
1507                for n in graph.nodes:
1508                    if n.target == torch.ops.aten.add.Tensor:
1509                        n.target = torch.ops.aten.mul.Tensor
1510                model = torch.fx.GraphModule(model, graph)
1511                return model
1512
1513            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1514                return model
1515
1516            def validate(self, model: torch.fx.GraphModule) -> None:
1517                pass
1518
1519        class M(torch.nn.Module):
1520            def forward(self, x):
1521                return x + 3
1522
1523        m = M().eval()
1524        quantizer = TestQuantizer()
1525        example_inputs = (torch.randn(1, 2, 3, 3),)
1526        m = capture_pre_autograd_graph(m, example_inputs)
1527        m = prepare_pt2e(m, quantizer)
1528        m(*example_inputs)
1529        node_occurrence = {
1530            ns.call_function(torch.ops.aten.add.Tensor): 0,
1531            ns.call_function(torch.ops.aten.mul.Tensor): 1,
1532        }
1533        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
1534
1535    def test_composable_quantizer_transform_for_annotation(self):
1536        class TestQuantizer1(Quantizer):
1537            def transform_for_annotation(
1538                self, model: torch.fx.GraphModule
1539            ) -> torch.fx.GraphModule:
1540                for n in model.graph.nodes:
1541                    if n.target == torch.ops.aten.add.Tensor:
1542                        n.target = torch.ops.aten.mul.Tensor
1543                return model
1544
1545            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1546                return model
1547
1548            def validate(self, model: torch.fx.GraphModule) -> None:
1549                pass
1550
1551        class TestQuantizer2(Quantizer):
1552            def transform_for_annotation(
1553                self, model: torch.fx.GraphModule
1554            ) -> torch.fx.GraphModule:
1555                for n in model.graph.nodes:
1556                    if n.target == torch.ops.aten.sub.Tensor:
1557                        n.target = torch.ops.aten.div.Tensor
1558                return model
1559
1560            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
1561                return model
1562
1563            def validate(self, model: torch.fx.GraphModule) -> None:
1564                pass
1565
1566        class M(torch.nn.Module):
1567            def forward(self, x, y, z):
1568                return x + y - z
1569
1570        m = M().eval()
1571        quantizer = ComposableQuantizer([TestQuantizer1(), TestQuantizer2()])
1572        example_inputs = (
1573            torch.randn(1, 2, 3, 3),
1574            torch.randn(1, 2, 3, 3),
1575            torch.randn(1, 2, 3, 3),
1576        )
1577        m = capture_pre_autograd_graph(m, example_inputs)
1578        m = prepare_pt2e(m, quantizer)
1579        m(*example_inputs)
1580        node_occurrence = {
1581            ns.call_function(torch.ops.aten.add.Tensor): 0,
1582            ns.call_function(torch.ops.aten.sub.Tensor): 0,
1583            ns.call_function(torch.ops.aten.mul.Tensor): 1,
1584            ns.call_function(torch.ops.aten.div.Tensor): 1,
1585        }
1586        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
1587
1588    def test_embedding_quantizer(self):
1589        m_eager = TestHelperModules.EmbeddingModule().eval()
1590        indices = torch.tensor(
1591            [
1592                9,
1593                6,
1594                5,
1595                7,
1596                8,
1597                8,
1598                9,
1599                2,
1600                8,
1601                6,
1602                6,
1603                9,
1604                1,
1605                6,
1606                8,
1607                8,
1608                3,
1609                2,
1610                3,
1611                6,
1612                3,
1613                6,
1614                5,
1615                7,
1616                0,
1617                8,
1618                4,
1619                6,
1620                5,
1621                8,
1622                2,
1623                3,
1624            ]
1625        )
1626        example_inputs = (indices,)
1627
1628        quantizer = EmbeddingQuantizer()
1629        node_occurrence = {
1630            # note: quantize op for weights are const propagated
1631            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1632            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1633        }
1634        node_list = [
1635            torch.ops.quantized_decomposed.dequantize_per_channel.default,
1636            torch.ops.aten.embedding.default,
1637        ]
1638        # Compare against short term workflow
1639        # cannot compare against fx quant because of the numerical differences coming
1640        # from quantize and dequantize ops
1641        qconfig = default_per_channel_symmetric_qnnpack_qconfig
1642        qconfig_mapping = QConfigMapping().set_global(qconfig)
1643        qconfig_mapping = qconfig_mapping.set_object_type(
1644            torch.nn.Embedding, float_qparams_weight_only_qconfig
1645        )
1646        self._test_quantizer(
1647            m_eager,
1648            example_inputs,
1649            quantizer,
1650            node_occurrence,
1651            node_list,
1652            True,
1653            qconfig_mapping,
1654        )
1655
1656    def test_composable_quantizer_linear_conv(self):
1657        dynamic_quantizer = XNNPACKQuantizer()
1658        quantization_config_dynamic = get_symmetric_quantization_config(
1659            is_per_channel=False, is_dynamic=True
1660        )
1661        dynamic_quantizer.set_global(quantization_config_dynamic)
1662        static_quantizer = XNNPACKQuantizer()
1663        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
1664        static_quantizer.set_global(quantization_config)
1665        # Note that dynamic quantization must be applied first here.
1666        # this is because static quantizer also quantizes linear with static qspec
1667        # and if we apply static_quantizer first then dynamic_quantizer cannot be applied
1668        composable_quantizer = ComposableQuantizer(
1669            [dynamic_quantizer, static_quantizer]
1670        )
1671        m_eager = TestHelperModules.ConvLinearWPermute().eval()
1672
1673        node_occurrence = {
1674            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
1675            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
1676            # note: quantize op for weights are const propagated
1677            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
1678            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
1679            # note: quantize op for weights are const propagated
1680            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1681            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
1682        }
1683        act_affine_quant_obs = observer.PlaceholderObserver.with_args(
1684            dtype=torch.qint8,
1685            qscheme=torch.per_tensor_affine,
1686            quant_min=-128,
1687            quant_max=127,
1688            eps=2**-12,
1689            is_dynamic=True,
1690        )
1691        dynamic_qconfig = QConfig(
1692            activation=act_affine_quant_obs,
1693            weight=weight_observer_range_neg_127_to_127,
1694        )
1695        # Test with 2d inputs
1696        example_inputs = (torch.randn(2, 3, 4, 4),)
1697        qconfig = default_per_channel_symmetric_qnnpack_qconfig
1698        qconfig_mapping = QConfigMapping().set_global(qconfig)
1699        qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig)
1700        # Had to turn off check against fx because fx quant workflow does not seem
1701        # to propagate observers for permute node for this model.
1702        # Suprisingly it does propagate it for EmbeddingConvLinearModule
1703        # TODO: Figure out the right behavior for propagation
1704        self._test_quantizer(
1705            m_eager,
1706            example_inputs,
1707            composable_quantizer,
1708            node_occurrence,
1709            [],
1710            False,
1711            qconfig_mapping,
1712        )
1713
1714    def test_embedding_conv_linear_quantization(self):
1715        m_eager = TestHelperModules.EmbeddingConvLinearModule().eval()
1716        indices = torch.tensor(
1717            [
1718                9,
1719                6,
1720                5,
1721                7,
1722                8,
1723                8,
1724                9,
1725                2,
1726                8,
1727                6,
1728                6,
1729                9,
1730                1,
1731                6,
1732                8,
1733                8,
1734                3,
1735                2,
1736                3,
1737                6,
1738                3,
1739                6,
1740                5,
1741                7,
1742                0,
1743                8,
1744                4,
1745                6,
1746                5,
1747                8,
1748                2,
1749                3,
1750            ]
1751        )
1752        indices = torch.unsqueeze(indices, 0)
1753        example_inputs = (indices,)
1754
1755        embedding_quantizer = EmbeddingQuantizer()
1756        dynamic_quantizer = XNNPACKQuantizer()
1757        quantization_config_dynamic = get_symmetric_quantization_config(
1758            is_per_channel=True, is_dynamic=True
1759        )
1760        dynamic_quantizer.set_global(quantization_config_dynamic)
1761        static_quantizer = XNNPACKQuantizer()
1762        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
1763        static_quantizer.set_global(quantization_config)
1764        composed_quantizer = ComposableQuantizer(
1765            [embedding_quantizer, dynamic_quantizer, static_quantizer]
1766        )
1767
1768        act_affine_quant_obs = observer.PlaceholderObserver.with_args(
1769            dtype=torch.qint8,
1770            qscheme=torch.per_tensor_affine,
1771            quant_min=-128,
1772            quant_max=127,
1773            eps=2**-12,
1774            is_dynamic=True,
1775        )
1776        dynamic_qconfig = QConfig(
1777            activation=act_affine_quant_obs,
1778            weight=per_channel_weight_observer_range_neg_127_to_127,
1779        )
1780        qconfig = default_per_channel_symmetric_qnnpack_qconfig
1781        qconfig_mapping = QConfigMapping().set_global(qconfig)
1782        qconfig_mapping.set_object_type(torch.nn.Linear, dynamic_qconfig)
1783        qconfig_mapping = qconfig_mapping.set_object_type(
1784            torch.nn.Embedding, float_qparams_weight_only_qconfig
1785        )
1786
1787        node_occurrence = {
1788            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
1789            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
1790            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
1791            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
1792            # note: quantize op for weights are const propagated
1793            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
1794            torch.ops.quantized_decomposed.dequantize_per_channel.default: 3,
1795        }
1796        self._test_quantizer(
1797            m_eager,
1798            example_inputs,
1799            composed_quantizer,
1800            node_occurrence,
1801            [],
1802            True,
1803            qconfig_mapping,
1804        )
1805
1806    def _get_node(self, m: torch.fx.GraphModule, target: torch._ops.OpOverload):
1807        """
1808        Return the first node matching the specified target, throwing an exception
1809        if no such batch norm node is found.
1810        """
1811        for n in m.graph.nodes:
1812            if n.target == target:
1813                return n
1814        raise ValueError("Did not find node with target ", target)
1815
1816    def _test_move_exported_model_dropout(self, inplace: bool):
1817        """
1818        Test switching dropout behavior between train and eval modes using
1819        `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
1820        """
1821
1822        class M(torch.nn.Module):
1823            def __init__(self) -> None:
1824                super().__init__()
1825                self.dropout = torch.nn.Dropout(0.5, inplace=inplace)
1826
1827            def forward(self, x):
1828                return self.dropout(x)
1829
1830        example_inputs = (torch.randn(1),)
1831        m = M().train()
1832        m = capture_pre_autograd_graph(m, example_inputs)
1833        if inplace:
1834            target = torch.ops.aten.dropout_.default
1835        else:
1836            target = torch.ops.aten.dropout.default
1837
1838        # Assert that dropout op exists and is in train mode
1839        dropout_node = self._get_node(m, target)
1840        self.assertTrue(dropout_node is not None)
1841        self.assertTrue(dropout_node.args[2])
1842
1843        # Move to eval
1844        torch.ao.quantization.move_exported_model_to_eval(m)
1845
1846        # Assert that dropout op is now in eval mode
1847        dropout_node = self._get_node(m, target)
1848        self.assertTrue(dropout_node is not None)
1849        self.assertTrue(not dropout_node.args[2])
1850
1851        # Move back to train
1852        torch.ao.quantization.move_exported_model_to_train(m)
1853
1854        # Assert that dropout op is now in train mode again
1855        dropout_node = self._get_node(m, target)
1856        self.assertTrue(dropout_node is not None)
1857        self.assertTrue(dropout_node.args[2])
1858
1859    def test_move_exported_model_dropout(self):
1860        self._test_move_exported_model_dropout(inplace=False)
1861
1862    def test_move_exported_model_dropout_inplace(self):
1863        self._test_move_exported_model_dropout(inplace=True)
1864
1865    def _get_bn_train_eval_ops(self):
1866        if capture_pre_autograd_graph_using_training_ir():
1867            return (
1868                torch.ops.aten.batch_norm.default,
1869                torch.ops.aten.batch_norm.default,
1870            )
1871        # TODO: This branch is going through a deprecated branch and should be deleted soon,
1872        # after capture_pre_autograd_graph fully migrate to training IR
1873        # T199018392
1874        if TEST_WITH_ROCM:
1875            return (
1876                torch.ops.aten.miopen_batch_norm.default,
1877                torch.ops.aten.miopen_batch_norm.default,
1878            )
1879        elif TEST_CUDA:
1880            return (
1881                torch.ops.aten.cudnn_batch_norm.default,
1882                torch.ops.aten.cudnn_batch_norm.default,
1883            )
1884        else:
1885            return (
1886                torch.ops.aten._native_batch_norm_legit.default,
1887                torch.ops.aten._native_batch_norm_legit_no_training.default,
1888            )
1889
1890    def test_move_exported_model_bn(self):
1891        """
1892        Test switching batch_norm behavior between train and eval modes using
1893        `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
1894        """
1895
1896        class M(torch.nn.Module):
1897            def __init__(self) -> None:
1898                super().__init__()
1899                self.bn = torch.nn.BatchNorm2d(3)
1900
1901            def forward(self, x):
1902                return self.bn(x)
1903
1904        if TEST_CUDA:
1905            m = M().train().cuda()
1906            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
1907        else:
1908            m = M().train()
1909            example_inputs = (torch.randn(1, 3, 3, 3),)
1910        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
1911        m = capture_pre_autograd_graph(m, example_inputs)
1912
1913        # Assert that batch norm op exists and is in train mode
1914        bn_node = self._get_node(m, bn_train_op)
1915        self.assertTrue(bn_node is not None)
1916        self.assertTrue(bn_node.args[5])
1917
1918        # Move to eval
1919        torch.ao.quantization.move_exported_model_to_eval(m)
1920
1921        # Assert that batch norm op is now in eval mode
1922        bn_node = self._get_node(m, bn_eval_op)
1923        self.assertTrue(bn_node is not None)
1924
1925        # Move to train
1926        torch.ao.quantization.move_exported_model_to_train(m)
1927
1928        # Assert that batch norm op is now in train mode again
1929        bn_node = self._get_node(m, bn_train_op)
1930        self.assertTrue(bn_node is not None)
1931        self.assertTrue(bn_node.args[5])
1932
1933    def test_disallow_eval_train(self):
1934        m = TestHelperModules.ConvWithBNRelu(relu=True)
1935        example_inputs = (torch.rand(3, 3, 5, 5),)
1936
1937        # Before export: this is OK
1938        m.eval()
1939        m.train()
1940
1941        # After export: this is not OK
1942        m = capture_pre_autograd_graph(m, example_inputs)
1943        with self.assertRaises(NotImplementedError):
1944            m.eval()
1945        with self.assertRaises(NotImplementedError):
1946            m.train()
1947
1948        # After prepare: still not OK
1949        quantizer = XNNPACKQuantizer()
1950        m = prepare_qat_pt2e(m, quantizer)
1951        with self.assertRaises(NotImplementedError):
1952            m.eval()
1953        with self.assertRaises(NotImplementedError):
1954            m.train()
1955
1956        # After convert: still not OK
1957        m = convert_pt2e(m)
1958        with self.assertRaises(NotImplementedError):
1959            m.eval()
1960        with self.assertRaises(NotImplementedError):
1961            m.train()
1962
1963    def test_allow_exported_model_train_eval(self):
1964        class M(torch.nn.Module):
1965            def __init__(self) -> None:
1966                super().__init__()
1967                self.bn = torch.nn.BatchNorm2d(3)
1968                self.dropout = torch.nn.Dropout(0.5)
1969
1970            def forward(self, x):
1971                x = self.bn(x)
1972                x = self.dropout(x)
1973                return x
1974
1975        if TEST_CUDA:
1976            m = M().train().cuda()
1977            example_inputs = (torch.randn(1, 3, 3, 3).cuda(),)
1978        else:
1979            m = M().train()
1980            example_inputs = (torch.randn(1, 3, 3, 3),)
1981        bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
1982        m = capture_pre_autograd_graph(m, example_inputs)
1983
1984        def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
1985            targets = [n.target for n in m.graph.nodes]
1986            bn_op = bn_train_op if train else bn_eval_op
1987            bn_node = self._get_node(m, bn_op)
1988            self.assertTrue(bn_node is not None)
1989            if TEST_CUDA:
1990                self.assertEqual(bn_node.args[5], train)
1991            dropout_node = self._get_node(m, torch.ops.aten.dropout.default)
1992            self.assertEqual(dropout_node.args[2], train)
1993
1994        # Before wrapping: this is not OK
1995        with self.assertRaises(NotImplementedError):
1996            m.eval()
1997        with self.assertRaises(NotImplementedError):
1998            m.train()
1999
2000        # After wrapping: does not error and swaps the ops accordingly
2001        torch.ao.quantization.allow_exported_model_train_eval(m)
2002        m.eval()
2003        _assert_ops_are_correct(m, train=False)
2004        m.train()
2005        _assert_ops_are_correct(m, train=True)
2006
2007        # After prepare but before wrapping: this is not OK
2008        quantizer = XNNPACKQuantizer()
2009        m = prepare_qat_pt2e(m, quantizer)
2010        with self.assertRaises(NotImplementedError):
2011            m.eval()
2012        with self.assertRaises(NotImplementedError):
2013            m.train()
2014
2015        # After prepare and after wrapping: does not error and swaps the ops accordingly
2016        torch.ao.quantization.allow_exported_model_train_eval(m)
2017        m.eval()
2018        _assert_ops_are_correct(m, train=False)
2019        m.train()
2020        _assert_ops_are_correct(m, train=True)
2021
2022        # After convert but before wrapping: this is not OK
2023        m = convert_pt2e(m, fold_quantize=True)
2024        with self.assertRaises(NotImplementedError):
2025            m.eval()
2026        with self.assertRaises(NotImplementedError):
2027            m.train()
2028
2029        # After convert and after wrapping: does not error and swaps the ops accordingly
2030        torch.ao.quantization.allow_exported_model_train_eval(m)
2031        m.eval()
2032        _assert_ops_are_correct(m, train=False)
2033        m.train()
2034        _assert_ops_are_correct(m, train=True)
2035
2036    def test_model_is_exported(self):
2037        m = TestHelperModules.ConvWithBNRelu(relu=True)
2038        example_inputs = (torch.rand(3, 3, 5, 5),)
2039        exported_gm = capture_pre_autograd_graph(m, example_inputs)
2040        fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
2041        self.assertTrue(
2042            torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
2043        )
2044        self.assertFalse(
2045            torch.ao.quantization.pt2e.export_utils.model_is_exported(fx_traced_gm)
2046        )
2047        self.assertFalse(torch.ao.quantization.pt2e.export_utils.model_is_exported(m))
2048
2049    def test_reentrant(self):
2050        """Test we can safely call quantization apis multiple times"""
2051        m = TestHelperModules.ConvBnReLU2dAndLinearReLU()
2052        example_inputs = (torch.randn(3, 3, 10, 10),)
2053
2054        quantizer = XNNPACKQuantizer().set_global(
2055            get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
2056        )
2057        m.conv_bn_relu = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs)
2058        m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
2059        m(*example_inputs)
2060        m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)
2061
2062        quantizer = XNNPACKQuantizer().set_module_type(
2063            torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
2064        )
2065        m = capture_pre_autograd_graph(m, example_inputs)
2066        m = prepare_pt2e(m, quantizer)
2067        m = convert_pt2e(m)
2068
2069        node_occurrence = {
2070            ns.call_function(
2071                torch.ops.quantized_decomposed.quantize_per_tensor.default
2072            ): 4,
2073            # one for weight
2074            ns.call_function(
2075                torch.ops.quantized_decomposed.dequantize_per_tensor.default
2076            ): 5,
2077            ns.call_function(
2078                torch.ops.quantized_decomposed.dequantize_per_channel.default
2079            ): 1,
2080        }
2081        node_list = [
2082            ns.call_function(
2083                torch.ops.quantized_decomposed.dequantize_per_tensor.default
2084            ),
2085            ns.call_function(torch.ops.aten.conv2d.default),
2086            ns.call_function(torch.ops.aten.relu.default),
2087            ns.call_function(
2088                torch.ops.quantized_decomposed.quantize_per_tensor.default
2089            ),
2090            ns.call_function(
2091                torch.ops.quantized_decomposed.dequantize_per_tensor.default
2092            ),
2093            ns.call_function(torch.ops.aten.linear.default),
2094            ns.call_function(
2095                torch.ops.quantized_decomposed.quantize_per_tensor.default
2096            ),
2097        ]
2098        self.checkGraphModuleNodes(
2099            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
2100        )
2101
2102    def test_groupwise_per_channel_quant(self):
2103        m = TestHelperModules.GroupwiseConv2d()
2104        quantizer = XNNPACKQuantizer()
2105        operator_config = get_symmetric_quantization_config(is_per_channel=True)
2106        quantizer.set_global(operator_config)
2107        example_inputs = m.example_inputs()
2108        m = self._quantize(m, quantizer, example_inputs)
2109        # make sure it runs
2110        m(*example_inputs)
2111
2112    def test_observer_callback(self):
2113        from torch.library import impl, Library
2114
2115        test_lib = Library("test_int4", "DEF")  # noqa: TOR901
2116        test_lib.define(
2117            "quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
2118        )
2119
2120        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
2121        def quantize_per_tensor_int4(
2122            input: torch.Tensor,
2123            scale: float,
2124            zero_point: int,
2125        ) -> torch.Tensor:
2126            inv_scale = 1.0 / scale
2127            return (
2128                torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15)
2129                .to(torch.uint8)
2130                .view(torch.bits8)
2131            )
2132
2133        test_lib.define(
2134            "dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor"
2135        )
2136
2137        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
2138        def dequantize_per_tensor_int4(
2139            input: torch.Tensor,
2140            scale: float,
2141            zero_point: int,
2142        ) -> torch.Tensor:
2143            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale
2144
2145        from torch.ao.quantization.observer import ObserverBase
2146
2147        class Int4Observer(ObserverBase):
2148            def __init__(self, *args, **kwargs):
2149                # just faking a dtype here
2150                super().__init__(dtype=torch.int8)
2151
2152            def forward(self, x):
2153                return x
2154
2155            def calculate_qparams(self, **kwargs):
2156                pass
2157
2158            def convert(self, model: torch.fx.GraphModule, observer_node: Node):
2159                with model.graph.inserting_before(observer_node):
2160                    q_node = model.graph.call_function(
2161                        torch.ops.test_int4.quantize_per_tensor_int4,
2162                        (observer_node.args[0], 1.0, 0),
2163                        {},
2164                    )
2165                    dq_node = model.graph.call_function(
2166                        torch.ops.test_int4.dequantize_per_tensor_int4,
2167                        (q_node, 1.0, 0),
2168                        {},
2169                    )
2170                    observer_node.replace_all_uses_with(dq_node)
2171                    model.graph.erase_node(observer_node)
2172
2173        class BackendAQuantizer(Quantizer):
2174            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
2175                for node in model.graph.nodes:
2176                    if (
2177                        node.op == "call_function"
2178                        and node.target == torch.ops.aten.add.Tensor
2179                    ):
2180                        input_act0 = node.args[0]
2181                        assert isinstance(input_act0, Node)
2182                        input_act1 = node.args[1]
2183                        assert isinstance(input_act1, Node)
2184
2185                        act_qspec = QuantizationSpec(
2186                            dtype=torch.uint8,
2187                            quant_min=0,
2188                            quant_max=255,
2189                            qscheme=torch.per_tensor_affine,
2190                            is_dynamic=False,
2191                            observer_or_fake_quant_ctr=Int4Observer,
2192                        )
2193                        node.meta["quantization_annotation"] = QuantizationAnnotation(
2194                            input_qspec_map={
2195                                input_act0: act_qspec,
2196                                input_act1: act_qspec,
2197                            },
2198                            output_qspec=act_qspec,
2199                            _annotated=True,
2200                        )
2201
2202            def validate(self, model: torch.fx.GraphModule) -> None:
2203                pass
2204
2205        class M(torch.nn.Module):
2206            def forward(self, x1, x2):
2207                return x1 + x2
2208
2209        example_inputs = (
2210            torch.randn(1, 3, 5, 5),
2211            torch.randn(1, 3, 5, 5),
2212        )
2213        node_occurrence = {
2214            # two for input of the first conv, one for output for the first conv
2215            torch.ops.test_int4.quantize_per_tensor_int4: 3,
2216            torch.ops.test_int4.dequantize_per_tensor_int4: 3,
2217        }
2218        node_list = [
2219            torch.ops.test_int4.dequantize_per_tensor_int4,
2220            torch.ops.test_int4.dequantize_per_tensor_int4,
2221            torch.ops.aten.add.Tensor,
2222            torch.ops.test_int4.quantize_per_tensor_int4,
2223        ]
2224        self._test_quantizer(
2225            M().eval(),
2226            example_inputs,
2227            BackendAQuantizer(),
2228            node_occurrence,
2229            node_list,
2230        )
2231
2232    def test_speed(self):
2233        import time
2234
2235        def dynamic_quantize_pt2e(model, example_inputs):
2236            torch._dynamo.reset()
2237            model = capture_pre_autograd_graph(model, example_inputs)
2238            # Per channel quantization for weight
2239            # Dynamic quantization for activation
2240            # Please read a detail: https://fburl.com/code/30zds51q
2241            embedding_quantizer = EmbeddingQuantizer()
2242            dynamic_quantizer = XNNPACKQuantizer()
2243            operator_config_dynamic = get_symmetric_quantization_config(
2244                is_per_channel=True, is_dynamic=True
2245            )
2246            dynamic_quantizer.set_global(operator_config_dynamic)
2247            composed_quantizer = ComposableQuantizer(
2248                [embedding_quantizer, dynamic_quantizer]
2249            )
2250            prev = time.time()
2251            model = prepare_qat_pt2e(model, composed_quantizer)
2252            cur = time.time()
2253            # print("prepare time:", cur - prev)
2254            # Without Calibraiton, scale/zero value will have an initialized value of 1.0
2255            # Per channel quantization needs a proper scale/zero shape/value to work properly.
2256            # So we need to run calibration before converting to quantized model.
2257            model(*example_inputs)
2258            prev = time.time()
2259            model = convert_pt2e(model)
2260            cur = time.time()
2261            # uncomment to see the time
2262            # print("convert time:", cur - prev)
2263            return model
2264
2265        class M(torch.nn.Module):
2266            def __init__(self) -> None:
2267                super().__init__()
2268                self.linear = torch.nn.Linear(5, 5)
2269
2270            def forward(self, x):
2271                return self.linear(x)
2272
2273        m = M().eval()
2274        example_inputs = (torch.randn(5, 5),)
2275        _ = dynamic_quantize_pt2e(m, example_inputs)
2276
2277    def test_conv_transpose_bn_relu(self):
2278        class BackendAQuantizer(Quantizer):
2279            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
2280                int8_qspec = QuantizationSpec(
2281                    dtype=torch.int8,
2282                    quant_min=-128,
2283                    quant_max=127,
2284                    qscheme=torch.per_tensor_symmetric,
2285                    is_dynamic=False,
2286                    observer_or_fake_quant_ctr=observer.default_weight_observer,
2287                )
2288                quantization_config = QuantizationConfig(
2289                    input_activation=int8_qspec,
2290                    weight=int8_qspec,
2291                    bias=None,
2292                    output_activation=int8_qspec,
2293                )
2294                # conv_transpose + bn is fused automatically in PTQ (not configurable)
2295                # so we just need to annotate conv_transpose + relu for conv_transpose + bn + relu
2296                # pattern
2297                OP_TO_ANNOTATOR["conv_transpose_relu"](model, quantization_config)
2298
2299            def validate(self, model: torch.fx.GraphModule) -> None:
2300                pass
2301
2302        example_inputs = (torch.randn(1, 3, 5, 5),)
2303        node_occurrence = {
2304            # two for input of the first conv, one for output for the first conv
2305            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
2306            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
2307        }
2308        node_list = [
2309            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2310            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
2311            torch.ops.aten.conv_transpose2d.input,
2312            torch.ops.aten.relu.default,
2313            torch.ops.quantized_decomposed.quantize_per_tensor.default,
2314        ]
2315        self._test_quantizer(
2316            TestHelperModules.ConvTWithBNRelu(relu=True, bn=True),
2317            example_inputs,
2318            BackendAQuantizer(),
2319            node_occurrence,
2320            node_list,
2321        )
2322
2323    def test_multi_users_without_output_observer(self):
2324        """
2325        Test the case in which a node is used by multiple users,
2326        and had its output observer removed.
2327        """
2328
2329        class M(torch.nn.Module):
2330            def __init__(self) -> None:
2331                super().__init__()
2332                self.conv = torch.nn.Conv2d(3, 3, 3)
2333
2334            def forward(self, x):
2335                x = self.conv(x)
2336                return x, x + 1
2337
2338        example_inputs = (torch.randn(1, 3, 5, 5),)
2339        m = M()
2340        m = capture_pre_autograd_graph(m, example_inputs)
2341        quantizer = XNNPACKQuantizer().set_global(
2342            get_symmetric_quantization_config(),
2343        )
2344        m = prepare_pt2e(m, quantizer)
2345        m(*example_inputs)
2346
2347        # Remove output observer
2348        observer_to_remove = None
2349        for n in m.graph.nodes:
2350            if n.op == "output":
2351                observer_to_remove = n.args[0][0]
2352                assert observer_to_remove.op == "call_module"
2353                assert observer_to_remove.target.startswith("activation_post_process_")
2354                break
2355        assert observer_to_remove is not None
2356        observer_to_remove.replace_all_uses_with(observer_to_remove.args[0])
2357        m.graph.erase_node(observer_to_remove)
2358        m.recompile()
2359
2360        # Convert should succeed
2361        m = convert_pt2e(m)
2362        m(*example_inputs)
2363
2364
2365instantiate_parametrized_tests(TestQuantizePT2E)
2366