# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import typing
import unittest
from contextlib import contextmanager
from typing import List, Optional, Tuple

import executorch.exir as exir

import executorch.exir.schema as schema
import executorch.exir.tests.models as models
import pytest
import torch
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    ExecutorchProgramManager,
    to_edge,
)
from executorch.exir._serialize._program import deserialize_pte_binary
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.emit import emit_program  # noqa
from executorch.exir.error import InternalError
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.print_program import pretty_print, print_program  # noqa
from executorch.exir.schema import (
    Bool,
    DelegateCall,
    Double,
    EValue,
    ExecutionPlan,
    Int,
    IntList,
    JumpFalseCall,
    KernelCall,
    KernelTypes,
    MoveCall,
    Null,
    OptionalTensorList,
    Program,
    String,
    Tensor,
)
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.models import Mul
from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

from functorch.experimental import control_flow
from torch import nn

from torch.export import Dim, export


class WrapperModule(torch.nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

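# WrapperModule exposes an arbitrary callable (e.g. a bound method) as a
# module's forward() so torch.export() can trace it; see make_program() in
# test_emit_execution_plans_sorted for an example.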

@contextmanager
def patch_forward(obj: torch.nn.Module, new_method):
    """Context manager that makes it easier to cleanly torch.export() a method
    on a module that is not `forward`.

    TODO(suo): upstream this to torch.export.wrapper.
    """
    # Save the original method
    original_method = obj.forward

    # Patch the method
    obj.forward = new_method.__get__(obj, obj.__class__)

    try:
        yield
    finally:
        # Restore the original method
        obj.forward = original_method

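# Example usage (mirrors test_emit_multiple_entry_points below):
#
#     with patch_forward(model, model.forward_relu):
#         program = to_edge(export(model, inputs)).to_executorch()
#
# While the `with` block is active, model.forward dispatches to the patched
# method so export() traces it; the original forward is restored on exit.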

class TestEmit(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def setUp(self) -> None:
        self.compile_config = EdgeCompileConfig(_check_ir_validity=False)

    def check_tensor_buffer_loc(
        self,
        value_index: int,
        values: List[EValue],
        exp_buffer_idx: int,
        exp_mem_id: Optional[int],
        exp_mem_offset: Optional[int],
    ) -> None:
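        """Asserts that values[value_index] is a Tensor with constant-buffer
        index exp_buffer_idx and, if it has allocation_info, the expected
        memory id and offset."""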
        value = typing.cast(schema.Tensor, values[value_index].val)
        self.assertIsInstance(value, schema.Tensor)

        self.assertEqual(value.data_buffer_idx, exp_buffer_idx)

        if not value.allocation_info:
            self.assertIsNone(exp_mem_id)
            self.assertIsNone(exp_mem_offset)
        else:
            self.assertEqual(value.allocation_info.memory_id, exp_mem_id)
            assert value.allocation_info
            self.assertEqual(value.allocation_info.memory_offset, exp_mem_offset)

    def count_node(self, graph_module: torch.fx.GraphModule, opname: str) -> int:
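        """Counts the call_function nodes in the graph whose qualified op name
        matches opname."""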
        return [
            node.target._overloadpacket._qualified_op_name
            for node in graph_module.graph.nodes
            if node.op == "call_function"
        ].count(opname)

    def run_dce(self, graph_module: torch.fx.GraphModule) -> None:
        for submodule in graph_module.modules():
            self.assertIsInstance(submodule, torch.fx.GraphModule)
            typing.cast(torch.fx.GraphModule, submodule).graph.eliminate_dead_code()

    def check_value_types(self, values: List[EValue]) -> None:
        for value in values:
            self.assertTrue(type(value.val) in KernelTypes.__args__)

    def count_move_instructions(self, program: Program) -> int:
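        """Counts the MoveCall instructions in the first chain of the first
        execution plan."""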
        instructions = program.execution_plan[0].chains[0].instructions
        assert instructions is not None
        res = 0
        for instr in instructions:
            if isinstance(instr.instr_args, MoveCall):
                res += 1
        return res

    def test_basic_api(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y + x

        f = Foo()

        program = (
            to_edge(
                export(
                    f,
                    (torch.ones(3, 2), torch.zeros(3, 2)),
                )
            )
            .to_executorch()
            .executorch_program
        )
        exec_plan = program.execution_plan[0]
        ops = exec_plan.operators
        for op in ops:
            self.assertEqual(op.overload, "out")

        self.assertEqual(ops[0].name, "aten::mul")
        self.assertEqual(ops[1].name, "aten::add")

        self.assertEqual(len(exec_plan.inputs), 2)
        self.assertEqual(len(exec_plan.outputs), 1)

        self.assertEqual(exec_plan.inputs[0], 0)
        self.assertEqual(exec_plan.outputs[0], 3)

    def test_basic_end_to_end(self) -> None:
        f = models.BasicSinMax()
        program = (
            to_edge(export(f, f.get_random_inputs())).to_executorch().executorch_program
        )
        exec_plan = program.execution_plan[0]
        ops = exec_plan.operators
        for op in ops:
            self.assertIn(op.overload, {"out", "unary_out"})

        self.assertEqual(ops[0].name, "aten::sin")

        self.assertEqual(len(exec_plan.inputs), 1)
        self.assertEqual(len(exec_plan.outputs), 1)

        self.assertEqual(exec_plan.inputs[0], 0)
        self.assertEqual(exec_plan.outputs[0], 1)

    @pytest.mark.skip(reason="Test not working on OSS")
    def test_nested_return(self) -> None:
        class Foo(torch.nn.Module):
            def forward(
                self, x: torch.Tensor
            ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                return (
                    torch.tensor(1),
                    torch.tensor(2),
                    [torch.sin(x).max(), torch.cos(x).max()],
                )

        f = Foo()

        x = (torch.randn(100),)
        program = to_edge(export(f, x)).to_executorch().executorch_program
        exec_plan = program.execution_plan[0]
        self.assertEqual(len(exec_plan.outputs), 4)
        self.assertEqual(len(exec_plan.inputs), 1)

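        # A reading of the encoded container strings below (based on these
        # examples; the exact grammar lives in exir's container metadata
        # encoding): T = tuple, L = list, D = dict, $ = a leaf value, and the
        # #N fields give child/leaf counts.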
        self.assertEqual(
            program.execution_plan[0].container_meta_type.encoded_out_str,
            "T3#1#1#2($,$,L2#1#1($,$))",
        )

        self.assertEqual(
            program.execution_plan[0].container_meta_type.encoded_inp_str,
            "T2#1#0(T1#1($),D0())",
        )

    def test_constant_output(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return [((1, 3, 1.2), True, [x + x, x * x], None)]

        ep = torch.export.export(M(), (torch.ones(2, 3),))
        res = ep.module()(torch.ones(2, 3))
        self.assertEqual(res[0][0], (1, 3, 1.2))
        program = to_edge(ep).to_executorch().executorch_program
        outputs = program.execution_plan[0].outputs
        self.assertEqual(len(outputs), 7)
        self.assertEqual(program.execution_plan[0].values[outputs[0]].val.int_val, 1)
        self.assertEqual(program.execution_plan[0].values[outputs[1]].val.int_val, 3)
        self.assertEqual(
            program.execution_plan[0].values[outputs[2]].val.double_val, 1.2
        )
        self.assertEqual(
            program.execution_plan[0].values[outputs[3]].val.bool_val, True
        )
        self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)

    def test_int_list_input(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y, x + x, x + y + z

        ep = torch.export.export(M(), (torch.ones(2, 3), 2, True))
        ep.module()(torch.ones(2, 3), 2, True)
        program = to_edge(ep).to_executorch().executorch_program
        inputs = program.execution_plan[0].inputs
        self.assertEqual(len(inputs), 3)
        self.assertEqual(program.execution_plan[0].values[inputs[1]].val.int_val, 2)
        self.assertEqual(program.execution_plan[0].values[inputs[2]].val.bool_val, True)

    def test_inplace_ops(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = torch.sin(x)
                z = y.view(100)
                torch.relu_(z)
                return z.max()

        f = Foo()

        inputs = (torch.ones((10, 10)),)
        edge = to_edge(export(f, inputs))

        removed_ops = ["aten::relu_", "aten::view"]
        expected_ops = [
            "aten::sin",
            "aten::relu",
            "aten::max",
            "executorch_prim::et_view",  # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False
        ]

        for opname in removed_ops:
            self.assertEqual(
                self.count_node(edge.exported_program().graph_module, opname), 0
            )
        for opname in expected_ops:
            if (
                opname != "executorch_prim::et_view"
            ):  # et_view appears as call_function with target = memory.view in graph
                self.assertTrue(
                    self.count_node(edge.exported_program().graph_module, opname) >= 1
                )

        program = edge.to_executorch().executorch_program
        for opname in removed_ops:
            self.assertTrue(
                all(op.name != opname for op in program.execution_plan[0].operators)
            )
        for opname in expected_ops:
            self.assertTrue(
                any(op.name == opname for op in program.execution_plan[0].operators)
            )

    def test_operators_unique(self) -> None:
        class OpRepeatedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(2, 2)
                self.b = 2 * torch.ones(2, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for _ in range(10):
                    z = self.a * x
                    y = z + self.b
                return y

        model = OpRepeatedModule()

        inputs = (torch.ones(2, 2),)

        program = to_edge(export(model, inputs)).to_executorch().executorch_program

        self.assertEqual(len(program.execution_plan[0].operators), 2)

    def test_list_type(self) -> None:
        """Tests that the types of lists are correctly found"""

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.permute(x, (2, 0, 1))

        f = Foo()

        program = (
            to_edge(export(f, (torch.randn(2, 3, 5),)))
            .to_executorch()
            .executorch_program
        )
        exir.print_program.pretty_print(program)

        deboxed_int_list = []
        for item in program.execution_plan[0].values[5].val.items:  # pyre-ignore[16]
            deboxed_int_list.append(
                program.execution_plan[0].values[item].val.int_val  # pyre-ignore[16]
            )

        self.assertEqual(IntList(deboxed_int_list), IntList([2, 0, 1]))

    def test_kwargs1(self) -> None:
        """Tests that the kwargs are placed in the order specified by
        native_functions.yaml
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                batch1 = torch.randn(10, 3, 4)
                batch2 = torch.randn(10, 4, 5)
                return torch.addbmm(x, batch1, batch2, alpha=2, beta=3)

        f = Foo()

        program = (
            to_edge(export(f, (torch.randn(3, 5),))).to_executorch().executorch_program
        )
        # The value for beta should appear before alpha
        self.assertEqual(program.execution_plan[0].values[12].val, Int(3))
        self.assertEqual(program.execution_plan[0].values[13].val, Int(2))

    def test_kwargs2(self) -> None:
        """Tests that the kwargs are placed in the order specified by
        native_functions.yaml
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                values = torch.randn(3, 2)
                return torch.searchsorted(x, values, side="right", right=True)

        f = Foo()

        x, _ = torch.sort(torch.randn(3, 4))
        program = to_edge(export(f, (x,))).to_executorch().executorch_program
        # The value for right should appear before side
        self.assertEqual(program.execution_plan[0].values[6].val, Bool(False))
        self.assertEqual(program.execution_plan[0].values[7].val, Bool(True))
        self.assertEqual(program.execution_plan[0].values[8].val, String("right"))
        self.assertEqual(program.execution_plan[0].values[9].val, Null())

    def _assertCallLength(self, program: Program, idx: int, expected_len: int) -> None:
        instr_args = program.execution_plan[0].chains[0].instructions[idx].instr_args

        if isinstance(instr_args, (KernelCall, DelegateCall)):
            self.assertEqual(len(instr_args.args), expected_len)
        else:
            self.fail(f"Unexpected instruction args type: {type(instr_args)}")

    def test_out(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                z = y.clone()
                return torch.mul(x, y, out=z)

        f = Foo()

        program = (
            to_edge(export(f, (torch.ones(3), torch.ones(3))))
            .to_executorch()
            .executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1)
        self._assertCallLength(program, 0, 4)

    def test_model_out(self) -> None:
        class Module_out(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 * torch.ones(2, 2, dtype=torch.int32)
                self.b = 2 * torch.ones(2, 2, dtype=torch.int32)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                z = x.clone()
                torch.mul(self.a, x, out=z)
                y = x.clone()
                torch.add(z, self.b, alpha=2, out=y)
                return y

        model_out = Module_out()

        inputs = (torch.ones(2, 2, dtype=torch.int32),)

        # Trace to FX Graph.
        program = to_edge(export(model_out, inputs)).to_executorch().executorch_program

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 2)
        self._assertCallLength(program, 0, 4)
        self._assertCallLength(program, 1, 5)

    def test_stacktrace(self) -> None:
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.mul(x, torch.randn(3, 2))

        def g(x: torch.Tensor) -> torch.Tensor:
            return torch.sin(f(x))

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(g(x), torch.randn(3, 2))

        h = Foo()

        x = (torch.randn(3, 2),)
        exec_prog = to_edge(export(h, x)).to_executorch(
            exir.ExecutorchBackendConfig(emit_stacktrace=True)
        )
        program = exec_prog.executorch_program

        # Check that the mul operator's stack trace contains f -> g -> forward.
        self.assertTrue(
            "return torch.mul(x, torch.randn(3, 2))"
            in program.execution_plan[0]  # pyre-ignore[16]
            .chains[0]
            .stacktrace[1]
            .items[-1]
            .context
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "g"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[1].items[-3].name, "forward"
        )

        # Check that the sin operator's stack trace contains g -> forward.
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[2].items[-1].name, "g"
        )
        self.assertEqual(
            program.execution_plan[0].chains[0].stacktrace[2].items[-2].name, "forward"
        )

    def test_stacktrace_off(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.mul(x, torch.randn(3, 2))

        f = Foo()

        class Goo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.sin(f(x))

        g = Goo()

        class Hoo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(g(x), torch.randn(3, 2))

        h = Hoo()

        x = (torch.randn(3, 2),)
        program = to_edge(export(h, x)).to_executorch().executorch_program

        # The stacktrace should be None since we did not request stacktraces.
        self.assertIsNone(program.execution_plan[0].chains[0].stacktrace)

    def test_positional_argument_default_value(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
                z = torch.ones(6, 2)
                return torch.ops.aten.cat.out((x, n), out=z)

        f = Foo()

        x = torch.randn(3, 2)
        program = (
            to_edge(export(f, (x, x)))
            # .to_edge(self.compile_config)  # TODO(larryliu): fix cat
            .to_executorch().executorch_program
        )

        self.assertEqual(len(program.execution_plan[0].chains[0].instructions), 1)
        self._assertCallLength(program, 0, 4)

    @pytest.mark.skip(reason="Test not working on OSS")
    def test_emit_multiple_out(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.topk(x, 5)

        f = Foo()

        x = (torch.randn(10),)
        program = to_edge(export(f, x)).to_executorch().executorch_program
        self._assertCallLength(program, 0, 8)

    def test_emit_layout(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.ones_like(x)

        f = Foo()

        x = (torch.randn(3, 2),)
        program = to_edge(export(f, x)).to_executorch().executorch_program

        vals = program.execution_plan[0].values
        for val in vals:
            v = val.val
            if isinstance(v, Tensor):
                self.assertEqual(v.layout, 0)

    def test_optional_tensor_list(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = torch.nonzero(x)
                torch._constrain_as_size(a.shape[0], min=1)
                b = torch.ops.aten.index.Tensor(x, [a])
                return b

        f = Foo()
        x = (torch.triu(torch.ones(2, 2)),)
        program = (
            to_edge(
                export(f, x),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .to_executorch()
            .executorch_program
        )
        self.assertIsInstance(
            program.execution_plan[0].values[3].val, OptionalTensorList
        )
        self._assertCallLength(program, 0, 3)
        self._assertCallLength(program, 1, 4)

    def test_optional_float_list(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.interpolate(x, scale_factor=2)

        x = (torch.randn(1, 1, 2, 2),)
        program = to_edge(export(M(), x)).to_executorch().executorch_program
        self.assertIsInstance(
            program.execution_plan[0].values[-1].val, schema.OptionalTensorList
        )

    def test_emit_cond(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, pred, x):
                def true_fn(y: torch.Tensor) -> torch.Tensor:
                    y = y + y
                    y = torch.mm(y, y)
                    return y

                def false_fn(y: torch.Tensor) -> torch.Tensor:
                    return torch.mm(y, y)

                ret = control_flow.cond(pred, true_fn, false_fn, [x])
                return ret

        module = to_edge(export(M(), (torch.tensor(True), torch.ones(2, 2))))
        program = module.to_executorch().executorch_program

        num_mm = 0
        num_add = 0
        num_other = 0
        for inst in program.execution_plan[0].chains[0].instructions:
            if not isinstance(inst.instr_args, KernelCall):
                continue

            op = (
                program.execution_plan[0]
                .operators[inst.instr_args.op_index]  # pyre-ignore[16]
                .name
            )

            if "mm" in op:
                num_mm += 1
            elif "add" in op:
                num_add += 1
            else:
                num_other += 1

        self.assertEqual(num_mm, 2)
        self.assertEqual(num_add, 1)
        self.assertEqual(num_other, 0)

    def test_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        program = module.to_executorch().executorch_program

        op_table = program.execution_plan[0].operators
        # The first two operators at the beginning of a map program should be
        # sym_size and select_copy, which is what we verify here. The first
        # operator generates the number of iterations and the second slices the
        # input tensor to produce the tensor on which this iteration operates.
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[0]
                .instr_args.op_index
            ].name,
            "aten::sym_size",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[1]
                .instr_args.op_index
            ].name,
            "aten::select_copy",
        )

        # The last three instructions in the map sub-program are:
        # - calling the custom op to append the output of this iteration to the
        #   accumulator tensor,
        # - incrementing the iteration count, and
        # - checking whether we have completed all the iterations.
        # We check here that all three of these have been generated.
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-5]
                .instr_args.op_index
            ].name,
            "executorch_prim::et_copy_index",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-4]
                .instr_args.op_index
            ].name,
            "executorch_prim::add",
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-3]
                .instr_args.op_index
            ].name,
            "executorch_prim::eq",
        )
        # The last two instructions in the overall program check whether we
        # should jump back to the beginning of the loop, and reset the
        # iteration counter if we fall through.
        self.assertIsInstance(
            program.execution_plan[0].chains[0].instructions[-2].instr_args,
            JumpFalseCall,
        )
        self.assertEqual(
            op_table[
                program.execution_plan[0]  # pyre-ignore[16]
                .chains[0]
                .instructions[-1]
                .instr_args.op_index
            ].name,
            "executorch_prim::sub",
        )

    def test_load_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        _load_for_executorch_from_buffer(module.to_executorch().buffer)

    def test_run_emit_map(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def map_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                return control_flow.map(map_fn, x, y)

        f = Foo()

        inputs = (torch.ones(4, 4), torch.ones(4))
        module = to_edge(
            export(f, inputs),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        buffer = module.to_executorch().buffer
        loaded_model = _load_for_executorch_from_buffer(buffer)
        outputs = loaded_model(inputs)[0]
        # Wrap allclose in assertTrue; a bare torch.allclose() call discards
        # its result and would never fail the test.
        self.assertTrue(torch.allclose(outputs, f(*inputs)))

    def test_dim_order(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        program = (
            to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
            .to_executorch()
            .executorch_program
        )

        addmm_found = False
        for inst in program.execution_plan[0].chains[0].instructions:
            kernel = inst.instr_args
            if isinstance(kernel, KernelCall):
                op_id = kernel.op_index
                op = program.execution_plan[0].operators[op_id]
                if op.name == "aten::addmm":
                    addmm_found = True
                    args = kernel.args
                    bias_id = args[0]
                    act_id = args[1]
                    weight_id = args[2]
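                    # For contiguous 2-D tensors the expected dim_order is
                    # [0, 1] (row-major); the 1-D bias is trivially [0].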
                    bias_dim_order = [0]
                    act_dim_order = [0, 1]
                    weight_dim_order = [0, 1]
                    bias_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[bias_id].val
                    )
                    act_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[act_id].val
                    )
                    weight_tensor = typing.cast(
                        schema.Tensor, program.execution_plan[0].values[weight_id].val
                    )
                    self.assertEqual(bias_tensor.dim_order, bias_dim_order)
                    self.assertEqual(act_tensor.dim_order, act_dim_order)
                    self.assertEqual(weight_tensor.dim_order, weight_dim_order)
        self.assertTrue(addmm_found)

    def test_non_const_buffer_sizes(self) -> None:
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                b = 3 + 1
                return x + b

        f = Add()

        edge_program_manager = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            )
        )
        edge_program_manager._edge_programs["forward"] = constant_prop_pass(
            edge_program_manager.exported_program()
        )
        non_const_buffer_size_with_const_prop_pass = (
            edge_program_manager.to_executorch()
            .executorch_program.execution_plan[0]
            .non_const_buffer_sizes
        )

        edge_program_manager = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            )
        )
        non_const_buffer_size_without_const_prop_pass = (
            edge_program_manager.to_executorch()
            .executorch_program.execution_plan[0]
            .non_const_buffer_sizes
        )
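        # Constant propagation folds the scalar arithmetic into a constant, so
        # fewer intermediates need planned runtime memory; index 1 of
        # non_const_buffer_sizes corresponds to the first planned memory arena.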
        self.assertTrue(
            non_const_buffer_size_with_const_prop_pass[1]
            < non_const_buffer_size_without_const_prop_pass[1]
        )

    # Can't compare plans directly with __eq__ because of the plan names, and
    # data_buffer_idx in tensor values.
    def _compare_execution_plans(
        self, plan_single: ExecutionPlan, plan_merged: ExecutionPlan
    ) -> None:
        self.assertEqual(
            plan_single.container_meta_type,
            plan_merged.container_meta_type,
        )
        self.assertEqual(
            plan_single.inputs,
            plan_merged.inputs,
        )
        self.assertEqual(
            plan_single.outputs,
            plan_merged.outputs,
        )
        self.assertEqual(
            plan_single.chains,
            plan_merged.chains,
        )
        self.assertEqual(
            plan_single.operators,
            plan_merged.operators,
        )
        self.assertEqual(
            plan_single.non_const_buffer_sizes,
            plan_merged.non_const_buffer_sizes,
        )
        self.assertEqual(
            len(plan_single.values),
            len(plan_merged.values),
        )
        for i in range(len(plan_single.values)):
            single_val = plan_single.values[i].val
            merged_val = plan_merged.values[i].val
            if isinstance(single_val, Tensor):
                # The constant buffer index might differ because the constant
                # buffer is shared between plans.
                self.assertIsInstance(merged_val, Tensor)
                self.assertEqual(single_val.storage_offset, merged_val.storage_offset)
                self.assertEqual(single_val.scalar_type, merged_val.scalar_type)
                self.assertEqual(single_val.sizes, merged_val.sizes)
                self.assertEqual(single_val.dim_order, merged_val.dim_order)
                self.assertEqual(single_val.requires_grad, merged_val.requires_grad)
                self.assertEqual(single_val.layout, merged_val.layout)
                self.assertEqual(single_val.allocation_info, merged_val.allocation_info)
                self.assertEqual(single_val.shape_dynamism, merged_val.shape_dynamism)
            else:
                self.assertEqual(single_val, merged_val)

    def test_emit_memory_format_valid(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                contiguous = x.to(
                    dtype=torch.float32, memory_format=torch.contiguous_format
                )
                preserve = x.to(
                    dtype=torch.float32, memory_format=torch.preserve_format
                )
                return contiguous + preserve

        # Should succeed at exporting a model with a legal memory format
        # (contiguous, preserve).
        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        try:
            to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        except Exception:
            self.fail("Failed to export model with legal memory format")

    def test_emit_memory_format_invalid(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(dtype=torch.float32, memory_format=torch.channels_last)

        # Failure is expected when exporting a model with an illegal memory
        # format (channels_last) while not using dim_order.
        model = SimpleLinear()
        inputs = (torch.ones(10, 5, 2, 1),)
        with self.assertRaises(InternalError):
            to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(
                    _check_ir_validity=False, _skip_dim_order=True
                ),
            ).to_executorch()

        # Success if you use dim_order.
        to_edge(
            export(model, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _skip_dim_order=False
            ),
        ).to_executorch()

    def test_emit_multiple_entry_points(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward_relu(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

            def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear2(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        with patch_forward(model, model.forward_relu):
            program_relu = to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        with patch_forward(model, model.forward_sigmoid):
            program_sigmoid = to_edge(
                export(model, inputs),
                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            ).to_executorch()
        exir_input = {
            "forward_relu": program_relu.exported_program(),
            "forward_sigmoid": program_sigmoid.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 2)

        self.assertEqual(
            merged_program.execution_plan[0].name,
            "forward_relu",
        )
        self.assertEqual(
            merged_program.execution_plan[1].name,
            "forward_sigmoid",
        )
        # reserved spot, weight, bias
        self.assertEqual(
            len(program_sigmoid._emitter_output.program.constant_buffer),
            3,
        )
        self.assertEqual(
            len(program_relu._emitter_output.program.constant_buffer),
            3,
        )
        # Sum of the entry points' buffers minus 1, because the merged program
        # still has only one reserved spot.
        self.assertEqual(
            len(merged_program.constant_buffer),
            len(program_sigmoid._emitter_output.program.constant_buffer)
            + len(program_relu._emitter_output.program.constant_buffer)
            - 1,
        )

        self._compare_execution_plans(
            merged_program.execution_plan[0],
            program_relu._emitter_output.program.execution_plan[0],
        )
        self._compare_execution_plans(
            merged_program.execution_plan[1],
            program_sigmoid._emitter_output.program.execution_plan[0],
        )

    def test_emit_weight_deduplication(self) -> None:
        class SimpleLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward_relu(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.relu(self.linear(x))

            def forward_sigmoid(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x))

        model = SimpleLinear()
        inputs = (torch.ones(10, 5),)
        with patch_forward(model, model.forward_relu):
            program_relu = to_edge(export(model, inputs)).to_executorch()
        with patch_forward(model, model.forward_sigmoid):
            program_sigmoid = to_edge(export(model, inputs)).to_executorch()
        exir_input = {
            "forward_relu": program_relu.exported_program(),
            "forward_sigmoid": program_sigmoid.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 2)

        # reserved spot, weight, bias
        self.assertEqual(
            len(program_sigmoid._emitter_output.program.constant_buffer),
            3,
        )
        self.assertEqual(
            len(program_relu._emitter_output.program.constant_buffer),
            3,
        )
        # The weights are shared between entry points, so the merged program
        # should deduplicate everything.
        self.assertEqual(len(merged_program.constant_buffer), 3)

        self._compare_execution_plans(
            merged_program.execution_plan[0],
            program_relu._emitter_output.program.execution_plan[0],
        )
        self._compare_execution_plans(
            merged_program.execution_plan[1],
            program_sigmoid._emitter_output.program.execution_plan[0],
        )

    def test_emit_execution_plans_sorted(self) -> None:
        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def a(self, x: torch.Tensor) -> torch.Tensor:
                return x

            def b(self, x: torch.Tensor) -> torch.Tensor:
                return x

            def c(self, x: torch.Tensor) -> torch.Tensor:
                return x

        model = Simple()
        inputs = (torch.ones(10, 5),)

        def make_program(
            fn,
            inputs,
        ) -> "ExecutorchProgramManager":
            return to_edge(
                export(
                    WrapperModule(fn),
                    inputs,
                )
            ).to_executorch()

        program_a = make_program(model.a, inputs)
        program_b = make_program(model.b, inputs)
        program_c = make_program(model.c, inputs)

        exir_input = {
            "b": program_b.exported_program(),
            "c": program_c.exported_program(),
            "a": program_a.exported_program(),
        }
        merged_program = emit_program(exir_input, False).program
        self.assertEqual(len(merged_program.execution_plan), 3)
        self.assertEqual(merged_program.execution_plan[0].name, "a")
        self.assertEqual(merged_program.execution_plan[1].name, "b")
        self.assertEqual(merged_program.execution_plan[2].name, "c")

        # Create a second program equivalent to the first, but with the inputs
        # in a different order. Python dicts are insertion ordered.
        exir_input2 = {
            "a": program_b.exported_program(),
            "b": program_c.exported_program(),
            "c": program_a.exported_program(),
        }
        merged_program2 = emit_program(exir_input2, False).program
        self.assertEqual(
            merged_program2.execution_plan[0], merged_program.execution_plan[0]
        )
        self.assertEqual(
            merged_program2.execution_plan[1], merged_program.execution_plan[1]
        )
        self.assertEqual(
            merged_program2.execution_plan[2], merged_program.execution_plan[2]
        )

    def test_upper_bound_memory_planning_respect_input_constraints(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, k: torch.Tensor) -> torch.Tensor:
                k = torch.cat((k, torch.ones(1, 4)))
                return k

        func = Foo()

        k = torch.rand(2, 4)
        dim0_k = Dim("dim0_k", max=3)
        dynamic_shapes = {"k": {0: dim0_k}}
        captured = export(
            func,
            (k,),
            dynamic_shapes=dynamic_shapes,
        )
        edge = to_edge(captured)

        config = exir.ExecutorchBackendConfig(
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            memory_planning_pass=MemoryPlanningPass(
                # allow_lifetime_and_storage_overlap: bool = False,
                alloc_graph_input=True,
                alloc_graph_output=False,
            ),
        )

        exe_prog = edge.to_executorch(config)
        program = exe_prog._emitter_output.program
        exir.print_program.pretty_print(exe_prog._emitter_output.program.execution_plan)
        execution_plan = program.execution_plan[0]
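        # With dim0_k bounded at 3, the float32 input's upper-bound size is
        # 3 * 4 * 4 = 48 bytes, so it occupies [0, 48) of arena 1 and the next
        # planned tensor starts at offset 48.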
        self.check_tensor_buffer_loc(0, execution_plan.values, 0, 1, 0)
        self.check_tensor_buffer_loc(1, execution_plan.values, 0, 1, 48)

    def test_emit_prims(self) -> None:
        tensor_output = torch.rand(1, 4)
        tensor_list_output = [torch.rand(1, 4), torch.rand(1, 4)]

        class Simple(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.x: int = 3
                self.y = 2

            def get_ints(self) -> Tuple[int, int]:
                return (self.x, self.y)

            def get_str(self) -> str:
                return "foo"

            def get_tensor(self) -> torch.Tensor:
                return tensor_output

            def get_tensor_list(self) -> List[torch.Tensor]:
                return tensor_list_output

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.sigmoid(self.linear(x))

        model = Simple()
        inputs = (torch.ones(10, 5),)
        program = to_edge(export(model, inputs)).to_executorch()
        exir_input = {
            "forward": program.exported_program(),
        }
        getters = {}
        getters["get_ints"] = model.get_ints()
        getters["get_str"] = model.get_str()
        getters["get_tensor"] = model.get_tensor()
        getters["get_tensor_list"] = model.get_tensor_list()

        merged_program = emit_program(exir_input, False, getters).program

        self.assertEqual(len(merged_program.execution_plan), 5)

        self.assertEqual(
            merged_program.execution_plan[0].name,
            "forward",
        )
        self.assertEqual(
            merged_program.execution_plan[1].name,
            "get_ints",
        )
        self.assertEqual(
            merged_program.execution_plan[2].name,
            "get_str",
        )
        self.assertEqual(
            merged_program.execution_plan[3].name,
            "get_tensor",
        )
        self.assertEqual(
            merged_program.execution_plan[4].name,
            "get_tensor_list",
        )

        # no instructions in a getter
        self.assertEqual(
            len(merged_program.execution_plan[1].chains[0].instructions),
            0,
        )
        # 2 outputs for the flattened tuple
        self.assertEqual(
            len(merged_program.execution_plan[1].outputs),
            2,
        )
        # outputs are 0 and 1 in the values table
        self.assertEqual(
            merged_program.execution_plan[1].outputs,
            [0, 1],
        )
        # value 0 is 3
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[1].values[0].val.int_val,
            3,
        )
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[1].values[1].val.int_val,
            2,
        )
        self.assertEqual(
            len(merged_program.execution_plan[2].outputs),
            1,
        )
        self.assertEqual(
            # pyre-ignore
            merged_program.execution_plan[2].values[0].val.string_val,
            "foo",
        )
        self.assertEqual(len(merged_program.execution_plan[3].outputs), 1)
        self.assertEqual(len(merged_program.execution_plan[4].outputs), 2)

        merged_program = to_edge(
            export(model, inputs), constant_methods=getters
        ).to_executorch()
        executorch_module = _load_for_executorch_from_buffer(merged_program.buffer)
        # Wrap allclose in assertTrue; a bare torch.allclose() call discards
        # its result and would never fail the test.
        self.assertTrue(
            torch.allclose(
                executorch_module.run_method("get_tensor", [])[0], tensor_output
            )
        )
        model_output = executorch_module.run_method("get_tensor_list", [])
        for i in range(len(tensor_list_output)):
            self.assertTrue(torch.allclose(model_output[i], tensor_list_output[i]))

    def test_emit_debug_handle_map(self) -> None:
        mul_model = Mul()
        program_mul = to_edge(
            export(
                mul_model,
                mul_model.get_random_inputs(),
            )
        ).to_executorch()
        # This triggers the actual emission of the graph.
        program_mul._emitter_output.program
        self.assertIsNotNone(program_mul.debug_handle_map)

    def test_final_graph_module_update_debug_handle(self) -> None:
        class SimpleAddMul(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = x + 1
                return a * 2

        mul_model = SimpleAddMul()
        program_mul = to_edge(
            export(
                mul_model,
                (torch.ones(2, 2),),
            )
        ).to_executorch()

        # This triggers the actual emission of the graph.
        program = program_mul._emitter_output.program
        node = None

        # Find the multiplication node in the graph that was emitted.
        for node in program_mul.exported_program().graph.nodes:
            if node.target == torch.ops.aten.mul.out:
                break
        self.assertIsNotNone(node)

        idx = 0
        # Find the multiplication instruction in the program that was emitted.
        for idx in range(len(program.execution_plan[0].chains[0].instructions)):
            instruction = program.execution_plan[0].chains[0].instructions[idx]
            op_index = instruction.instr_args.op_index  # pyre-ignore[16]
            if "mul" in program.execution_plan[0].operators[op_index].name:
                break

        # The instruction id of the multiplication instruction and the debug
        # handle of the multiplication node in the graph module (which was
        # updated in the emitter to be the same as the instruction id) must be
        # the same.
        self.assertEqual(
            idx,
            node.meta.get("debug_handle"),
        )

    def test_delegate_with_input_list(self) -> None:
        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> PreprocessResult:
                return PreprocessResult(
                    processed_bytes=bytes("test", encoding="utf8"),
                    debug_handle_map=None,
                )

        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cat(x)

        inputs = ([torch.ones(2, 2), torch.ones(2, 2)],)
        model = TestModel()
        edgeir_m = to_edge(export(model, inputs))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, list_a):
                return self.lowered_module(list_a)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, inputs),
        ).to_executorch()
        # Accessing .buffer serializes the program; this should not raise.
        exec_prog.buffer

    def test_delegate_with_input_tuple(self) -> None:
        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> PreprocessResult:
                return PreprocessResult(
                    processed_bytes=bytes("test", encoding="utf8"),
                    debug_handle_map=None,
                )

        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):  # input = (a, x, b)
                y = torch.mm(input[0], input[1])
                z = torch.add(y, input[2])
                return z

        model_inputs = ((torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)),)
        model = AddMulModule()
        edgeir_m = to_edge(export(model, model_inputs))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, list_a):
                return self.lowered_module(list_a)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, model_inputs),
        ).to_executorch()
        # Accessing .buffer serializes the program; this should not raise.
        exec_prog.buffer

    def test_delegate_mapping(self) -> None:
        debug_handle_map = {1: [1, 2]}

        class BackendWithCompilerExample(BackendDetails):
            @staticmethod
            def preprocess(
                edge_program,
                compile_specs,
            ) -> PreprocessResult:
                return PreprocessResult(
                    processed_bytes=bytes("test", encoding="utf8"),
                    debug_handle_map=debug_handle_map,
                )

        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return torch.add(x, y)

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        model = TestModel()
        edgeir_m = to_edge(export(model, inputs))
        lowered_module = to_backend(
            "BackendWithCompilerExample", edgeir_m.exported_program(), []
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_module = lowered_module

            def forward(self, x, y):
                return self.lowered_module(x, y)

        composite_model = CompositeModule()
        exec_prog = to_edge(
            export(composite_model, inputs),
        ).to_executorch()
        # Reading the program triggers the call to emit_program underneath,
        # which needs to happen for our test to succeed.
        exec_prog._emitter_output.program
        self.assertIsNotNone(exec_prog.delegate_map)
        self.assertIsNotNone(exec_prog.delegate_map.get("forward"))
        self.assertIsNotNone(
            exec_prog.delegate_map.get("forward").get(0)  # pyre-ignore[16]
        )
        self.assertEqual(
            exec_prog.delegate_map.get("forward").get(0).get("name"),
            "BackendWithCompilerExample",
        )
        self.assertTrue(
            len(exec_prog.delegate_map.get("forward").get(0).get("delegate_map")) != 0
        )

    def test_emit_weight_view(self) -> None:
        class ModWithWeightViews(nn.Module):
            def __init__(self):
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.W1 = self.W[:1]
                self.W2 = self.W[1:]

            def forward(self, x):
                return self.W1 + self.W2 + x

        model = ModWithWeightViews()
        # Each weight is a view of the same underlying 8-byte storage.
        self.assertEqual(model.W1.nbytes, 4)
        self.assertEqual(model.W1.untyped_storage().nbytes(), 8)
        self.assertEqual(model.W2.nbytes, 4)
        self.assertEqual(model.W2.untyped_storage().nbytes(), 8)
        program = to_edge(
            export(
                model,
                (torch.ones(1),),
            )
        ).to_executorch()

        program = program._emitter_output.program
        # Each emitted weight gets its own 4-byte (one float32) buffer instead
        # of sharing the original storage, i.e. it is no longer a view.
        self.assertEqual(len(program.constant_buffer[1].storage), 4)
        self.assertEqual(len(program.constant_buffer[2].storage), 4)

    def test_non_persistent_buffer(self) -> None:
        class NonPersistentBuffer(nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buf", torch.tensor([1]), persistent=False)

            def forward(self, x):
                return x + self.buf

        model = NonPersistentBuffer()
        program = to_edge(
            export(
                model,
                (torch.ones(1),),
            )
        ).to_executorch()
        program = program._emitter_output.program
        # Confirm that the non-persistent buffer was still emitted: index 0 of
        # the constant buffer is reserved, and torch.tensor([1]) defaults to
        # int64, hence 8 bytes of storage.
        self.assertEqual(len(program.constant_buffer), 2)
        self.assertEqual(len(program.constant_buffer[1].storage), 8)

    def test_emit_lifted_tensor_constant(self) -> None:
        class LiftedConstants(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x * torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float)
                return x

        model = LiftedConstants()

        program = to_edge(
            export(
                model,
                (torch.ones(3, 2),),
            )
        ).to_executorch()

        program = program._emitter_output.program
        exec_plan = program.execution_plan[0]
        # There should only be 1 user input to this model; the lifted tensor
        # constant must not show up as an extra input.
        self.assertEqual(len(exec_plan.inputs), 1)
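        # The lifted constant lands in the constant buffer: 3 * 2 float32
        # elements = 24 bytes.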
        self.assertEqual(len(program.constant_buffer), 2)
        self.assertEqual(len(program.constant_buffer[1].storage), 24)

    def test_mutable_buffers(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target == torch.ops.aten.copy_
                    or node.target == exir_ops.edge.aten.copy_.default
                )
                for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(
            export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        model = model.to_executorch()
        model.dump_executorch_program(True)
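        # The mutable state buffer should be memory planned, i.e. its EValue
        # carries allocation_info.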
        self.assertTrue(
            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
            .values[0]
            .val.allocation_info
            is not None
        )
        executorch_module = _load_for_executorch_from_buffer(model.buffer)
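        # The buffer state persists across calls: the first call observes
        # state == 0 and the second observes state == 1.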
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)

    def test_mutable_buffers_without_memplanned_inputs(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target == torch.ops.aten.copy_
                    or node.target == exir_ops.edge.aten.copy_.default
                )
                for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(
            export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
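        # Same scenario as above, but graph inputs are left out of memory
        # planning and symbolic shapes are resolved with the constraint-based
        # evaluator.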
        model = model.to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        model.dump_executorch_program(True)
        self.assertTrue(
            model.executorch_program.execution_plan[0]  # pyre-ignore[16]
            .values[0]
            .val.allocation_info
            is not None
        )
        executorch_module = _load_for_executorch_from_buffer(model.buffer)
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1))
        self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1)

    def test_infinity_in_model(self) -> None:
        class InfinityMaskModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32)

            def forward(self, x):
                masked_weights = x.masked_fill(self.mask == 0, float("-inf"))
                return masked_weights

        model = to_edge(
            export(
                InfinityMaskModel(),
                (torch.randn(2, 2),),
            )
        )

        # Confirm that we can serialize the model with infinity in it.
        model = model.to_executorch()

        # Assert that the -inf constant was emitted as a Double EValue; the
        # serializer represents non-finite floats with strings like "-inf".
        values = model.executorch_program.execution_plan[0].values
        self.assertEqual(values[5].val, Double(double_val=float("-inf")))

        # Confirm that we can also deserialize the model with infinity in it.
        pte_data = deserialize_pte_binary(model.buffer)
        self.assertEqual(
            pte_data.execution_plan, model.executorch_program.execution_plan
        )

    def test_mutate_input_tensor(self) -> None:
        class MutateInputTensorModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x.add_(1)

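        # With alloc_graph_input=False the graph input is not memory planned,
        # so the runtime operates directly on the caller's tensor and the
        # in-place add_ is visible to the caller.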
        model = to_edge(
            export(MutateInputTensorModule(), (torch.zeros(1),))
        ).to_executorch(
            config=ExecutorchBackendConfig(
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False)
            )
        )
        executorch_model = _load_for_executorch_from_buffer(model.buffer)
        input = torch.zeros(1)
        executorch_model(input)
        self.assertEqual(input, torch.ones(1))
1672