# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Dict, List

import executorch.exir as exir
import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

# import the backend implementation
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.hta_partitioner_demo import (
    HTAPartitionerMultiplePatternsDemo,
    HTAPartitionerOnePatternDemo,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AddMulPartitionerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend

from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import (
    get_lowered_backend_modules,
    get_lowered_submodules,
)
from executorch.exir.print_program import print_program
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
    DelegateCall,
    Program,
)

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten

from functorch.experimental import control_flow
from torch.ao.quantization import get_default_qconfig_mapping  # @manual
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import export, ExportedProgram
from torch.testing import FileCheck


def vary_segments(test_method):
    """A decorator that calls the test method with `extract_delegate_segments` set to
    True and False.

    Decorated test methods must expect a boolean parameter named
    `extract_delegate_segments`, and they should pass that value to to_executorch() like:

        m.to_executorch(
            config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments)
        )

    This will cause the delegate data blobs to be extracted from the program and
    serialized as separate, freeable program segments. Backends should detect no
    difference at runtime.
    """

    def wrapper(self):
        for extract_delegate_segments in [False, True]:
            # subTest will create a different top-level test entry for each
            # value, whose full names have a suffix like
            # "(extract_delegate_segments=True)".
            with self.subTest(extract_delegate_segments=extract_delegate_segments):
                test_method(self, extract_delegate_segments=extract_delegate_segments)

    return wrapper
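
# A usage sketch for @vary_segments (illustrative only; the real tests below
# follow the same shape, and `module`/`inputs` here are hypothetical):
#
#     @vary_segments
#     def test_foo(self, extract_delegate_segments: bool):
#         prog = to_edge(export(module, inputs)).to_executorch(
#             config=exir.ExecutorchBackendConfig(
#                 extract_delegate_segments=extract_delegate_segments
#             )
#         )
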
class TestBackends(unittest.TestCase):
    def check_delegate_input(
        self, delegate: LoweredBackendModule, input_len: int
    ) -> None:
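        """Asserts that the delegate's original module has exactly
        `input_len` placeholder (input) nodes."""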
        counter = 0
        for node in delegate.original_module.graph.nodes:
            if node.op == "placeholder":
                counter += 1
        self.assertEqual(counter, input_len)

    def check_backend_delegate(
        self,
        program: Program,
        delegate: BackendDelegate,
        expected_id: str,
        expected_processed: bytes,
    ) -> None:
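        """Asserts that `delegate` targets the backend `expected_id` and that
        its processed blob, stored inline in the program's delegate data,
        matches `expected_processed`."""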
        self.assertEqual(delegate.id, expected_id)
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, expected_processed
        )

    def test_simple(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        expected_res = sin_module(*model_inputs)
        edgeir_m = to_edge(export(sin_module, model_inputs))

        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), []
        )
        new_res = lowered_sin_module(*model_inputs)

        self.assertTrue(torch.allclose(new_res, expected_res))

        # TODO(tkaruturi): emitting single LoweredBackendModule
        # program = to_edge(export(graph_module)).to_executorch()._emitter_output.program

    @vary_segments
    def test_backend_with_compiler(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
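        # CompileSpec entries are opaque (key, bytes) pairs that are forwarded
        # to the backend's preprocess step; "max_value" is a knob specific to
        # the demo backend.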
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            )
        )
        graph_module = exec_prog.exported_program().graph_module

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call to
        # the delegated function.
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )
        lowered_submodules = get_lowered_submodules(graph_module)
        self.assertEqual(len(lowered_submodules), 1)

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that the first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        program = exec_prog._emitter_output.program

        # Check that the program can be printed
        print_program(program)

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)
        model_outputs = executorch_module.forward([model_inputs])
        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
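        # The demo backend appears to approximate sin(x) with the cubic term
        # x - x**3/6, so for x = 1 the expected value is about
        # 1 - 1/6 ~= 0.8333 rather than sin(1) ~= 0.8415.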
        expected_output = 0.8333 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    @vary_segments
    def test_lowered_add_mul(self, extract_delegate_segments: bool):
        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = to_edge(export(add_mul_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_add_mul = to_backend(
            "BackendWithCompilerDemo",
            edge_graph_module.exported_program(),
            compile_specs,
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
        )

    def run_model_in_unsupported_backend(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        # The backend only accepts shapes <= 4.
        model_inputs = (torch.ones(6),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.zeros(6),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        buff = exec_prog.buffer

        # This line should raise an exception like
        # RuntimeError: failed with error 0x12
        _load_for_executorch_from_buffer(buff)

    @vary_segments
    def test_backend_with_compiler_out_of_range(self, extract_delegate_segments: bool):
        with self.assertRaisesRegex(
            RuntimeError,
            "loading method forward failed with error 0x12",
        ):
            self.run_model_in_unsupported_backend(
                extract_delegate_segments=extract_delegate_segments
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator(
        self, extract_delegate_segments: bool
    ):
        # This test includes both a delegate and a regular operator.
        # import the backend implementation
        from executorch.exir.backend.test.backend_with_compiler_demo import (
            BackendWithCompilerDemo,
        )

        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return [torch.sin(x)]

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program(), compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                a = self.lowered_linear_sin(x)[0]
                b = self.lowered_linear_sin(x)[0]
                return torch.add(a, b)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        graph_module = exec_prog.exported_program().graph_module
        program = exec_prog._emitter_output.program
        buff = exec_prog.buffer

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call to
        # the delegated function.
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that the first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)

        model_outputs = executorch_module.forward([model_inputs])

        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
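        # a + b = 2 * sin(1) under the demo backend's approximate sin, i.e.
        # roughly 2 * 0.8333 ~= 1.6667.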
        expected_output = 1.666667 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_backend_with_compiler_backend_runtime_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        error_msg = r"call_function aten.cos.default is not supported in backend BackendWithCompilerDemo"

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program(), [])

    def test_backend_with_compiler_backend_not_found_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name once
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = to_edge(export(sin_module, model_inputs))
        error_msg = r"Backend FakeBackendWithCompilerDemo was not found."

        with self.assertRaisesRegex(
            NotImplementedError,
            error_msg,
        ):
            _ = to_backend(
                "FakeBackendWithCompilerDemo", edgeir_m.exported_program(), []
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator_with_two_modules(
        self, extract_delegate_segments: bool
    ):
        # The submodule runs on a specific backend; in this example, the
        # `BackendWithCompilerDemo` backend.
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        # to_be_lowered is an nn.Module
        to_be_lowered = LowerableSubModel()
        example_input = (torch.ones(1),)
        to_be_lowered_exir_submodule = to_edge(export(to_be_lowered, example_input))

        max_value = example_input[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_module = to_backend(
            "BackendWithCompilerDemo",
            to_be_lowered_exir_submodule.exported_program(),
            compile_specs,
        )

        class NonLowerableSubModel(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.register_buffer("bias", bias)

            def forward(self, a, b):
                return torch.add(torch.add(a, b), self.bias)

        # The composite module, combining the lowerable and non-lowerable parts.
        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3)
                self.lowerable = lowered_module

            def forward(self, x):
                a = self.lowerable(x)
                b = self.lowerable(a)
                ret = self.non_lowerable(a, b)
                return a, b, ret

        composite_model = CompositeModel()

        # Prepare the model input
        model_inputs = (torch.ones(1),)

        # Verify the input works with the eager module
        composite_model(*model_inputs)

        exec_prog = to_edge(export(composite_model, model_inputs)).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        flatbuffer = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(flatbuffer)
        model_outputs = executorch_module.forward([*model_inputs])
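        # Expected values chain the demo backend's approximate sin:
        # a ~= sin(1) ~= 0.8333, b ~= sin(a) ~= 0.7369, and
        # ret = a + b + 0.3 ~= 1.8702.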
        expected_outputs = [
            0.8333 * torch.ones(1),
            0.7369 * torch.ones(1),
            1.8702 * torch.ones(1),
        ]

        for index, expected_output in enumerate(expected_outputs):
            self.assertTrue(
                torch.allclose(
                    model_outputs[index], expected_output, atol=1e-03, rtol=1e-03
                )
            )

    @vary_segments
    def test_partition_delegate_graph_with_multiple_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = to_edge(
            export(composite_m, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _use_edge_ops=True
            ),
        )

        program_without_delegates = to_edge(
            export(CompositeModel(3), inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        # After this step, part of the graph will be lowered to the backend,
        # depending on HTAPartitionerMultiplePatternsDemo's rules.
        program_with_delegates = traced
        program_with_delegates = program_with_delegates.to_backend(
            HTAPartitionerMultiplePatternsDemo()
        )
        program_with_delegates = program_with_delegates.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = program_with_delegates.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates._emitter_output.program,
            delegate=program_with_delegates._emitter_output.program.execution_plan[
                0
            ].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that sub is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check that convolution is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check that convolution is in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_partition_delegate_graph_with_one_pattern(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = to_edge(
            export(composite_m, inputs),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False, _use_edge_ops=True
            ),
        )

        program_without_delegates = to_edge(
            export(
                CompositeModel(3),
                (input_x, input_h, input_c),
            ),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )
        # After this step, part of the graph will be lowered to the backend,
        # depending on HTAPartitionerOnePatternDemo's rule.
        traced_with_delegate = traced
        traced_with_delegate = traced_with_delegate.to_backend(
            HTAPartitionerOnePatternDemo()
        )

        new_res = traced_with_delegate.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        program_with_delegates = traced_with_delegate.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # TODO(T143084047): Currently not retraceable
        # Retracing is not needed, but keeping this here to make sure the result
        # of to_backend is retraceable
        # graph_module_with_delegate = to_edge(export(
        #     traced_with_delegate,
        #     (input_x, input_h, input_c),
        #
        # ))

        # program_with_delegates = graph_module_with_delegate.to_executorch(
        #     config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments),
        # )

        new_res = program_with_delegates.exported_program().module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates._emitter_output.program,
            delegate=program_with_delegates._emitter_output.program.execution_plan[
                0
            ].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check that sub is in the program with delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check that convolution is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check that convolution is in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates._emitter_output.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_add_mul_partitioner(self, extract_delegate_segments: bool):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        orig_res = m(*inputs)

        ep = to_edge(export(m, inputs))
        executorch_prog = ep
        executorch_prog = executorch_prog.to_backend(AddMulPartitionerDemo())
        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = executorch_prog.exported_program().graph_module(*inputs)
        self.assertTrue(torch.allclose(new_res[0], orig_res))

        counter = 0
        for node in executorch_prog.exported_program().graph_module.graph.nodes:
            if node.op == "get_attr":
                self.assertEqual(node.target, f"lowered_module_{counter}")
                counter += 1
        # There should be 2 delegated modules
        self.assertEqual(counter, 2)

        executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = m(*inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03),
        )

    @vary_segments
    def test_partitioner_with_attributes(self, extract_delegate_segments: bool):
        """
        Check that parameters that are lowered are correctly moved into the
        sub-program, rather than being retained and passed as inputs.
        """

        class AddOne(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("one", torch.ones(1, 3))

            def forward(self, x):
                return x + self.one

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_one = AddOne()
                self.add_one_2 = AddOne()

            def forward(self, x, y):
                x = self.add_one(x) * y
                return self.add_one_2(x)

        inputs = (torch.randn(1, 3), torch.randn(1, 3))
        orig_res = Model()(*inputs)
        ep = to_edge(export(Model(), inputs))
        executorch_prog = ep
        executorch_prog = executorch_prog.to_backend(AddAttributePartitionerDemo())
        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # Check the delegated submodules
        lowered_backends = get_lowered_backend_modules(
            executorch_prog.exported_program().graph_module
        )
        self.assertEqual(len(lowered_backends), 2)
        for backend in lowered_backends:
            original_program = backend.original_module
            # Check that the program has the lowered attributes
            self.assertEqual(len(original_program.state_dict), 1)
            # Check that the backend has one placeholder input and one
            # placeholder parameter
            self.check_delegate_input(backend, 2)
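        # Accessing .buffer forces serialization; here we only verify that it
        # does not raise.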
        executorch_prog.buffer

        new_res = executorch_prog.exported_program().graph_module(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_bad_partitioner(self):
        """
        Checks that we throw an error if a user-provided partitioner modifies
        the graph module.
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                x = x * y
                x = x + y
                return x

        class BadPartitioner(Partitioner):
            partition_tags = {"tag1": DelegationSpec("BackendWithCompilerDemo", [])}

            def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                # A partitioner should not modify the given graph module, but
                # this one rewrites add nodes into mul nodes.
                partition_tags: Dict[str, DelegationSpec] = {}
                for node in exported_program.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == exir_ops.edge.aten.add.Tensor
                    ):
                        node.target = exir_ops.edge.aten.mul.Tensor
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags=partition_tags,
                )

        ep = to_edge(export(Model(), inputs))
        with self.assertRaises(AssertionError):
            _ = ep.to_backend(BadPartitioner())
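        # For contrast, a minimal sketch of a well-behaved partitioner: it only
        # tags nodes and leaves the graph itself untouched. Hypothetical, but
        # modeled on AddMulPartitionerDemo's tagging scheme:
        #
        #     class GoodPartitioner(Partitioner):
        #         def partition(self, ep: ExportedProgram) -> PartitionResult:
        #             partition_tags: Dict[str, DelegationSpec] = {}
        #             for node in ep.graph.nodes:
        #                 if (
        #                     node.op == "call_function"
        #                     and node.target == exir_ops.edge.aten.add.Tensor
        #                 ):
        #                     node.meta["delegation_tag"] = "tag0"
        #                     partition_tags["tag0"] = DelegationSpec(
        #                         "BackendWithCompilerDemo", []
        #                     )
        #             return PartitionResult(
        #                 tagged_exported_program=ep, partition_tags=partition_tags
        #             )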

    def test_quantized_with_delegate(self) -> None:
        torch.ops.load_library(
            "//executorch/kernels/quantized:custom_ops_generated_lib"
        )
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        in_size = 2
        input_size = 3
        output_size = 4
        linear = torch.nn.Linear(input_size, output_size).eval()
        example_inputs = (torch.ones(in_size, input_size),)
        prepared_linear = prepare_fx(
            linear,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        converted_linear: torch.nn.Module = _convert_to_reference_decomposed_fx(
            prepared_linear,
        )

        # fails to trace here
        converted_linear_gm = to_edge(
            export(
                converted_linear,
                example_inputs,
            ),
            compile_config=exir.EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        )
        FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
            converted_linear_gm.exported_program().graph_module.code
        )

    def test_partition_with_control_flow(self) -> None:
        def true_fn(x, y):
            x = x - y
            x = x + y
            x = x - y
            return x

        def false_fn(x, y):
            x = x - y
            x = torch.mm(x, y)
            x = x - y
            return x

        class Module(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = control_flow.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
                x = x - y
                return x

        f = Module()
        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        new_res = partitioned.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # The toplevel module only has the cond submodules
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 2)

        true_gm = partitioned_submodules[0][1]
        true_lowered = get_lowered_submodules(true_gm)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        false_gm = partitioned_submodules[1][1]
        false_lowered = get_lowered_submodules(false_gm)
        self.assertEqual(len(false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            false_lowered[0][1].original_module.graph_module.code
        )

    def test_partition_with_map(self) -> None:
        def map_fn(x, y):
            x = x - y
            x = x + y
            return x

        class Module(torch.nn.Module):
            def forward(self, xs, y):
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, y)

        f = Module()
        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # The toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        map_fn_gm = partitioned_submodules[0][1]
        map_fn_lowered = get_lowered_submodules(map_fn_gm)
        self.assertEqual(len(map_fn_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            map_fn_lowered[0][1].original_module.graph_module.code
        )

        new_res = partitioned.exported_program().module()(*inputs)

        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_partition_with_nested_control_flow(self) -> None:
        """
        Partitions the add and mul ops, including the ones inside the submodules.
        """

        def true_nested(y):
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y):
            return torch.mm(y, y)

        def true_fn(x, pred2):
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def map_fn(x, pred1, pred2, y):
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        class Module(torch.nn.Module):
            def forward(self, xs, pred1, pred2, y):
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, pred1, pred2, y)

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )

        f = Module()
        orig_res = f(*inputs)
        orig = to_edge(
            export(
                f,
                inputs,
            )
        )
        partitioned = orig
        partitioned = partitioned.to_backend(AddMulPartitionerDemo())

        new_res = partitioned.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # The toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program().graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        # The map module has the cond submodules
        map_submodules = get_control_flow_submodules(partitioned_submodules[0][1])
        self.assertEqual(len(map_submodules), 2)

        # True module
        true_module = map_submodules[0][1]
        true_lowered = get_lowered_submodules(true_module)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        # False module
        false_lowered = get_lowered_submodules(map_submodules[1][1])
        self.assertEqual(len(false_lowered), 0)

        # The true module has the nested cond submodules
        true_submodules = get_control_flow_submodules(true_module)
        self.assertEqual(len(true_submodules), 2)

        # Nested true module
        true_true_lowered = get_lowered_submodules(true_submodules[0][1])
        self.assertEqual(len(true_true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").check(
            "executorch_exir_dialects_edge__ops_aten_mm_default"
        ).run(true_true_lowered[0][1].original_module.graph_module.code)

        # Nested false module
        true_false_lowered = get_lowered_submodules(true_submodules[1][1])
        self.assertEqual(len(true_false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            true_false_lowered[0][1].original_module.graph_module.code
        )

    def test_list_input(self):
        class Module(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                y = x[0] + x[1]
                return y

        f = Module()
        inputs = ([torch.randn(2, 2), torch.randn(2, 2)],)
        edge_prog = to_edge(export(f, inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: List[torch.Tensor]):
                return self.lowered(x)

        gm = to_edge(export(ComposedM(), inputs))
        gm.exported_program().module()(*inputs)

    def test_dict_input(self):
        class Module(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                y = x["a"] + x["b"]
                return y

        f = Module()
        inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
        edge_prog = to_edge(export(f, inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: Dict[str, torch.Tensor]):
                return self.lowered(x)

        gm = to_edge(export(ComposedM(), inputs))
        gm.exported_program().module()(*inputs)
1261