# Owner(s): ["oncall: quantization"]

# torch
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import QuantizationTestCase


class TestFusionPasses(QuantizationTestCase):
    def test_quantized_add_relu_fusion(self):
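        # The fusion pass under test rewrites a quantized add followed by
        # aten::relu into the fused quantized add_relu variants. Each case
        # below scripts a small module, runs the pass, and checks both the
        # rewritten graph (via FileCheck) and numerical equivalence.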
        class MAdd(torch.nn.Module):
            def forward(self, x, y):
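                # The trailing arguments to quantized::add are the output
                # scale and zero point.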
                a = torch.ops.quantized.add(x, y, 1.0, 0)
                relu_out = torch.relu(a)
                return relu_out

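        # Quantize two float ranges to quint8 tensors; these inputs are
        # reused across the sub-cases below.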
        A = torch.arange(-128, 130, dtype=torch.float)
        B = torch.arange(-128, 130, dtype=torch.float)
        scale = 2.0
        zero_point = 127
        qA = torch.quantize_per_tensor(
            A, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        qB = torch.quantize_per_tensor(
            B, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )

        # Check quantized add + relu fusion
        m = MAdd()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB)

        # Must inline the graph. In this test case it does not matter, since
        # we call the ops directly; however, if we were calling nn modules we
        # would have to inline the graph.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
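        # The fused graph should contain quantized::add_relu and no
        # standalone aten::relu.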
        FileCheck().check_not("aten::relu").check("quantized::add_relu").run(
            scripted_m.graph
        )
        output = scripted_m(qA, qB)
        self.assertEqual(ref_output, output)

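        # Out variant: quantized::add_out writes into a preallocated
        # quantized tensor instead of returning a freshly allocated one.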
        class MAddOut(torch.nn.Module):
            def forward(self, x, y, z):
                a = torch.ops.quantized.add_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

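        # Preallocate an uninitialized quantized tensor to serve as the out
        # argument.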
        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        # Check quantized add_out + relu fusion
        m = MAddOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB, qC)
        # Must inline the graph (same reasoning as above).
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not("quantized::add_out").check(
            "quantized::add_relu_out"
        ).run(scripted_m.graph)
        output = scripted_m(qA, qB, qC)
        self.assertEqual(ref_output, output)

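        # Scalar variant: add a Python float to a quantized tensor.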
        class MAddScalar(torch.nn.Module):
            def forward(self, x, y: float):
                a = torch.ops.quantized.add_scalar(x, y)
                relu_out = torch.relu(a)
                return relu_out

        # Check quantized add_scalar + relu fusion
        m = MAddScalar()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not("quantized::add_scalar(").check(
            "quantized::add_scalar_relu"
        ).run(scripted_m.graph)
        output = scripted_m(qA, 3.0)
        self.assertEqual(ref_output, output)

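        # Scalar out variant: quantized::add_scalar_out writes into a
        # preallocated output tensor.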
        class MAddScalarOut(torch.nn.Module):
            def forward(self, x, y: float, z):
                a = torch.ops.quantized.add_scalar_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        m = MAddScalarOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0, qC)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not(
            "quantized::add_scalar_out"
        ).check("quantized::add_scalar_relu_out").run(scripted_m.graph)
        output = scripted_m(qA, 3.0, qC)
        self.assertEqual(ref_output, output)