1# Owner(s): ["oncall: quantization"] 2 3import torch 4 5from torch.testing._internal.common_quantization import ( 6 QuantizationTestCase, 7 ModelMultipleOps, 8 ModelMultipleOpsNoAvgPool, 9) 10from torch.testing._internal.common_quantized import ( 11 override_quantized_engine, 12 supported_qengines, 13) 14 15class TestModelNumericsEager(QuantizationTestCase): 16 def test_float_quant_compare_per_tensor(self): 17 for qengine in supported_qengines: 18 with override_quantized_engine(qengine): 19 torch.manual_seed(42) 20 my_model = ModelMultipleOps().to(torch.float32) 21 my_model.eval() 22 calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32) 23 eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32) 24 out_ref = my_model(eval_data) 25 qModel = torch.ao.quantization.QuantWrapper(my_model) 26 qModel.eval() 27 qModel.qconfig = torch.ao.quantization.default_qconfig 28 torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True) 29 torch.ao.quantization.prepare(qModel, inplace=True) 30 qModel(calib_data) 31 torch.ao.quantization.convert(qModel, inplace=True) 32 out_q = qModel(eval_data) 33 SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q)) 34 # Quantized model output should be close to floating point model output numerically 35 # Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired 36 # output 37 self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB') 38 39 def test_float_quant_compare_per_channel(self): 40 # Test for per-channel Quant 41 torch.manual_seed(67) 42 my_model = ModelMultipleOps().to(torch.float32) 43 my_model.eval() 44 calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) 45 eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) 46 out_ref = my_model(eval_data) 47 q_model = torch.ao.quantization.QuantWrapper(my_model) 48 q_model.eval() 49 q_model.qconfig = torch.ao.quantization.default_per_channel_qconfig 50 torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) 51 torch.ao.quantization.prepare(q_model) 52 q_model(calib_data) 53 torch.ao.quantization.convert(q_model) 54 out_q = q_model(eval_data) 55 SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q)) 56 # Quantized model output should be close to floating point model output numerically 57 # Setting target SQNR to be 35 dB 58 self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB') 59 60 def test_fake_quant_true_quant_compare(self): 61 for qengine in supported_qengines: 62 with override_quantized_engine(qengine): 63 torch.manual_seed(67) 64 my_model = ModelMultipleOpsNoAvgPool().to(torch.float32) 65 calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) 66 eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) 67 my_model.eval() 68 out_ref = my_model(eval_data) 69 fq_model = torch.ao.quantization.QuantWrapper(my_model) 70 fq_model.train() 71 fq_model.qconfig = torch.ao.quantization.default_qat_qconfig 72 torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) 73 torch.ao.quantization.prepare_qat(fq_model) 74 fq_model.eval() 75 fq_model.apply(torch.ao.quantization.disable_fake_quant) 76 fq_model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats) 77 fq_model(calib_data) 78 fq_model.apply(torch.ao.quantization.enable_fake_quant) 79 fq_model.apply(torch.ao.quantization.disable_observer) 80 out_fq = fq_model(eval_data) 81 SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq)) 82 # Quantized model output should be close to floating point model output numerically 83 # Setting target SQNR to be 35 dB 84 self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB') 85 torch.ao.quantization.convert(fq_model) 86 out_q = fq_model(eval_data) 87 SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10)) 88 self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB') 89 90 # Test to compare weight only quantized model numerics and 91 # activation only quantized model numerics with float 92 def test_weight_only_activation_only_fakequant(self): 93 for qengine in supported_qengines: 94 with override_quantized_engine(qengine): 95 torch.manual_seed(67) 96 calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) 97 eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) 98 qconfigset = {torch.ao.quantization.default_weight_only_qconfig, 99 torch.ao.quantization.default_activation_only_qconfig} 100 SQNRTarget = [35, 45] 101 for idx, qconfig in enumerate(qconfigset): 102 my_model = ModelMultipleOpsNoAvgPool().to(torch.float32) 103 my_model.eval() 104 out_ref = my_model(eval_data) 105 fq_model = torch.ao.quantization.QuantWrapper(my_model) 106 fq_model.train() 107 fq_model.qconfig = qconfig 108 torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) 109 torch.ao.quantization.prepare_qat(fq_model) 110 fq_model.eval() 111 fq_model.apply(torch.ao.quantization.disable_fake_quant) 112 fq_model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats) 113 fq_model(calib_data) 114 fq_model.apply(torch.ao.quantization.enable_fake_quant) 115 fq_model.apply(torch.ao.quantization.disable_observer) 116 out_fq = fq_model(eval_data) 117 SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq)) 118 self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float') 119 120 121if __name__ == '__main__': 122 raise RuntimeError("This test file is not meant to be run directly, use:\n\n" 123 "\tpython test/test_quantization.py TESTNAME\n\n" 124 "instead.") 125