xref: /aosp_15_r20/external/pytorch/test/quantization/pt2e/test_metadata_porting.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: quantization"]
2import copy
3import unittest
4from typing import List
5
6import torch
7import torch._export
8from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
9from torch.ao.quantization.quantizer import QuantizationAnnotation, Quantizer
10from torch.ao.quantization.quantizer.xnnpack_quantizer import (
11    get_symmetric_quantization_config,
12)
13from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
14from torch.fx import Node
15from torch.testing._internal.common_quantization import QuantizationTestCase
16from torch.testing._internal.common_utils import IS_WINDOWS
17
18
class TestHelperModules:
    """Namespace holding the small eager models used by the tests in this file."""

    class Conv2dWithObsSharingOps(torch.nn.Module):
        """conv2d -> adaptive_avg_pool2d -> hardtanh -> view(-1, 3) -> linear.

        The pool collapses the spatial dims to 1x1, so with 3 conv output
        channels the ``view(-1, 3)`` produces one 3-feature row per input
        sample before the final ``Linear(3, 3)``.
        """

        def __init__(self) -> None:
            super().__init__()
            # Submodules are created in this order; conv and linear draw their
            # initial weights from the global RNG in this order.
            self.conv = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
            self.linear = torch.nn.Linear(in_features=3, out_features=3)

        def forward(self, x):
            pooled = self.adaptive_avg_pool2d(self.conv(x))
            flattened = self.hardtanh(pooled).view(-1, 3)
            return self.linear(flattened)
35
36
def _tag_partitions(
    backend_name: str, op_name: str, annotated_partitions: List[List[Node]]
):
    """Write a per-partition ``quantization_tag`` into every node's metadata.

    The i-th partition of ``annotated_partitions`` is tagged
    ``"<backend_name>_<op_name>_<i>"`` via ``node.meta["quantization_tag"]``.
    An assertion fires if a node already carries a tag, i.e. partitions (or
    earlier tagging passes) are expected not to overlap.
    """
    for partition_index, nodes in enumerate(annotated_partitions):
        partition_tag = "_".join((backend_name, op_name, str(partition_index)))
        for node in nodes:
            assert "quantization_tag" not in node.meta, f"{node} is already tagged"
            node.meta["quantization_tag"] = partition_tag
