# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

# TODO: Move away from using torch's internal testing utils
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
    TestHelperModules,
)


class MyTestHelperModules:
    """Namespace for small models used only by the tests in this file."""

    class TwoFanOutLinears(torch.nn.Module):
        """Two linear layers that consume the same input; outputs are added.

        ``linear1`` has no bias and ``linear2`` has one. Both map 8 input
        features to 16 output features.
        """

        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(8, 16)

        def forward(self, x):
            # The same tensor ``x`` fans out to both linear layers.
            x1 = self.linear1(x)
            x2 = self.linear2(x)
            return x1 + x2


# Dequantize op overloads from the quantized_decomposed namespace.
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


def _dynamic_per_channel_quantizer():
    """Return an XNNPACKQuantizer whose global config is the symmetric,
    per-channel, dynamic quantization config."""
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
    )
    return quantizer


class TestDuplicateDynamicQuantChainPass(QuantizationTestCase):
    """Node-count checks around DuplicateDynamicQuantChainPass."""

    def _test_duplicate_chain(
        self,
        model,
        example_inputs,
        quantizer,
        before_node_occurrences,
        after_node_occurrences,
    ):
        """Quantize ``model``, run the pass, and check op counts around it.

        Args:
            model: eager module to quantize; it is put in eval mode and then
                deep-copied before export.
            example_inputs: tuple of positional inputs used both for export
                and for the single calibration call.
            quantizer: quantizer handed to ``prepare_pt2e``.
            before_node_occurrences: mapping of op overload -> expected
                number of call_function nodes after ``convert_pt2e`` and
                before the pass runs.
            after_node_occurrences: same mapping, checked after the pass.

        Returns:
            The converted graph module after the pass has been applied.
        """
        # program capture (on a copy, so the caller's module is not exported)
        m = copy.deepcopy(model.eval())
        m = torch.export.export_for_training(
            m,
            example_inputs,
        ).module()

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        expected_before = {
            ns.call_function(op): count
            for op, count in before_node_occurrences.items()
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_before)

        # The pass result is not used; the same module ``m`` is re-checked below.
        DuplicateDynamicQuantChainPass()(m)

        expected_after = {
            ns.call_function(op): count
            for op, count in after_node_occurrences.items()
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_after)
        return m

    def test_no_need_for_duplicate(self):
        """
        Model under test
        linear -> linear
        Check that there are two choose_qparams, q, dq nodes both before and
        after the pass
        """

        quantizer = _dynamic_per_channel_quantizer()
        example_inputs = (torch.randn(9, 8),)
        # Same expectation before and after: the pass must not change counts.
        expected_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 2,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
            # note: quantize op for weights are const propagated
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
        self._test_duplicate_chain(
            TestHelperModules.TwoLinearModule().eval(),
            example_inputs,
            quantizer,
            before_node_occurrences=expected_node_occurrence,
            after_node_occurrences=expected_node_occurrence,
        )

    def test_simple_duplicate_chain(self):
        """
        Model under test
        x -> linear -> add
        |               |
         -> linear -----
        Before duplication there should be only 1 dynamic q chain
        After duplication there should be 2 dynamic q chains
        """

        quantizer = _dynamic_per_channel_quantizer()
        example_inputs = (torch.randn(9, 8),)
        before_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            # note: quantize op for weights are const propagated
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
        after_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 2,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
        }
        self._test_duplicate_chain(
            MyTestHelperModules.TwoFanOutLinears().eval(),
            example_inputs,
            quantizer,
            before_node_occurrences=before_node_occurrence,
            after_node_occurrences=after_node_occurrence,
        )

    @unittest.skip("Set module name API does not work correctly when used as here.")
    def test_no_duplicate_chain_different_qscheme(self):
        """
        Model under test
        x -> linear1 -> linear2
        linear1 is configured with the dynamic config and linear2 with the
        static one; node counts are expected to be the same before and after
        the pass
        """

        quantizer = XNNPACKQuantizer()
        dynamic_qconfig = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=True
        )
        static_qconfig = get_symmetric_quantization_config(is_per_channel=False)
        quantizer.set_module_name("linear1", dynamic_qconfig)
        quantizer.set_module_name("linear2", static_qconfig)
        example_inputs = (torch.randn(9, 8),)
        # Same expectation before and after: the pass must not change counts.
        expected_node_occurrence = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
        }
        self._test_duplicate_chain(
            TestHelperModules.TwoLinearModule().eval(),
            example_inputs,
            quantizer,
            before_node_occurrences=expected_node_occurrence,
            after_node_occurrences=expected_node_occurrence,
        )