# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Dict, List

import executorch.exir as exir
import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

# import the backend implementation
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.hta_partitioner_demo import (
    HTAPartitionerMultiplePatternsDemo,
    HTAPartitionerOnePatternDemo,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AddMulPartitionerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend

from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import (
    get_lowered_backend_modules,
    get_lowered_submodules,
)
from executorch.exir.print_program import print_program
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
    DelegateCall,
    Program,
)

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten

from functorch.experimental import control_flow
from torch.ao.quantization import get_default_qconfig_mapping  # @manual
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import export, ExportedProgram
from torch.testing import FileCheck


def vary_segments(test_method):
    """A decorator that calls the test method with `extract_delegate_segments` set to
    True and False.

    Decorated test methods must expect a boolean parameter named
    `extract_delegate_segments`, and they should pass that value to to_executorch() like:

        m.to_executorch(
            config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments)
        )

    This will cause the delegate data blobs to be extracted from the program and
    serialized as separate, freeable program segments. Backends should detect no
    difference at runtime.
    """

    def wrapper(self):
        for extract_delegate_segments in [False, True]:
            # subTest will create a different top-level test entry for each
            # value, whose full names have a suffix like
            # "(extract_delegate_segments=True)".
            with self.subTest(extract_delegate_segments=extract_delegate_segments):
                test_method(self, extract_delegate_segments=extract_delegate_segments)

    return wrapper


