xref: /aosp_15_r20/external/executorch/exir/tests/test_delegate.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import executorch.exir.tests.models as models
10
11import torch
12from executorch.exir import EdgeCompileConfig, to_edge
13from executorch.exir.dialects._ops import ops as exir_ops
14from executorch.exir.lowered_backend_module import (
15    create_submodule_from_nodes,
16    LoweredBackendModule,
17)
18from executorch.exir.schema import (
19    BackendDelegate,
20    BackendDelegateDataReference,
21    DataLocation,
22    DelegateCall,
23)
24from executorch.exir.tests.common import register_additional_test_aten_ops
25from torch.export import export
26from torch.testing import FileCheck
27
28
class WrapperModule(torch.nn.Module):
    """Minimal nn.Module adapter around an arbitrary callable.

    torch.export.export expects an nn.Module, so free functions used in these
    tests are wrapped in this class; ``forward`` simply forwards every
    positional and keyword argument to the wrapped callable unchanged.
    """

    def __init__(self, fn):
        super().__init__()
        # Keep a reference to the wrapped callable; invoked on every forward().
        self.fn = fn

    def forward(self, *args, **kwargs):
        # Pure passthrough: no argument transformation, no extra state.
        wrapped = self.fn
        return wrapped(*args, **kwargs)
