# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import Any, Dict

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

    class Conv2dWithSharedDQ(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv1(x)
            z = x.view(-1, 3)
            w = self.linear(z)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x * 2
            return w, add_output, extra_output

    class ModuleForDifferentQconfig(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv1(x)
            w = self.adaptive_avg_pool2d(x)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x + 2
            return w, add_output, extra_output


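# Decomposed dequantize overloads emitted by convert_pt2e; the assertions in
# this file look for these targets when counting the users of each DQ node.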
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


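# The tests below all share the same structure: annotate a model with a toy
# "BackendA" quantizer, run prepare_pt2e/convert_pt2e, and then check that
# every dequantize node feeding an annotated op has exactly one user, i.e.
# that a DQ shared by several quantized consumers was duplicated per consumer.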
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestDuplicateDQPass(QuantizationTestCase):
    def _test_duplicate_dq(
        self,
        model,
        example_inputs,
        quantizer,
    ):
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        pt2_quant_output = m(*example_inputs)
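        # Every DQ node that feeds an annotated (quantized) op must have
        # exactly one user after convert_pt2e; a DQ shared by several
        # quantized consumers should have been duplicated, one copy per use.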
        for n in m.graph.nodes:
            annotation = n.meta.get("quantization_annotation", None)
            if annotation is not None:
                for arg in n.args:
                    if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
                        self.assertEqual(len(arg.users.keys()), 1)

    def test_no_need_for_duplicate_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        The model is a straight chain, so each intermediate output has a single
        quantized consumer and no DQ node needs to be duplicated.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
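                # linear, conv and adaptive_avg_pool2d are annotated with the
                # same config; since every value in the chain has one consumer,
                # no DQ duplication is expected.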
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_simple_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because the output of the
        first conv2d is fed to the next conv2d, add, and view_copy + linear.
        All three uses are quantized, so the DQ node for the first conv2d's
        output is duplicated, one copy per use.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
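                # linear, conv and add are all annotated, so the first conv's
                # output has three quantized users (conv2, add, and linear via
                # view_copy) and its DQ should be duplicated for each of them.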
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_no_add_quant_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because the output of the
        first conv2d is fed to the next conv2d and to view_copy + linear,
        both of which are quantized.
        However, the skip connections to add and mul are not quantized,
        so the DQ node is not duplicated for those two uses.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
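                # Only linear and conv are annotated; add and mul stay in
                # floating point, so the DQ node they consume is left shared
                # rather than duplicated.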
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_avgpool_use_different_qconfig(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> adaptive_avgpool2d (different qconfig)
             |
              -----> add
        Expected output graph
        conv2d -> dq -> conv2d -> add
             |                  |
              -------> dq ----->
             |
              -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
             |
              -> dq -----> add
        Because adaptive_avgpool2d is annotated with a different qconfig, its
        branch gets its own dq -> q -> dq requantization chain.
        """

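        # Build a per-tensor uint8 QuantizationConfig (HistogramObserver for
        # activations, MinMaxObserver for weights, float bias) that differs
        # from the symmetric per-channel config used for conv and add below.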
        def _get_uint8_quantization_config():
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
            act_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                    eps=2**-12
                ),
            )
            weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                MinMaxObserver
            )

            extra_args: Dict[str, Any] = {"eps": 2**-12}
            weight_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                ch_axis=0,
                is_dynamic=False,
                observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                    **extra_args
                ),
            )

            bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                PlaceholderObserver
            )
            bias_quantization_spec = QuantizationSpec(
                dtype=torch.float,
                observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
            )
            quantization_config = QuantizationConfig(
                act_quantization_spec,
                act_quantization_spec,
                weight_quantization_spec,
                bias_quantization_spec,
            )
            return quantization_config

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                avgpool_qconfig = _get_uint8_quantization_config()
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)
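                # The model's adaptive_avg_pool2d((1, 1)) shows up as
                # aten.mean.dim in the captured graph (hence the target matched
                # below); annotate it by hand with the uint8 qconfig's
                # activation spec and share its output qspec with its input.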
                for n in gm.graph.nodes:
                    if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
                        qspec = avgpool_qconfig.input_activation
                        input_act = n.args[0]
                        output_qspec = SharedQuantizationSpec((input_act, n))
                        n.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={input_act: qspec},
                            output_qspec=output_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.ModuleForDifferentQconfig(),
            example_inputs,
            BackendAQuantizer(),
        )