# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Union

import torch
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.utils import DelegateMappingBuilder
from executorch.exir.dialects._ops import ops as exir_ops
from torch import nn
from torch.export.exported_program import ExportedProgram


# A simple way to represent an op along with its delegate debug identifier.
class DummyOp:
    def __init__(
        self,
        op: str,
        delegate_debug_identifier: Union[int, str],
    ):
        self.op = op
        self.delegate_debug_identifier = delegate_debug_identifier
        self.__name__ = self.__repr__()

    def __repr__(self):
        return f"{self.op}"

    def serialize(self):
        return f"{self.op},{self.delegate_debug_identifier},"


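# For illustration (not part of the original file): a DummyOp serializes to a simple
# comma-delimited record, which preprocess() below concatenates into the payload blob:
#   DummyOp("conv_relu_op", 0).serialize()  ->  "conv_relu_op,0,"

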
"""
This demo implementation is mainly intended to show how the DelegateMappingBuilder should be used
in backends. There are two use cases represented here:
1. A list of decomposed ops is fused into a single backend op, and we show how the delegate debug
identifier mapping is generated for that.
2. A single op is decomposed into three backend ops, and we show how the delegate debug identifier
mapping is generated for that.

Here is what the graph looks like for the demo model ConvReLUTanModel implemented in this class:
input
  |
conv (debug_handle : 1)
  |
relu (debug_handle : 2)
  |
tan (debug_handle : 3)
  |
output

Here is what the graph that runs in the backend looks like:
input
  |
fused_conv_relu (delegate_debug_identifier : a)
  |
sin (delegate_debug_identifier : b)
  |
cos (delegate_debug_identifier : c)
  |
div (delegate_debug_identifier : d)
  |
output

Here is what the delegate mapping looks like. The key is the delegate_debug_identifier and the
value is the tuple of debug handles it maps to:
{ a : (1,2), b : (3), c : (3), d : (3) }
(NOTE: Here a, b, c, d can be integers or strings; the decision is left to the user, but whatever
is used during the AOT process to generate the mapping should be the same int/string logged in
the runtime.)

NOTE: These two graphs are not necessarily functionally equivalent; they are representative
examples of how to generate delegate debug identifiers for various use cases, such as fusion of
ops in the backend, decomposition of ops in the backend, etc.
"""


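# A minimal sketch (illustration only, not part of the original file) of how the mapping
# above would be produced with DelegateMappingBuilder, assuming generated integer
# identifiers and hypothetical fx nodes conv_node, relu_node and tan_node whose debug
# handles are 1, 2 and 3 respectively:
#
#   builder = DelegateMappingBuilder(generated_identifiers=True)
#   a = builder.insert_delegate_mapping_entry([conv_node, relu_node])  # a -> (1, 2)
#   b = builder.insert_delegate_mapping_entry(tan_node)                # b -> (3,)
#   builder.get_delegate_mapping()  # e.g. {0: (1, 2), 1: (3,)}

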
class BackendWithDelegateMappingDemo(BackendDetails):
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        processed_bytes = ""
        number_of_instructions = 0
        delegate_builder = DelegateMappingBuilder(generated_identifiers=True)

        for node in edge_program.graph.nodes:
            if node.op == "call_function":
                # Here we demonstrate case 1, where a couple of ops are fused into a single op.
                # In this case the pattern of conv + relu is detected and fused into a single
                # delegated op, and the corresponding delegate debug identifier for it is
                # generated and stored in the serialized blob.
                if node.target == exir_ops.edge.aten.relu.default:
                    input_node = node.args[0]
                    if input_node.target == exir_ops.edge.aten.convolution.default:
                        delegate_debug_identifier = (
                            delegate_builder.insert_delegate_mapping_entry(
                                [input_node, node]
                            )
                        )
                        conv_relu_op = DummyOp(
                            "conv_relu_op",
                            delegate_debug_identifier,
                        )
                        number_of_instructions += 1
                        processed_bytes += conv_relu_op.serialize()
                # Here we demonstrate case 2, where a single op is decomposed into three backend
                # ops. In this case the tan op is detected and decomposed into sin, cos and div
                # ops. The corresponding delegate debug identifiers are generated for the three
                # delegated ops, all of which map back to the original tan op. These delegate
                # debug identifiers are then serialized into the blob.
                elif node.target == exir_ops.edge.aten.tan.default:
                    delegate_debug_identifier = (
                        delegate_builder.insert_delegate_mapping_entry(node)
                    )
                    sin_decomp_from_tan = DummyOp(
                        "sin_decomp_from_tan",
                        delegate_debug_identifier,
                    )
                    number_of_instructions += 1
                    processed_bytes += sin_decomp_from_tan.serialize()

                    delegate_debug_identifier = (
                        delegate_builder.insert_delegate_mapping_entry(node)
                    )
                    cos_decomp_from_tan = DummyOp(
                        "cos_decomp_from_tan",
                        delegate_debug_identifier,
                    )
                    number_of_instructions += 1
                    processed_bytes += cos_decomp_from_tan.serialize()

                    delegate_debug_identifier = (
                        delegate_builder.insert_delegate_mapping_entry(node)
                    )
                    div_decomp_from_tan = DummyOp(
                        "div_decomp_from_tan",
                        delegate_debug_identifier,
                    )
                    number_of_instructions += 1
                    processed_bytes += div_decomp_from_tan.serialize()
            elif node.op in ["placeholder", "output", "get_attr"]:
                continue
            else:
                raise RuntimeError(
                    f"{node.op} is not supported in backend BackendWithDelegateMappingDemo"
                )

        return PreprocessResult(
            processed_bytes=bytes(
                str(number_of_instructions) + "#" + processed_bytes, encoding="utf8"
            ),
            debug_handle_map=delegate_builder.get_delegate_mapping(),
        )

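    # For illustration (an assumption, not asserted by the original file): assuming
    # DelegateMappingBuilder hands out consecutive integer identifiers starting at 0, running
    # preprocess() on the demo model below would produce a blob along the lines of
    #   "5#conv_relu_op,0,conv_relu_op,1,sin_decomp_from_tan,2,cos_decomp_from_tan,3,div_decomp_from_tan,4,"
    # i.e. the instruction count, a "#" separator, then each serialized DummyOp in graph order
    # (two fused conv+relu pairs, then the three ops decomposed from tan).
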
    # The sample model that will work with the BackendWithDelegateMappingDemo backend shown
    # above.
    @staticmethod
    def get_test_model_and_inputs():
        class SimpleConvNet(nn.Module):
            def __init__(self):
                super().__init__()

                # First convolutional layer
                self.conv1 = nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
                )
                self.relu1 = nn.ReLU()

                # Second convolutional layer
                self.conv2 = nn.Conv2d(
                    in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
                )
                self.relu2 = nn.ReLU()

            def forward(self, x):
                # Forward pass through the first convolutional layer
                x = self.conv1(x)
                x = self.relu1(x)

                # Forward pass through the second convolutional layer
                x = self.conv2(x)
                x = self.relu2(x)

                return x

        class ConvReLUTanModel(nn.Module):
            def __init__(self):
                super().__init__()

                # Define the convolutional network
                self.conv_layer = SimpleConvNet()

            def forward(self, x):
                # Forward pass through the convolutional layers
                conv_output = self.conv_layer(x)

                # Perform tan on conv_output
                tan_output = torch.tan(conv_output)

                return tan_output

        batch_size = 4
        channels = 3
        height = 64
        width = 64
        return (ConvReLUTanModel(), (torch.randn(batch_size, channels, height, width),))
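

# A hedged usage sketch (not part of the original file; exact APIs may differ across
# ExecuTorch versions): lowering the demo model through this backend.
#
#   import torch
#   from executorch.exir import to_edge
#   from executorch.exir.backend.backend_api import to_backend
#
#   model, inputs = BackendWithDelegateMappingDemo.get_test_model_and_inputs()
#   edge = to_edge(torch.export.export(model, inputs))
#   lowered = to_backend(
#       "BackendWithDelegateMappingDemo", edge.exported_program(), []
#   )
#
# The delegate mapping built in preprocess() is returned via
# PreprocessResult.debug_handle_map, which downstream tooling can use to translate
# runtime-logged delegate_debug_identifiers back to debug handles.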
211