# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import executorch.exir.tests.models as models

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.lowered_backend_module import (
    create_submodule_from_nodes,
    LoweredBackendModule,
)
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
    DelegateCall,
)
from executorch.exir.tests.common import register_additional_test_aten_ops
from torch.export import export
from torch.testing import FileCheck


class WrapperModule(torch.nn.Module):
    """Wrap a plain callable in an ``nn.Module`` so it can be passed to ``export``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class TestDelegate(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def _assert_submodule_output_arity(
        self, sub_gm: torch.fx.GraphModule, expected_len: int
    ) -> None:
        """Assert that ``sub_gm`` has exactly one output node whose single
        argument is a list holding ``expected_len`` values.

        The output node is looked up explicitly before asserting on it, so the
        checks cannot be silently skipped (a ``for ...: if op == "output"``
        loop would pass without checking anything if no output node existed).
        """
        output_nodes = [n for n in sub_gm.graph.nodes if n.op == "output"]
        self.assertEqual(len(output_nodes), 1)
        output_node = output_nodes[0]
        self.assertEqual(len(output_node.args), 1)
        self.assertIsInstance(output_node.args[0], list)
        self.assertEqual(len(output_node.args[0]), expected_len)

    def test_call_delegate(self) -> None:
        """Exporting a function that calls ``executorch_call_delegate`` on a
        LoweredBackendModule keeps the delegate call (on ``lowered_module_0``)
        in the exported graph, and the exported module matches eager output."""

        def g(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        inputs = (torch.ones(1, 3), torch.ones(1, 3))
        edge_ir_m = to_edge(export(WrapperModule(g), inputs))
        lowered_module: LoweredBackendModule = LoweredBackendModule(
            edge_ir_m.exported_program(), "BackendWithCompilerDemo", b"moo", []
        )

        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.higher_order.executorch_call_delegate(lowered_module, x, y)

        orig_res = f(*inputs)
        ep = export(
            WrapperModule(f),
            inputs,
        )
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(ep.graph_module.code)
        self.assertTrue(torch.allclose(orig_res, ep.module()(*inputs)))

    def test_to_backend(self) -> None:
        """Check if we have patched a lowered module correctly (for delegation)"""

        m = models.CompositeDelegateModule()

        exec_prog = to_edge(
            export(m, m.get_random_inputs()),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        ).to_executorch()  # TODO(larryliu): fix split_copy.Tensor
        graph_module = exec_prog.exported_program().graph_module
        program = exec_prog._emitter_output.program

        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(graph_module.code)

        # Check that there does not exist an add node (from the non-delegated
        # BasicModuleAdd.forward function)
        self.assertNotIn(
            exir_ops.edge.aten.add.default,
            {node.target for node in graph_module.graph.nodes},
        )

        delegate_call_nodes = [
            node
            for node in graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.higher_order.executorch_call_delegate
        ]
        # Guard against the per-node checks below being skipped without
        # checking anything.
        self.assertGreater(len(delegate_call_nodes), 0)

        for node in delegate_call_nodes:
            # Check that the first argument is the lowered backend module
            # (which we got from a getattr)
            self.assertEqual(node.args[0].op, "get_attr")
            get_attr_backend = getattr(graph_module, node.args[0].target)
            self.assertEqual(
                get_attr_backend._backend_id, m.lowered_module._backend_id
            )
            self.assertEqual(
                get_attr_backend._processed_bytes, m.lowered_module._processed_bytes
            )
            self.assertEqual(
                get_attr_backend._compile_specs, m.lowered_module._compile_specs
            )

        # Check the BackendDelegate object itself
        delegate: BackendDelegate = program.execution_plan[0].delegates[0]
        self.assertEqual(delegate.id, "backend_demo")
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, b"basic_module_add"
        )

        # Check the delegate instruction
        self.assertIsInstance(
            program.execution_plan[0].chains[0].instructions[0].instr_args,
            DelegateCall,
        )

    def test_cannot_assign_attr(self) -> None:
        """Assigning ``backend_id`` on a LoweredBackendModule raises AttributeError."""
        deleg = LoweredBackendModule(None, "", b"", [])  # pyre-ignore
        with self.assertRaises(AttributeError):
            deleg.backend_id = "123"  # pyre-ignore

    def test_create_submodule_single_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(mul_tensor, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [mul_tensor]  # Output is pytree.flatten-ed

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                return x

        orig_res = Model()(*inputs)
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        partition_targets = {
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.mul.Tensor,
        }
        node_list = [
            node
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target in partition_targets
        ]

        sub_gm, _ = create_submodule_from_nodes(gm, node_list, "tag")
        sub_gm.recompile()
        gm.recompile()

        self._assert_submodule_output_arity(sub_gm, 1)

        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))

    def test_create_submodule_multiple_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(add_tensor, mul_tensor)
            div_tensor = div(sub_tensor, mul_tensor)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [add_tensor, mul_tensor]

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            getitem_1 = partitioned_res[1]
            sub_tensor = sub(getitem_0, getitem_1)
            div_tensor = div(sub_tensor, getitem_1)
            return [div_tensor]
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                y = x * y
                x = x - y
                x = x / y
                return x

        orig_res = Model()(*inputs)
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        partition_targets = {
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.mul.Tensor,
        }
        node_list = [
            node
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target in partition_targets
        ]

        sub_gm, _ = create_submodule_from_nodes(gm, node_list, "tag")
        sub_gm.recompile()
        gm.recompile()

        self._assert_submodule_output_arity(sub_gm, 2)

        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))

    def test_create_submodule_list_return(self) -> None:
        """
        Original graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            getitem_1 = split_tensor[1]
            return [getitem_0, getitem_1]  # List output is "opened"

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """
        inputs = (torch.randn(10), torch.randn(5))

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = torch.split(x, 5)
                x = x[0] - y
                x = x / y
                return x

        orig_res = Model()(*inputs)
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        # TODO(ssjia): split.Tensor now gets decomposed to split_with_sizes. Due to how executorch uses a pinned Pytorch
        # nightly, the CI may not catch the changes to Pytorch's core decomposition table. As a temporary workaround,
        # make the test backwards compatible with the old decomposition table. Remove split_copy.Tensor from
        # split_targets once Pytorch nightly has been updated.
        split_targets = (
            exir_ops.edge.aten.split_with_sizes_copy.default,
            exir_ops.edge.aten.split_copy.Tensor,
        )
        node_list = [
            node
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target in split_targets
        ]

        sub_gm, _ = create_submodule_from_nodes(gm, node_list, "tag")
        # Recompile both modules, as the sibling create_submodule tests do, so
        # the generated code reflects the rewritten graphs.
        sub_gm.recompile()
        gm.recompile()

        self._assert_submodule_output_arity(sub_gm, 2)

        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))