# xref: /aosp_15_r20/external/executorch/exir/backend/test/test_delegate_map_builder.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8from typing import Iterator, Union
9
10import torch
11from executorch import exir
12from executorch.exir.backend.backend_api import to_backend
13from executorch.exir.backend.test.backend_with_delegate_mapping_demo import (
14    BackendWithDelegateMappingDemo,
15)
16
17from executorch.exir.backend.utils import DelegateMappingBuilder
18
19
class TestDelegateMapBuilder(unittest.TestCase):
    """Tests for ``DelegateMappingBuilder``.

    The builder maps a delegate debug identifier to the tuple of debug handles
    covered by that entry. Entries can be built either from graph nodes (whose
    ``meta["debug_handle"]`` is read) or directly from debug handles.
    """

    def setUp(self) -> None:
        class Model(torch.nn.Module):
            def forward(self, x):
                y = torch.sin(x)
                return torch.cos(y)

        model = Model()
        model_inputs = (torch.ones(1, 1),)
        program = (
            exir.capture(model, model_inputs, exir.CaptureConfig(pt2_mode=True))
            .to_edge()
            .to_executorch()
        )

        # Nodes of the captured graph, used as inputs for the mapping tests.
        # Expected layout (this is what the assertions below rely on, e.g.
        # nodes[0] -> (1,), handles[2] -> 2, all nodes -> (1, 2, 3, 4)):
        #   nodes:         [arg0_1, alloc, aten_sin_default, alloc_1, aten_cos_default, output]
        #   debug handles: [1,      None,  2,                None,    3,                4]
        self.nodes = list(program.graph_module.graph.nodes)

        # Nodes without a "debug_handle" meta entry yield None here.
        self.handles = [node.meta.get("debug_handle") for node in self.nodes]

    def test_basic_generated_identifier(self) -> None:
        """Generated identifiers are assigned sequentially starting at 0."""
        delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

        # Entry from every node; the None handles of the alloc nodes are dropped.
        expected_mapping = {0: (1, 2, 3, 4)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry from a single node (not wrapped in a list).
        expected_mapping = {0: (1, 2, 3, 4), 1: (1,)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[0]), 1
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry from a single debug handle.
        expected_mapping = {0: (1, 2, 3, 4), 1: (1,), 2: (2,)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles[2]),
            2,
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry from the full handle list (None entries included) matches the
        # entry that was built from the nodes.
        expected_mapping = {
            0: (1, 2, 3, 4),
            1: (1,),
            2: (2,),
            3: (1, 2, 3, 4),
        }
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

    def test_basic_manual_int_identifier(self) -> None:
        self._test_basic_manual_identifier(iter([22, 55]))

    def test_basic_manual_string_identifier(self) -> None:
        self._test_basic_manual_identifier(iter(["22", "55"]))

    def test_adding_manual_identifier_when_generated(self) -> None:
        """An explicit identifier is rejected when identifiers are generated."""
        delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes, identifier="22"
            )
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                handles=self.handles, identifier="22"
            )

    def test_omitting_identifier_when_not_generated(self) -> None:
        """An identifier is required when identifiers are not generated."""
        delegate_builder = DelegateMappingBuilder()

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes)
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles)

    def test_reinsert_delegate_debug_identifier(self) -> None:
        """Reusing an identifier is rejected for both nodes and handles."""
        delegate_builder = DelegateMappingBuilder()
        delegate_builder.insert_delegate_mapping_entry(
            nodes=self.nodes[0], identifier="1"
        )

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes[0], identifier="1"
            )
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                handles=self.handles[0], identifier="1"
            )

    def test_backend_with_delegate_mapping(self) -> None:
        """Lowering with the demo backend stores a ``debug_handle_map`` in the
        lowered module's meta, and a module wrapping the lowered module can
        still be captured and converted to an ExecuTorch program."""
        model, inputs = BackendWithDelegateMappingDemo.get_test_model_and_inputs()
        edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge(
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )
        lowered_module = to_backend(
            "BackendWithDelegateMappingDemo", edgeir_m.exported_program, []
        )
        debug_handle_map = lowered_module.meta.get("debug_handle_map")
        self.assertIsNotNone(debug_handle_map)
        # The demo backend is expected to record 5 delegate debug entries.
        self.assertEqual(len(debug_handle_map), 5)
        # Delegate debug identifiers 1 through 4 must be present.
        # NOTE(review): only four of the five keys are pinned here.
        for identifier in (1, 2, 3, 4):
            self.assertIn(identifier, debug_handle_map.keys())

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, x):
                return self.lowered_module(x)

        composite_model = CompositeModule()
        # TODO: Switch this to lowered_module.program() once lowered_module has support
        # for storing debug delegate identifier maps.
        exir.capture(
            composite_model, inputs, exir.CaptureConfig()
        ).to_edge().to_executorch()

    def test_passing_both_nodes_and_handles(self) -> None:
        """Supplying both ``nodes`` and ``handles`` for one entry is rejected."""
        delegate_builder = DelegateMappingBuilder()

        # An identifier is passed so the call cannot fail merely because it is
        # missing (see test_omitting_identifier_when_not_generated).
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes, handles=self.handles, identifier="1"
            )

    def test_missing_handle_filtering(self) -> None:
        """An entry whose debug handles are all missing (None) is rejected."""
        delegate_builder = DelegateMappingBuilder()

        # Identifiers are passed (and kept distinct) so each call can only fail
        # because no valid debug handle remains, not because the identifier is
        # missing or duplicated.
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                handles=[None], identifier="1"
            )
        # Use real graph nodes that carry no debug handle (the alloc nodes,
        # per the layout documented in setUp) rather than a bare None.
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=[self.nodes[1], self.nodes[3]], identifier="2"
            )

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    def _test_basic_manual_identifier(
        self, identifiers: Iterator[Union[int, str]]
    ) -> None:
        """Exercise manually supplied identifiers drawn from ``identifiers``.

        Two builders are driven in lockstep, one fed graph nodes and one fed
        the equivalent debug handles:
        1) Insert all nodes/handles under the first identifier.
        2) Insert the first node/handle alone under the second identifier.

        After each step both builders must return the identifier they were
        given and report the same expected mapping.
        """

        delegate_builder_nodes = DelegateMappingBuilder()
        delegate_builder_handles = DelegateMappingBuilder()

        # Entry with a list of nodes / handles
        iden_1 = next(identifiers)
        expected_mapping = {iden_1: (1, 2, 3, 4)}
        self.assertEqual(
            delegate_builder_nodes.insert_delegate_mapping_entry(
                nodes=self.nodes, identifier=iden_1
            ),
            iden_1,
        )
        self.assertEqual(
            delegate_builder_handles.insert_delegate_mapping_entry(
                handles=self.handles, identifier=iden_1
            ),
            iden_1,
        )
        self.assertEqual(
            delegate_builder_nodes.get_delegate_mapping(), expected_mapping
        )
        self.assertEqual(
            delegate_builder_handles.get_delegate_mapping(), expected_mapping
        )

        # Entry with a single node / handle
        iden_2 = next(identifiers)
        expected_mapping = {iden_1: (1, 2, 3, 4), iden_2: (1,)}
        self.assertEqual(
            delegate_builder_nodes.insert_delegate_mapping_entry(
                nodes=self.nodes[0], identifier=iden_2
            ),
            iden_2,
        )
        self.assertEqual(
            delegate_builder_handles.insert_delegate_mapping_entry(
                handles=self.handles[0], identifier=iden_2
            ),
            iden_2,
        )
        self.assertEqual(
            delegate_builder_nodes.get_delegate_mapping(), expected_mapping
        )
        self.assertEqual(
            delegate_builder_handles.get_delegate_mapping(), expected_mapping
        )
244