# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import inspect
import os
import sys
from typing import Any, Dict, List, Type

import torch
from executorch.exir import CaptureConfig
from executorch.exir.passes import MemoryPlanningPass
from torch import nn
from torch.export import Dim

from ..end2end.exported_module import ExportedModule

"""Traces and exports nn.Modules to ExecuTorch .pte program files.

This tool mainly exists to export programs for C++ tests, but can also
be used to export models manually.
"""

#
# Module definitions.
#
# If we ever have more than a handful, consider splitting into multiple files.
#


class ModuleBasic(nn.Module):
    def __init__(self):
        super(ModuleBasic, self).__init__()

    def forward(self, x):
        return torch.sin(x).max()

    def get_random_inputs(self):
        return (torch.randn(100),)

    @staticmethod
    def get_export_kwargs() -> Dict[str, Any]:
        """Returns custom trace params for ExportedModule."""
        return {
            # aten::max.default does not have an out variant.
            "ignore_to_out_var_failure": True,
        }


class ModuleIndex(nn.Module):
    def __init__(self):
        super(ModuleIndex, self).__init__()

    def forward(self, x):
        # Weird index that happens to generate a None in torch.index.Tensor_out,
        # which is desirable for deserialization testing. A modified form of
        # an example index from https://pytorch.org/cppdocs/notes/tensor_indexing.html.
        return x[1::2, torch.tensor([1, 2])]

    def get_random_inputs(self):
        return (torch.randn(10, 10, 10),)


class ModuleNoOp(nn.Module):
    def __init__(self):
        super(ModuleNoOp, self).__init__()

    def forward(self, x, y):
        return (x, y)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2))


class ModuleAdd(nn.Module):
    def __init__(self):
        super(ModuleAdd, self).__init__()

    def forward(self, x, y, alpha):
        return torch.add(x, y, alpha=alpha)

    def get_random_inputs(self):
        return (torch.randn(2, 2), torch.randn(2, 2), 1.0)


class ModuleAddHalf(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y, alpha):
        return torch.add(x, y, alpha=alpha)

    def get_random_inputs(self):
        return (
            torch.randn(2, 2).half(),
            torch.randn(2, 2).half(),
            1.0,
        )


class ModuleDynamicCatUnallocatedIO(nn.Module):
    def __init__(self):
        super(ModuleDynamicCatUnallocatedIO, self).__init__()
        # TODO(T163238401)
        self._inputs = (torch.randn(3, 4),)

    def forward(self, k):
        k = torch.cat((k, torch.ones(1, 4)))
        return k

    def get_random_inputs(self):
        return self._inputs

    def get_dynamic_shapes(self):
        return ({0: Dim("dim0_k", max=3)},)

    def get_memory_planning_pass(self):
        return MemoryPlanningPass(
            alloc_graph_input=False,
            alloc_graph_output=False,
        )

    @staticmethod
    def get_export_kwargs():
        return {"capture_config": CaptureConfig(pt2_mode=True, enable_aot=True)}


class ModuleLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
        self.b = 2 * torch.ones(2, 2, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        out_1 = torch.mul(self.a, x)
        out_2 = torch.add(out_1, self.b)
        return out_2

    def get_random_inputs(self):
        return (torch.ones(2, 2, dtype=torch.float),)


class ModuleMultipleEntry(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 3 * torch.ones(2, 2, dtype=torch.float)
        self.b = 2 * torch.ones(2, 2, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        return x + self.a

    def forward2(self, x: torch.Tensor):
        return x + self.a + self.b

    def get_random_inputs(self):
        return (torch.ones(2, 2, dtype=torch.float),)

    @staticmethod
    def get_method_names_to_export() -> List[str]:
        return ["forward", "forward2"]


class ModuleSimpleTrain(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        return self.loss(self.linear(x).softmax(dim=0), y)

    def get_random_inputs(self):
        return (torch.randn(3), torch.tensor([1.0, 0.0, 0.0]))

    @staticmethod
    def export_joint():
        return True


#
# Main logic.
#


def export_module_to_program(
    module_class: Type[nn.Module],
    skip_type_promotion: bool,
):
    """Exports the module and returns the serialized program data."""
    torch.manual_seed(0)
    # Look for optional @staticmethods that define custom export behavior.
    export_kwargs: Dict[str, Any] = {}
    if hasattr(module_class, "get_export_kwargs"):
        # pyre-ignore[16]: pyre doesn't know about get_export_kwargs.
        export_kwargs = module_class.get_export_kwargs()
    export_joint = False
    if hasattr(module_class, "export_joint"):
        export_joint = module_class.export_joint()  # pyre-ignore
    if hasattr(module_class, "get_method_names_to_export"):
        # pyre-ignore[16]: pyre doesn't know about get_method_names_to_export.
        methods = module_class.get_method_names_to_export()
    else:
        methods = ["forward"]
    module = ExportedModule.export(
        module_class,
        methods,
        skip_type_promotion=skip_type_promotion,
        export_joint_graph=export_joint,
        **export_kwargs,
    )
    return module.executorch_program.buffer
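

# As an illustration (not part of the tool itself): a single module can be
# exported programmatically with the function above, mirroring what main()
# below does per class. `ModuleAdd` is defined in this file; the output path
# is arbitrary.
#
#     data = export_module_to_program(ModuleAdd, skip_type_promotion=False)
#     with open("/tmp/ModuleAdd.pte", "wb") as f:
#         f.write(data)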


def main() -> None:
    # These args are optimized for genrule usage. There's a lot of startup
    # overhead for this tool, so it's faster to export multiple models at once
    # when possible.
    torch.manual_seed(0)
    parser = argparse.ArgumentParser(
        prog="export_program",
        description="Exports nn.Module models to ExecuTorch .pte files",
    )
    parser.add_argument(
        "--modules",
        help="Comma-separated list of model class names to export; "
        + "e.g., '--modules=ModuleBasic,ModuleAdd'",
        type=lambda s: [item.strip() for item in s.split(",")],
        required=True,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="Path to the directory to write <classname>.pte files to",
    )
    args = parser.parse_args()

    # Find the classes to export. Only looks in this module for now, but could
    # be extended to look in other modules if helpful.
    module_names_to_classes: Dict[str, Type[nn.Module]] = {}
    for module in args.modules:
        module_class = getattr(sys.modules[__name__], module, None)
        if not (inspect.isclass(module_class) and issubclass(module_class, nn.Module)):
            raise NameError(f"Could not find nn.Module class named '{module}'")
        module_names_to_classes[module] = module_class

    # Export and write to the output files.
    os.makedirs(args.outdir, exist_ok=True)
    for module_name, module_class in module_names_to_classes.items():
        skip_type_promotion = False
        if module_name == "ModuleAddHalf":
            # Skip type promotion to keep the model in fp16; type promotion
            # would otherwise convert it to fp32.
            skip_type_promotion = True
        outfile = os.path.join(args.outdir, f"{module_name}.pte")
        with open(outfile, "wb") as fp:
            fp.write(
                export_module_to_program(
                    module_class,
                    skip_type_promotion=skip_type_promotion,
                )
            )
        print(f"Exported {module_name} and wrote program data to {outfile}")


if __name__ == "__main__":
    main()
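

# Example invocation, assuming this file is runnable as the module
# `test.models.export_program` (the exact package path depends on where the
# file lives in the source tree):
#
#     python -m test.models.export_program \
#         --modules=ModuleBasic,ModuleAdd --outdir=/tmp/pte
#
# This writes /tmp/pte/ModuleBasic.pte and /tmp/pte/ModuleAdd.pte.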