# Owner(s): ["oncall: quantization"]

# torch
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import QuantizationTestCase


class TestFusionPasses(QuantizationTestCase):
    def test_quantized_add_relu_fusion(self):
        class MAdd(torch.nn.Module):
            def forward(self, x, y):
                a = torch.ops.quantized.add(x, y, 1.0, 0)
                relu_out = torch.relu(a)
                return relu_out

        A = torch.arange(-128, 130, dtype=torch.float)
        B = torch.arange(-128, 130, dtype=torch.float)
        scale = 2.0
        zero_point = 127
        qA = torch.quantize_per_tensor(
            A, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        qB = torch.quantize_per_tensor(
            B, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )

        # Check quantized add + relu fusion
        m = MAdd()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB)

        # Must inline the graph before running the fusion pass.
        # In this test case it does not matter since we call the ops
        # directly, but when calling nn modules the graph has to be
        # inlined.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check("quantized::add_relu").run(
            scripted_m.graph
        )
        output = scripted_m(qA, qB)
        self.assertEqual(ref_output, output)

        class MAddOut(torch.nn.Module):
            def forward(self, x, y, z):
                a = torch.ops.quantized.add_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        # Check quantized add_out + relu fusion
        m = MAddOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB, qC)
        # Inline the graph before fusing, as above.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not("quantized::add_out").check(
            "quantized::add_relu_out"
        ).run(scripted_m.graph)
        output = scripted_m(qA, qB, qC)
        self.assertEqual(ref_output, output)

        class MAddScalar(torch.nn.Module):
            def forward(self, x, y: float):
                a = torch.ops.quantized.add_scalar(x, y)
                relu_out = torch.relu(a)
                return relu_out

        # Check quantized add_scalar + relu fusion
        m = MAddScalar()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not("quantized::add_scalar(").check(
            "quantized::add_scalar_relu"
        ).run(scripted_m.graph)
        output = scripted_m(qA, 3.0)
        self.assertEqual(ref_output, output)

        class MAddScalarOut(torch.nn.Module):
            def forward(self, x, y: float, z):
                a = torch.ops.quantized.add_scalar_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        # Check quantized add_scalar_out + relu fusion
        m = MAddScalarOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0, qC)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not(
            "quantized::add_scalar_out"
        ).check("quantized::add_scalar_relu_out").run(scripted_m.graph)
        output = scripted_m(qA, 3.0, qC)
        self.assertEqual(ref_output, output)
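

# A minimal sketch of an entry-point guard, assuming (not confirmed by this file)
# that the surrounding repo drives these tests through test/test_quantization.py
# rather than by executing this module directly; drop or adjust if the test
# runner already handles direct invocation.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )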