# Owner(s): ["oncall: quantization"]
import copy
import unittest
from typing import Any, Dict

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.observer import (
    HistogramObserver,
    MinMaxObserver,
    PlaceholderObserver,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

    class Conv2dWithSharedDQ(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv1(x)
            z = x.view(-1, 3)
            w = self.linear(z)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x * 2
            return w, add_output, extra_output

    class ModuleForDifferentQconfig(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv1(x)
            w = self.adaptive_avg_pool2d(x)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x + 2
            return w, add_output, extra_output


_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestDuplicateDQPass(QuantizationTestCase):
    def _test_duplicate_dq(
        self,
        model,
        example_inputs,
        quantizer,
    ):
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        pt2_quant_output = m(*example_inputs)
        # After convert_pt2e, every dequantize node feeding an annotated node
        # must have exactly one user, i.e. the duplicate-DQ pass has given each
        # quantized use its own dequantize.
        for n in m.graph.nodes:
            annotation = n.meta.get("quantization_annotation", None)
            if annotation is not None:
                for arg in n.args:
                    if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
                        self.assertEqual(len(arg.users.keys()), 1)

    def test_no_need_for_duplicate_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check that quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_simple_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because output for the
        first conv2d is fed to next conv2d, add, and view_copy + linear.
        All three are quantized.
        Thus the DQ node is duplicated for those three uses.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_no_add_quant_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because output for the
        first conv2d is fed to next conv2d, and view_copy + linear.
        Both are quantized.
        However the skip connections to add and mul are not quantized.
        Thus the DQ node is not duplicated for those two uses.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_avgpool_use_different_qconfig(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> adaptive_avgpool2d (different qconfig)
             |
              -----> add
        output
        conv2d -> dq -> conv2d -> add
             |                  |
              -------> dq ----->
             |
              -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
             |
              -> dq -----> add
        """

        def _get_uint8_quantization_config():
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
            act_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                    eps=2**-12
                ),
            )
            weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                MinMaxObserver
            )

            extra_args: Dict[str, Any] = {"eps": 2**-12}
            weight_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                ch_axis=0,
                is_dynamic=False,
                observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                    **extra_args
                ),
            )

            bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                PlaceholderObserver
            )
            bias_quantization_spec = QuantizationSpec(
                dtype=torch.float,
                observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
            )
            quantization_config = QuantizationConfig(
                act_quantization_spec,
                act_quantization_spec,
                weight_quantization_spec,
                bias_quantization_spec,
            )
            return quantization_config

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                avgpool_qconfig = _get_uint8_quantization_config()
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)
                for n in gm.graph.nodes:
                    if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
                        qspec = avgpool_qconfig.input_activation
                        input_act = n.args[0]
                        output_qspec = SharedQuantizationSpec((input_act, n))
                        n.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={input_act: qspec},
                            output_qspec=output_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.ModuleForDifferentQconfig(),
            example_inputs,
            BackendAQuantizer(),
        )