37
class TestDelegate(unittest.TestCase):
    """Tests for ExecuTorch delegation: the
    ``torch.ops.higher_order.executorch_call_delegate`` higher-order op,
    ``LoweredBackendModule`` semantics, and ``create_submodule_from_nodes``
    partitioning behavior.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Register extra aten ops the test models need (see
        # exir.tests.common); run once for the whole class.
        register_additional_test_aten_ops()

    def test_call_delegate(self) -> None:
        """Lower a simple add function into a LoweredBackendModule, call it
        through the executorch_call_delegate higher-order op, and check that
        the exported graph keeps the delegate call and matches eager output.
        """

        def g(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        inputs = (torch.ones(1, 3), torch.ones(1, 3))
        edge_ir_m = to_edge(export(WrapperModule(g), inputs))
        # b"moo" stands in for real backend-compiled bytes; [] is an empty
        # list of compile specs.
        lowered_module: LoweredBackendModule = LoweredBackendModule(
            edge_ir_m.exported_program(), "BackendWithCompilerDemo", b"moo", []
        )

        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.higher_order.executorch_call_delegate(lowered_module, x, y)

        orig_res = f(*inputs)  # eager-mode reference result
        gm = export(
            WrapperModule(f),
            inputs,
        )
        # The lowered module should appear as a graph attribute
        # ("lowered_module_0") followed by the delegate higher-order op call.
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(gm.graph_module.code)
        self.assertTrue(torch.allclose(orig_res, gm.module()(*inputs)))

    def test_to_backend(self) -> None:
        """Check if we have patched a lowered module correctly (for delegation)"""

        m = models.CompositeDelegateModule()

        # _check_ir_validity=False works around an op that currently fails
        # edge-dialect validation; see the TODO below.
        exec_prog = to_edge(
            export(m, m.get_random_inputs()),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        ).to_executorch()  # TODO(larryliu): fix split_copy.Tensor
        graph_module = exec_prog.exported_program().graph_module
        # NOTE(review): reaches into the private _emitter_output to inspect
        # the serialized Program schema object.
        program = exec_prog._emitter_output.program

        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(graph_module.code)

        # Check that there does not exist an add node (from the non-delegated
        # BasicModuleAdd.forward function)
        self.assertTrue(
            exir_ops.edge.aten.add.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        for node in graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.higher_order.executorch_call_delegate
            ):
                # Check that the first argument is the lowered backend module
                # (which we got from a getattr)
                self.assertEqual(node.args[0].op, "get_attr")
                get_attr_backend = getattr(graph_module, node.args[0].target)
                # The lowered module stored on the graph must carry the same
                # backend metadata as the one attached to the eager model.
                self.assertEqual(
                    get_attr_backend._backend_id, m.lowered_module._backend_id
                )
                self.assertEqual(
                    get_attr_backend._processed_bytes, m.lowered_module._processed_bytes
                )
                self.assertEqual(
                    get_attr_backend._compile_specs, m.lowered_module._compile_specs
                )

        # Check the BackendDelegate object itself
        delegate: BackendDelegate = program.execution_plan[0].delegates[0]
        self.assertEqual(delegate.id, "backend_demo")
        processed: BackendDelegateDataReference = delegate.processed
        # Processed bytes are emitted inline in the program's delegate data
        # list, and the reference index must point inside that list.
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, b"basic_module_add"
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )

    def test_cannot_assign_attr(self) -> None:
        """LoweredBackendModule attributes are read-only; direct assignment
        must raise AttributeError."""
        deleg = LoweredBackendModule(None, "", b"", [])  # pyre-ignore
        with self.assertRaises(AttributeError):
            deleg.backend_id = "123"  # pyre-ignore

    def test_create_submodule_single_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(mul_tensor, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [mul_tensor]  # Output is pytree.flatten-ed

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                return x

        orig_res = Model()(*inputs)  # eager-mode reference result
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        # Collect the add/mul nodes that form the partition to split out.
        node_list = []
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target in {
                exir_ops.edge.aten.add.Tensor,
                exir_ops.edge.aten.mul.Tensor,
            }:
                node_list.append(node)

        # Mutates gm in place: the tagged nodes are replaced by a call into
        # the new submodule.
        sub_gm, node = create_submodule_from_nodes(gm, node_list, "tag")
        sub_gm.recompile()
        gm.recompile()

        # The submodule's output must be a single flattened list with exactly
        # one entry (only mul_tensor escapes the partition).
        for node in sub_gm.graph.nodes:
            if node.op == "output":
                self.assertEqual(len(node.args), 1)
                self.assertTrue(isinstance(node.args[0], list))
                self.assertEqual(len(node.args[0]), 1)

        # The partitioned program must still compute the same result.
        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))

    def test_create_submodule_multiple_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(add_tensor, mul_tensor)
            div_tensor = div(sub_tensor, mul_tensor)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [add_tensor, mul_tensor]

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            getitem_1 = partitioned_res[1]
            sub_tensor = sub(getitem_0, getitem_1)
            div_tensor = div(sub_tensor, getitem_1)
            return [div_tensor]
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                y = x * y
                x = x - y
                x = x / y
                return x

        orig_res = Model()(*inputs)  # eager-mode reference result
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        # Collect the add/mul nodes that form the partition to split out.
        node_list = []
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target in {
                exir_ops.edge.aten.add.Tensor,
                exir_ops.edge.aten.mul.Tensor,
            }:
                node_list.append(node)

        # Mutates gm in place, replacing the tagged nodes with a submodule call.
        sub_gm, node = create_submodule_from_nodes(gm, node_list, "tag")
        sub_gm.recompile()
        gm.recompile()

        # Both add_tensor and mul_tensor are used outside the partition, so
        # the submodule's flattened output list must have two entries.
        for node in sub_gm.graph.nodes:
            if node.op == "output":
                self.assertEqual(len(node.args), 1)
                self.assertTrue(isinstance(node.args[0], list))
                self.assertEqual(len(node.args[0]), 2)

        # The partitioned program must still compute the same result.
        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))

    def test_create_submodule_list_return(self) -> None:
        """
        Original graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            getitem_1 = split_tensor[1]
            return [getitem_0, getitem_1]  # List output is "opened"

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """
        inputs = (torch.randn(10), torch.randn(5))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = torch.split(x, 5)
                x = x[0] - y
                x = x / y
                return x

        orig_res = Model()(*inputs)  # eager-mode reference result
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        # Collect the split node(s) that form the partition to split out.
        node_list = []
        for node in gm.graph.nodes:
            # TODO(ssjia): split.Tensor now gets decomposed to split_with_sizes. Due to how executorch uses a pinned Pytorch
            # nightly, the CI may not catch the changes to Pytorch's core decomposition table. As a temporary workaround,
            # make the test backwards compatible with the old decomposition table. Remove the or statement once Pytorch nightly
            # has been updated.
            if node.op == "call_function" and (
                node.target == exir_ops.edge.aten.split_with_sizes_copy.default
                or node.target == exir_ops.edge.aten.split_copy.Tensor
            ):
                node_list.append(node)

        # Mutates gm in place, replacing the tagged nodes with a submodule call.
        sub_gm, node = create_submodule_from_nodes(gm, node_list, "tag")

        # The split's list output is flattened: the submodule returns a list
        # with one entry per split chunk (two chunks of 5 from a length-10 x).
        for node in sub_gm.graph.nodes:
            if node.op == "output":
                self.assertEqual(len(node.args), 1)
                self.assertTrue(isinstance(node.args[0], list))
                self.assertEqual(len(node.args[0]), 2)

        # The partitioned program must still compute the same result.
        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))
312