1# Owner(s): ["oncall: quantization"] 2 3import copy 4import unittest 5from collections import Counter 6from typing import Dict 7 8import torch 9from torch._export import capture_pre_autograd_graph 10from torch.ao.quantization import ( 11 compare_results, 12 CUSTOM_KEY, 13 extract_results_from_loggers, 14 generate_numeric_debug_handle, 15 NUMERIC_DEBUG_HANDLE_KEY, 16 prepare_for_propagation_comparison, 17) 18from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e 19from torch.ao.quantization.quantizer.xnnpack_quantizer import ( 20 get_symmetric_quantization_config, 21 XNNPACKQuantizer, 22) 23from torch.testing._internal.common_quantization import TestHelperModules 24from torch.testing._internal.common_utils import IS_WINDOWS, TestCase 25 26 27def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]: 28 debug_handle_map: Dict[torch.fx.Node, int] = {} 29 30 for node in model.graph.nodes: 31 if ( 32 CUSTOM_KEY in node.meta 33 and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] 34 ): 35 debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ 36 NUMERIC_DEBUG_HANDLE_KEY 37 ] 38 39 return debug_handle_map 40 41 42def is_fbcode(): 43 return not hasattr(torch.version, "git_version") 44 45 46@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") 47class TestNumericDebugger(TestCase): 48 def test_simple(self): 49 m = TestHelperModules.Conv2dThenConv1d() 50 example_inputs = m.example_inputs() 51 m = torch.export.export(m, example_inputs) 52 generate_numeric_debug_handle(m) 53 unique_ids = set() 54 count = 0 55 for n in m.graph.nodes: 56 if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]: 57 unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]) 58 count += 1 59 self.assertEqual(len(unique_ids), count) 60 61 @unittest.skipIf( 62 is_fbcode(), 63 "fbcode changes the code path for `capture_pre_autograd_graph` " 64 "we can enable the test in fbcode after we remove `capture_pre_autograd_graph`", 65 ) 66 def test_quantize_pt2e_preserve_handle(self): 67 m = TestHelperModules.Conv2dThenConv1d() 68 example_inputs = m.example_inputs() 69 m = capture_pre_autograd_graph(m, example_inputs) 70 generate_numeric_debug_handle(m) 71 72 quantizer = XNNPACKQuantizer().set_global( 73 get_symmetric_quantization_config(is_per_channel=False) 74 ) 75 m = prepare_pt2e(m, quantizer) 76 debug_handle_map = _extract_debug_handles(m) 77 res_counter = Counter(debug_handle_map.values()) 78 repeated_debug_handle_ids = [2, 3, 6] 79 # 3 ids were repeated because we copy over the id from node to its output observer 80 # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default 81 for dh_id in repeated_debug_handle_ids: 82 self.assertEqual(res_counter[dh_id], 2) 83 84 m(*example_inputs) 85 m = convert_pt2e(m) 86 debug_handle_map = _extract_debug_handles(m) 87 res_counter = Counter(debug_handle_map.values()) 88 # same set of ids where repeated, because we copy over the id from observer/fake_quant to 89 # dequantize node 90 repeated_debug_handle_ids = [2, 3, 6] 91 for dh_id in repeated_debug_handle_ids: 92 self.assertEqual(res_counter[dh_id], 2) 93 94 def test_copy_preserve_handle(self): 95 m = TestHelperModules.Conv2dThenConv1d() 96 example_inputs = m.example_inputs() 97 m = torch.export.export(m, example_inputs) 98 generate_numeric_debug_handle(m) 99 100 debug_handle_map_ref = _extract_debug_handles(m) 101 102 m_copy = copy.copy(m) 103 debug_handle_map = _extract_debug_handles(m_copy) 104 105 self.assertEqual(debug_handle_map, debug_handle_map_ref) 106 107 def test_deepcopy_preserve_handle(self): 108 m = TestHelperModules.Conv2dThenConv1d() 109 example_inputs = m.example_inputs() 110 m = torch.export.export(m, example_inputs) 111 generate_numeric_debug_handle(m) 112 113 debug_handle_map_ref = _extract_debug_handles(m) 114 m_copy = copy.deepcopy(m) 115 debug_handle_map = _extract_debug_handles(m_copy) 116 117 self.assertEqual(debug_handle_map, debug_handle_map_ref) 118 119 @unittest.skip("All nodes' meta are preserved but get_attr nodes' meta are wrong.") 120 def test_re_export_preserve_handle(self): 121 m = TestHelperModules.Conv2dThenConv1d() 122 example_inputs = m.example_inputs() 123 m = capture_pre_autograd_graph(m, example_inputs) 124 generate_numeric_debug_handle(m) 125 126 debug_handle_map_ref = _extract_debug_handles(m) 127 m_export = capture_pre_autograd_graph(m, example_inputs) 128 debug_handle_map = _extract_debug_handles(m_export) 129 130 self.assertEqual(debug_handle_map, debug_handle_map_ref) 131 132 @unittest.skip( 133 "All nodes' meta are preserved but the first arg for the first node seems to be dropped" 134 ) 135 def test_run_decompositions_preserve_handle(self): 136 m = TestHelperModules.Conv2dThenConv1d() 137 example_inputs = m.example_inputs() 138 m = torch.export.export(m, example_inputs) 139 generate_numeric_debug_handle(m) 140 141 debug_handle_map_ref = _extract_debug_handles(m) 142 143 m_copy = copy.copy(m) 144 m_copy = m_copy.run_decompositions() 145 debug_handle_map = _extract_debug_handles(m_copy) 146 147 # checking the map still has the same ids, the node may change 148 self.assertEqual( 149 set(debug_handle_map.values()), set(debug_handle_map_ref.values()) 150 ) 151 152 def test_prepare_for_propagation_comparison(self): 153 m = TestHelperModules.Conv2dThenConv1d() 154 example_inputs = m.example_inputs() 155 m = capture_pre_autograd_graph(m, example_inputs) 156 generate_numeric_debug_handle(m) 157 m_logger = prepare_for_propagation_comparison(m) 158 ref = m(*example_inputs) 159 res = m_logger(*example_inputs) 160 161 from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger 162 163 loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)] 164 self.assertEqual(len(loggers), 7) 165 self.assertTrue("conv2d" in [logger.node_name for logger in loggers]) 166 self.assertEqual(res, ref) 167 168 def test_extract_results_from_loggers(self): 169 m = TestHelperModules.Conv2dThenConv1d() 170 example_inputs = m.example_inputs() 171 m = capture_pre_autograd_graph(m, example_inputs) 172 generate_numeric_debug_handle(m) 173 m_ref_logger = prepare_for_propagation_comparison(m) 174 175 quantizer = XNNPACKQuantizer().set_global( 176 get_symmetric_quantization_config(is_per_channel=False) 177 ) 178 m = prepare_pt2e(m, quantizer) 179 m(*example_inputs) 180 m = convert_pt2e(m) 181 m_quant_logger = prepare_for_propagation_comparison(m) 182 183 m_ref_logger(*example_inputs) 184 m_quant_logger(*example_inputs) 185 ref_results = extract_results_from_loggers(m_ref_logger) 186 quant_results = extract_results_from_loggers(m_quant_logger) 187 comparison_results = compare_results(ref_results, quant_results) 188 for node_summary in comparison_results.values(): 189 if len(node_summary.results) > 0: 190 self.assertGreaterEqual(node_summary.results[0].sqnr, 35) 191