# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import inspect
import os
import sys
from typing import Any, Dict, List, Type

import torch
from executorch.exir import CaptureConfig
from executorch.exir.passes import MemoryPlanningPass
from torch import nn
from torch.export import Dim

from ..end2end.exported_module import ExportedModule

"""Traces and exports nn.Modules to ExecuTorch .pte program files.

This tool mainly exists to export programs for C++ tests, but can also
be used to export models manually.
"""

#
# Module definitions.
#
# If we ever have more than a handful, consider splitting into multiple files.
#


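# Convention: every module class defines get_random_inputs(), which returns a
# tuple of example inputs used to trace it. Optional hooks customize the
# export: get_export_kwargs(), get_method_names_to_export(), and export_joint()
# are read by export_module_to_program() below; get_dynamic_shapes() and
# get_memory_planning_pass() are consumed during export (see
# ExportedModule.export()).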
class ModuleBasic(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.sin(x).max()

    def get_random_inputs(self):
        return (torch.randn(100),)

    @staticmethod
    def get_export_kwargs() -> Dict[str, Any]:
        """Returns custom trace params for ExportedModule."""
        return {
            # aten::max.default does not have an out variant.
            "ignore_to_out_var_failure": True,
        }


class ModuleIndex(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Weird index that happens to generate a None in torch.index.Tensor_out,
        # which is desirable for deserialization testing. A modified form of
        # an example index from https://pytorch.org/cppdocs/notes/tensor_indexing.html.
        return x[1::2, torch.tensor([1, 2])]

    def get_random_inputs(self):
        return (torch.randn(10, 10, 10),)


class ModuleNoOp(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return (x, y)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2))


class ModuleAdd(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y, alpha):
        return torch.add(x, y, alpha=alpha)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2), 1.0)


class ModuleAddHalf(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y, alpha):
        return torch.add(x, y, alpha=alpha)

    def get_random_inputs(self):
        return (
            torch.randn(2, 2).half(),
            torch.randn(2, 2).half(),
            1.0,
        )


class ModuleDynamicCatUnallocatedIO(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO(T163238401)
        self._inputs = (torch.randn(3, 4),)

    def forward(self, k):
        k = torch.cat((k, torch.ones(1, 4)))
        return k

    def get_random_inputs(self):
        return self._inputs

    def get_dynamic_shapes(self):
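        # torch.export dynamic_shapes spec: one entry per forward() argument,
        # marking dim 0 of `k` as dynamic with an upper bound matching the
        # static sample input above.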
        return ({0: Dim("dim0_k", max=3)},)

    def get_memory_planning_pass(self):
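        # Do not plan buffers for graph inputs and outputs (the "unallocated
        # IO" in this class's name); they are left for the runtime to provide.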
        return MemoryPlanningPass(
            alloc_graph_input=False,
            alloc_graph_output=False,
        )

    @staticmethod
    def get_export_kwargs():
        return {"capture_config": CaptureConfig(pt2_mode=True, enable_aot=True)}


class ModuleLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
        self.b = 2 * torch.ones(2, 2, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        out_1 = torch.mul(self.a, x)
        out_2 = torch.add(out_1, self.b)
        return out_2

    def get_random_inputs(self):
        return (torch.ones(2, 2, dtype=torch.float),)


class ModuleMultipleEntry(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
        self.b = 2 * torch.ones(2, 2, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        return x + self.a

    def forward2(self, x: torch.Tensor):
        return x + self.a + self.b

    def get_random_inputs(self):
        return (torch.ones(2, 2, dtype=torch.float),)

    @staticmethod
    def get_method_names_to_export() -> List[str]:
        return ["forward", "forward2"]


class ModuleSimpleTrain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        return self.loss(self.linear(x).softmax(dim=0), y)

    def get_random_inputs(self):
        return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0]))

    @staticmethod
    def export_joint():
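        """Export a joint (forward + backward) graph so the program is trainable.

        Read by export_module_to_program(), which passes it to
        ExportedModule.export() as export_joint_graph.
        """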
        return True


#
# Main logic.
#


def export_module_to_program(
    module_class: Type[nn.Module],
    skip_type_promotion: bool,
):
    """Exports the module and returns the serialized program data."""
    torch.manual_seed(0)
    # Look for an optional @staticmethod that defines custom trace params.
    export_kwargs: Dict[str, Any] = {}
    if hasattr(module_class, "get_export_kwargs"):
        # pyre-ignore[16]: pyre doesn't know about get_export_kwargs.
        export_kwargs = module_class.get_export_kwargs()
    export_joint = False
    if hasattr(module_class, "export_joint"):
        export_joint = module_class.export_joint()  # pyre-ignore
    if hasattr(module_class, "get_method_names_to_export"):
        # pyre-ignore[16]: pyre doesn't know about get_method_names_to_export.
        methods = module_class.get_method_names_to_export()
    else:
        methods = ["forward"]
    module = ExportedModule.export(
        module_class,
        methods,
        skip_type_promotion=skip_type_promotion,
        export_joint_graph=export_joint,
        **export_kwargs,
    )
    return module.executorch_program.buffer


def main() -> None:
    # These args are optimized for genrule usage. There's a lot of startup
    # overhead for this tool, so it's faster to export multiple models at once
    # when possible.
    torch.manual_seed(0)
    parser = argparse.ArgumentParser(
        prog="export_program",
        description="Exports nn.Module models to ExecuTorch .pte files",
    )
    parser.add_argument(
        "--modules",
        required=True,
        help="Comma-separated list of model class names to export; "
        + "e.g., '--modules=ModuleBasic,ModuleAdd'",
        type=lambda s: [item.strip() for item in s.split(",")],
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="Path to the directory to write <classname>.pte files to",
    )
    args = parser.parse_args()

    # Find the classes to export. Only looks in this module for now, but could
    # be extended to look in other modules if helpful.
    module_names_to_classes: Dict[str, Type[nn.Module]] = {}
    for module in args.modules:
        module_class = getattr(sys.modules[__name__], module, None)
        if not (inspect.isclass(module_class) and issubclass(module_class, nn.Module)):
            raise NameError(f"Could not find nn.Module class named '{module}'")
        module_names_to_classes[module] = module_class

    # Export and write to the output files.
    os.makedirs(args.outdir, exist_ok=True)
    for module_name, module_class in module_names_to_classes.items():
        skip_type_promotion = False
        if module_name == "ModuleAddHalf":
            # Skip type promotion to keep the model in fp16.
            # Type promotion will convert to fp32.
            skip_type_promotion = True
        outfile = os.path.join(args.outdir, f"{module_name}.pte")
        with open(outfile, "wb") as fp:
            fp.write(
                export_module_to_program(
                    module_class,
                    skip_type_promotion=skip_type_promotion,
                )
            )
        print(f"Exported {module_name} and wrote program data to {outfile}")


if __name__ == "__main__":
    main()