45
46
# Decomposed quantization ops whose ``quantization_tag`` metadata is recorded
# by the porting checks below. Only membership is tested against this set.
_QUANT_OPS = {
    # quantize: per-tensor (.default / .tensor overloads) and per-channel
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
    # dequantize: per-tensor (.default / .tensor overloads) and per-channel
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    # qparam computation (appears in the dynamic-quantization graphs)
    torch.ops.quantized_decomposed.choose_qparams.tensor,
}
56
57
# TODO: rename to TestPortMetadataPass to align with the util name?
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestMetaDataPorting(QuantizationTestCase):
    """Check that ``quantization_tag`` metadata set during annotation shows up
    on the quantize/dequantize and ``get_attr`` nodes after ``convert_pt2e``.

    Each test defines a small ``Quantizer`` whose ``annotate`` may tag nodes
    via ``node.meta["quantization_tag"]``. ``_test_metadata_porting`` then
    collects the tags found on q/dq and ``get_attr`` nodes of the converted
    graph and compares them against the expected sets.
    """

    def _test_quant_tag_preservation_through_decomp(
        self, model, example_inputs, from_node_to_tags
    ):
        """Export ``model`` and check that decomposed nodes keep their tag.

        For every exported node whose ``from_node`` metadata has an entry whose
        second element equals a key of ``from_node_to_tags``, the node's
        ``quantization_tag`` must be present and equal to the mapped tag.
        Every mismatching node is reported in a single assertion failure.

        Raises:
            ValueError: if a node's ``from_node`` metadata is not a list.
        """
        ep = torch.export.export(model, example_inputs)
        mismatched_nodes = []
        for from_node, tag in from_node_to_tags.items():
            for n in ep.graph_module.graph.nodes:
                from_node_meta = n.meta.get("from_node", None)
                if from_node_meta is None:
                    continue
                if not isinstance(from_node_meta, list):
                    raise ValueError(
                        f"from_node metadata is of type {type(from_node_meta)}, but expected list"
                    )
                # NOTE(review): assumes each entry is indexable with the source
                # op target at index 1 (e.g. a (name, target) pair) — confirm
                # against the torch.export version in use.
                if not any(meta[1] == from_node for meta in from_node_meta):
                    continue
                node_tag = n.meta.get("quantization_tag", None)
                if node_tag is None or tag != node_tag:
                    mismatched_nodes.append(str(n.target))
        self.assertFalse(
            mismatched_nodes,
            "Decomposition did not preserve quantization tag for "
            f"{', '.join(mismatched_nodes)}",
        )

    def _test_metadata_porting(
        self,
        model,
        example_inputs,
        quantizer,
        node_tags=None,
    ) -> torch.fx.GraphModule:
        """Quantize ``model`` with ``quantizer`` and check the ported tags.

        The model is captured with ``capture_pre_autograd_graph``, prepared,
        calibrated once on ``example_inputs``, converted, and run once more.
        Tags are then collected per node kind: ``call_function`` nodes whose
        target is in ``_QUANT_OPS`` are keyed by that target, and ``get_attr``
        nodes are keyed by the string ``"get_attr"``.

        Args:
            model: eager ``nn.Module``; switched to eval mode in place and then
                deep-copied before capture.
            example_inputs: tuple of inputs used for capture and calibration.
            quantizer: the ``Quantizer`` passed to ``prepare_pt2e``.
            node_tags: expected mapping from key (op overload or ``"get_attr"``)
                to the set of tags found on nodes of that kind. ``None`` means
                no tagged q/dq/``get_attr`` node is expected.

        Returns:
            The converted ``GraphModule``.

        Raises:
            ValueError: if two q/dq nodes with the same target share a tag.
        """
        if node_tags is None:
            node_tags = {}
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = torch._export.capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        # Run the converted model once to make sure it still executes.
        m(*example_inputs)

        recorded_node_tags = {}
        for n in m.graph.nodes:
            if "quantization_tag" not in n.meta:
                continue
            if n.op == "call_function" and n.target in _QUANT_OPS:
                key = n.target
            elif n.op == "get_attr":
                key = "get_attr"
            else:
                continue

            tags_for_key = recorded_node_tags.setdefault(key, set())
            tag = n.meta["quantization_tag"]
            # Only q/dq ops must carry unique tags per op type; get_attr nodes
            # are exempt (presumably because one partition can own several
            # constants).
            if n.op == "call_function" and tag in tags_for_key:
                raise ValueError(
                    f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that "
                    "is associated with another node of the same type"
                )
            tags_for_key.add(tag)

        self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys()))
        for k, v in recorded_node_tags.items():
            self.assertEqual(v, node_tags[k])
        return m

    def test_simple_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
            "BackendA_linear_0",
        }
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {
            torch.ops.aten.adaptive_avg_pool2d.default: "BackendA_adaptive_avg_pool2d_0",
            torch.ops.aten.linear.default: "BackendA_linear_0",
        }
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_metadata_porting_with_no_quant_inbetween(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize avgpool
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config
                )
                _tag_partitions(backend_string, "linear", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    @unittest.skip("Temporarily disabled")
    def test_metadata_porting_for_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize all except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # static quantization
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                _tag_partitions(backend_string, "conv2d", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["adaptive_avg_pool2d"](
                    gm, quantization_config
                )
                _tag_partitions(
                    backend_string, "adaptive_avg_pool2d", annotated_partitions
                )

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        # TODO: add get_attr_tags when the test is re-enabled
        # (must be a set: it is compared against the recorded set of tags)
        get_attr_tags = set()
        quantize_per_tensor_tags = {
            "BackendA_conv2d_0",
            "BackendA_adaptive_avg_pool2d_0",
        }
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tags = {
            "BackendA_adaptive_avg_pool2d_0",
            "BackendA_conv2d_0",
        }
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {
            "BackendA_conv2d_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_two_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Quantize linear and conv with dynamic quantization
        Check quantization tags on conv2d and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"

                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["conv"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "conv2d_dynamic", annotated_partitions)
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        choose_qparams_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        quantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_tensor_tensor_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        dequantize_per_channel_tags = {
            "BackendA_conv2d_dynamic_0",
            "BackendA_linear_dynamic_0",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_metadata_porting_for_dq_no_static_q(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Don't quantize anything except linear.
        Quantize linear with dynamic quantization
        Check quantization tags on linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                # dynamic quantization
                quantization_config_dynamic = get_symmetric_quantization_config(
                    is_per_channel=True, is_dynamic=True
                )
                annotated_partitions = OP_TO_ANNOTATOR["linear"](
                    gm, quantization_config_dynamic
                )
                _tag_partitions(backend_string, "linear_dynamic", annotated_partitions)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        get_attr_tags = {"BackendA_linear_dynamic_0"}
        choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"}
        quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"}
        dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"}
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags,
            torch.ops.quantized_decomposed.choose_qparams.tensor: choose_qparams_tensor_tags,
        }
        self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

    def test_no_metadata_porting(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Annotate conv2d, avgpool and linear without setting any
        quantization_tag; no tagged q/dq or get_attr node is expected.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 5),)
        node_tags = {}
        m = self._test_metadata_porting(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )

        from_node_to_tags = {}
        self._test_quant_tag_preservation_through_decomp(
            m, example_inputs, from_node_to_tags
        )

    def test_no_metadata_porting_through_unknown_ops(self):
        """
        Model under test
        matmul -> add -> relu
        matmul has get_attr as first input, but the quantization_tag should not be
        propagated to add even if it's part of a chain that ends at get_attr
        """

        class MatmulWithConstInput(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_parameter("w", torch.nn.Parameter(torch.rand(8, 16)))

            def forward(self, x, y):
                x = torch.matmul(self.w, x)
                z = x + y
                return torch.nn.functional.relu(z)

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                qconfig = get_symmetric_quantization_config()
                for n in gm.graph.nodes:
                    if n.op != "call_function":
                        continue

                    n.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map={n.args[0]: qconfig.input_activation},
                        output_qspec=qconfig.output_activation,
                    )

                    # Tag the op itself and any get_attr node it reads
                    # directly; nothing further up the chain is tagged here.
                    tag = str(n.target)
                    n.meta["quantization_tag"] = tag
                    for arg in n.args:
                        if isinstance(arg, Node) and arg.op == "get_attr":
                            arg.meta["quantization_tag"] = tag
                return gm

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(16, 24), torch.randn(8, 24))
        get_attr_tags = {"aten.matmul.default"}
        quantize_per_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        dequantize_per_tensor_tags = {
            "aten.matmul.default",
            "aten.add.Tensor",
            "aten.relu.default",
        }
        node_tags = {
            "get_attr": get_attr_tags,
            torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags,
        }
        self._test_metadata_porting(
            MatmulWithConstInput(),
            example_inputs,
            BackendAQuantizer(),
            node_tags,
        )
519        )
520