# xref: /aosp_15_r20/external/executorch/exir/tests/test_pass_infra.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8
9import unittest
10
11import executorch.exir as exir
12
13import torch
14from executorch.exir.pass_manager import PassManager
15from executorch.exir.passes import ScalarToTensorPass
16from executorch.exir.passes.pass_registry import PassRegistry
17from torch.fx.passes.infra.pass_base import PassBase
18
19
20class TestPassInfra(unittest.TestCase):
21    def test_fail_passbase(self) -> None:
22        """
23        Tests if we catch errors when we do not inherit PassBase correctly
24        """
25
26        # Catches error if we do not implement call()
27        class TestPass3(PassBase):
28            def __init__(self):
29                pass
30
31        with self.assertRaises(TypeError):
32            # pyre-ignore
33            TestPass3()
34
35    def test_pass_registry_func(self) -> None:
36        """
37        Test if we register a callable correctly
38        """
39
40        # Registering w/o specifying pass_name
41        @PassRegistry.register()
42        def test_pass1(graph_module: torch.fx.GraphModule) -> None:
43            pass
44
45        self.assertEqual(len(PassRegistry.get("test_pass1")), 1)
46
47        # Registering with a specified pass_name
48        @PassRegistry.register(pass_name="test_pass1_1")
49        def test_pass11(graph_module: torch.fx.GraphModule) -> None:
50            pass
51
52        self.assertEqual(len(PassRegistry.get("test_pass1_1")), 1)
53
54    def test_pass_registry_passbase(self) -> None:
55        """
56        Test if we register a PassBase subclass correctly
57        """
58
59        class TestPass2(PassBase):
60            def __init__(self) -> None:
61                pass
62
63            def call(self, graph_module: torch.fx.GraphModule) -> None:
64                pass
65
66        PassRegistry.register("test_pass2")(TestPass2())
67
68        self.assertEqual(len(PassRegistry.get("test_pass2")), 1)
69
70    def test_pass_registry_list(self) -> None:
71        def test_pass1(graph_module: torch.fx.GraphModule) -> None:
72            pass
73
74        class TestPass2(PassBase):
75            def __init__(self) -> None:
76                pass
77
78            def call(self, graph_module: torch.fx.GraphModule) -> None:
79                pass
80
81        # Register a list of passes
82        PassRegistry.register_list(
83            pass_name="test_pass3", pass_list=[test_pass1, TestPass2()]
84        )
85        self.assertEqual(len(PassRegistry.get("test_pass3")), 2)
86
87    def test_pass_manager(self) -> None:
88        """
89        Tests that the pass manager runs the passes correctly.
90        """
91
92        def replace_add_with_mul(gm: torch.fx.GraphModule) -> None:
93            for node in gm.graph.nodes:
94                if node.op == "call_function" and "aten.add.Tensor" in str(node.target):
95                    node.target = torch.mul
96
97        def replace_mul_with_div(gm: torch.fx.GraphModule) -> None:
98            for node in gm.graph.nodes:
99                if node.op == "call_function" and node.target == torch.mul:
100                    node.target = torch.div
101
102        def f(x: torch.Tensor) -> torch.Tensor:
103            y = torch.add(x, x)
104            z = torch.add(y, x)
105            return z
106
107        f = (
108            exir.capture(f, (torch.randn(10),), exir.CaptureConfig())
109            .to_edge()
110            .exported_program.graph_module
111        )
112        pm = PassManager(passes=[replace_add_with_mul, replace_mul_with_div])
113        self.assertEqual(len(pm.passes), 2)
114        pm(f)
115
116        # Check that all call_function nodes are divs
117        for node in f.graph.nodes:
118            if node.op == "call_function":
119                self.assertEqual(node.target, torch.div)
120
121    def test_pass_manager_invalid_passes(self) -> None:
122        """
123        Tests that the pass manager detects invalid passes
124        """
125
126        class Foo(torch.nn.Module):
127            def __init__(self) -> None:
128                super().__init__()
129
130            def forward(self, x: torch.Tensor) -> torch.Tensor:
131                return x
132
133        def introduce_call_method(gm: torch.fx.GraphModule) -> None:
134            node = list(gm.graph.nodes)[-2]
135            with gm.graph.inserting_after(node):
136                new_node = gm.graph.call_method("torch.ops.relu", (torch.randn(2),))
137                node.replace_all_uses_with(new_node)
138
139        def introduce_call_module(gm: torch.fx.GraphModule) -> None:
140            node = list(gm.graph.nodes)[-2]
141            gm.add_submodule("foo", Foo())
142
143            with gm.graph.inserting_after(node):
144                new_node = gm.graph.call_module("foo", (torch.randn(2),))
145                node.replace_all_uses_with(new_node)
146
147        def f(x: torch.Tensor) -> torch.Tensor:
148            y = torch.add(x, x)
149            z = torch.add(y, x)
150            return z
151
152        traced_f1 = (
153            exir.capture(f, (torch.randn(10),), exir.CaptureConfig())
154            .to_edge()
155            .exported_program.graph_module
156        )
157        pm1 = PassManager(
158            passes=[introduce_call_method], run_checks_after_each_pass=True
159        )
160
161        with self.assertRaisesRegex(Exception, "call_method"):
162            pm1(traced_f1)
163
164    def test_pass_metadata(self) -> None:
165        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
166            return x + y
167
168        sample_inputs = (torch.randn(1, 3), torch.randn(1, 3))
169        gm = exir.capture(
170            f, sample_inputs, exir.CaptureConfig()
171        ).exported_program.graph_module
172
173        pass_result = ScalarToTensorPass()(gm)
174        self.assertIsNotNone(pass_result)
175        new_gm = pass_result.graph_module
176
177        for node in new_gm.graph.nodes:
178            if node.target != "output":
179                self.assertIn("val", node.meta)
180