xref: /aosp_15_r20/external/pytorch/test/fx/test_cse_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: fx"]
2
3import random
4
5import torch
6from torch.fx import symbolic_trace
7from torch.fx.experimental.proxy_tensor import make_fx
8from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
9from torch.testing._internal.common_utils import run_tests, TestCase
10
11
12banned_ops = get_CSE_banned_ops()
13P_default = CSEPass(banned_ops=banned_ops)
14
15
16def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
17    """
18    check if the CSE modified graph of ``f``
19    1) has delta less nodes, and
20    2) do not reduce the number of nodes further on a second pass, and
21    3) modified returned is true only if the number of nodes decreases.
22
23    Args:
24        f: function to be checked
25        t: tensor to be passed to f
26        delta: an integer >= -1.
27               If delta = -1, it only checks if the new graph has less or equal number of nodes
28        check_val: if True, check if the output of f is correct
29        graph_input: True is f is type GraphModule
30        P: the pass to use. If None, use P_default
31    """
32    if graph_input:
33        fx_g = f
34    else:
35        fx_g = make_fx(f)(t)
36
37    if P is None:
38        P = P_default
39
40    res = P(fx_g)
41    new_g = res.graph_module
42    new_graph = new_g.graph
43    modified = res.modified
44
45    # the number of nodes decrease/ or stay the same
46    old_num_nodes = len(fx_g.graph.nodes)
47    new_num_nodes = len(new_graph.nodes)
48
49    assert (
50        new_num_nodes < old_num_nodes
51    ) == modified, "modified should be True if the number of nodes decrease"
52
53    if delta == -1:
54        self.assertTrue(
55            old_num_nodes >= new_num_nodes,
56            (f"number of nodes increased {old_num_nodes}, {new_num_nodes}"),
57        )
58    else:
59        self.assertTrue(
60            old_num_nodes == new_num_nodes + delta,
61            (
62                f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
63            ),
64        )
65
66    # a second pass should not reduce more nodes
67    res = P(new_g)
68    pass_2_graph = res.graph_module.graph
69    pass_2_num_nodes = len(pass_2_graph.nodes)
70    self.assertTrue(
71        pass_2_num_nodes == new_num_nodes,
72        (
73            f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
74        ),
75    )
76
77    # check correctness
78    if check_val:
79        true_result = fx_g(t)
80        our_result = new_g(t)
81        if true_result is None:  # both return None
82            self.assertTrue(
83                our_result is None, f"true result is None, CSE result is {our_result}"
84            )
85        else:  # results returned are the same
86            self.assertTrue(
87                torch.all(true_result == our_result),
88                (f"results are different {true_result}, {our_result}"),
89            )  # check results are the same
90
91
92class TestCSEPass(TestCase):
93    def test_nochange(self):
94        def f(x):
95            a = x + 1
96            b = x + a
97            a = x
98            d = x + a
99            return b + d
100
101        t = torch.randn(2, 2)
102        check(self, f, t, 0)
103
104    def test_empty(self):
105        def f(x):
106            pass
107
108        t = torch.randn(2, 2)
109        check(self, f, t, 0)
110
111    def test_immutable_list_type(self):
112        def f(x):
113            a = x.sum(dim=1)
114            b = x.sum(dim=1)
115            c = x.sum()
116            d = x.sum()
117            return a + b + c + d
118
119        t = torch.randn(2, 2)
120        check(self, f, t, 2)
121
122    def test_immutable_list_multiple_entries(self):
123        def f(x):
124            a = x.sum(dim=[0, 1])
125            b = x.sum(dim=[0, 1])
126            c = x.sum(dim=1)
127            d = x.sum(dim=1)
128            return a + b + c + d
129
130        t = torch.randn(2, 2)
131        check(self, f, t, 2)
132
133    def test_simple(self):
134        def f(x):
135            a = x.cos()
136            b = x.cos()
137            c = a + a
138            d = b + b
139            return c + d
140
141        t = torch.randn(2, 2)
142        check(self, f, t, 2)
143
144    def test_simple_2(self):
145        def f(x):
146            a = x.cos().sin()
147            b = x.cos().sin()
148            c = a + a
149            d = b + b
150            return c + d
151
152        t = torch.randn(1)
153        check(self, f, t, 3)
154
155    def test_two_args_default(self):
156        def f(x):
157            a = x.sum(dim=1)
158            b = x.sum(dim=1, keepdim=False)
159            c = x.sum(dim=1, keepdim=False)
160            d = x.sum(dim=1)
161            return a + b + c + d
162
163        t = torch.randn(2, 2)
164        check(self, f, t, 3)
165
166    def test_two_args(self):
167        def f(x):
168            a = x.sum(dim=1)
169            b = x.sum(dim=1, keepdim=True)
170            c = x.sum(dim=1, keepdim=True)
171            d = x.sum(dim=1)
172            return a + b + c + d
173
174        t = torch.randn(2, 2)
175        check(self, f, t, 2)
176
177    def test_simple_multiple_same_ops(self):
178        def f(x):
179            a = x.sum()
180            b = x.sum()
181            c = x.sum()
182            d = x.sum()
183            return a + b + c + d
184
185        t = torch.randn(2, 2)
186        check(self, f, t, 3)
187
188    def test_nested_immutable_list_type(self):
189        def f(x):
190            a = torch.cat((x, x))
191            b = torch.cat((x, x))
192            return a + b
193
194        t = torch.randn(2, 2)
195        check(self, f, t, 1)
196
197    def test_kwarg(self):
198        def f(x):
199            a = torch.ones_like(x)
200            b = torch.ones_like(x)
201            return a + b
202
203        t = torch.randn(2, 2)
204        check(self, f, t, 1)
205
206    """
207    Generate function with random ops and check if the result is the same
208    """
209
210    def test_random(self):
211        def f(x):
212            vals = [x]
213            ops = [torch.clone, torch.cos, torch.tanh, torch.nn.functional.gelu]
214            for _ in range(100):
215                new_val = random.choice(ops)(random.choice(vals))
216                vals.append(new_val)
217            return vals[-1]
218
219        fx_g = symbolic_trace(f)
220        fx_g.graph.eliminate_dead_code()
221        fx_g.recompile()
222        t = torch.randn(2, 2)
223
224        for _ in range(30):
225            check(self, fx_g, t, -1, graph_input=True)
226
227    """
228    Test that banned list ban ops as expected.
229    """
230
231    def test_banned_list(self):
232        def f(x):
233            a = x + 1
234            b = x + 1
235            return a + b
236
237        t = torch.randn(2, 2)
238        P_ban_add = P = CSEPass(banned_ops=[torch.ops.aten.add])
239        check(self, f, t, 0, P=P_ban_add)  # check that add is banned
240        check(self, f, t, 1)  # check that add is not banned by default
241
242    def test_rand_like(self):
243        def f(x):
244            a = torch.rand_like(x)
245            b = torch.rand_like(x)
246            return a + b
247
248        t = torch.randn(2, 2)
249        check(self, f, t, 0, check_val=False)
250
251    def test_rand_n(self):
252        def f(x):
253            a = torch.randn(4)
254            b = torch.randn(4)
255            return a + b
256
257        t = torch.randn(2, 2)
258        check(self, f, t, 0, check_val=False)
259
260
261if __name__ == "__main__":
262    run_tests()
263