xref: /aosp_15_r20/external/executorch/exir/backend/test/test_lowered_backend_module.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import executorch.exir.tests.models as models
10
11import torch
12from executorch import exir
13from executorch.exir import to_edge
14from executorch.exir.backend.backend_api import to_backend
15from executorch.exir.backend.compile_spec_schema import CompileSpec
16from executorch.exir.backend.test.backend_with_compiler_demo import (
17    BackendWithCompilerDemo,
18)
19from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
20from executorch.exir.schema import DelegateCall, Program
21
22from executorch.extension.pybindings.portable_lib import (  # @manual
23    _load_for_executorch_from_buffer,
24)
25from hypothesis import given, settings, strategies as st
26from torch.export import export
27
28
class TestBackendAPI(unittest.TestCase):
    """Tests for programs emitted from lowered backend modules.

    Each test lowers an edge-dialect program with ``to_backend`` and inspects
    the ExecuTorch program produced by the resulting lowered module.
    """

    @staticmethod
    def _edge_compile_config():
        """Return the edge compile config shared by every test in this class."""
        # Every lowering in this class disables IR validity checking and
        # enables edge ops.
        return exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True)

    @staticmethod
    def _models():
        """Return freshly constructed models to lower and emit."""
        return [
            models.Emformer(),
            models.Repeat(),
            models.ElementwiseAdd(),
            models.MLP(),
            models.ModelWithUnusedArg(),
        ]

    @staticmethod
    def _wrap(lowered_module):
        """Return a ``torch.nn.Module`` whose ``forward`` only calls ``lowered_module``."""

        class WrappedModule(torch.nn.Module):
            def __init__(self, inner):
                super().__init__()
                self.one_module = inner

            def forward(self, *args):
                return self.one_module(*args)

        return WrappedModule(lowered_module)

    def validate_lowered_module_program(self, program: Program) -> None:
        """Assert that ``program`` consists of exactly one delegate call.

        Any program emitted directly from a lowered backend module is expected
        to have a single instruction in the first chain of its first execution
        plan, and that instruction's arguments must be a ``DelegateCall``.
        """
        instructions = program.execution_plan[0].chains[0].instructions
        self.assertEqual(len(instructions), 1)
        self.assertIsInstance(instructions[0].instr_args, DelegateCall)

    def get_program_from_wrapped_module(
        self, lowered_module, example_inputs, edge_compile_config
    ):
        """Return the ExecuTorch program for a wrapper that calls ``lowered_module``.

        This is the reference path: the lowered module is invoked from a plain
        ``torch.nn.Module``, which is exported with ``example_inputs``,
        converted to edge with ``edge_compile_config`` and then to ExecuTorch.
        """
        edge_program = to_edge(
            export(self._wrap(lowered_module), example_inputs),
            compile_config=edge_compile_config,
        )
        return edge_program.to_executorch().executorch_program

    def test_emit_lowered_backend_module_end_to_end(self):
        """Lower ``torch.sin`` to the demo backend and execute it in the runtime."""
        # No Hypothesis @settings here: without @given it has no effect, and
        # Hypothesis's pytest plugin rejects it with InvalidArgument.

        class SinModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        expected_res = sin_module(*model_inputs)

        edgeir_m = to_edge(
            export(sin_module, model_inputs),
            compile_config=self._edge_compile_config(),
        )
        # The demo backend is given the input length as its "max_value" spec.
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
        )

        # Calling the lowered module directly must reproduce torch.sin.
        new_res = lowered_sin_module(*model_inputs)
        self.assertTrue(torch.allclose(new_res[0], expected_res))

        program = lowered_sin_module.program()
        self.validate_lowered_module_program(program)

        # Load the serialized program into the portable runtime and run it.
        executorch_module = _load_for_executorch_from_buffer(
            lowered_sin_module.buffer()
        )
        runtime_input = torch.ones(1)
        model_outputs = executorch_module.forward([runtime_input])
        # The runtime call must leave its input unchanged.
        self.assertTrue(torch.equal(runtime_input, torch.ones(1)))
        # NOTE(review): 0.8333 ~= 1 - 1/6, which looks like a truncated
        # approximation of sin(1) computed by the demo backend -- confirm
        # against BackendWithCompilerDemo.
        runtime_expected = 0.8333 * torch.ones(1)
        self.assertTrue(
            torch.allclose(model_outputs[0], runtime_expected, atol=1e-03, rtol=1e-03)
        )

    def test_emit_lowered_backend_module(self):
        """Compare each lowered module's program against a wrapped reference.

        The program emitted directly from the lowered module must have the
        same number of instructions, values, inputs and outputs as the program
        obtained by exporting a wrapper around the lowered module, and must
        itself be a single delegate call.
        """
        # TODO: lifted vs. unlifted graphs are not exercised here. A former
        # Hypothesis `unlift` parameter on this test was never read, so it
        # only repeated the same work; wire it up if that coverage is needed.
        edge_compile_config = self._edge_compile_config()

        for model in self._models():
            name = type(model).__name__
            model_inputs = model.get_random_inputs()

            edgeir_m = to_edge(
                export(model, model_inputs), compile_config=edge_compile_config
            )
            lowered_model = to_backend(
                QnnBackend.__name__, edgeir_m.exported_program(), []
            )
            program = lowered_model.program()
            reference_program = self.get_program_from_wrapped_module(
                lowered_model, model_inputs, edge_compile_config
            )

            plan = program.execution_plan[0]
            reference_plan = reference_program.execution_plan[0]

            # Check the program is structurally close to the reference program.
            self.assertEqual(
                len(plan.chains[0].instructions),
                len(reference_plan.chains[0].instructions),
                msg=f"{name}: instruction count",
            )
            self.assertEqual(
                len(plan.values),
                len(reference_plan.values),
                msg=f"{name}: value count",
            )
            self.assertEqual(
                len(plan.inputs),
                len(reference_plan.inputs),
                msg=f"{name}: input count",
            )
            self.assertEqual(
                len(plan.outputs),
                len(reference_plan.outputs),
                msg=f"{name}: output count",
            )

            # Ensure the serialized buffer can be produced.
            lowered_model.buffer()
            self.validate_lowered_module_program(program)

    def test_emit_nested_lowered_backend_module(self):
        """Lower a wrapper around a lowered module and check it is one delegate call."""
        # TODO: lifted vs. unlifted graphs are not exercised here (a former
        # Hypothesis `unlift` parameter on this test was never read).
        edge_compile_config = self._edge_compile_config()

        for model in self._models():
            model_inputs = model.get_random_inputs()

            edgeir_m = to_edge(
                export(model, model_inputs), compile_config=edge_compile_config
            )
            lowered_module = to_backend(
                QnnBackend.__name__, edgeir_m.exported_program(), []
            )

            # The wrapper contains nothing but a call into the lowered module;
            # lowering it again nests one lowered module inside another.
            wrapped_module_edge = to_edge(
                export(self._wrap(lowered_module), model_inputs),
                compile_config=edge_compile_config,
            )
            nested_lowered_model = to_backend(
                QnnBackend.__name__, wrapped_module_edge.exported_program(), []
            )

            self.validate_lowered_module_program(nested_lowered_model.program())
218