xref: /aosp_15_r20/external/executorch/examples/portable/scripts/export_and_delegate.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# Example script for exporting simple models to flatbuffer
8
9import argparse
10import logging
11
12import torch
13from executorch.exir.backend.backend_api import to_backend
14from executorch.exir.backend.test.backend_with_compiler_demo import (
15    BackendWithCompilerDemo,
16)
17from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
18from executorch.extension.export_util import export_to_edge
19
20from ...models import MODEL_NAME_TO_MODEL
21from ...models.model_factory import EagerModelFactory
22
23
# Log format: level, timestamp, file:line, then the message.
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
26
27
"""
BackendWithCompilerDemo is a demo test backend that only supports torch.mm and
torch.add. The examples below show how to lower torch.mm and torch.add into
this backend via the to_backend API.

Three ways are supported:
1. Lower the whole graph
2. Lower part of the graph via a graph partitioner
3. Compose a model with a lowered module
"""
37
38
def export_composite_module_with_lower_graph():
    """
    Lower AddMulModule to the demo backend, then compose the lowered module
    with additional eager ops and export the composite to a .pte file.

    AddMulModule:

        input -> torch.mm -> torch.add -> output

    this module can be lowered to the demo backend as a delegate

        input -> [lowered module (delegate)] -> output

    the lowered module can be used to composite with other modules

        input -> [lowered module (delegate)] -> sub  -> output
               |--------  composite module    -------|

    """
    logging.info(
        "Running the example to export a composite module with lowered graph..."
    )

    m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
    m_compile_spec = m.get_compile_spec()

    # pre-autograd export. eventually this will become torch.export
    m = torch.export.export_for_training(m, m_inputs).module()
    edge = export_to_edge(m, m_inputs)
    logging.info(f"Exported graph:\n{edge.exported_program().graph}")

    # Lower AddMulModule to the demo backend
    logging.info("Lowering to the demo backend...")
    lowered_graph = to_backend(
        BackendWithCompilerDemo.__name__, edge.exported_program(), m_compile_spec
    )

    # Composite the lowered graph with another eager module: delegate output
    # minus a tensor of ones.
    class CompositeModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lowered_graph = lowered_graph

        def forward(self, *args):
            return torch.sub(self.lowered_graph(*args), torch.ones(1))

    # Get the graph for the composite module, which includes lowered graph
    m = CompositeModule()
    m = m.eval()
    # pre-autograd export. eventually this will become torch.export
    m = torch.export.export_for_training(m, m_inputs).module()
    composited_edge = export_to_edge(m, m_inputs)

    # The graph module is still runnable
    composited_edge.exported_program().graph_module(*m_inputs)

    logging.info(f"Lowered graph:\n{composited_edge.exported_program().graph}")

    exec_prog = composited_edge.to_executorch()
    buffer = exec_prog.buffer

    model_name = "composite_model"
    filename = f"{model_name}.pte"
    # Fix: log the actual output path instead of the placeholder "(unknown)".
    logging.info(f"Saving exported program to {filename}")
    with open(filename, "wb") as file:
        file.write(buffer)
104
def export_and_lower_partitioned_graph():
    """
    Lower only the partitioner-tagged portions of a graph to the demo backend
    and export the result to a .pte file.

    Model:
        input -> torch.mm -> torch.add -> torch.sub -> torch.mm -> torch.add -> output

    AddMulPartitionerDemo is a graph partitioner that tags the nodes to be
    lowered; in this case it tags the torch.mm and torch.add nodes. After
    to_backend, the graph becomes:

        input -> [lowered module (delegate)] -> torch.sub -> [lowered module (delegate)] -> output
    """

    # Fix: this message previously said "whole graph" (copy-paste from the
    # whole-graph flow below).
    logging.info("Running the example to export and lower the partitioned graph...")

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a, x, b):
            # The mm/add pairs get tagged and delegated; torch.sub stays in core.
            y = torch.mm(a, x)
            z = y + b
            a = z - a
            y = torch.mm(a, x)
            z = y + b
            return z

        def get_example_inputs(self):
            return (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

    m = Model()
    m_inputs = m.get_example_inputs()
    # pre-autograd export. eventually this will become torch.export
    m = torch.export.export_for_training(m, m_inputs).module()
    edge = export_to_edge(m, m_inputs)
    logging.info(f"Exported graph:\n{edge.exported_program().graph}")

    # Lower to backend_with_compiler_demo
    logging.info("Lowering to the demo backend...")
    edge = edge.to_backend(AddMulPartitionerDemo())
    logging.info(f"Lowered graph:\n{edge.exported_program().graph}")

    exec_prog = edge.to_executorch()
    buffer = exec_prog.buffer

    model_name = "partition_lowered_model"
    filename = f"{model_name}.pte"
    # Fix: log the actual output path instead of the placeholder "(unknown)".
    logging.info(f"Saving exported program to {filename}")
    with open(filename, "wb") as file:
        file.write(buffer)
154
155
def export_and_lower_the_whole_graph():
    """
    Lower the entire AddMulModule graph to the demo backend and save the
    serialized lowered module to a .pte file.

    AddMulModule:

        input -> torch.mm -> torch.add -> output

    this module can be lowered to the demo backend as a delegate

        input -> [lowered module (delegate)] -> output
    """
    logging.info("Running the example to export and lower the whole graph...")

    m, m_inputs, _, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL["add_mul"])
    m_compile_spec = m.get_compile_spec()

    # NOTE(review): this overwrites the inputs returned by the factory above;
    # presumably equivalent for add_mul — confirm before simplifying.
    m_inputs = m.get_example_inputs()
    # pre-autograd export. eventually this will become torch.export
    m = torch.export.export_for_training(m, m_inputs).module()
    edge = export_to_edge(m, m_inputs)
    logging.info(f"Exported graph:\n{edge.exported_program().graph}")

    # Lower AddMulModule to the demo backend
    logging.info("Lowering to the demo backend...")
    lowered_module = to_backend(
        BackendWithCompilerDemo.__name__, edge.exported_program(), m_compile_spec
    )

    # Unlike the other flows, this serializes the lowered module's own buffer
    # rather than a full executorch program.
    buffer = lowered_module.buffer()

    model_name = "whole"
    filename = f"{model_name}.pte"
    # Fix: log the actual output path instead of the placeholder "(unknown)".
    logging.info(f"Saving exported program to {filename}")
    with open(filename, "wb") as file:
        file.write(buffer)
191
192
# Map of CLI "--option" values to the example flow each one runs.
OPTIONS_TO_LOWER = {
    "composite": export_composite_module_with_lower_graph,
    "partition": export_and_lower_partitioned_graph,
    "whole": export_and_lower_the_whole_graph,
}
198
if __name__ == "__main__":
    # Parse the requested flow name from the command line.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--option",
        required=True,
        choices=list(OPTIONS_TO_LOWER.keys()),
        help=f"Provide the flow name. Valid ones: {list(OPTIONS_TO_LOWER.keys())}",
    )
    cli_args = arg_parser.parse_args()

    # Dispatch to the selected example flow and run it.
    OPTIONS_TO_LOWER[cli_args.option]()
215