# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import executorch.exir.tests.models as models

import torch
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
from executorch.exir.schema import DelegateCall, Program

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from hypothesis import given, settings, strategies as st
from torch.export import export


class _WrappedModule(torch.nn.Module):
    """Forward every positional argument to a single wrapped (lowered) module.

    Used to re-export an already-lowered module as part of an outer program.
    """

    def __init__(self, lowered_module):
        super().__init__()
        self.one_module = lowered_module

    def forward(self, *args):
        return self.one_module(*args)


def _edge_compile_config() -> exir.EdgeCompileConfig:
    """Return the edge config shared by these tests.

    IR validity checking is disabled and edge ops are enabled.
    """
    return exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True)


def _models_to_lower() -> list:
    """Return fresh instances of the sample models lowered to ``QnnBackend``."""
    return [
        models.Emformer(),
        models.Repeat(),
        models.ElementwiseAdd(),
        models.MLP(),
        models.ModelWithUnusedArg(),
    ]


def _lower_to_qnn(module, example_inputs, edge_compile_config):
    """Export ``module``, convert it to edge dialect, then lower it.

    The whole exported program is lowered to the demo ``QnnBackend`` with
    no compile specs.
    """
    edge_program = to_edge(
        export(module, example_inputs), compile_config=edge_compile_config
    )
    return to_backend(QnnBackend.__name__, edge_program.exported_program(), [])


class TestBackendAPI(unittest.TestCase):
    """Tests for programs emitted from modules lowered with ``to_backend``."""

    def validate_lowered_module_program(self, program: Program) -> None:
        """Check that ``program`` consists of exactly one delegate call.

        For any program emitted from a lowered backend module, the first
        chain of the first execution plan must hold a single instruction,
        and that instruction must be a ``DelegateCall``.
        """
        instructions = program.execution_plan[0].chains[0].instructions
        # There should only be one instruction ...
        self.assertEqual(len(instructions), 1)
        # ... and it should be a delegate call.
        self.assertIsInstance(instructions[0].instr_args, DelegateCall)

    def get_program_from_wrapped_module(
        self, lowered_module, example_inputs, edge_compile_config
    ):
        """Build a reference program around ``lowered_module``.

        ``lowered_module`` is embedded in an outer module, which is run
        through export -> edge -> ExecuTorch.

        Returns:
            The resulting ``executorch_program``. It serves as a reference to
            compare against ``lowered_module.program()``.
        """
        return (
            to_edge(
                export(_WrappedModule(lowered_module), example_inputs),
                compile_config=edge_compile_config,
            )
            .to_executorch()
            .executorch_program
        )

    def test_emit_lowered_backend_module_end_to_end(self):
        """Lower ``torch.sin`` to the demo backend and run it.

        The lowered module is run both eagerly and through the portable
        runtime loaded from its serialized buffer.
        """

        class SinModule(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        expected_eager_res = sin_module(*model_inputs)
        edgeir_m = to_edge(
            export(sin_module, model_inputs),
            compile_config=_edge_compile_config(),
        )
        # The demo backend receives the input length as a one-byte
        # "max_value" compile spec.
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            BackendWithCompilerDemo.__name__,
            edgeir_m.exported_program(),
            compile_specs,
        )

        # Calling the lowered module directly must match eager torch.sin.
        new_res = lowered_sin_module(*model_inputs)
        self.assertTrue(torch.allclose(new_res[0], expected_eager_res))

        program = lowered_sin_module.program()
        self.validate_lowered_module_program(program)

        # Execute the serialized lowered module with the portable runtime.
        executorch_module = _load_for_executorch_from_buffer(
            lowered_sin_module.buffer()
        )
        runtime_input = torch.ones(1)
        runtime_outputs = executorch_module.forward([runtime_input])
        # The input tensor must be left unchanged by the runtime call.
        self.assertTrue(torch.equal(runtime_input, torch.ones(1)))

        # NOTE(review): 0.8333 ~= 1 - 1/6, not sin(1) ~= 0.8415. Presumably
        # the demo backend uses a truncated approximation of sin; confirm
        # against BackendWithCompilerDemo.
        expected_runtime_res = 0.8333 * torch.ones(1)
        self.assertTrue(
            torch.allclose(
                runtime_outputs[0], expected_runtime_res, atol=1e-03, rtol=1e-03
            )
        )

    @given(
        # NOTE(review): `unlift` is drawn but never read by the test body,
        # so every example currently exercises the same path.
        unlift=st.booleans(),
    )
    @settings(deadline=500000)
    def test_emit_lowered_backend_module(self, unlift):
        """Compare each lowered program against a wrapped reference program.

        The program emitted directly from each lowered sample model must
        roughly match (in instruction/value/input/output counts) the program
        obtained by wrapping the lowered module and emitting it end to end.
        """
        edge_compile_config = _edge_compile_config()

        for model in _models_to_lower():
            model_name = type(model).__name__
            model_inputs = model.get_random_inputs()

            lowered_model = _lower_to_qnn(model, model_inputs, edge_compile_config)
            program = lowered_model.program()
            reference_program = self.get_program_from_wrapped_module(
                lowered_model, model_inputs, edge_compile_config
            )

            plan = program.execution_plan[0]
            reference_plan = reference_program.execution_plan[0]

            # Check the program is fairly equal to the reference program.
            self.assertEqual(
                len(plan.chains[0].instructions),
                len(reference_plan.chains[0].instructions),
                msg=f"{model_name}: instructions",
            )
            for field in ("values", "inputs", "outputs"):
                self.assertEqual(
                    len(getattr(plan, field)),
                    len(getattr(reference_plan, field)),
                    msg=f"{model_name}: {field}",
                )

            # Ensure the serialized buffer can be produced.
            _ = lowered_model.buffer()
            self.validate_lowered_module_program(program)

    @given(
        # NOTE(review): `unlift` is drawn but never read by the test body,
        # so every example currently exercises the same path.
        unlift=st.booleans(),
    )
    @settings(deadline=500000)
    def test_emit_nested_lowered_backend_module(self, unlift):
        """Lower a module that itself wraps an already-lowered module.

        The resulting program must still be a single delegate call.
        """
        edge_compile_config = _edge_compile_config()

        for model in _models_to_lower():
            model_inputs = model.get_random_inputs()
            lowered_module = _lower_to_qnn(model, model_inputs, edge_compile_config)

            # Wrap the lowered module in an outer module and lower that
            # wrapper to QnnBackend as a whole.
            nested_lowered_model = _lower_to_qnn(
                _WrappedModule(lowered_module), model_inputs, edge_compile_config
            )

            program = nested_lowered_model.program()
            self.validate_lowered_module_program(program)