xref: /aosp_15_r20/external/executorch/backends/transforms/test/test_duplicate_dynamic_quant_chain.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8import unittest
9
10import torch
11from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
12    DuplicateDynamicQuantChainPass,
13)
14from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
15from torch.ao.quantization.quantizer.xnnpack_quantizer import (
16    get_symmetric_quantization_config,
17    XNNPACKQuantizer,
18)
19
20# TODO: Move away from using torch's internal testing utils
21from torch.testing._internal.common_quantization import (
22    NodeSpec as ns,
23    QuantizationTestCase,
24    TestHelperModules,
25)
26
27
class MyTestHelperModules:
    """Container for small eager models used only by the tests in this file."""

    class TwoFanOutLinears(torch.nn.Module):
        """One input consumed by two parallel linear layers whose outputs are summed.

        ``linear1`` maps 8 -> 16 features without a bias; ``linear2`` maps
        8 -> 16 features with a bias.
        """

        def __init__(self):
            super().__init__()
            # Keep linear1 constructed before linear2 so that, under a fixed
            # seed, random initialization assigns the same weights.
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(8, 16)

        def forward(self, x):
            # ``x`` fans out to both branches. After quantization this gives a
            # single dynamic quant chain with two consumers, which is the case
            # the duplication pass is meant to split.
            left = self.linear1(x)
            right = self.linear2(x)
            return left + right
39
40
# Dequantize op overloads from the ``quantized_decomposed`` namespace: both
# per-tensor overloads (``default`` and ``tensor``), followed by the default
# per-channel overload.
# NOTE(review): no test visible in this file reads this list; confirm whether
# it is still needed before extending it.
_DEQUANTIZE_OPS = [
    *(
        getattr(torch.ops.quantized_decomposed.dequantize_per_tensor, overload)
        for overload in ("default", "tensor")
    ),
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
46
47
class TestDuplicateDynamicQuantChainPass(QuantizationTestCase):
    """Tests for ``DuplicateDynamicQuantChainPass``.

    Each test quantizes a small model with ``XNNPACKQuantizer`` through the
    PT2E flow. It then checks the quantize/dequantize node counts in the
    converted graph, runs the pass, and checks the node counts again.
    """

    @staticmethod
    def _to_node_occurrence(op_counts):
        """Convert an ``{op: count}`` mapping to ``NodeSpec`` keys.

        The result is the form ``checkGraphModuleNodes`` expects for
        ``expected_node_occurrence``.
        """
        return {ns.call_function(op): count for op, count in op_counts.items()}

    def _test_duplicate_chain(
        self,
        model,
        example_inputs,
        quantizer,
        before_node_occurrences,
        after_node_occurrences,
    ):
        """Quantize ``model``, run the pass, and check node counts around it.

        Args:
            model: Eager ``torch.nn.Module`` under test. It is put into eval
                mode in place and then deep-copied before export, so export and
                quantization operate on the copy.
            example_inputs: Tuple of example inputs, used both for export and
                for a single calibration forward pass.
            quantizer: Quantizer handed to ``prepare_pt2e``.
            before_node_occurrences: ``{op: expected_count}`` checked on the
                converted graph before the pass runs.
            after_node_occurrences: ``{op: expected_count}`` checked on the
                same graph module after the pass has been invoked on it.

        Returns:
            The converted graph module that the pass was invoked on.
        """
        m_eager = model.eval()

        # Program capture happens on a copy so the caller's module is untouched
        # beyond the eval() call above.
        m = copy.deepcopy(m_eager)
        m = torch.export.export_for_training(
            m,
            example_inputs,
        ).module()

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=self._to_node_occurrence(before_node_occurrences),
        )
        # NOTE: the pass's return value is not used; the check below relies on
        # the pass rewriting ``m`` in place.
        DuplicateDynamicQuantChainPass()(m)
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=self._to_node_occurrence(after_node_occurrences),
        )
        return m

    def test_no_need_for_duplicate(self):
        """
        Model under test
        linear -> linear
        Each linear already has its own dynamic choose_qparams -> quantize ->
        dequantize chain, so there is nothing to duplicate: expect two chains
        both before and after the pass.
        """

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=True
        )
        quantizer.set_global(quantization_config)
        example_inputs = (torch.randn(9, 8),)
        before_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 2,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
            # note: quantize op for weights are const propagated
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
        self._test_duplicate_chain(
            TestHelperModules.TwoLinearModule().eval(),
            example_inputs,
            quantizer,
            before_node_occurrences=before_node_occurrence,
            after_node_occurrences=before_node_occurrence,
        )

    def test_simple_duplicate_chain(self):
        """
        Model under test
        x --> linear1 --> add
          |               ^
          +-> linear2 ----+
        Before duplication there should be only 1 dynamic q chain
        After duplication there should be 2 dynamic q chains
        """

        quantizer = XNNPACKQuantizer()
        quantization_config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=True
        )
        quantizer.set_global(quantization_config)
        example_inputs = (torch.randn(9, 8),)
        before_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            # note: quantize op for weights are const propagated
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
        after_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 2,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
        self._test_duplicate_chain(
            MyTestHelperModules.TwoFanOutLinears().eval(),
            example_inputs,
            quantizer,
            before_node_occurrences=before_node_occurrence,
            after_node_occurrences=after_node_occurrence,
        )

    @unittest.skip("Set module name API does not work correctly when used as here.")
    def test_no_duplicate_chain_different_qscheme(self):
        """
        Model under test
        x -> linear1 -> linear2
        linear1 uses a dynamic config and linear2 a static per-tensor config;
        the node counts are expected to be identical before and after the pass.
        """

        quantizer = XNNPACKQuantizer()
        dynamic_qconfig = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=True
        )
        static_qconfig = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_module_name("linear1", dynamic_qconfig)
        quantizer.set_module_name("linear2", static_qconfig)
        example_inputs = (torch.randn(9, 8),)
        before_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
        }
        after_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
        }
        self._test_duplicate_chain(
            TestHelperModules.TwoLinearModule().eval(),
            example_inputs,
            quantizer,
            before_node_occurrences=before_node_occurrence,
            after_node_occurrences=after_node_occurrence,
        )
188