# Owner(s): ["oncall: quantization"]

import copy
import unittest
from collections import Counter
from typing import Dict

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization import (
    compare_results,
    CUSTOM_KEY,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    NUMERIC_DEBUG_HANDLE_KEY,
    prepare_for_propagation_comparison,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase


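# Build a map from node name (str(node)) to the numeric debug handle id stored
# under node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY].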
def _extract_debug_handles(model) -> Dict[str, int]:
    debug_handle_map: Dict[str, int] = {}

    for node in model.graph.nodes:
        if (
            CUSTOM_KEY in node.meta
            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
        ):
            debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
                NUMERIC_DEBUG_HANDLE_KEY
            ]

    return debug_handle_map


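# OSS builds of torch expose torch.version.git_version; fbcode builds do not.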
def is_fbcode():
    return not hasattr(torch.version, "git_version")


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestNumericDebugger(TestCase):
    def test_simple(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = torch.export.export(m, example_inputs)
        generate_numeric_debug_handle(m)
        unique_ids = set()
        count = 0
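        # every node that was assigned a numeric debug handle should have a unique id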
        for n in m.graph.nodes:
            if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
                unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY])
                count += 1
        self.assertEqual(len(unique_ids), count)

    @unittest.skipIf(
        is_fbcode(),
        "fbcode changes the code path for `capture_pre_autograd_graph`; "
        "we can enable the test in fbcode after we remove `capture_pre_autograd_graph`",
    )
    def test_quantize_pt2e_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = capture_pre_autograd_graph(m, example_inputs)
        generate_numeric_debug_handle(m)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        debug_handle_map = _extract_debug_handles(m)
        res_counter = Counter(debug_handle_map.values())
        repeated_debug_handle_ids = [2, 3, 6]
        # these 3 ids are repeated because we copy the id from a node to its output observer
        # (torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default)
        for dh_id in repeated_debug_handle_ids:
            self.assertEqual(res_counter[dh_id], 2)

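        # calibrate with the example inputs so observers record statistics before conversion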
        m(*example_inputs)
        m = convert_pt2e(m)
        debug_handle_map = _extract_debug_handles(m)
        res_counter = Counter(debug_handle_map.values())
        # the same set of ids is repeated, because we copy over the id from observer/fake_quant to
        # the dequantize node
        repeated_debug_handle_ids = [2, 3, 6]
        for dh_id in repeated_debug_handle_ids:
            self.assertEqual(res_counter[dh_id], 2)

    def test_copy_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = torch.export.export(m, example_inputs)
        generate_numeric_debug_handle(m)

        debug_handle_map_ref = _extract_debug_handles(m)

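        # a shallow copy of the exported program should keep the same node -> debug handle mapping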
        m_copy = copy.copy(m)
        debug_handle_map = _extract_debug_handles(m_copy)

        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    def test_deepcopy_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = torch.export.export(m, example_inputs)
        generate_numeric_debug_handle(m)

        debug_handle_map_ref = _extract_debug_handles(m)
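        # deep copying should also preserve node meta, including the debug handles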
        m_copy = copy.deepcopy(m)
        debug_handle_map = _extract_debug_handles(m_copy)

        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    @unittest.skip("All nodes' meta are preserved but get_attr nodes' meta are wrong.")
    def test_re_export_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = capture_pre_autograd_graph(m, example_inputs)
        generate_numeric_debug_handle(m)

        debug_handle_map_ref = _extract_debug_handles(m)
        m_export = capture_pre_autograd_graph(m, example_inputs)
        debug_handle_map = _extract_debug_handles(m_export)

        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    @unittest.skip(
        "All nodes' meta are preserved but the first arg for the first node seems to be dropped"
    )
    def test_run_decompositions_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = torch.export.export(m, example_inputs)
        generate_numeric_debug_handle(m)

        debug_handle_map_ref = _extract_debug_handles(m)

        m_copy = copy.copy(m)
        m_copy = m_copy.run_decompositions()
        debug_handle_map = _extract_debug_handles(m_copy)

        # check that the map still has the same set of ids; the nodes themselves may change
        self.assertEqual(
            set(debug_handle_map.values()), set(debug_handle_map_ref.values())
        )

    def test_prepare_for_propagation_comparison(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = capture_pre_autograd_graph(m, example_inputs)
        generate_numeric_debug_handle(m)
        m_logger = prepare_for_propagation_comparison(m)
        ref = m(*example_inputs)
        res = m_logger(*example_inputs)

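        # prepare_for_propagation_comparison inserts an OutputLogger for each node that carries
        # a debug handle; running the logged model should not change the numerics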
        from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger

        loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)]
        self.assertEqual(len(loggers), 7)
        self.assertTrue("conv2d" in [logger.node_name for logger in loggers])
        self.assertEqual(res, ref)

    def test_extract_results_from_loggers(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        m = capture_pre_autograd_graph(m, example_inputs)
        generate_numeric_debug_handle(m)
        m_ref_logger = prepare_for_propagation_comparison(m)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m_quant_logger = prepare_for_propagation_comparison(m)

        m_ref_logger(*example_inputs)
        m_quant_logger(*example_inputs)
        ref_results = extract_results_from_loggers(m_ref_logger)
        quant_results = extract_results_from_loggers(m_quant_logger)
        comparison_results = compare_results(ref_results, quant_results)
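        # logged outputs are matched by debug handle; an SQNR of at least 35 dB means the
        # quantized model tracks the float reference closely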
        for node_summary in comparison_results.values():
            if len(node_summary.results) > 0:
                self.assertGreaterEqual(node_summary.results[0].sqnr, 35)