# Owner(s): ["oncall: quantization"]

import copy

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import (
    quantize,
    prepare,
    convert,
    prepare_qat,
    quantize_qat,
    fuse_modules,
    fuse_modules_qat,
    QConfig,
    default_qconfig,
    default_qat_qconfig,
)

from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    ModelForFusion,
    ModelWithSequentialFusion,
    ModelForLinearBNFusion,
    ModelForFusionWithBias,
    ModelForConvTransposeBNFusion,
    SingleLayerLinearModel,
    test_only_eval_fn,
    test_only_train_fn,
    skipIfNoFBGEMM,
)

from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
)


@skipIfNoFBGEMM
class TestFuseEager(QuantizationTestCase):
    """Eager-mode tests for module fusion (Conv/BN/ReLU, Linear/BN, etc.).

    Covers the fuse -> prepare(_qat) -> convert pipeline in both train (QAT)
    and eval (PTQ) modes, and checks that fusion preserves numerics and
    forward hooks.
    """

    def test_fuse_module_train(self):
        """Fusion in train mode: Conv+BN(+ReLU) become nni.ConvBn(ReLU)2d,
        then nniqat.* after prepare_qat, and nniq./nnq. after convert.
        """
        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                         msg="Fused Conv + BN + Relu first layer")
        self.assertEqual(type(model.bn1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                         msg="Fused submodule Conv + BN")
        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                         msg="Fused submodule Conv + BN (skipped BN)")
        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            # After prepare_qat the fused modules swap to their QAT variants.
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data_1d_train)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        # NOTE(review): the eval run over the converted model is expected to
        # fail — the model still contains a float batch norm that cannot run
        # on quantized tensors — so we assert on that specific error.
        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)

        # Same flow, but fusing both module lists in one call and using the
        # one-shot quantize_qat API.
        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules_qat(
            model,
            [['conv1', 'bn1', 'relu1'],
             ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)


    def test_fuse_module_eval(self):
        """Fusion in eval mode: BN is folded into the preceding conv, so the
        fused modules are nni.ConvReLU{1,2,3}d / nni.BNReLU3d rather than
        the ConvBn* train-mode variants.
        """
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(
            model,
            [['conv3', 'bn3', 'relu4'],
             ['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'],
             ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn']])
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2), nni.ConvReLU3d,
                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.bn2), nni.BNReLU3d,
                         msg="Fused BN + Relu first layer (Relu is folded))")
        self.assertEqual(type(model.relu3), nn.Identity,
                         msg="Fused BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2[0]), nn.Conv3d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv2[1]), nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.relu2), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped Relu)")

        self.assertEqual(type(model.conv3), nni.ConvReLU1d,
                         msg="Fused Conv + Relu for Conv1d (folded BN)")
        self.assertEqual(type(model.conv3[0]), nn.Conv1d,
                         msg="Fused Conv + Relu for Conv1d ")
        self.assertEqual(type(model.conv3[1]), nn.ReLU,
                         msg="Fused Conv + Relu for Conv1d")
        self.assertEqual(type(model.bn3), nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")

        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                         msg="Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn), nn.Identity,
                         msg="Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")

        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data_1d)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # Same checks through the one-shot quantize() convenience API.
        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(
            model,
            [['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'],
             ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn'],
             ['conv3', 'bn3', 'relu4']])
        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
        checkQuantized(model)

    def test_fusion_sequential_model_train(self):
        """QAT fusion (inplace) on a model built from nn.Sequential
        containers, across every supported quantized engine.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().train()
                model.to(torch.float)
                fuse_modules_qat(
                    model, [['conv1', 'relu1'],
                            ['features.0.0', 'features.0.1', 'features.0.2'],
                            ['features.1.0', 'features.1.1', 'features.1.2'],
                            ['features.2.0', 'features.2.1', 'features.2.2'],
                            ['classifier.0', 'classifier.1']],
                    inplace=True)
                self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]), nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1), nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                                     msg="Fused submodule Conv + folded BN")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Non-fused submodule Conv")
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                prepare_qat(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])


                def checkQAT(model):
                    # Fused modules should have swapped to QAT variants.
                    self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    for i in range(3):
                        self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                         msg="Fused submodule Conv + folded BN")
                        self.assertEqual(type(model.features[i][1]), nn.Identity,
                                         msg="Fused submodule (skipped BN)")
                        self.assertEqual(type(model.features[i][2]), nn.Identity,
                                         msg="Non-fused submodule Conv")
                    self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                    self.assertEqual(type(model.classifier[1]), nn.Identity)

                checkQAT(model)
                model(self.img_data_2d[1][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

    def test_fusion_sequential_model_eval(self):
        """Eval-mode (PTQ) fusion (inplace) on a model built from
        nn.Sequential containers, across every supported quantized engine.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()
                model.to(torch.float)
                fuse_modules(
                    model,
                    [['conv1', 'relu1'],
                     ['features.0.0', 'features.0.1', 'features.0.2'],
                     ['features.1.0', 'features.1.1', 'features.1.2'],
                     ['features.2.0', 'features.2.1', 'features.2.2'],
                     ['classifier.0', 'classifier.1']],
                    inplace=True)
                self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]), nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1), nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                for i in range(3):
                    # In eval mode BN is folded, so each stage fuses to
                    # ConvReLU2d (not ConvBnReLU2d as in the train test).
                    self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                                     msg="Fused submodule Conv + folded BN")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Non-fused submodule Conv")
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

    def checkModelWithSequentialQuantized(self, model):
        """Assert that a converted ModelWithSequentialFusion has the expected
        quantized fused module types.
        """
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.relu1), nn.Identity)
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
            self.assertEqual(type(model.features[i][1]), nn.Identity)
            self.assertEqual(type(model.features[i][2]), nn.Identity)
        self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)

    def test_fusion_conv_with_bias(self):
        """Fusing Conv(+bias)+BN must not change numerics: outputs and the
        BN statistics of the fused model must match an unfused reference.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_orig = ModelForFusionWithBias().train()

                # reference model
                model_ref = copy.deepcopy(model_orig)
                # output with no fusion.
                out_ref = model_ref(self.img_data_2d[0][0])

                # fused model
                # Identity observers so prepare_qat does not perturb values.
                model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                                             weight=torch.nn.Identity)
                model = fuse_modules_qat(
                    model_orig,
                    [["conv1", "bn1", "relu1"],
                     ["conv2", "bn2"]])
                prep_model = prepare_qat(model, inplace=False)
                # output with fusion but no observers.
                out_fused = prep_model(self.img_data_2d[0][0])

                self.assertEqual(out_ref, out_fused)

                def checkBN(bn_ref, bn):
                    # BN params/statistics must survive fusion unchanged.
                    self.assertEqual(bn_ref.weight, bn.weight)
                    self.assertEqual(bn_ref.bias, bn.bias)
                    self.assertEqual(bn_ref.running_mean, bn.running_mean)
                    self.assertEqual(bn_ref.running_var, bn.running_var)

                checkBN(model_ref.bn1, prep_model.conv1.bn)
                checkBN(model_ref.bn2, prep_model.conv2.bn)

                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare_qat(model, inplace=True)

                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                    self.assertEqual(type(model.bn1), nn.Identity)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                    self.assertEqual(type(model.bn2), nn.Identity)

                checkQAT(model)


    def test_fusion_linear_bn_eval(self):
        """Linear+BN fusion in eval mode preserves outputs exactly."""
        model = ModelForLinearBNFusion().train()
        inp1 = torch.randn(8, 20)
        inp2 = torch.randn(8, 20)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)

        model = fuse_modules(model, [["fc", "bn"]])
        self.assertEqual(type(model.bn), nn.Identity)
        self.assertEqual(golden, model(inp2))

    def test_fusion_convtranspose_bn_eval(self):
        """ConvTranspose{1,2,3}d+BN fusion in eval mode preserves outputs."""
        model = ModelForConvTransposeBNFusion().train()
        inp1 = torch.randn(8, 3, 16)
        inp2 = torch.randn(8, 3, 16)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)

        model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]])
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.bn2), nn.Identity)
        self.assertEqual(type(model.bn3), nn.Identity)

        self.assertEqual(golden, model(inp2))

    def test_fuse_function_customization(self):
        """fuse_modules accepts a user-supplied fuser_func that decides what
        each module list is replaced with.
        """
        dummy_model = SingleLayerLinearModel().train()
        dummy_model.eval()

        # A custom fuse funct
        def custom_fuse_func(module, is_qat, add_fuser_mapping):
            return [torch.nn.Identity()]

        dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func)
        self.assertEqual(type(dummy_model.fc1), nn.Identity)

    def test_forward_hooks_preserved(self):
        r"""Test case that checks whether forward pre hooks of the first module and
        post forward hooks of the last module in modules list passed to fusion function preserved.
        (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)]
        after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity])
        """
        model = ModelForFusion(default_qat_qconfig).train()

        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }
        fused = False

        def fw_pre_hook(fused_module_class, h_module, input):
            if fused:
                self.assertEqual(type(h_module), fused_module_class,
                                 "After fusion owner of the first module's forward pre hook is not a fused module")
            counter['pre_forwards'] += 1

        def fw_hook(fused_module_class, h_module, input, output):
            if fused:
                self.assertEqual(type(h_module), fused_module_class,
                                 "After fusion owner of the last module's forward hook is not a fused module")
            counter['forwards'] += 1

        # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference
        model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
        model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
        model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
        model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))

        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))

        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])

        fused = True
        before_fusion_pre_count = counter['pre_forwards']
        before_fusion_post_count = counter['forwards']
        test_only_eval_fn(model, self.img_data_1d)
        # Hooks must still fire the same number of times after fusion.
        self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))

    def test_fuse_modules_with_nested_hooks(self):
        r"""Test case that checks whether a nested module with sub-sub modules registered with hooks
        can be safely fused. Safeguard for issues similar to https://github.com/pytorch/pytorch/issues/105063
        in the future.
        """
        def myhook(*x):
            return ""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()

                # Attach a forward hook to every layer inside every
                # Sequential container before fusing.
                for sub_model in model.modules():
                    if isinstance(sub_model, nn.Sequential):
                        for layer in sub_model:
                            if hasattr(layer, 'register_forward_hook'):
                                layer.register_forward_hook(myhook)

                fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True)
                self.assertEqual(
                    type(model.features[0][0]),
                    nni.ConvReLU2d,
                    msg="Fused submodule Conv + folded BN"
                )
                self.assertEqual(
                    type(model.features[0][1]),
                    nn.Identity,
                    msg="Fused submodule (skipped BN)"
                )
                self.assertEqual(
                    type(model.features[0][2]),
                    nn.Identity,
                    msg="Non-fused submodule Conv"
                )


if __name__ == '__main__':
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )