xref: /aosp_15_r20/external/executorch/exir/tests/test_quantization.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import unittest
10
11import torch
12import torchvision
13from executorch.exir import EdgeCompileConfig, to_edge
14from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
15from executorch.exir.passes.spec_prop_pass import SpecPropPass
16from torch.ao.ns.fx.utils import compute_sqnr
17from torch.ao.quantization import QConfigMapping  # @manual
18from torch.ao.quantization.backend_config import get_executorch_backend_config
19from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig
20from torch.ao.quantization.quantize_fx import prepare_fx
21from torch.ao.quantization.quantize_pt2e import (
22    _convert_to_reference_decomposed_fx,
23    convert_pt2e,
24    prepare_pt2e,
25)
26
27from torch.ao.quantization.quantizer.xnnpack_quantizer import (
28    get_symmetric_quantization_config,
29    XNNPACKQuantizer,
30)
31from torch.export import export
32from torch.testing import FileCheck
33from torch.testing._internal.common_quantized import override_quantized_engine
34
# Load the ExecuTorch out-variant kernels for the quantized ops at module
# import time, so they are registered with torch.ops before any test runs.
# NOTE(review): the "//executorch/...:target" argument looks like a build-system
# target label rather than a filesystem path; presumably it is resolved by the
# internal build tooling — confirm how this behaves when run outside it.
torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")
37
38
class TestQuantization(unittest.TestCase):
    """Tests for quantizing models with PT2E and lowering them to the EXIR edge dialect.

    prepare_pt2e and convert_pt2e are OSS APIs; the rest are meta-only APIs for
    now, but we plan to open source them in the future.
    """

    def test_resnet(self) -> None:
        """Quantize ResNet-18 through PT2E and compare it with the FX reference flow.

        The model is quantized with the XNNPACK quantizer (symmetric config with
        is_per_channel=True) and lowered to the edge dialect with QuantFusionPass
        and SpecPropPass applied. The test then checks that the lowered graph
        contains quantize/dequantize edge ops. Finally, it compares the outputs
        numerically against FX graph mode quantization using the ExecuTorch
        backend config.
        """
        import copy

        # override_quantized_engine sets torch.backends.quantized.engine to
        # "qnnpack" for the duration of the block and restores it afterwards.
        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            float_model = torchvision.models.resnet18().eval()
            # Untouched copy of the float model for the FX reference flow below.
            fx_model = copy.deepcopy(float_model)

            # program capture
            captured = torch.export.export_for_training(
                float_model, copy.deepcopy(example_inputs)
            ).module()

            quantizer = XNNPACKQuantizer()
            operator_config = get_symmetric_quantization_config(is_per_channel=True)
            quantizer.set_global(operator_config)
            prepared = prepare_pt2e(captured, quantizer)  # pyre-fixme[6]
            # The prepared model is expected to reuse a single observer instance
            # for activation_post_process_2 and activation_post_process_3.
            self.assertIs(
                prepared.activation_post_process_3, prepared.activation_post_process_2
            )
            after_prepare_result = prepared(*example_inputs)[0]
            converted = convert_pt2e(prepared)

            # TODO: conv, conv_relu, linear delegation
            # quantized ops to implement: add_relu
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
            )
            edge_program = to_edge(
                export(converted, example_inputs), compile_config=compile_config
            ).transform([QuantFusionPass(), SpecPropPass()])
            exported_program = edge_program.exported_program()

            after_quant_result = exported_program.module()(*example_inputs)[0]
            # The lowered graph code must contain a quantize_per_tensor edge op
            # followed (later in the code) by a dequantize_per_tensor edge op.
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor"
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor"
            ).run(
                exported_program.graph_module.code
            )

            # TODO: implement torch.ops.quantized_decomposed.add_relu.out so the
            # program can be lowered with to_executorch(). Then assert that the
            # output after quant fusion matches the to_executorch() output
            # exactly (torch.equal, or an infinite SQNR).
            # TODO: debug why the results before and after to_executorch mismatch.

            # Reference: FX graph mode quantization with a per-channel symmetric
            # qnnpack qconfig and the ExecuTorch backend config.
            qconfig = default_per_channel_symmetric_qnnpack_qconfig
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            backend_config = get_executorch_backend_config()
            fx_prepared = prepare_fx(
                fx_model, qconfig_mapping, example_inputs, backend_config=backend_config
            )
            # Index with [0] exactly like the PT2E results above so both sides
            # have the same shape, rather than relying on broadcasting.
            after_prepare_result_fx = fx_prepared(*example_inputs)[0]
            fx_converted = _convert_to_reference_decomposed_fx(
                fx_prepared, backend_config=backend_config
            )
            after_quant_result_fx = fx_converted(*example_inputs)[0]

            # After prepare, observers have been inserted but nothing is quantized
            # yet, so both flows should agree to within float tolerance.
            self.assertTrue(
                torch.allclose(after_prepare_result, after_prepare_result_fx, atol=1e-6)
            )

            # There are slight differences after convert due to different
            # implementations of quant/dequant. Bound the largest *absolute*
            # elementwise error (a signed max would let large negative
            # deviations through) and require a high SQNR.
            max_abs_diff = torch.max(
                torch.abs(after_quant_result - after_quant_result_fx)
            ).item()
            self.assertLess(max_abs_diff, 1e-1)
            sqnr = compute_sqnr(after_quant_result, after_quant_result_fx).item()
            self.assertGreater(sqnr, 35)
121