# Owner(s): ["module: functorch"]

import torch
from functorch import make_fx
from functorch.compile import minifier
from torch._functorch.compile_utils import get_outputs, get_placeholders
from torch.testing._internal.common_utils import run_tests, TestCase


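# Each test traces a small function (or module) to an FX graph, then asks the
# minifier to shrink it to the smallest subgraph that still satisfies the
# checker passed in, and verifies the size of the minified graph and input list.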
class TestMinifier(TestCase):
    def test_has_mul_minifier(self):
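        # Checker: flag any graph containing aten.mul. The expected minimal
        # repro is just the mul with its two placeholder inputs and the
        # output node (4 graph nodes, 2 inputs).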
        def failing_f(x, y):
            y = y / 3
            x = x + 3
            x = x * y
            return x + y

        inps = [torch.randn(3), torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def has_mul(fx_g, inps):
            return torch.ops.aten.mul.Tensor in (i.target for i in fx_g.graph.nodes)

        min_f, inps = minifier(failing_f, inps, has_mul)
        self.assertEqual(len(min_f.graph.nodes), 4)
        self.assertEqual(len(inps), 2)

    def test_has_add_mul(self):
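        # Checker: flag graphs that turn NaN-free inputs into a NaN output.
        # The NaN originates from zero / zero, so the expected minimal repro
        # is that single division (3 graph nodes, 1 input).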
        def failing_f(x):
            x = x * 3
            x = x + 5
            x = x.cos()
            zero = x - x
            result = zero / zero
            result = result + 3
            return (result * 2,)

        inps = [torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def has_nans(fx_g, inps):
            # The failure being minified: the graph produces a NaN output
            # even though none of the inputs contain NaNs.
            for i in inps:
                if torch.isnan(i).any():
                    return False
            return torch.isnan(fx_g(*inps)[0]).any()

        min_f, inps = minifier(failing_f, inps, has_nans)
        self.assertEqual(len(min_f.graph.nodes), 3)
        self.assertEqual(len(inps), 1)

    def test_input_returned(self):
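        # Checker: flag graphs that return one of their own placeholders.
        # The expected minimal repro just returns its input unchanged
        # (2 graph nodes, 1 input).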
        def f(a, b, c):
            a = a.sin()
            c = c.cos()
            d = a * c
            return (a, b, c, d)

        inps = [torch.randn(3) for _ in range(3)]

        def inputs_returned(fx_g, inps):
            inps = set(get_placeholders(fx_g.graph))
            outs = set(get_outputs(fx_g.graph))
            return len(inps & outs) > 0

        failing_f = make_fx(f)(*inps)
        min_f, inps = minifier(failing_f, inps, inputs_returned)
        self.assertEqual(len(min_f.graph.nodes), 2)
        self.assertEqual(len(inps), 1)

    def test_tup_use(self):
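        # Checker: flag graphs containing aten.add. std_mean returns a tuple,
        # so the minifier has to cut through the getitem nodes feeding the
        # add; the minimized graph is expected to have 4 nodes and 2 inputs.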
        def f(a, b):
            tup = torch.std_mean(a)
            return (tup[0] + b * tup[1],)

        inps = [torch.randn(3), torch.randn(3)]

        def has_add(fx_g, inps):
            return torch.ops.aten.add.Tensor in (i.target for i in fx_g.graph.nodes)

        failing_f = make_fx(f)(*inps)
        min_f, inps = minifier(failing_f, inps, has_add)

        self.assertEqual(len(min_f.graph.nodes), 4)
        self.assertEqual(len(inps), 2)

    def test_module(self):
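        # Same NaN-producing pattern as test_has_add_mul, but the graph comes
        # from an nn.Module traced with torch.fx.symbolic_trace rather than
        # from make_fx.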
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.relu(x)
                zero = y - y
                result = zero / zero
                result = result + 3
                return result

        mod = MockModule()
        failing_f = torch.fx.symbolic_trace(mod)

        inps = [torch.randn(3)]

        def pass_checker(fx_g, inps):
            # Reject candidates whose inputs already contain NaNs; flag
            # graphs whose output is NaN.
            for i in inps:
                if torch.isnan(i).any():
                    return False
            return torch.isnan(fx_g(*inps)[0]).any()

        min_f, inps = minifier(failing_f, inps, pass_checker)
        self.assertEqual(len(min_f.graph.nodes), 3)
        self.assertEqual(len(inps), 1)


if __name__ == "__main__":
    run_tests()