1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-unsafe 8 9import typing 10import unittest 11from contextlib import contextmanager 12from typing import List, Optional, Tuple 13 14import executorch.exir as exir 15 16import executorch.exir.schema as schema 17import executorch.exir.tests.models as models 18import pytest 19import torch 20from executorch.exir import ( 21 EdgeCompileConfig, 22 ExecutorchBackendConfig, 23 ExecutorchProgramManager, 24 to_edge, 25) 26from executorch.exir._serialize._program import deserialize_pte_binary 27from executorch.exir.backend.backend_api import to_backend 28from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult 29from executorch.exir.dialects._ops import ops as exir_ops 30from executorch.exir.emit import emit_program # noqa 31from executorch.exir.error import InternalError 32from executorch.exir.passes import MemoryPlanningPass 33from executorch.exir.passes.constant_prop_pass import constant_prop_pass 34from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass 35from executorch.exir.print_program import pretty_print, print_program # noqa 36from executorch.exir.schema import ( 37 Bool, 38 DelegateCall, 39 Double, 40 EValue, 41 ExecutionPlan, 42 Int, 43 IntList, 44 JumpFalseCall, 45 KernelCall, 46 KernelTypes, 47 MoveCall, 48 Null, 49 OptionalTensorList, 50 Program, 51 String, 52 Tensor, 53) 54from executorch.exir.tests.common import register_additional_test_aten_ops 55from executorch.exir.tests.models import Mul 56from executorch.extension.pybindings.portable_lib import ( 57 _load_for_executorch_from_buffer, 58) 59 60from functorch.experimental import control_flow 61from torch import nn 62 63from torch.export import Dim, export 64 65 66class WrapperModule(torch.nn.Module): 67 def __init__(self, fn): 68 super().__init__() 69 self.fn = fn 70 71 def forward(self, *args, **kwargs): 72 return self.fn(*args, **kwargs) 73 74 75@contextmanager 76def patch_forward(obj: torch.nn.Module, new_method): 77 """Helper method to make it easier to cleanly torch.export() a method on a 78 module that is not `forward`. 79 80 TODO(suo): upstream this to torch.export.wrapper. 81 """ 82 # Save the original method 83 original_method = obj.forward 84 85 # Patch the method 86 obj.forward = new_method.__get__(obj, obj.__class__) 87 88 try: 89 yield 90 finally: 91 # Restore the original method 92 obj.forward = original_method 93 94 95class TestEmit(unittest.TestCase): 96 @classmethod 97 def setUpClass(cls) -> None: 98 register_additional_test_aten_ops() 99 100 def setUp(self) -> None: 101 self.compile_config = EdgeCompileConfig(_check_ir_validity=False) 102 103 def check_tensor_buffer_loc( 104 self, 105 value_index: int, 106 values: List[EValue], 107 exp_buffer_idx: int, 108 exp_mem_id: Optional[int], 109 exp_mem_offset: Optional[int], 110 ) -> None: 111 value = typing.cast(schema.Tensor, values[value_index].val) 112 self.assertIsInstance(value, schema.Tensor) 113 114 self.assertEqual(value.data_buffer_idx, exp_buffer_idx) 115 116 if not value.allocation_info: 117 self.assertIsNone(exp_mem_id) 118 self.assertIsNone(exp_mem_offset) 119 else: 120 self.assertEqual(value.allocation_info.memory_id, exp_mem_id) 121 assert value.allocation_info 122 self.assertEqual(value.allocation_info.memory_offset, exp_mem_offset) 123 124 def count_node(self, graph_module: torch.fx.GraphModule, opname: str) -> int: 125 return [ 126 node.target._overloadpacket._qualified_op_name 127 for node in graph_module.graph.nodes 128 if node.op == "call_function" 129 ].count(opname) 130 131 def run_dce(self, graph_module: torch.fx.GraphModule) -> None: 132 for submodule in graph_module.modules(): 133 self.assertIsInstance(submodule, torch.fx.GraphModule) 134 typing.cast(torch.fx.GraphModule, submodule).graph.eliminate_dead_code() 135 136 def check_value_types(self, values: List[EValue]) -> None: 137 for value in values: 138 self.assertTrue(type(value.val) in KernelTypes.__args__) 139 140 def count_move_instructions(self, program: Program) -> int: 141 instructions = program.execution_plan[0].chains[0].instructions 142 assert instructions is not None 143 res = 0 144 for instr in instructions: 145 if isinstance(instr.instr_args, MoveCall): 146 res += 1 147 return res 148 149 def test_basic_api(self) -> None: 150 class Foo(torch.nn.Module): 151 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 152 return x * y + x 153 154 f = Foo() 155 156 program = ( 157 to_edge( 158 export( 159 f, 160 (torch.ones(3, 2), torch.zeros(3, 2)), 161 ) 162 ) 163 .to_executorch() 164 .executorch_program 165 ) 166 exec_plan = program.execution_plan[0] 167 ops = exec_plan.operators 168 for op in ops: 169 self.assertEqual(op.overload, "out") 170 171 self.assertEqual(ops[0].name, "aten::mul") 172 self.assertEqual(ops[1].name, "aten::add") 173 174 self.assertEqual(len(exec_plan.inputs), 2) 175 self.assertEqual(len(exec_plan.outputs), 1) 176 177 self.assertEqual(exec_plan.inputs[0], 0) 178 self.assertEqual(exec_plan.outputs[0], 3) 179 180 def test_basic_end_to_end(self) -> None: 181 f = models.BasicSinMax() 182 program = ( 183 to_edge(export(f, f.get_random_inputs())).to_executorch().executorch_program 184 ) 185 exec_plan = program.execution_plan[0] 186 ops = exec_plan.operators 187 for op in ops: 188 self.assertIn(op.overload, {"out", "unary_out"}) 189 190 self.assertEqual(ops[0].name, "aten::sin") 191 192 self.assertEqual(len(exec_plan.inputs), 1) 193 self.assertEqual(len(exec_plan.outputs), 1) 194 195 self.assertEqual(exec_plan.inputs[0], 0) 196 self.assertEqual(exec_plan.outputs[0], 1) 197 198 @pytest.mark.skip(reason="Test not working on OSS") 199 def test_nested_return(self) -> None: 200 class Foo(torch.nn.Module): 201 def forward( 202 self, x: torch.Tensor 203 ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: 204 return ( 205 torch.tensor(1), 206 torch.tensor(2), 207 [torch.sin(x).max(), torch.cos(x).max()], 208 ) 209 210 f = Foo() 211 212 x = (torch.randn(100),) 213 program = to_edge(export(f, x)).to_executorch().executorch_program 214 exec_plan = program.execution_plan[0] 215 self.assertEqual(len(exec_plan.outputs), 4) 216 self.assertEqual(len(exec_plan.inputs), 1) 217 218 self.assertEqual( 219 program.execution_plan[0].container_meta_type.encoded_out_str, 220 "T3#1#1#2($,$,L2#1#1($,$))", 221 ) 222 223 self.assertEqual( 224 program.execution_plan[0].container_meta_type.encoded_inp_str, 225 "T2#1#0(T1#1($),D0())", 226 ) 227 228 def test_constant_output(self): 229 class M(torch.nn.Module): 230 def forward(self, x): 231 return [((1, 3, 1.2), True, [x + x, x * x], None)] 232 233 ep = torch.export.export(M(), (torch.ones(2, 3),)) 234 res = ep.module()(torch.ones(2, 3)) 235 self.assertEqual(res[0][0], (1, 3, 1.2)) 236 program = to_edge(ep).to_executorch().executorch_program 237 outputs = program.execution_plan[0].outputs 238 self.assertEqual(len(outputs), 7) 239 self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1) 240 self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3) 241 self.assertEqual( 242 program.execution_plan[0].values[outputs[2]].val.double_val, 1.2 243 ) 244 self.assertEqual( 245 program.execution_plan[0].values[outputs[3]].val.bool_val, True 246 ) 247 self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null) 248 249 def test_int_list_input(self): 250 class M(torch.nn.Module): 251 def forward(self, x, y, z): 252 return x + y, x + x, x + y + z 253 254 ep = torch.export.export(M(), (torch.ones(2, 3), 2, True)) 255 ep.module()(torch.ones(2, 3), 2, True) 256 program = to_edge(ep).to_executorch().executorch_program 257 inputs = program.execution_plan[0].inputs 258 self.assertEqual(len(inputs), 3) 259 self.assertEqual(program.execution_plan[0].values[inputs[1]].val.int_val, 2) 260 self.assertEqual(program.execution_plan[0].values[inputs[2]].val.bool_val, True) 261 262 def test_inplace_ops(self) -> None: 263 class Foo(torch.nn.Module): 264 def forward(self, x: torch.Tensor) -> torch.Tensor: 265 y = torch.sin(x) 266 z = y.view(100) 267 torch.relu_(z) 268 return z.max() 269 270 f = Foo() 271 272 inputs = (torch.ones((10, 10)),) 273 edge = to_edge(export(f, inputs)) 274 275 removed_ops = ["aten::relu_", "aten::view"] 276 expected_ops = [ 277 "aten::sin", 278 "aten::relu", 279 "aten::max", 280 "executorch_prim::et_view", # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False 281 ] 282 283 for opname in removed_ops: 284 self.assertEqual( 285 self.count_node(edge.exported_program().graph_module, opname), 0 286 ) 287 for opname in expected_ops: 288 if ( 289 opname != "executorch_prim::et_view" 290 ): # et_view appears as call_function with target = memory.view in graph 291 self.assertTrue( 292 self.count_node(edge.exported_program().graph_module, opname) >= 1 293 ) 294 295 program = edge.to_executorch().executorch_program 296 for opname in removed_ops: 297 self.assertTrue( 298 all(op.name != opname for op in program.execution_plan[0].operators) 299 ) 300 for opname in expected_ops: 301 self.assertTrue( 302 any(op.name == opname for op in program.execution_plan[0].operators) 303 ) 304 305 def test_operators_unique(self) -> None: 306 class OpRepeatedModule(torch.nn.Module): 307 def __init__(self) -> None: 308 super().__init__() 309 self.a = torch.ones(2, 2) 310 self.b = 2 * torch.ones(2, 2) 311 312 def forward(self, x: torch.Tensor) -> torch.Tensor: 313 for _ in range(10): 314 z = self.a * x 315 y = z + self.b 316 return y 317 318 model = OpRepeatedModule() 319 320 inputs = (torch.ones(2, 2),) 321 322 program = to_edge(export(model, inputs)).to_executorch().executorch_program 323 324 self.assertEqual(len(program.execution_plan[0].operators), 2) 325 326 def test_list_type(self) -> None: 327 """Tests that the types of lists are correctly found""" 328 329 class Foo(torch.nn.Module): 330 def forward(self, x: torch.Tensor) -> torch.Tensor: 331 return torch.permute(x, (2, 0, 1)) 332 333 f = Foo() 334 335 program = ( 336 to_edge(export(f, (torch.randn(2, 3, 5),))) 337 .to_executorch() 338 .executorch_program 339 ) 340 exir.print_program.pretty_print(program) 341 342 deboxed_int_list = [] 343 for item in program.execution_plan[0].values[5].val.items: # pyre-ignore[16] 344 deboxed_int_list.append( 345 program.execution_plan[0].values[item].val.int_val # pyre-ignore[16] 346 ) 347 348 self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1])) 349 350 def test_kwargs1(self) -> None: 351 """Tests that the kwargs are placed in the order specified by 352 native_functions.yaml 353 """ 354 355 class Foo(torch.nn.Module): 356 def forward(self, x: torch.Tensor) -> torch.Tensor: 357 batch1 = torch.randn(10, 3, 4) 358 batch2 = torch.randn(10, 4, 5) 359 return torch.addbmm(x, batch1, batch2, alpha=2, beta=3) 360 361 f = Foo() 362 363 program = ( 364 to_edge(export(f, (torch.randn(3, 5),))).to_executorch().executorch_program 365 ) 366 # The value for beta should appear before alpha 367 self.assertEqual(program.execution_plan[0].values[12].val, Int(3)) 368 self.assertEqual(program.execution_plan[0].values[13].val, Int(2)) 369 370 def test_kwargs2(self) -> None: 371 """Tests that the kwargs are placed in the order specified by 372 native_functions.yaml 373 """ 374 375 class Foo(torch.nn.Module): 376 def forward(self, x: torch.Tensor) -> torch.Tensor: 377 values = torch.randn(3, 2) 378 return torch.searchsorted(x, values, side="right", right=True) 379 380 f = Foo() 381 382 x, _ = torch.sort(torch.randn(3, 4)) 383 program = to_edge(export(f, (x,))).to_executorch().executorch_program 384 # The value for right should appear before side 385 self.assertEqual(program.execution_plan[0].values[6].val, Bool(False)) 386 self.assertEqual(program.execution_plan[0].values[7].val, Bool(True)) 387 self.assertEqual(program.execution_plan[0].values[8].val, String("right")) 388 self.assertEqual(program.execution_plan[0].values[9].val, Null()) 389 390 def _assertCallLength(self, program: Program, idx: int, expected_len: int) -> None: 391 instr_args = program.execution_plan[0].chains[0].instructions[idx].instr_args 392 393 if isinstance(instr_args, KernelCall) or isinstance(instr_args, DelegateCall): 394 self.assertEqual(len(instr_args.args), expected_len) 395 else: 396 self.assertTrue(False) 397 398 def test_out(self) -> None: 399 class Foo(torch.nn.Module): 400 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 401 z = y.clone() 402 return torch.mul(x, y, out=z) 403 404 f = Foo() 405 406 program = ( 407 to_edge(export(f, (torch.ones(3), torch.ones(3)))) 408 .to_executorch() 409 .executorch_program 410 ) 411 412 self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1) 413 self._assertCallLength(program, 0, 4) 414 415 def test_model_out(self) -> None: 416 class Module_out(torch.nn.Module): 417 def __init__(self) -> None: 418 super().__init__() 419 self.a = 3 * torch.ones(2, 2, dtype=torch.int32) 420 self.b = 2 * torch.ones(2, 2, dtype=torch.int32) 421 422 def forward(self, x: torch.Tensor) -> torch.Tensor: 423 z = x.clone() 424 torch.mul(self.a, x, out=z) 425 y = x.clone() 426 torch.add(z, self.b, alpha=2, out=y) 427 return y 428 429 model_out = Module_out() 430 431 inputs = (torch.ones(2, 2, dtype=torch.int32),) 432 433 # Trace to FX Graph. 434 program = to_edge(export(model_out, inputs)).to_executorch().executorch_program 435 436 self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 2) 437 self._assertCallLength(program, 0, 4) 438 self._assertCallLength(program, 1, 5) 439 440 def test_stacktrace(self) -> None: 441 def f(x: torch.Tensor) -> torch.Tensor: 442 return torch.mul(x, torch.randn(3, 2)) 443 444 def g(x: torch.Tensor) -> torch.Tensor: 445 return torch.sin(f(x)) 446 447 class Foo(torch.nn.Module): 448 def forward(self, x: torch.Tensor) -> torch.Tensor: 449 return torch.add(g(x), torch.randn(3, 2)) 450 451 h = Foo() 452 453 x = (torch.randn(3, 2),) 454 exec_prog = to_edge(export(h, x)).to_executorch( 455 exir.ExecutorchBackendConfig(emit_stacktrace=True) 456 ) 457 program = exec_prog.executorch_program 458 459 # Check the mul operator's stack trace contains f -> g -> h 460 self.assertTrue( 461 "return torch.mul(x, torch.randn(3, 2))" 462 in program.execution_plan[0] # pyre-ignore[16] 463 .chains[0] 464 .stacktrace[1] 465 .items[-1] 466 .context 467 ) 468 self.assertEqual( 469 program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f" 470 ) 471 self.assertEqual( 472 program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "g" 473 ) 474 self.assertEqual( 475 program.execution_plan[0].chains[0].stacktrace[1].items[-3].name, "forward" 476 ) 477 478 # Check the sin operator's stack trace contains g -> h 479 self.assertEqual( 480 program.execution_plan[0].chains[0].stacktrace[2].items[-1].name, "g" 481 ) 482 self.assertEqual( 483 program.execution_plan[0].chains[0].stacktrace[2].items[-2].name, "forward" 484 ) 485 486 def test_stacktrace_off(self) -> None: 487 class Foo(torch.nn.Module): 488 def forward(self, x: torch.Tensor) -> torch.Tensor: 489 return torch.mul(x, torch.randn(3, 2)) 490 491 f = Foo() 492 493 class Goo(torch.nn.Module): 494 def forward(self, x: torch.Tensor) -> torch.Tensor: 495 return torch.sin(f(x)) 496 497 g = Goo() 498 499 class Hoo(torch.nn.Module): 500 def forward(self, x: torch.Tensor) -> torch.Tensor: 501 return torch.add(g(x), torch.randn(3, 2)) 502 503 h = Hoo() 504 505 x = (torch.randn(3, 2),) 506 program = to_edge(export(h, x)).to_executorch().executorch_program 507 508 # Check the stacktrace is None since we did not specify to get the stacktrace 509 self.assertTrue(program.execution_plan[0].chains[0].stacktrace is None) 510 511 def test_positional_argument_default_value(self) -> None: 512 class Foo(torch.nn.Module): 513 def forward(self, x: torch.Tensor, n: torch.Tensor) -> torch.Tensor: 514 z = torch.ones(6, 2) 515 return torch.ops.aten.cat.out((x, n), out=z) 516 517 f = Foo() 518 519 x = torch.randn(3, 2) 520 program = ( 521 to_edge(export(f, (x, x))) 522 # .to_edge(self.compile_config) # TODO(larryliu): fix cat 523 .to_executorch().executorch_program 524 ) 525 526 self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1) 527 self._assertCallLength(program, 0, 4) 528 529 @pytest.mark.skip(reason="Test not working on OSS") 530 def test_emit_multiple_out(self) -> None: 531 class Foo(torch.nn.Module): 532 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 533 return torch.topk(x, 5) 534 535 f = Foo() 536 537 x = (torch.randn(10),) 538 program = to_edge(export(f, x)).to_executorch().executorch_program 539 self._assertCallLength(program, 0, 8) 540 541 def test_emit_layout(self) -> None: 542 class Foo(torch.nn.Module): 543 def forward(self, x: torch.Tensor) -> torch.Tensor: 544 return torch.ones_like(x) 545 546 f = Foo() 547 548 x = (torch.randn(3, 2),) 549 program = to_edge(export(f, x)).to_executorch().executorch_program 550 551 vals = program.execution_plan[0].values 552 for val in vals: 553 v = val.val 554 if isinstance(v, Tensor): 555 self.assertEqual(v.layout, 0) 556 557 def test_optional_tensor_list(self) -> None: 558 class Foo(torch.nn.Module): 559 def forward(self, x: torch.Tensor) -> torch.Tensor: 560 a = torch.nonzero(x) 561 torch._constrain_as_size(a.shape[0], min=1) 562 b = torch.ops.aten.index.Tensor(x, [a]) 563 return b 564 565 f = Foo() 566 x = (torch.triu(torch.ones(2, 2)),) 567 program = ( 568 to_edge( 569 export(f, x), 570 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 571 ) 572 .to_executorch() 573 .executorch_program 574 ) 575 self.assertTrue( 576 isinstance(program.execution_plan[0].values[3].val, OptionalTensorList) 577 ) 578 self._assertCallLength(program, 0, 3) 579 self._assertCallLength(program, 1, 4) 580 581 def test_optional_float_list(self) -> None: 582 class M(torch.nn.Module): 583 def forward(self, x): 584 return torch.nn.functional.interpolate(x, scale_factor=2) 585 586 x = (torch.randn(1, 1, 2, 2),) 587 program = to_edge(export(M(), x)).to_executorch().executorch_program 588 self.assertIsInstance( 589 program.execution_plan[0].values[-1].val, schema.OptionalTensorList 590 ) 591 592 def test_emit_cond(self) -> None: 593 class M(torch.nn.Module): 594 def __init__(self): 595 super().__init__() 596 597 def forward(self, pred, x): 598 def true_fn(y: torch.Tensor) -> torch.Tensor: 599 y = y + y 600 y = torch.mm(y, y) 601 return y 602 603 def false_fn(y: torch.Tensor) -> torch.Tensor: 604 return torch.mm(y, y) 605 606 ret = control_flow.cond(pred, true_fn, false_fn, [x]) 607 return ret 608 609 module = to_edge(export(M(), (torch.tensor(True), torch.ones(2, 2)))) 610 program = module.to_executorch().executorch_program 611 612 num_mm = 0 613 num_add = 0 614 num_other = 0 615 for inst in program.execution_plan[0].chains[0].instructions: 616 if not isinstance(inst.instr_args, KernelCall): 617 continue 618 619 op = ( 620 program.execution_plan[0] 621 .operators[inst.instr_args.op_index] # pyre-ignore[16] 622 .name 623 ) 624 625 if "mm" in op: 626 num_mm += 1 627 elif "add" in op: 628 num_add += 1 629 else: 630 num_other += 1 631 632 self.assertEqual(num_mm, 2) 633 self.assertEqual(num_add, 1) 634 self.assertEqual(num_other, 0) 635 636 def test_emit_map(self) -> None: 637 class Foo(torch.nn.Module): 638 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 639 def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 640 return x + y 641 642 return control_flow.map(map_fn, x, y) 643 644 f = Foo() 645 646 inputs = (torch.ones(4, 4), torch.ones(4)) 647 module = to_edge( 648 export(f, inputs), 649 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 650 ) 651 program = module.to_executorch().executorch_program 652 653 op_table = program.execution_plan[0].operators 654 # The first two operators at the beginning of a map program should be sym_size 655 # and select_copy, which is what we verify here. The first operator is to generate 656 # the number of iterations and the second operator is to slice the input tensor to 657 # generate the tensor on which this iteration will operate on. 658 self.assertEqual( 659 op_table[ 660 program.execution_plan[0] # pyre-ignore[16] 661 .chains[0] 662 .instructions[0] 663 .instr_args.op_index 664 ].name, 665 "aten::sym_size", 666 ) 667 self.assertEqual( 668 op_table[ 669 program.execution_plan[0] # pyre-ignore[16] 670 .chains[0] 671 .instructions[1] 672 .instr_args.op_index 673 ].name, 674 "aten::select_copy", 675 ) 676 677 # The last three instructions in the map sub-program are: 678 # - Calling the custom op to append the output of this iteration to the accumulator tensor 679 # - Increment the iteration count. 680 # - Then checking if we've completed all the iterations. 681 # We check here that both of these have been generated. 682 self.assertEqual( 683 op_table[ 684 program.execution_plan[0] # pyre-ignore[16] 685 .chains[0] 686 .instructions[-5] 687 .instr_args.op_index 688 ].name, 689 "executorch_prim::et_copy_index", 690 ) 691 self.assertEqual( 692 op_table[ 693 program.execution_plan[0] # pyre-ignore[16] 694 .chains[0] 695 .instructions[-4] 696 .instr_args.op_index 697 ].name, 698 "executorch_prim::add", 699 ) 700 self.assertEqual( 701 op_table[ 702 program.execution_plan[0] # pyre-ignore[16] 703 .chains[0] 704 .instructions[-3] 705 .instr_args.op_index 706 ].name, 707 "executorch_prim::eq", 708 ) 709 # The last two instructions in the overall program check if we should jump back to the 710 # beginning of the loop and then resets the iteration counter if we fall through. 711 self.assertTrue( 712 isinstance( 713 program.execution_plan[0].chains[0].instructions[-2].instr_args, 714 JumpFalseCall, 715 ) 716 ) 717 self.assertEqual( 718 op_table[ 719 program.execution_plan[0] # pyre-ignore[16] 720 .chains[0] 721 .instructions[-1] 722 .instr_args.op_index 723 ].name, 724 "executorch_prim::sub", 725 ) 726 727 def test_load_emit_map(self) -> None: 728 class Foo(torch.nn.Module): 729 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 730 def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 731 return x + y 732 733 return control_flow.map(map_fn, x, y) 734 735 f = Foo() 736 737 inputs = (torch.ones(4, 4), torch.ones(4)) 738 module = to_edge( 739 export(f, inputs), 740 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 741 ) 742 _load_for_executorch_from_buffer(module.to_executorch().buffer) 743 744 def test_run_emit_map(self) -> None: 745 class Foo(torch.nn.Module): 746 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 747 def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 748 return x + y 749 750 return control_flow.map(map_fn, x, y) 751 752 f = Foo() 753 754 inputs = (torch.ones(4, 4), torch.ones(4)) 755 module = to_edge( 756 export(f, inputs), 757 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 758 ) 759 buffer = module.to_executorch().buffer 760 loaded_model = _load_for_executorch_from_buffer(buffer) 761 outputs = loaded_model(inputs)[0] 762 torch.allclose(outputs, f(*inputs)) 763 764 def test_dim_order(self) -> None: 765 class SimpleLinear(torch.nn.Module): 766 def __init__(self) -> None: 767 super().__init__() 768 self.linear = torch.nn.Linear(5, 5) 769 770 def forward(self, x: torch.Tensor) -> torch.Tensor: 771 return torch.nn.functional.relu(self.linear(x)) 772 773 model = SimpleLinear() 774 inputs = (torch.ones(10, 5),) 775 program = ( 776 to_edge( 777 export(model, inputs), 778 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 779 ) 780 .to_executorch() 781 .executorch_program 782 ) 783 784 addmm_found = False 785 for inst in program.execution_plan[0].chains[0].instructions: 786 kernel = inst.instr_args 787 if isinstance(kernel, KernelCall): 788 op_id = kernel.op_index 789 op = program.execution_plan[0].operators[op_id] 790 if op.name == "aten::addmm": 791 addmm_found = True 792 args = kernel.args 793 bias_id = args[0] 794 act_id = args[1] 795 weight_id = args[2] 796 bias_dim_order = [0] 797 act_dim_order = [0, 1] 798 weight_dim_order = [0, 1] 799 bias_tensor = typing.cast( 800 schema.Tensor, program.execution_plan[0].values[bias_id].val 801 ) 802 act_tensor = typing.cast( 803 schema.Tensor, program.execution_plan[0].values[act_id].val 804 ) 805 weight_tensor = typing.cast( 806 schema.Tensor, program.execution_plan[0].values[weight_id].val 807 ) 808 self.assertTrue(bias_tensor.dim_order == bias_dim_order) 809 self.assertTrue(act_tensor.dim_order == act_dim_order) 810 self.assertTrue(weight_tensor.dim_order == weight_dim_order) 811 self.assertTrue(addmm_found) 812 813 def test_non_const_buffer_sizes(self) -> None: 814 class Add(torch.nn.Module): 815 def forward(self, x: torch.Tensor) -> torch.Tensor: 816 b = 3 + 1 817 return x + b 818 819 f = Add() 820 821 edge_program_manager = to_edge( 822 export( 823 f, 824 (torch.ones(3, 2),), 825 ) 826 ) 827 edge_program_manager._edge_programs["forward"] = constant_prop_pass( 828 edge_program_manager.exported_program() 829 ) 830 non_const_buffer_size_with_const_prop_pass = ( 831 edge_program_manager.to_executorch() 832 .executorch_program.execution_plan[0] 833 .non_const_buffer_sizes 834 ) 835 836 edge_program_manager = to_edge( 837 export( 838 f, 839 (torch.ones(3, 2),), 840 ) 841 ) 842 non_const_buffer_size_without_const_prop_pass = ( 843 edge_program_manager.to_executorch() 844 .executorch_program.execution_plan[0] 845 .non_const_buffer_sizes 846 ) 847 self.assertTrue( 848 non_const_buffer_size_with_const_prop_pass[1] 849 < non_const_buffer_size_without_const_prop_pass[1] 850 ) 851 852 # cant compare plans directly with __eq__ because of the plan names, and data_buffer_idx in tensor values 853 def _compare_execution_plans( 854 self, plan_single: ExecutionPlan, plan_merged: ExecutionPlan 855 ) -> None: 856 self.assertEqual( 857 plan_single.container_meta_type, 858 plan_merged.container_meta_type, 859 ) 860 self.assertEqual( 861 plan_single.inputs, 862 plan_merged.inputs, 863 ) 864 self.assertEqual( 865 plan_single.outputs, 866 plan_merged.outputs, 867 ) 868 self.assertEqual( 869 plan_single.chains, 870 plan_merged.chains, 871 ) 872 self.assertEqual( 873 plan_single.operators, 874 plan_merged.operators, 875 ) 876 self.assertEqual( 877 plan_single.non_const_buffer_sizes, 878 plan_merged.non_const_buffer_sizes, 879 ) 880 self.assertEqual( 881 len(plan_single.values), 882 len(plan_merged.values), 883 ) 884 for i in range(0, len(plan_single.values)): 885 single_val = plan_single.values[i].val 886 merged_val = plan_merged.values[i].val 887 if isinstance(single_val, Tensor): 888 # constant buffer index might be different as the constant buffer is shared between plans 889 self.assertTrue(isinstance(merged_val, Tensor)) 890 self.assertEqual(single_val.storage_offset, merged_val.storage_offset) 891 self.assertEqual(single_val.scalar_type, merged_val.scalar_type) 892 self.assertEqual(single_val.sizes, merged_val.sizes) 893 self.assertEqual(single_val.dim_order, merged_val.dim_order) 894 self.assertEqual(single_val.requires_grad, merged_val.requires_grad) 895 self.assertEqual(single_val.layout, merged_val.layout) 896 self.assertEqual(single_val.allocation_info, merged_val.allocation_info) 897 self.assertEqual(single_val.shape_dynamism, merged_val.shape_dynamism) 898 else: 899 self.assertEqual(single_val, merged_val) 900 901 def test_emit_memory_format_valid(self) -> None: 902 class SimpleLinear(torch.nn.Module): 903 def __init__(self) -> None: 904 super().__init__() 905 906 def forward(self, x: torch.Tensor) -> torch.Tensor: 907 contiguous = x.to( 908 dtype=torch.float32, memory_format=torch.contiguous_format 909 ) 910 preserve = x.to( 911 dtype=torch.float32, memory_format=torch.preserve_format 912 ) 913 return contiguous + preserve 914 915 # Should succeed at exporting model with legal memory format (contiguous, preserve) 916 model = SimpleLinear() 917 inputs = (torch.ones(10, 5),) 918 try: 919 to_edge( 920 export(model, inputs), 921 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 922 ).to_executorch() 923 except: 924 self.fail("Failed to export model with legal memory format") 925 926 def test_emit_memory_format_invalid(self) -> None: 927 class SimpleLinear(torch.nn.Module): 928 def __init__(self) -> None: 929 super().__init__() 930 931 def forward(self, x: torch.Tensor) -> torch.Tensor: 932 return x.to(dtype=torch.float32, memory_format=torch.channels_last) 933 934 # Failure expected when exporting model with illegal memory format (channels_last) when not using dim_order 935 model = SimpleLinear() 936 inputs = (torch.ones(10, 5, 2, 1),) 937 with self.assertRaises(InternalError): 938 to_edge( 939 export(model, inputs), 940 compile_config=exir.EdgeCompileConfig( 941 _check_ir_validity=False, _skip_dim_order=True 942 ), 943 ).to_executorch() 944 945 # Success if you use dim_order 946 to_edge( 947 export(model, inputs), 948 compile_config=exir.EdgeCompileConfig( 949 _check_ir_validity=False, _skip_dim_order=False 950 ), 951 ).to_executorch() 952 953 def test_emit_multiple_entry_points(self) -> None: 954 class SimpleLinear(torch.nn.Module): 955 def __init__(self) -> None: 956 super().__init__() 957 self.linear = torch.nn.Linear(5, 5) 958 self.linear2 = torch.nn.Linear(5, 5) 959 960 def forward_relu(self, x: torch.Tensor) -> torch.Tensor: 961 return torch.nn.functional.relu(self.linear(x)) 962 963 def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor: 964 return torch.nn.functional.sigmoid(self.linear2(x)) 965 966 model = SimpleLinear() 967 inputs = (torch.ones(10, 5),) 968 with patch_forward(model, model.forward_relu): 969 program_relu = to_edge( 970 export(model, inputs), 971 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 972 ).to_executorch() 973 with patch_forward(model, model.forward_sigmoid): 974 program_sigmoid = to_edge( 975 export(model, inputs), 976 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), 977 ).to_executorch() 978 exir_input = { 979 "forward_relu": program_relu.exported_program(), 980 "forward_sigmoid": program_sigmoid.exported_program(), 981 } 982 merged_program = emit_program(exir_input, False).program 983 self.assertEqual(len(merged_program.execution_plan), 2) 984 985 self.assertEqual( 986 merged_program.execution_plan[0].name, 987 "forward_relu", 988 ) 989 self.assertEqual( 990 merged_program.execution_plan[1].name, 991 "forward_sigmoid", 992 ) 993 # reserved spot, weight, bias 994 self.assertEqual( 995 len(program_sigmoid._emitter_output.program.constant_buffer), 996 3, 997 ) 998 self.assertEqual( 999 len(program_relu._emitter_output.program.constant_buffer), 1000 3, 1001 ) 1002 # sum of the entry points minus 1 because we only have one reserved spot still 1003 self.assertEqual( 1004 len(merged_program.constant_buffer), 1005 len(program_sigmoid._emitter_output.program.constant_buffer) 1006 + len(program_relu._emitter_output.program.constant_buffer) 1007 - 1, 1008 ) 1009 1010 self._compare_execution_plans( 1011 merged_program.execution_plan[0], 1012 program_relu._emitter_output.program.execution_plan[0], 1013 ) 1014 self._compare_execution_plans( 1015 merged_program.execution_plan[1], 1016 program_sigmoid._emitter_output.program.execution_plan[0], 1017 ) 1018 1019 def test_emit_weight_deduplication(self) -> None: 1020 class SimpleLinear(torch.nn.Module): 1021 def __init__(self) -> None: 1022 super().__init__() 1023 self.linear = torch.nn.Linear(5, 5) 1024 1025 def forward_relu(self, x: torch.Tensor) -> torch.Tensor: 1026 return torch.nn.functional.relu(self.linear(x)) 1027 1028 def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor: 1029 return torch.nn.functional.sigmoid(self.linear(x)) 1030 1031 model = SimpleLinear() 1032 inputs = (torch.ones(10, 5),) 1033 with patch_forward(model, model.forward_relu): 1034 program_relu = to_edge(export(model, inputs)).to_executorch() 1035 with patch_forward(model, model.forward_sigmoid): 1036 program_sigmoid = to_edge(export(model, inputs)).to_executorch() 1037 exir_input = { 1038 "forward_relu": program_relu.exported_program(), 1039 "forward_sigmoid": program_sigmoid.exported_program(), 1040 } 1041 merged_program = emit_program(exir_input, False).program 1042 self.assertEqual(len(merged_program.execution_plan), 2) 1043 1044 # reserved spot, weight, bias 1045 self.assertEqual( 1046 len(program_sigmoid._emitter_output.program.constant_buffer), 1047 3, 1048 ) 1049 self.assertEqual( 1050 len(program_relu._emitter_output.program.constant_buffer), 1051 3, 1052 ) 1053 # weights are shared between entry points so the merged one should deduplicate everything 1054 self.assertEqual(len(merged_program.constant_buffer), 3) 1055 1056 self._compare_execution_plans( 1057 merged_program.execution_plan[0], 1058 program_relu._emitter_output.program.execution_plan[0], 1059 ) 1060 self._compare_execution_plans( 1061 merged_program.execution_plan[1], 1062 program_sigmoid._emitter_output.program.execution_plan[0], 1063 ) 1064 1065 def test_emit_execution_plans_sorted(self) -> None: 1066 class Simple(torch.nn.Module): 1067 def __init__(self) -> None: 1068 super().__init__() 1069 1070 def a(self, x: torch.Tensor) -> torch.Tensor: 1071 return x 1072 1073 def b(self, x: torch.Tensor) -> torch.Tensor: 1074 return x 1075 1076 def c(self, x: torch.Tensor) -> torch.Tensor: 1077 return x 1078 1079 model = Simple() 1080 inputs = (torch.ones(10, 5),) 1081 1082 def make_program( 1083 fn, 1084 inputs, 1085 ) -> "ExecutorchProgramManager": 1086 return to_edge( 1087 export( 1088 WrapperModule(fn), 1089 inputs, 1090 ) 1091 ).to_executorch() 1092 1093 program_a = make_program(model.a, inputs) 1094 program_b = make_program(model.b, inputs) 1095 program_c = make_program(model.c, inputs) 1096 1097 exir_input = { 1098 "b": program_b.exported_program(), 1099 "c": program_c.exported_program(), 1100 "a": program_a.exported_program(), 1101 } 1102 merged_program = emit_program(exir_input, False).program 1103 self.assertEqual(len(merged_program.execution_plan), 3) 1104 self.assertEqual(merged_program.execution_plan[0].name, "a") 1105 self.assertEqual(merged_program.execution_plan[1].name, "b") 1106 self.assertEqual(merged_program.execution_plan[2].name, "c") 1107 1108 # Create a second program equivalent to the first, but the input is in a different order. 1109 # python dicts are instertion ordered 1110 exir_input2 = { 1111 "a": program_b.exported_program(), 1112 "b": program_c.exported_program(), 1113 "c": program_a.exported_program(), 1114 } 1115 merged_program2 = emit_program(exir_input2, False).program 1116 self.assertEqual( 1117 merged_program2.execution_plan[0], merged_program.execution_plan[0] 1118 ) 1119 self.assertEqual( 1120 merged_program2.execution_plan[1], merged_program.execution_plan[1] 1121 ) 1122 self.assertEqual( 1123 merged_program2.execution_plan[2], merged_program.execution_plan[2] 1124 ) 1125 1126 def test_upper_bound_memory_planning_respect_input_constraints(self) -> None: 1127 class Foo(torch.nn.Module): 1128 def forward(self, k: torch.Tensor) -> torch.Tensor: 1129 k = torch.cat((k, torch.ones(1, 4))) 1130 return k 1131 1132 func = Foo() 1133 1134 k = torch.rand(2, 4) 1135 dim0_k = Dim("dim0_k", max=3) 1136 dynamic_shapes = {"k": {0: dim0_k}} 1137 captured = export( 1138 func, 1139 (k,), 1140 dynamic_shapes=dynamic_shapes, 1141 ) 1142 edge = to_edge(captured) 1143 from executorch.exir.passes import MemoryPlanningPass 1144 1145 config = exir.ExecutorchBackendConfig( 1146 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), 1147 memory_planning_pass=MemoryPlanningPass( 1148 # allow_lifetime_and_storage_overlap: bool = False, 1149 alloc_graph_input=True, 1150 alloc_graph_output=False, 1151 ), 1152 ) 1153 1154 exe_prog = edge.to_executorch(config) 1155 program = exe_prog._emitter_output.program 1156 exir.print_program.pretty_print(exe_prog._emitter_output.program.execution_plan) 1157 execution_plan = program.execution_plan[0] 1158 self.check_tensor_buffer_loc(0, execution_plan.values, 0, 1, 0) 1159 self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48) 1160 1161 def test_emit_prims(self) -> None: 1162 tensor_output = torch.rand(1, 4) 1163 tensor_list_output = [torch.rand(1, 4), torch.rand(1, 4)] 1164 1165 class Simple(torch.nn.Module): 1166 def __init__(self) -> None: 1167 super().__init__() 1168 self.linear = torch.nn.Linear(5, 5) 1169 self.x: int = 3 1170 self.y = 2 1171 1172 def get_ints(self) -> Tuple[int]: 1173 return (self.x, self.y) 1174 1175 def get_str(self) -> str: 1176 return "foo" 1177 1178 def get_tensor(self) -> torch.Tensor: 1179 return tensor_output 1180 1181 def get_tensor_list(self) -> List[torch.Tensor]: 1182 return tensor_list_output 1183 1184 def forward(self, x: torch.Tensor) -> torch.Tensor: 1185 return torch.nn.functional.sigmoid(self.linear(x)) 1186 1187 model = Simple() 1188 inputs = (torch.ones(10, 5),) 1189 program = to_edge(export(model, inputs)).to_executorch() 1190 exir_input = { 1191 "forward": program.exported_program(), 1192 } 1193 getters = {} 1194 getters["get_ints"] = model.get_ints() 1195 getters["get_str"] = model.get_str() 1196 getters["get_tensor"] = model.get_tensor() 1197 getters["get_tensor_list"] = model.get_tensor_list() 1198 1199 merged_program = emit_program(exir_input, False, getters).program 1200 1201 self.assertEqual(len(merged_program.execution_plan), 5) 1202 1203 self.assertEqual( 1204 merged_program.execution_plan[0].name, 1205 "forward", 1206 ) 1207 self.assertEqual( 1208 merged_program.execution_plan[1].name, 1209 "get_ints", 1210 ) 1211 self.assertEqual( 1212 merged_program.execution_plan[2].name, 1213 "get_str", 1214 ) 1215 self.assertEqual( 1216 merged_program.execution_plan[3].name, 1217 "get_tensor", 1218 ) 1219 self.assertEqual( 1220 merged_program.execution_plan[4].name, 1221 "get_tensor_list", 1222 ) 1223 1224 # no instructions in a getter 1225 self.assertEqual( 1226 len(merged_program.execution_plan[1].chains[0].instructions), 1227 0, 1228 ) 1229 # 2 outputs for the flattened tuple 1230 self.assertEqual( 1231 len(merged_program.execution_plan[1].outputs), 1232 2, 1233 ) 1234 # outputs are 0 and 1 in the values table 1235 self.assertEqual( 1236 merged_program.execution_plan[1].outputs, 1237 [0, 1], 1238 ) 1239 # value 0 is 3 1240 self.assertEqual( 1241 # pyre-ignore 1242 merged_program.execution_plan[1].values[0].val.int_val, 1243 3, 1244 ) 1245 self.assertEqual( 1246 # pyre-ignore 1247 merged_program.execution_plan[1].values[1].val.int_val, 1248 2, 1249 ) 1250 self.assertEqual( 1251 len(merged_program.execution_plan[2].outputs), 1252 1, 1253 ) 1254 self.assertEqual( 1255 # pyre-ignore 1256 merged_program.execution_plan[2].values[0].val.string_val, 1257 "foo", 1258 ) 1259 self.assertEqual(len(merged_program.execution_plan[3].outputs), 1) 1260 self.assertEqual(len(merged_program.execution_plan[4].outputs), 2) 1261 1262 merged_program = to_edge( 1263 export(model, inputs), constant_methods=getters 1264 ).to_executorch() 1265 executorch_module = _load_for_executorch_from_buffer(merged_program.buffer) 1266 torch.allclose(executorch_module.run_method("get_tensor", [])[0], tensor_output) 1267 model_output = executorch_module.run_method("get_tensor_list", []) 1268 for i in range(len(tensor_list_output)): 1269 torch.allclose(model_output[i], tensor_list_output[i]) 1270 1271 def test_emit_debug_handle_map(self) -> None: 1272 mul_model = Mul() 1273 program_mul = to_edge( 1274 export( 1275 mul_model, 1276 mul_model.get_random_inputs(), 1277 ) 1278 ).to_executorch() 1279 # this triggers the actual emission of the graph 1280 program_mul._emitter_output.program 1281 self.assertIsNotNone(program_mul.debug_handle_map) 1282 1283 def test_final_graph_module_update_debug_handle(self) -> None: 1284 class SimpleAddMul(torch.nn.Module): 1285 def __init__(self) -> None: 1286 super().__init__() 1287 1288 def forward(self, x: torch.Tensor) -> torch.Tensor: 1289 a = x + 1 1290 return a * 2 1291 1292 mul_model = SimpleAddMul() 1293 program_mul = to_edge( 1294 export( 1295 mul_model, 1296 (torch.ones(2, 2),), 1297 ) 1298 ).to_executorch() 1299 1300 # this triggers the actual emission of the graph 1301 program = program_mul._emitter_output.program 1302 node = None 1303 program.execution_plan[0].chains[0].instructions[ # pyre-ignore[16] 1304 0 1305 ].instr_args.op_index 1306 1307 # Find the multiplication node in the graph that was emitted. 1308 for node in program_mul.exported_program().graph.nodes: 1309 if node.target == torch.ops.aten.mul.out: 1310 break 1311 self.assertIsNotNone(node) 1312 1313 idx = 0 1314 # Find the multiplication instruction in the program that was emitted. 1315 for idx in range(len(program.execution_plan[0].chains[0].instructions)): 1316 instruction = program.execution_plan[0].chains[0].instructions[idx] 1317 op_index = instruction.instr_args.op_index # pyre-ignore[16] 1318 if "mul" in program.execution_plan[0].operators[op_index].name: 1319 break 1320 1321 # The instruction id of the multiplication instruction and the debug handle of the 1322 # multiplication node in the graph module (which was updated in the emitter to be 1323 # the same as the instruction id) must be the same. 1324 self.assertEqual( 1325 idx, 1326 node.meta.get("debug_handle"), 1327 ) 1328 1329 def test_delegate_with_input_list(self) -> None: 1330 class BackendWithCompilerExample(BackendDetails): 1331 @staticmethod 1332 def preprocess( 1333 edge_program, 1334 compile_specs, 1335 ) -> bytes: 1336 return PreprocessResult( 1337 processed_bytes=bytes(str("test"), encoding="utf8"), 1338 debug_handle_map=None, 1339 ) 1340 1341 class TestModel(nn.Module): 1342 def __init__(self): 1343 super(TestModel, self).__init__() 1344 1345 def forward(self, x): 1346 return torch.cat(x) 1347 1348 inputs = ([torch.ones(2, 2), torch.ones(2, 2)],) 1349 model = TestModel() 1350 edgeir_m = to_edge(export(model, inputs)) 1351 lowered_module = to_backend( 1352 "BackendWithCompilerExample", edgeir_m.exported_program(), [] 1353 ) 1354 1355 class CompositeModule(torch.nn.Module): 1356 def __init__(self): 1357 super().__init__() 1358 self.lowered_module = lowered_module 1359 1360 def forward(self, list_a): 1361 return self.lowered_module(list_a) 1362 1363 composite_model = CompositeModule() 1364 exec_prog = to_edge( 1365 export(composite_model, inputs), 1366 ).to_executorch() 1367 exec_prog.buffer 1368 1369 def test_delegate_with_input_tuple(self) -> None: 1370 class BackendWithCompilerExample(BackendDetails): 1371 @staticmethod 1372 def preprocess( 1373 edge_program, 1374 compile_specs, 1375 ) -> bytes: 1376 return PreprocessResult( 1377 processed_bytes=bytes(str("test"), encoding="utf8"), 1378 debug_handle_map=None, 1379 ) 1380 1381 class AddMulModule(torch.nn.Module): 1382 def __init__(self): 1383 super().__init__() 1384 1385 def forward(self, input): # a, x, b): 1386 y = torch.mm(input[0], input[1]) 1387 z = torch.add(y, input[2]) 1388 return z 1389 1390 model_inputs = ((torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)),) 1391 model = AddMulModule() 1392 edgeir_m = to_edge(export(model, model_inputs)) 1393 lowered_module = to_backend( 1394 "BackendWithCompilerExample", edgeir_m.exported_program(), [] 1395 ) 1396 1397 class CompositeModule(torch.nn.Module): 1398 def __init__(self): 1399 super().__init__() 1400 self.lowered_module = lowered_module 1401 1402 def forward(self, list_a): 1403 return self.lowered_module(list_a) 1404 1405 composite_model = CompositeModule() 1406 exec_prog = to_edge( 1407 export(composite_model, model_inputs), 1408 ).to_executorch() 1409 exec_prog.buffer 1410 1411 def test_delegate_mapping(self) -> None: 1412 debug_handle_map = {1: [1, 2]} 1413 1414 class BackendWithCompilerExample(BackendDetails): 1415 @staticmethod 1416 def preprocess( 1417 edge_program, 1418 compile_specs, 1419 ) -> bytes: 1420 return PreprocessResult( 1421 processed_bytes=bytes(str("test"), encoding="utf8"), 1422 debug_handle_map=debug_handle_map, 1423 ) 1424 1425 class TestModel(nn.Module): 1426 def __init__(self): 1427 super(TestModel, self).__init__() 1428 1429 def forward(self, x, y): 1430 return torch.add(x, y) 1431 1432 inputs = (torch.ones(2, 2), torch.ones(2, 2)) 1433 model = TestModel() 1434 edgeir_m = to_edge(export(model, inputs)) 1435 lowered_module = to_backend( 1436 "BackendWithCompilerExample", edgeir_m.exported_program(), [] 1437 ) 1438 1439 class CompositeModule(torch.nn.Module): 1440 def __init__(self): 1441 super().__init__() 1442 self.lowered_module = lowered_module 1443 1444 def forward(self, x, y): 1445 return self.lowered_module(x, y) 1446 1447 composite_model = CompositeModule() 1448 exec_prog = to_edge( 1449 export(composite_model, inputs), 1450 ).to_executorch() 1451 # Reading the program triggers the call to emit_program underneath which 1452 # we need to be done for our test to succeed. 1453 exec_prog._emitter_output.program 1454 self.assertIsNotNone(exec_prog.delegate_map) 1455 self.assertIsNotNone(exec_prog.delegate_map.get("forward")) 1456 self.assertIsNotNone( 1457 exec_prog.delegate_map.get("forward").get(0) # pyre-ignore[16] 1458 ) 1459 self.assertEqual( 1460 exec_prog.delegate_map.get("forward").get(0).get("name"), 1461 "BackendWithCompilerExample", 1462 ) 1463 self.assertTrue( 1464 len(exec_prog.delegate_map.get("forward").get(0).get("delegate_map")) != 0 1465 ) 1466 1467 def test_emit_weight_view(self) -> None: 1468 class ModWithWeightViews(nn.Module): 1469 def __init__(self): 1470 super(ModWithWeightViews, self).__init__() 1471 self.W = torch.nn.Parameter(torch.randn(2)) 1472 self.W1 = self.W[:1] 1473 self.W2 = self.W[1:] 1474 1475 def forward(self, x): 1476 return self.W1 + self.W2 + x 1477 1478 model = ModWithWeightViews() 1479 # each weight is a view of the same storage 1480 self.assertEqual(model.W1.nbytes, 4) 1481 self.assertEqual(model.W1.untyped_storage().nbytes(), 8) 1482 self.assertEqual(model.W2.nbytes, 4) 1483 self.assertEqual(model.W2.untyped_storage().nbytes(), 8) 1484 program = to_edge( 1485 export( 1486 model, 1487 (torch.ones(1),), 1488 ) 1489 ).to_executorch() 1490 1491 program = program._emitter_output.program 1492 # each emitted weight is not a view 1493 self.assertEqual(len(program.constant_buffer[1].storage), 4) 1494 self.assertEqual(len(program.constant_buffer[2].storage), 4) 1495 1496 def test_non_persistent_buffer(self) -> None: 1497 class NonPersistentBuffer(nn.Module): 1498 def __init__(self): 1499 super(NonPersistentBuffer, self).__init__() 1500 self.register_buffer("buf", torch.tensor([1]), persistent=False) 1501 1502 def forward(self, x): 1503 return x + self.buf 1504 1505 model = NonPersistentBuffer() 1506 program = to_edge( 1507 export( 1508 model, 1509 (torch.ones(1),), 1510 ) 1511 ).to_executorch() 1512 program = program._emitter_output.program 1513 # confirm that the buffer was emitted 1514 self.assertEqual(len(program.constant_buffer), 2) 1515 self.assertEqual(len(program.constant_buffer[1].storage), 8) 1516 1517 def test_emit_lifted_tensor_constant(self) -> None: 1518 class LiftedConstants(nn.Module): 1519 def __init__(self): 1520 super().__init__() 1521 1522 def forward(self, x): 1523 x = x * torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float) 1524 return x 1525 1526 model = LiftedConstants() 1527 1528 program = to_edge( 1529 export( 1530 model, 1531 (torch.ones(3, 2),), 1532 ) 1533 ).to_executorch() 1534 1535 program = program._emitter_output.program 1536 exec_plan = program.execution_plan[0] 1537 # There should only be 1 input to this model. 1538 self.assertEqual(len(exec_plan.inputs), 1) 1539 self.assertEqual(len(program.constant_buffer), 2) 1540 self.assertEqual(len(program.constant_buffer[1].storage), 24) 1541 1542 def test_mutable_buffers(self) -> None: 1543 def count_copies(gm: torch.fx.GraphModule) -> int: 1544 return sum( 1545 ( 1546 node.target == torch.ops.aten.copy_ 1547 or node.target == exir_ops.edge.aten.copy_.default 1548 ) 1549 for node in gm.graph.nodes 1550 ) 1551 1552 class MutableStateModule(torch.nn.Module): 1553 def __init__(self): 1554 super().__init__() 1555 self.register_buffer("state", torch.zeros(1)) 1556 1557 def forward(self, x): 1558 y = x + self.state 1559 self.state.add_(1) 1560 return y 1561 1562 model = to_edge( 1563 export( 1564 MutableStateModule(), 1565 (torch.zeros(1),), 1566 ) 1567 ) 1568 model = model.to_executorch() 1569 model.dump_executorch_program(True) 1570 self.assertTrue( 1571 model.executorch_program.execution_plan[0] # pyre-ignore[16] 1572 .values[0] 1573 .val.allocation_info 1574 is not None 1575 ) 1576 executorch_module = _load_for_executorch_from_buffer(model.buffer) 1577 self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1)) 1578 self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1) 1579 1580 def test_mutable_buffers_without_memplanned_inputs(self) -> None: 1581 def count_copies(gm: torch.fx.GraphModule) -> int: 1582 return sum( 1583 ( 1584 node.target == torch.ops.aten.copy_ 1585 or node.target == exir_ops.edge.aten.copy_.default 1586 ) 1587 for node in gm.graph.nodes 1588 ) 1589 1590 class MutableStateModule(torch.nn.Module): 1591 def __init__(self): 1592 super().__init__() 1593 self.register_buffer("state", torch.zeros(1)) 1594 1595 def forward(self, x): 1596 y = x + self.state 1597 self.state.add_(1) 1598 return y 1599 1600 model = to_edge( 1601 export( 1602 MutableStateModule(), 1603 (torch.zeros(1),), 1604 ) 1605 ) 1606 model = model.to_executorch( 1607 config=ExecutorchBackendConfig( 1608 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), 1609 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), 1610 ) 1611 ) 1612 model.dump_executorch_program(True) 1613 self.assertTrue( 1614 model.executorch_program.execution_plan[0] # pyre-ignore[16] 1615 .values[0] 1616 .val.allocation_info 1617 is not None 1618 ) 1619 executorch_module = _load_for_executorch_from_buffer(model.buffer) 1620 self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1)) 1621 self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1) 1622 1623 def test_infinity_in_model(self) -> None: 1624 class InfinityMaskModel(nn.Module): 1625 def __init__(self): 1626 super().__init__() 1627 self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32) 1628 1629 def forward(self, x): 1630 masked_weights = x.masked_fill(self.mask == 0, float("-inf")) 1631 return masked_weights 1632 1633 model = to_edge( 1634 export( 1635 InfinityMaskModel(), 1636 (torch.randn(2, 2),), 1637 ) 1638 ) 1639 1640 # Confirm that we can serialize the model with infinity in it. 1641 model = model.to_executorch() 1642 1643 # Assert that the infinity is stored as a string "-inf". 1644 values = model.executorch_program.execution_plan[0].values 1645 self.assertEqual(values[5].val, Double(double_val=float("-inf"))) 1646 1647 # Confirm that we can also deserialize the model with infinity in it. 1648 pte_data = deserialize_pte_binary(model.buffer) 1649 self.assertEqual( 1650 pte_data.execution_plan, model.executorch_program.execution_plan 1651 ) 1652 1653 def test_mutate_input_tensor(self) -> None: 1654 class MutateInputTensorModule(torch.nn.Module): 1655 def __init__(self): 1656 super().__init__() 1657 1658 def forward(self, x): 1659 x.add_(1) 1660 1661 model = to_edge( 1662 export(MutateInputTensorModule(), (torch.zeros(1),)) 1663 ).to_executorch( 1664 config=ExecutorchBackendConfig( 1665 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False) 1666 ) 1667 ) 1668 executorch_model = _load_for_executorch_from_buffer(model.buffer) 1669 input = torch.zeros(1) 1670 executorch_model(input) 1671 self.assertEqual(input, torch.ones(1)) 1672