# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from types import MappingProxyType

import torch

from executorch import exir
from executorch.exir.backend.backend_details import CompileSpec, ExportedProgram
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)

from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
    AnyOperatorSupport,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AllNodesPartitionerDemo,
)
from executorch.exir.backend.utils import get_delegates, tag_constant_data

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.tests.models import MLP
from executorch.extension.pybindings.portable_lib import (  # @manual=//executorch/extension/pybindings:portable_lib
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import export, export_for_training
from torch.fx.passes.operator_support import any_chain


class TestPartitioner(unittest.TestCase):
    """Tests for custom ``Partitioner`` implementations used with ``to_backend``.

    Covers spec immutability, rejection of illegal tags (output nodes and
    user-input placeholders), ownership of constant data (params, buffers,
    lifted tensor constants) between the owning program and the delegated
    program, and handling of mutable buffers.
    """

    def test_partitioner_with_spec(self):
        # Create a custom partitioner with a spec and check that the spec can
        # be read but not mutated.
        class PartitionerWithSpec(Partitioner):
            def __init__(self, spec) -> None:
                super().__init__(spec)
                self.op_support = any_chain(AnyOperatorSupport())
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                partition_list = generate_pattern_op_partitions(
                    edge_exported_program.graph_module, op_support=self.op_support
                )
                for partition in partition_list:
                    for node in partition.nodes:
                        delegation_tag = f"tag{partition.id}"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        aten = export(model, example_inputs)
        spec_key = "path"
        spec_value = "/a/b/c/d"
        spec = MappingProxyType({spec_key: spec_value})
        my_partitioner = PartitionerWithSpec(spec)
        edge = exir.to_edge(aten).to_backend(my_partitioner)

        lowered_module_nodes = get_delegates(edge.exported_program().graph)

        self.assertEqual(len(lowered_module_nodes), 1)
        # Check the lowered module has the correct compile spec.
        for lower_module_node in lowered_module_nodes:
            lower_module = getattr(
                edge.exported_program().graph_module, lower_module_node.name
            )
            self.assertEqual(lower_module.compile_specs[0].key, spec_key)
            self.assertEqual(lower_module.compile_specs[0].value, spec_value)

        # Check the custom partitioner has the correct spec.
        self.assertEqual(my_partitioner.spec[spec_key], spec_value)

        with self.assertRaisesRegex(
            TypeError,
            "'mappingproxy' object does not support item assignment",
        ):
            my_partitioner.spec[spec_key] = "new_value"

        # The wording of the AttributeError raised when assigning to a
        # property without a setter changed in Python 3.11 ("can't set
        # attribute ..." -> "property ... has no setter"); accept both.
        with self.assertRaisesRegex(
            AttributeError,
            r"can't set attribute|has no setter",
        ):
            my_partitioner.spec = {"new_key": "new_value"}

    def test_bad_partitioner_tagged_output(self):
        # Create a bad partitioner that tags the output node, which is not allowed.
        class PartitionerTagOutput(Partitioner):
            def __init__(self) -> None:
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "output":
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        aten = export(model, example_inputs)
        edge = exir.to_edge(aten)

        with self.assertRaisesRegex(
            RuntimeError,
            "output node output should not be tagged",
        ):
            _ = edge.to_backend(PartitionerTagOutput())

    def test_bad_partitioner_tagged_model_input(self):
        # Create a bad partitioner that tags an input placeholder that is
        # neither a param nor a buffer, which is not allowed.
        class PartitionerTagInput(Partitioner):
            def __init__(self) -> None:
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "placeholder":
                        if not is_param(edge_exported_program, node) and not is_buffer(
                            edge_exported_program, node
                        ):
                            delegation_tag = "tag_" + str(node.meta["debug_handle"])
                            node.meta["delegation_tag"] = delegation_tag
                            partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        edge = exir.to_edge(export(model, example_inputs))

        with self.assertRaisesRegex(
            RuntimeError,
            "placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged",
        ):
            _ = edge.to_backend(PartitionerTagInput())

    class AddConst(torch.nn.Module):
        # Module with three kinds of constant data: a plain tensor attribute
        # (const1), a non-persistent buffer (const2) and a parameter (const3).
        def __init__(self):
            super().__init__()
            self.const1 = torch.ones(2, 2)
            self.register_buffer("const2", torch.ones(2, 2), persistent=False)
            self.register_parameter("const3", torch.nn.Parameter(torch.ones(2, 2)))

        def forward(self, x):
            return x + self.const1 + self.const2 + self.const3

    def test_partitioner_not_tag_data(self):
        """
        We test here that when partitioners do not explicitly tag constant data
        nodes, the partitioned ExportedProgram does not own the data. Instead,
        the owning program keeps the constant data and feeds it as inputs to
        the partitioned program.
        """

        class PartitionerNoTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        delegated = edge.to_backend(PartitionerNoTagData())

        # Check the owning program still owns all constant data.
        owning_program = delegated.exported_program()
        self.assertEqual(
            len(owning_program.state_dict) + len(owning_program.constants), 3
        )
        self.assertEqual(
            len(owning_program.graph_signature.buffers)
            + len(owning_program.graph_signature.lifted_tensor_constants),
            2,
        )
        self.assertEqual(len(owning_program.graph_signature.parameters), 1)

        # Check the lowered module's exported program does not own any constant data.
        lowered_module_nodes = get_delegates(delegated.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # Get the call_delegate node.
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 5 args to the lowered module: delegated_payload, x, const1, const2, const3
        self.assertEqual(len(call_delegate_node.args), 5)
        lower_module = getattr(
            delegated.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict), 0)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 0)
        self.assertEqual(len(delegated_ep.graph_signature.parameters), 0)

        # Check the exported program is still runnable.
        output = delegated.exported_program().module()(torch.ones(2, 2))
        reference_output = model(torch.ones(2, 2))
        self.assertTrue(torch.allclose(reference_output, output))

    def test_partitioner_tag_data(self):
        """
        We test here that when partitioners explicitly tag constant data nodes,
        the partitioned ExportedProgram owns the data, and the data is removed
        from the owning program.
        """

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and (
                        is_param(edge_exported_program, node)
                        or is_buffer(edge_exported_program, node)
                        or is_lifted_tensor_constant(edge_exported_program, node)
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        delegated = edge.to_backend(PartitionerTagData())

        # Check the owning program no longer owns any constant data.
        owning_program = delegated.exported_program()
        self.assertEqual(len(owning_program.state_dict), 0)
        self.assertEqual(len(owning_program.graph_signature.buffers), 0)
        self.assertEqual(len(owning_program.graph_signature.parameters), 0)

        # Check the lowered module's exported program owns all constant data.
        lowered_module_nodes = get_delegates(delegated.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # Get the call_delegate node.
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 2 args to the lowered module: delegated_payload, x
        self.assertEqual(len(call_delegate_node.args), 2)
        lower_module = getattr(
            delegated.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict) + len(delegated_ep.constants), 3)
        self.assertEqual(
            len(delegated_ep.graph_signature.buffers)
            + len(delegated_ep.graph_signature.lifted_tensor_constants),
            2,
        )
        self.assertEqual(len(delegated_ep.graph_signature.parameters), 1)

        # Check the exported program is still runnable.
        output = delegated.exported_program().module()(torch.ones(2, 2))
        reference_output = model(torch.ones(2, 2))
        self.assertTrue(torch.allclose(reference_output, output))

    def test_partitioner_tag_only_params(self):
        """
        We test here that when partitioners tag only parameter placeholders,
        only the parameter moves into the partitioned ExportedProgram. The
        buffer and the lifted tensor constant stay with the owning program and
        are passed to the delegate as inputs.
        """

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and is_param(
                        edge_exported_program, node
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        delegated = edge.to_backend(PartitionerTagData())

        # Check the owning program still owns the non-parameter constants only.
        owning_program = delegated.exported_program()
        self.assertEqual(
            len(owning_program.state_dict) + len(owning_program.constants), 2
        )
        self.assertEqual(
            len(owning_program.graph_signature.buffers)
            + len(owning_program.graph_signature.lifted_tensor_constants),
            2,
        )
        self.assertEqual(len(owning_program.graph_signature.parameters), 0)

        # Check the lowered module's exported program owns only the parameter.
        lowered_module_nodes = get_delegates(delegated.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # Get the call_delegate node.
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 4 args to the lowered module: delegated_payload, x, and the two
        # untagged constants (const1, const2).
        self.assertEqual(len(call_delegate_node.args), 4)
        lower_module = getattr(
            delegated.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 0)
        self.assertEqual(len(delegated_ep.graph_signature.parameters), 1)

        # Check the exported program is still runnable.
        output = delegated.exported_program().module()(torch.ones(2, 2))
        reference_output = model(torch.ones(2, 2))
        self.assertTrue(torch.allclose(reference_output, output))

    def test_partitioner_splits_constant_data(self):
        """
        We test that lowering succeeds, and the program runs, when a constant
        is used both inside and outside the delegate but the partitioner does
        not tag it. Only params and buffers are tagged below, and
        ``self.const`` is a plain tensor attribute, so (presumably as a lifted
        tensor constant) it stays with the owning program.
        """

        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.ones(2, 2)

            def forward(self, x):
                y = x + self.const
                z = x - self.const
                return y, z

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and (
                        is_param(edge_exported_program, node)
                        or is_buffer(edge_exported_program, node)
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        inputs = (torch.ones(2, 2),)
        model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
        executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
        inputs_flattened, _ = tree_flatten(inputs)

        # Send the input from the server executor to the client executor, and
        # receive the result from the client executor.
        _ = executorch_module.run_method("forward", inputs_flattened)

    def test_partitioner_alert_split_constant_data(self):
        """
        We test that an error is thrown when the users of a constant marked
        ``no_copy`` are split between a delegated payload and the owning
        program.
        """

        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.ones(2, 2)

            def forward(self, x):
                y = x + self.const
                z = x - self.const
                return y, z

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and (
                        is_param(edge_exported_program, node)
                        or is_buffer(edge_exported_program, node)
                        or is_lifted_tensor_constant(edge_exported_program, node)
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        node.meta["no_copy"] = True
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        with self.assertRaises(RuntimeError) as error:
            _ = edge.to_backend(PartitionerTagData())

        self.assertIn(
            "is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)",
            str(error.exception),
        )

    def test_not_delegate_mutable_buffers(self) -> None:
        """
        A test case to check that the mutated buffer is not delegated. We'll
        need to add a test case for when the delegate can consume the mutable
        buffer.
        """

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("my_state", torch.zeros(1))

            def forward(self, x):
                y = x + self.my_state
                self.my_state.add_(1)
                return y

        edge = exir.to_edge(
            torch.export.export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        self.assertGreater(
            len(edge.exported_program().graph_signature.buffers_to_mutate),
            0,
            "The test case should have at least one mutable buffer",
        )

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec
                tag_constant_data(edge_exported_program)
                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        # Check the edge program's initial buffers_to_mutate.
        mutate_op = "aten_add_tensor_1"
        self.assertEqual(
            edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
            "my_state",
        )
        edge = edge.to_backend(PartitionerTagData())
        # After to_backend, add is delegated and is no longer in buffers_to_mutate.
        self.assertNotIn(
            mutate_op,
            edge.exported_program().graph_signature.buffers_to_mutate,
        )

        mutate_op = "getitem_1"
        # Ensure the mutated buffer is not delegated, and the new mutate node
        # is a getitem (from call_delegate).
        self.assertEqual(
            edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
            "my_state",
        )
        # Check the copy_ node is inserted.
        edge = edge.to_executorch()
        copy_node = [
            node
            for node in edge.exported_program().graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.copy_.default
        ]
        self.assertEqual(len(copy_node), 1)

    def test_buffer_mutation1(self):
        """
        A buffer mutated in place and then read is lowered with
        AddAttributePartitionerDemo. The delegate ends up owning both the
        buffer and its mutation, and the result matches eager execution.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("b", torch.ones(3, 3))

            def forward(self, x):
                self.b.add_(x)
                return x + self.b

        model_inputs = (torch.ones(3, 3),)
        orig_res = TestModule()(*model_inputs)
        edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
        lowered = edge_program.to_backend(AddAttributePartitionerDemo())

        self.assertTrue(
            torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
        )

        self.assertEqual(
            len(lowered.exported_program().graph_signature.buffers_to_mutate),
            0,
        )
        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # Get the call_delegate node; 2 args: delegated_payload, x
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        self.assertEqual(len(call_delegate_node.args), 2)

        lower_module = getattr(
            lowered.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module

        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

    def test_buffer_mutation_llama_repro(self):
        """
        Repro of a KV-cache style in-place ``index_put_`` on a buffer, lowered
        with AllNodesPartitionerDemo. The delegate owns the buffer and its
        mutation, so the top-level program mutates nothing.
        """
        SHAPE = (2, 3)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))

            def forward(self, q, k_val, input_pos):
                q_T = q.transpose(0, 1)
                k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
                attn = k.mm(q_T)
                return attn

        q = torch.rand(1, 3)
        k = torch.rand(1, 3)
        example_inputs = (q, k, torch.tensor([1, 1]))

        model = Model()
        model.eval()

        exir_program_aten = torch.export.export(model, example_inputs)
        exir_program_aten.module()(*example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
        lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())

        self.assertEqual(
            len(lowered.exported_program().graph_signature.buffers_to_mutate),
            0,
        )
        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # Get the call_delegate node; 4 args: presumably delegated_payload plus
        # the three user inputs (q, k_val, input_pos).
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        self.assertEqual(len(call_delegate_node.args), 4)

        lower_module = getattr(
            lowered.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module

        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

    def test_buffer_mutation_unsupported(self):
        """
        Lowering a module whose output is the result of an in-place add on a
        buffer with AddAttributePartitionerDemo is expected to raise an
        AssertionError.
        """
        SHAPE = (2, 3)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))

            def forward(self, x):
                add = self.state_1.add_(x)
                return add

        model = Model()
        model.eval()

        example_inputs = (torch.randn(SHAPE),)
        exir_program_aten = torch.export.export(model, example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
        with self.assertRaises(AssertionError):
            edge_program_manager.to_backend(AddAttributePartitionerDemo())