# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Iterator, Union

import torch
from executorch import exir
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.test.backend_with_delegate_mapping_demo import (
    BackendWithDelegateMappingDemo,
)

from executorch.exir.backend.utils import DelegateMappingBuilder


class TestDelegateMapBuilder(unittest.TestCase):
    """Tests for ``DelegateMappingBuilder`` and the delegate-mapping demo backend."""

    def setUp(self) -> None:
        """Capture a tiny sin/cos model and cache its nodes and debug handles."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                y = torch.sin(x)
                return torch.cos(y)

        model = Model()
        model_inputs = (torch.ones(1, 1),)
        program = (
            exir.capture(model, model_inputs, exir.CaptureConfig(pt2_mode=True))
            .to_edge()
            .to_executorch()
        )

        # Create nodes for testing mapping
        # nodes: [arg0_1, alloc, aten_sin_default, alloc_1, aten_cos_default, output]
        # debug handles, as pinned by the assertions in this class:
        #   nodes[0] -> 1, nodes[2] -> 2, and the full set of handles is {1, 2, 3, 4}.
        #   The two alloc nodes presumably carry no handle (None) -- TODO confirm.
        self.nodes = list(program.graph_module.graph.nodes)

        self.handles = [node.meta.get("debug_handle") for node in self.nodes]

    def test_basic_generated_identifier(self) -> None:
        """Generated identifiers count up from 0, one per inserted entry."""
        delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

        # Entry built from the full list of nodes.
        expected_mapping = {0: (1, 2, 3, 4)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes), 0
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry built from a single node.
        expected_mapping = {0: (1, 2, 3, 4), 1: (1,)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes[0]), 1
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry built from a single raw debug handle.
        expected_mapping = {0: (1, 2, 3, 4), 1: (1,), 2: (2,)}
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles[2]),
            2,
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

        # Entry built from the full list of handles (which includes None values).
        expected_mapping = {
            0: (1, 2, 3, 4),
            1: (1,),
            2: (2,),
            3: (1, 2, 3, 4),
        }
        self.assertEqual(
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles), 3
        )
        self.assertEqual(delegate_builder.get_delegate_mapping(), expected_mapping)

    def test_basic_manual_int_identifier(self) -> None:
        """Manual identifiers may be ints."""
        self._test_basic_manual_identifier(iter([22, 55]))

    def test_basic_manual_string_identifier(self) -> None:
        """Manual identifiers may be strings."""
        self._test_basic_manual_identifier(iter(["22", "55"]))

    def test_adding_manual_identifier_when_generated(self) -> None:
        """A builder with generated identifiers rejects an explicit identifier."""
        delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes, identifier="22"
            )
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                handles=self.handles, identifier="22"
            )

    def test_omitting_identifier_when_not_generated(self) -> None:
        """A builder without generated identifiers requires an explicit identifier."""
        delegate_builder = DelegateMappingBuilder()

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(nodes=self.nodes)
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(handles=self.handles)

    def test_reinsert_delegate_debug_identifier(self) -> None:
        """Re-using an identifier that is already in the mapping raises."""
        delegate_builder = DelegateMappingBuilder()
        delegate_builder.insert_delegate_mapping_entry(
            nodes=self.nodes[0], identifier="1"
        )

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes[0], identifier="1"
            )
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                handles=self.handles[0], identifier="1"
            )

    def test_backend_with_delegate_mapping(self) -> None:
        """Lowering to the demo backend attaches a debug handle map to the module."""
        model, inputs = BackendWithDelegateMappingDemo.get_test_model_and_inputs()
        edgeir_m = exir.capture(model, inputs, exir.CaptureConfig()).to_edge(
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )
        lowered_module = to_backend(
            "BackendWithDelegateMappingDemo", edgeir_m.exported_program, []
        )
        debug_handle_map = lowered_module.meta.get("debug_handle_map")
        self.assertIsNotNone(debug_handle_map)
        # The demo backend is expected to produce 5 entries in its debug handle map.
        # NOTE(review): an earlier wording of this comment said "3 backend ops" and
        # "range [0,2]"; the assertions are treated as authoritative -- confirm.
        self.assertEqual(len(debug_handle_map), 5)
        # Delegate debug identifiers 1 through 4 must all be present.
        for delegate_identifier in (1, 2, 3, 4):
            self.assertIn(delegate_identifier, debug_handle_map)

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, x):
                return self.lowered_module(x)

        composite_model = CompositeModule()
        # TODO: Switch this to lowered_module.program() once lowered_module has support
        # for storing debug delegate identifier maps.
        # The result is discarded: this only checks that the full pipeline runs
        # without raising when a lowered module is embedded in a parent module.
        exir.capture(
            composite_model, inputs, exir.CaptureConfig()
        ).to_edge().to_executorch()

    def test_passing_both_nodes_and_handles(self) -> None:
        """Supplying both ``nodes`` and ``handles`` in one call raises."""
        # NOTE(review): this builder has no generated identifiers and no identifier
        # is passed, which test_omitting_identifier_when_not_generated already
        # expects to raise on its own. Consider generated_identifiers=True here so
        # the nodes+handles conflict is the only possible cause -- confirm against
        # the builder implementation.
        delegate_builder = DelegateMappingBuilder()

        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(
                nodes=self.nodes, handles=self.handles
            )

    def test_missing_handle_filtering(self) -> None:
        """Entries whose only handles/nodes are ``None`` raise."""
        # NOTE(review): as above, the missing identifier on a default builder may be
        # what raises here rather than the None filtering -- confirm against the
        # builder implementation.
        delegate_builder = DelegateMappingBuilder()
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(handles=[None])
        with self.assertRaises(Exception):
            delegate_builder.insert_delegate_mapping_entry(nodes=[None])

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    def _test_basic_manual_identifier(
        self, identifiers: Iterator[Union[int, str]]
    ) -> None:
        """
        Shared body for the manual-identifier tests.

        Using the iteration of identifiers:
        1) Create two Delegate Map Builders: one fed with nodes, one fed with the
           matching debug handles
        2) Add an entry with the full list of nodes/handles using the first identifier
        3) Add an entry with a single node/handle using the second identifier

        After each insertion, verify that the identifier is returned and that both
        builders report the same expected mapping.

        Args:
            identifiers: Iterator yielding at least two distinct identifiers.
        """

        delegate_builder_nodes = DelegateMappingBuilder()
        delegate_builder_handles = DelegateMappingBuilder()

        # Entry with a list of nodes
        iden_1 = next(identifiers)
        expected_mapping = {iden_1: (1, 2, 3, 4)}
        self.assertEqual(
            delegate_builder_nodes.insert_delegate_mapping_entry(
                nodes=self.nodes, identifier=iden_1
            ),
            iden_1,
        )
        self.assertEqual(
            delegate_builder_handles.insert_delegate_mapping_entry(
                handles=self.handles, identifier=iden_1
            ),
            iden_1,
        )
        self.assertEqual(
            delegate_builder_nodes.get_delegate_mapping(), expected_mapping
        )
        self.assertEqual(
            delegate_builder_handles.get_delegate_mapping(), expected_mapping
        )

        # Entry with a single node
        iden_2 = next(identifiers)
        expected_mapping = {iden_1: (1, 2, 3, 4), iden_2: (1,)}
        self.assertEqual(
            delegate_builder_nodes.insert_delegate_mapping_entry(
                nodes=self.nodes[0], identifier=iden_2
            ),
            iden_2,
        )
        self.assertEqual(
            delegate_builder_handles.insert_delegate_mapping_entry(
                handles=self.handles[0], identifier=iden_2
            ),
            iden_2,
        )
        self.assertEqual(
            delegate_builder_nodes.get_delegate_mapping(), expected_mapping
        )
        self.assertEqual(
            delegate_builder_handles.get_delegate_mapping(), expected_mapping
        )