# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

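"""Tests for the Partitioner interface: spec immutability, illegal node tagging,
constant-data ownership between the owning and delegated programs, and
mutable-buffer handling during delegation."""
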
import unittest
from types import MappingProxyType

import torch

from executorch import exir
from executorch.exir.backend.backend_details import CompileSpec, ExportedProgram
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)

from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
    AnyOperatorSupport,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AllNodesPartitionerDemo,
)
from executorch.exir.backend.utils import get_delegates, tag_constant_data

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.tests.models import MLP
from executorch.extension.pybindings.portable_lib import (  # @manual=//executorch/extension/pybindings:portable_lib
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import export, export_for_training
from torch.fx.passes.operator_support import any_chain


class TestPartitioner(unittest.TestCase):
    def test_partitioner_with_spec(self):
        # Create a custom partitioner with a spec and check that the spec can be read but not mutated.
        class PartitionerWithSpec(Partitioner):
            def __init__(self, spec) -> None:
                super().__init__(spec)
                self.op_support = any_chain(AnyOperatorSupport())
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                partition_list = generate_pattern_op_partitions(
                    edge_exported_program.graph_module, op_support=self.op_support
                )
                for partition in partition_list:
                    for node in partition.nodes:
                        delegation_tag = f"tag{partition.id}"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        aten = export(model, example_inputs)
        spec_key = "path"
        spec_value = "/a/b/c/d"
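        # MappingProxyType provides a read-only view of the underlying dict.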
        spec = MappingProxyType({spec_key: spec_value})
        my_partitioner = PartitionerWithSpec(spec)
        edge = exir.to_edge(aten).to_backend(my_partitioner)

        lowered_module_nodes = get_delegates(edge.exported_program().graph)

        self.assertEqual(len(lowered_module_nodes), 1)
        # Check the lowered module has the correct compile spec
        for lower_module_node in lowered_module_nodes:
            lower_module = getattr(
                edge.exported_program().graph_module, lower_module_node.name
            )
            self.assertEqual(lower_module.compile_specs[0].key, spec_key)
            self.assertEqual(lower_module.compile_specs[0].value, spec_value)

        # Check the custom partitioner has the correct spec
        self.assertEqual(my_partitioner.spec[spec_key], spec_value)

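        # Both the mapping contents and the spec attribute itself should reject mutation.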
        with self.assertRaisesRegex(
            TypeError,
            "'mappingproxy' object does not support item assignment",
        ):
            my_partitioner.spec[spec_key] = "new_value"

        with self.assertRaisesRegex(
            AttributeError,
            "can't set attribute 'spec'",
        ):
            my_partitioner.spec = {"new_key": "new_value"}

    def test_bad_partitioner_tagged_output(self):
        # Create a bad partitioner that tags the output node, which is not allowed.
        class PartitionerTagOutput(Partitioner):
            def __init__(self) -> None:
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "output":
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        aten = export(model, example_inputs)
        edge = exir.to_edge(aten)

        with self.assertRaisesRegex(
            RuntimeError,
            "output node output should not be tagged",
        ):
            _ = edge.to_backend(PartitionerTagOutput())

    def test_bad_partitioner_tagged_model_input(self):
        # Create a bad partitioner that tags an input which is neither a param nor a buffer; this is not allowed.
        class PartitionerTagInput(Partitioner):
            def __init__(self) -> None:
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "placeholder":
                        if not is_param(edge_exported_program, node) and not is_buffer(
                            edge_exported_program, node
                        ):
                            delegation_tag = "tag_" + str(node.meta["debug_handle"])
                            node.meta["delegation_tag"] = delegation_tag
                            partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        mlp = MLP()
        example_inputs = mlp.get_random_inputs()
        model = export_for_training(mlp, example_inputs).module()
        edge = exir.to_edge(export(model, example_inputs))

        with self.assertRaisesRegex(
            RuntimeError,
            "placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged",
        ):
            _ = edge.to_backend(PartitionerTagInput())

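    # AddConst exercises the three kinds of constant data that torch.export can
    # lift: a plain tensor attribute (lifted tensor constant), a non-persistent
    # buffer, and a parameter.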
    class AddConst(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.const1 = torch.ones(2, 2)
            self.register_buffer("const2", torch.ones(2, 2), persistent=False)
            self.register_parameter("const3", torch.nn.Parameter(torch.ones(2, 2)))

        def forward(self, x):
            return x + self.const1 + self.const2 + self.const3

    def test_partitioner_not_tag_data(self):
        """
        We test here that when partitioners do not explicitly tag constant data nodes,
        the partitioned ExportedProgram does not own the data. The owning program
        retains the constant data and feeds it to the partitioned program as inputs.
        """

        class PartitionerNoTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        delegated = edge.to_backend(PartitionerNoTagData())

        # Check Owning Program still owns all constant data
        owning_program = delegated.exported_program()
        self.assertEqual(
            len(owning_program.state_dict) + len(owning_program.constants), 3
        )
        self.assertEqual(
            len(owning_program.graph_signature.buffers)
            + len(owning_program.graph_signature.lifted_tensor_constants),
            2,
        )
        self.assertEqual(len(owning_program.graph_signature.parameters), 1)

        # Check Lowered Module Exported Program does not have any constant data
        lowered_module_nodes = get_delegates(delegated.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # get call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 5 args to lowered module are: delegated_payload, x, const1, const2, const3
        self.assertEqual(len(call_delegate_node.args), 5)
        lower_module = getattr(
            delegated.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict), 0)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 0)
        self.assertEqual(len(delegated_ep.graph_signature.parameters), 0)

        # check exported program is still runnable
        output = delegated.exported_program().module()(torch.ones(2, 2))
        reference_output = model(torch.ones(2, 2))
        self.assertTrue(torch.allclose(reference_output, output))

    def test_partitioner_tag_data(self):
        """
        We test here that when partitioners explicitly tag constant data nodes,
        then the partitioned ExportedProgram will own the data, and the data will
        be removed from the owning program.
        """

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and (
                        is_param(edge_exported_program, node)
                        or is_buffer(edge_exported_program, node)
                        or is_lifted_tensor_constant(edge_exported_program, node)
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        delegated = edge.to_backend(PartitionerTagData())

        # Check Owning Program no longer owns any constant data
        owning_program = delegated.exported_program()
        self.assertEqual(len(owning_program.state_dict), 0)
        self.assertEqual(len(owning_program.graph_signature.buffers), 0)
        self.assertEqual(len(owning_program.graph_signature.parameters), 0)

        # Check Lowered Module Exported Program now owns all the constant data
        lowered_module_nodes = get_delegates(delegated.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # get call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 2 args to lowered module are: delegated_payload, x
        self.assertEqual(len(call_delegate_node.args), 2)
        lower_module = getattr(
            delegated.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict) + len(delegated_ep.constants), 3)
        self.assertEqual(
            len(delegated_ep.graph_signature.buffers)
            + len(delegated_ep.graph_signature.lifted_tensor_constants),
            2,
        )
        self.assertEqual(len(delegated_ep.graph_signature.parameters), 1)

        # check exported program is still runnable
        output = delegated.exported_program().module()(torch.ones(2, 2))
        reference_output = model(torch.ones(2, 2))
        self.assertTrue(torch.allclose(reference_output, output))

    def test_partitioner_tag_only_params(self):
        """
        We test here that when partitioners explicitly tag only parameter nodes,
        the partitioned ExportedProgram owns just the parameters, while the
        buffers and tensor constants remain with the owning program.
        """

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and is_param(
                        edge_exported_program, node
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        delegated = edge.to_backend(PartitionerTagData())

        # Check Owning Program still owns the buffer and the lifted tensor constant
        owning_program = delegated.exported_program()
        self.assertEqual(
            len(owning_program.state_dict) + len(owning_program.constants), 2
        )
        self.assertEqual(
            len(owning_program.graph_signature.buffers)
            + len(owning_program.graph_signature.lifted_tensor_constants),
            2,
        )
        self.assertEqual(len(owning_program.graph_signature.parameters), 0)

        # Check Lowered Module Exported Program owns only the parameter, not the buffers
        lowered_module_nodes = get_delegates(delegated.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # get call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        # 4 args to lowered module are: delegated_payload, x, const1, const2
        self.assertEqual(len(call_delegate_node.args), 4)
        lower_module = getattr(
            delegated.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module
        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 0)
        self.assertEqual(len(delegated_ep.graph_signature.parameters), 1)

        # check exported program is still runnable
        output = delegated.exported_program().module()(torch.ones(2, 2))
        reference_output = model(torch.ones(2, 2))
        self.assertTrue(torch.allclose(reference_output, output))

    def test_partitioner_splits_constant_data(self):
        """
        We test that when constant data users are split between a delegated
        payload and the owning program, the data is copied to the delegate
        and the program still runs.
        """

        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.ones(2, 2)

            def forward(self, x):
                y = x + self.const
                z = x - self.const
                return y, z

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and (
                        is_param(edge_exported_program, node)
                        or is_buffer(edge_exported_program, node)
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        inputs = (torch.ones(2, 2),)
        model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
        exec_prog = edge.to_backend(PartitionerTagData()).to_executorch()
        executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer)
        inputs_flattened, _ = tree_flatten(inputs)

        # Run the lowered program end-to-end to verify it is still executable.
        _ = executorch_module.run_method("forward", tuple(inputs_flattened))

    def test_partitioner_alert_split_constant_data(self):
        """
        We test that we throw an error when constant data users are split
        between different delegated payloads or the owning program and the
        data is marked no_copy.
        """

        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.ones(2, 2)

            def forward(self, x):
                y = x + self.const
                z = x - self.const
                return y, z

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec

                    if node.op == "placeholder" and (
                        is_param(edge_exported_program, node)
                        or is_buffer(edge_exported_program, node)
                        or is_lifted_tensor_constant(edge_exported_program, node)
                    ):
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        node.meta["no_copy"] = True
                        partition_tags[delegation_tag] = self.delegation_spec

                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module()
        edge = exir.to_edge(export(model, (torch.ones(2, 2),)))
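        # With no_copy set, the shared constant cannot be duplicated into the
        # delegate, so lowering must fail.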
        with self.assertRaises(RuntimeError) as error:
            _ = edge.to_backend(PartitionerTagData())

        self.assertTrue(
            "is tagged with (tag0) but has user (aten_sub_tensor) which has tag (None)"
            in str(error.exception),
        )

    def test_not_delegate_mutable_buffers(self) -> None:
        """
        A test case to check that a mutated buffer is not delegated. We'll need to
        add a test case to cover when the delegate can consume the mutable buffer.
        """

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("my_state", torch.zeros(1))

            def forward(self, x):
                y = x + self.my_state
                self.my_state.add_(1)
                return y

        edge = exir.to_edge(
            torch.export.export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        self.assertGreater(
            len(edge.exported_program().graph_signature.buffers_to_mutate),
            0,
            "The test case should have at least one mutable buffer",
        )

        class PartitionerTagData(Partitioner):
            def __init__(self):
                super().__init__()
                self.delegation_spec = DelegationSpec(
                    ExecutorBackend.__name__,
                    [CompileSpec(key, value) for key, value in self.spec.items()],
                )

            def partition(
                self, edge_exported_program: ExportedProgram
            ) -> PartitionResult:
                partition_tags = {}
                for node in edge_exported_program.graph.nodes:
                    if node.op == "call_function" and node.target in [
                        exir_ops.edge.aten.add.Tensor
                    ]:
                        delegation_tag = "tag0"
                        node.meta["delegation_tag"] = delegation_tag
                        partition_tags[delegation_tag] = self.delegation_spec
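                # tag_constant_data tags constant data feeding tagged nodes;
                # mutable buffers are expected to be left untagged (asserted below).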
                tag_constant_data(edge_exported_program)
                return PartitionResult(
                    tagged_exported_program=edge_exported_program,
                    partition_tags=partition_tags,
                )

        # Check the edge program's initial buffers_to_mutate
        mutate_op = "aten_add_tensor_1"
        self.assertEqual(
            edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
            "my_state",
        )
        edge = edge.to_backend(PartitionerTagData())
        # After to_backend, add is delegated and is no longer in buffers_to_mutate.
        self.assertNotIn(
            mutate_op,
            edge.exported_program().graph_signature.buffers_to_mutate,
        )

        mutate_op = "getitem_1"
        # Ensure the mutated buffer is not delegated, and the new mutate node is getitem (from call_delegate)
        self.assertEqual(
            edge.exported_program().graph_signature.buffers_to_mutate[mutate_op],
            "my_state",
        )
        # Check the copy_ node is inserted
        edge = edge.to_executorch()
        copy_node = [
            node
            for node in edge.exported_program().graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten.copy_.default
        ]
        self.assertEqual(len(copy_node), 1)

    def test_buffer_mutation1(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("b", torch.ones(3, 3))

            def forward(self, x):
                self.b.add_(x)
                return x + self.b

        model_inputs = (torch.ones(3, 3),)
        orig_res = TestModule()(*model_inputs)
        edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
        lowered = edge_program.to_backend(AddAttributePartitionerDemo())

        self.assertTrue(
            torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
        )

        self.assertEqual(
            len(lowered.exported_program().graph_signature.buffers_to_mutate),
            0,
        )
        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # get call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
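        # Only the payload and x are passed; buffer b was consumed into the delegate.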
        self.assertEqual(len(call_delegate_node.args), 2)

        lower_module = getattr(
            lowered.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module

        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

    def test_buffer_mutation_llama_repro(self):
        SHAPE = (2, 3)

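        # The partitioner consumes both the cache buffer and its index_put_
        # mutation, so no buffers_to_mutate entry remains in the toplevel signature.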
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))

            def forward(self, q, k_val, input_pos):
                q_T = q.transpose(0, 1)
                k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
                attn = k.mm(q_T)
                return attn

        q = torch.rand(1, 3)
        k = torch.rand(1, 3)
        example_inputs = (q, k, torch.tensor([1, 1]))

        model = Model()
        model.eval()

        exir_program_aten = torch.export.export(model, example_inputs)
        exir_program_aten.module()(*example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
        lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())

        self.assertEqual(
            len(lowered.exported_program().graph_signature.buffers_to_mutate),
            0,
        )
        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # get call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        self.assertEqual(len(call_delegate_node.args), 4)

        lower_module = getattr(
            lowered.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module

        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

    def test_buffer_mutation_unsupported(self):
        SHAPE = (2, 3)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))

            def forward(self, x):
                add = self.state_1.add_(x)
                return add

        model = Model()
        model.eval()

        example_inputs = (torch.randn(SHAPE),)
        exir_program_aten = torch.export.export(model, example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
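        # AddAttributePartitionerDemo does not consume this mutation pattern,
        # so lowering is expected to fail.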
        with self.assertRaises(AssertionError):
            edge_program_manager.to_backend(AddAttributePartitionerDemo())
733