class TestBackends(unittest.TestCase):
    def check_delegate_input(
        self, delegate: LoweredBackendModule, input_len: int
    ) -> None:
        """Checks that the lowered module has the expected number of placeholder inputs."""
        counter = 0
        for node in delegate.original_module.graph.nodes:
            if node.op == "placeholder":
                counter += 1
        self.assertEqual(counter, input_len)

    def check_backend_delegate(
        self,
        program: Program,
        delegate: BackendDelegate,
        expected_id: str,
        expected_processed: bytes,
    ) -> None:
        """Checks that the delegate entry points at the expected inline processed blob."""
        self.assertEqual(delegate.id, expected_id)
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, expected_processed
        )

    def test_simple(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        expected_res = sin_module(*model_inputs)
        edgeir_m = to_edge(export(sin_module, model_inputs))

        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), []
        )
        new_res = lowered_sin_module(*model_inputs)

        self.assertTrue(torch.allclose(new_res, expected_res))

        # TODO(tkaruturi): emitting single LoweredBackendModule
        # program = to_edge(export(graph_module)).to_executorch()._emitter_output.program

    @vary_segments
    def test_backend_with_compiler(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved in compiler side.
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            )
        )
        graph_module = exec_prog.exported_program().graph_module

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )
        lowered_submodules = get_lowered_submodules(graph_module)
        self.assertEqual(len(lowered_submodules), 1)

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        program = exec_prog._emitter_output.program

        # Check the program can be printed
        print_program(program)

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)
        model_outputs = executorch_module.forward([model_inputs])
        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
        expected_output = 0.8333 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    @vary_segments
    def test_lowered_add_mul(self, extract_delegate_segments: bool):
        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = to_edge(export(add_mul_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_add_mul = to_backend(
            "BackendWithCompilerDemo",
            edge_graph_module.exported_program(),
            compile_specs,
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
        )

    def run_model_in_unsupported_backend(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        # The backend only accepts shapes <= 4
        model_inputs = (torch.ones(6),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.zeros(6),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        buff = exec_prog.buffer

        # This line should raise an exception like
        # RuntimeError: failed with error 0x12
        _load_for_executorch_from_buffer(buff)

    @vary_segments
    def test_backend_with_compiler_out_of_range(self, extract_delegate_segments: bool):
        with self.assertRaisesRegex(
            RuntimeError,
            "loading method forward failed with error 0x12",
        ):
            self.run_model_in_unsupported_backend(
                extract_delegate_segments=extract_delegate_segments
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator(
        self, extract_delegate_segments: bool
    ):
        # Test includes both delegates and operators
        # import the backend implementation
        from executorch.exir.backend.test.backend_with_compiler_demo import (
            BackendWithCompilerDemo,
        )

        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved in compiler side.
            def forward(self, x):
                return [torch.sin(x)]

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                a = self.lowered_linear_sin(x)[0]
                b = self.lowered_linear_sin(x)[0]
                return torch.add(a, b)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        graph_module = exec_prog.exported_program().graph_module
        program = exec_prog._emitter_output.program
        buff = exec_prog.buffer

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)

        model_outputs = executorch_module.forward([model_inputs])

        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
        expected_output = 1.666667 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_backend_with_compiler_backend_runtime_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved in compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        error_msg = r"call_function aten.cos.default is not supported in backend BackendWithCompilerDemo"

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program(), [])

    def test_backend_with_compiler_backend_not_found_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved in compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        error_msg = r"Backend FakeBackendWithCompilerDemo was not found."

        with self.assertRaisesRegex(
            NotImplementedError,
            error_msg,
        ):
            _ = to_backend(
                "FakeBackendWithCompilerDemo", edgeir_m.exported_program(), []
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator_with_two_modules(
        self, extract_delegate_segments: bool
    ):
        # The submodule runs in a specific backend; in this example, the
        # `BackendWithCompilerDemo` backend.
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        # to_be_lowered is an nn.Module
        to_be_lowered = LowerableSubModel()
        example_input = (torch.ones(1),)
        to_be_lowered_exir_submodule = to_edge(export(to_be_lowered, example_input))

        max_value = example_input[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_module = to_backend(
            "BackendWithCompilerDemo",
            to_be_lowered_exir_submodule.exported_program(),
            compile_specs,
        )

        class NonLowerableSubModel(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.register_buffer("bias", bias)

            def forward(self, a, b):
                return torch.add(torch.add(a, b), self.bias)

        # The composite module, including the lowered part and the non-lowered part
        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3)
                self.lowerable = lowered_module

            def forward(self, x):
                a = self.lowerable(x)
                b = self.lowerable(a)
                ret = self.non_lowerable(a, b)
                return a, b, ret

        composite_model = CompositeModel()

        # Prepare the model input
        model_inputs = (torch.ones(1),)

        # Verify the input works with the eager module
        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        flatbuffer = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(flatbuffer)
        model_outputs = executorch_module.forward([*model_inputs])

        expected_outputs = [
            0.8333 * torch.ones(1),
            0.7369 * torch.ones(1),
            1.8702 * torch.ones(1),
        ]

        for index, expected_output in enumerate(expected_outputs):
            self.assertTrue(
                torch.allclose(
                    model_outputs[index], expected_output, atol=1e-03, rtol=1e-03
                )
            )

    @vary_segments
    def test_partition_delegate_graph_with_multiple_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = to_edge(
            export(composite_m, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _use_edge_ops=True
            ),
        )

        program_without_delegates = to_edge(
            export(CompositeModel(3), inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        # After this step, part of the graph will be lowered to the backend, depending on
        # HTAPartitionerMultiplePatternsDemo's rule.
        program_with_delegates = traced
        program_with_delegates = program_with_delegates.to_backend(
            HTAPartitionerMultiplePatternsDemo()
        )
        program_with_delegates = program_with_delegates.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = program_with_delegates.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates._emitter_output.program,
            delegate=program_with_delegates._emitter_output.program.execution_plan[
                0
            ].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that sub is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check that convolution is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check that convolution is in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_partition_delegate_graph_with_one_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)
            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = to_edge(
            export(composite_m, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _use_edge_ops=True
            ),
        )

        program_without_delegates = to_edge(
            export(
                CompositeModel(3),
                (input_x, input_h, input_c),
            ),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        # After this step, part of the graph will be lowered to the backend, depending on
        # HTAPartitionerOnePatternDemo's rule.
        traced_with_delegate = traced
        traced_with_delegate = traced_with_delegate.to_backend(
            HTAPartitionerOnePatternDemo()
        )

        new_res = traced_with_delegate.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        program_with_delegates = traced_with_delegate.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # TODO(T143084047): Currently not retraceable
        # Retracing is not needed, but keeping this here to make sure the result
        # of to_backend is retraceable
        # graph_module_with_delegate = to_edge(export(
        #     traced_with_delegate,
        #     (input_x, input_h, input_c),
        # ))

        # program_with_delegates = graph_module_with_delegate.to_executorch(
        #     config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments),
        # )

        new_res = program_with_delegates.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates._emitter_output.program,
            delegate=program_with_delegates._emitter_output.program.execution_plan[
                0
            ].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that sub is in the program with delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check that convolution is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check that convolution is in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_add_mul_partitioner(self, extract_delegate_segments: bool):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        orig_res = m(*inputs)

        ep = to_edge(export(m, inputs))
        executorch_prog = ep
        executorch_prog = executorch_prog.to_backend(AddMulPartitionerDemo())
        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = executorch_prog.exported_program().graph_module(*inputs)
        self.assertTrue(torch.allclose(new_res[0], orig_res))

        counter = 0
        for node in executorch_prog.exported_program().graph_module.graph.nodes:
            if node.op == "get_attr":
                self.assertEqual(node.target, f"lowered_module_{counter}")
                counter += 1
        # There should be 2 delegated modules
        self.assertEqual(counter, 2)

        executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = m(*inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03),
        )

    @vary_segments
    def test_partitioner_with_attributes(self, extract_delegate_segments: bool):
        """
        Check that parameters that are lowered are correctly moved into the sub
        program, rather than being retained and passed as inputs.
        """

        class AddOne(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("one", torch.ones(1, 3))

            def forward(self, x):
                return x + self.one

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_one = AddOne()
                self.add_one_2 = AddOne()

            def forward(self, x, y):
                x = self.add_one(x) * y
                return self.add_one_2(x)

        inputs = (torch.randn(1, 3), torch.randn(1, 3))
        orig_res = Model()(*inputs)
        ep = to_edge(export(Model(), inputs))
        executorch_prog = ep
        executorch_prog = executorch_prog.to_backend(AddAttributePartitionerDemo())
        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # Check the delegated submodules
        lowered_backends = get_lowered_backend_modules(
            executorch_prog.exported_program().graph_module
        )
        self.assertEqual(len(lowered_backends), 2)
        for backend in lowered_backends:
            original_program = backend.original_module
            # Check that the program has the lowered attributes
            self.assertEqual(len(original_program.state_dict), 1)
            # Check that the backend has one placeholder input and one placeholder parameter
            self.check_delegate_input(backend, 2)

        executorch_prog.buffer

        new_res = executorch_prog.exported_program().graph_module(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_bad_partitioner(self):
        """
        Checks that we throw an error if a user-provided partitioner modifies the
        graph module.
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                x = x * y
                x = x + y
                return x

        class BadPartitioner(Partitioner):
            partition_tags = {"tag1": DelegationSpec("BackendWithCompilerDemo", [])}

            def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                # Partitioner should not modify the given graph module
                partition_tags: Dict[str, DelegationSpec] = {}
                for node in exported_program.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == exir_ops.edge.aten.add.Tensor
                    ):
                        node.target = exir_ops.edge.aten.mul.Tensor
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags=partition_tags,
                )

        ep = to_edge(export(Model(), inputs))
        with self.assertRaises(AssertionError):
            _ = ep.to_backend(BadPartitioner())

    def test_quantized_with_delegate(self) -> None:
        torch.ops.load_library(
            "//executorch/kernels/quantized:custom_ops_generated_lib"
        )
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        in_size = 2
        input_size = 3
        output_size = 4
        linear = torch.nn.Linear(input_size, output_size).eval()
        example_inputs = (torch.ones(in_size, input_size),)
        prepared_linear = prepare_fx(
            linear,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        converted_linear: torch.nn.Module = _convert_to_reference_decomposed_fx(
            prepared_linear,
        )

        # fails to trace here
        converted_linear_gm = to_edge(
            export(
                converted_linear,
                example_inputs,
            ),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        )
        FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
            converted_linear_gm.exported_program().graph_module.code
        )

    def test_partition_with_control_flow(self) -> None:
        def true_fn(x, y):
            x = x - y
            x = x + y
            x = x - y
            return x

        def false_fn(x, y):
            x = x - y
            x = torch.mm(x, y)
            x = x - y
            return x

        class Module(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = control_flow.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
                x = x - y
                return x

        f = Module()
        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        new_res = partitioned.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the cond submodules
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 2)

        true_gm = partitioned_submodules[0][1]
        true_lowered = get_lowered_submodules(true_gm)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        false_gm = partitioned_submodules[1][1]
        false_lowered = get_lowered_submodules(false_gm)
        self.assertEqual(len(false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            false_lowered[0][1].original_module.graph_module.code
        )

    def test_partition_with_map(self) -> None:
        def map_fn(x, y):
            x = x - y
            x = x + y
            return x

        class Module(torch.nn.Module):
            def forward(self, xs, y):
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, y)

        f = Module()
        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        map_fn_gm = partitioned_submodules[0][1]
        map_fn_lowered = get_lowered_submodules(map_fn_gm)
        self.assertEqual(len(map_fn_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            map_fn_lowered[0][1].original_module.graph_module.code
        )

        new_res = partitioned.exported_program().module()(*inputs)

        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_partition_with_nested_control_flow(self) -> None:
        """
        Partitions the add and mul ops, including the ones inside the submodules.
        """

        def true_nested(y):
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y):
            return torch.mm(y, y)

        def true_fn(x, pred2):
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def map_fn(x, pred1, pred2, y):
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        class Module(torch.nn.Module):
            def forward(self, xs, pred1, pred2, y):
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, pred1, pred2, y)

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.Tensor([False]),
            torch.ones(2, 2),
        )

        f = Module()
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        new_res = partitioned.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)
        # Map module has the cond submodules
        map_submodules = get_control_flow_submodules(partitioned_submodules[0][1])
        self.assertEqual(len(map_submodules), 2)

        # True module
        true_module = map_submodules[0][1]
        true_lowered = get_lowered_submodules(true_module)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        # False module
        false_lowered = get_lowered_submodules(map_submodules[1][1])
        self.assertEqual(len(false_lowered), 0)

        # True module has the nested cond submodules
        true_submodules = get_control_flow_submodules(true_module)
        self.assertEqual(len(true_submodules), 2)

        # Nested True module
        true_true_lowered = get_lowered_submodules(true_submodules[0][1])
        self.assertEqual(len(true_true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").check(
            "executorch_exir_dialects_edge__ops_aten_mm_default"
        ).run(true_true_lowered[0][1].original_module.graph_module.code)

        # Nested False module
        true_false_lowered = get_lowered_submodules(true_submodules[1][1])
        self.assertEqual(len(true_false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            true_false_lowered[0][1].original_module.graph_module.code
        )

    def test_list_input(self):
        class Module(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                y = x[0] + x[1]
                return y

        f = Module()
        inputs = ([torch.randn(2, 2), torch.randn(2, 2)],)
        edge_prog = to_edge(export(f, inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: List[torch.Tensor]):
                return self.lowered(x)

        gm = to_edge(export(ComposedM(), inputs))
        gm.exported_program().module()(*inputs)

    def test_dict_input(self):
        class Module(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                y = x["a"] + x["b"]
                return y

        f = Module()
        inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
        edge_prog = to_edge(export(f, inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: Dict[str, torch.Tensor]):
                return self.lowered(x)

        gm = to_edge(export(ComposedM(), inputs))
        gm.exported_program().module()(*inputs)