# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import argparse
import json
import logging
import os

from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

from executorch.backends.arm.util.arm_model_evaluator import (
    GenericModelEvaluator,
    MobileNetV2Evaluator,
)
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program
from tabulate import tabulate

# Quantize the model if required, using the standard export quantization flow.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.utils.data import DataLoader

from ..models import MODEL_NAME_TO_MODEL
from ..models.model_factory import EagerModelFactory

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.WARNING, format=FORMAT)


def get_model_and_inputs_from_name(model_name: str) -> Tuple[torch.nn.Module, Any]:
    """Given the name of an example PyTorch model, return it and its example inputs.

    Raises RuntimeError if there is no example model corresponding to the given name.
    """
    # Case 1: Model is defined in this file
    if model_name in models.keys():
        model = models[model_name]()
        example_inputs = models[model_name].example_input
    # Case 2: Model is defined in examples/models/
    elif model_name in MODEL_NAME_TO_MODEL.keys():
        logging.warning(
            "Using a model from examples/models; not all of these are currently supported."
        )
        model, example_inputs, _, _ = EagerModelFactory.create_model(
            *MODEL_NAME_TO_MODEL[model_name]
        )
    # Case 3: Model is in an external Python file loaded as a module.
    # ModelUnderTest should be a torch.nn.Module instance.
    # ModelInputs should be a tuple of inputs to the forward function.
    elif model_name.endswith(".py"):
        import importlib.util

        # Load the model's module from the given file path and execute it.
        spec = importlib.util.spec_from_file_location("tmp_model", model_name)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        model = module.ModelUnderTest
        example_inputs = module.ModelInputs

    else:
        raise RuntimeError(
            f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
        )

    return model, example_inputs
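

# A minimal sketch of an external model file that "Case 3" in
# get_model_and_inputs_from_name() can load when a path ending in ".py" is
# passed via --model_name. The file name and class below are hypothetical; the
# only requirements are a ModelUnderTest torch.nn.Module instance and a
# ModelInputs tuple of example inputs, e.g. in my_model.py:
#
#     import torch
#
#     class TinyLinear(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.fc = torch.nn.Linear(4, 2)
#
#         def forward(self, x):
#             return self.fc(x)
#
#     ModelUnderTest = TinyLinear()
#     ModelInputs = (torch.randn(1, 4),)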


def quantize(
    model: torch.nn.Module,
    model_name: str,
    example_inputs: Tuple[torch.Tensor],
    evaluator_name: str | None,
    evaluator_config: Dict[str, Any] | None,
) -> torch.nn.Module:
    """This is the officially recommended flow for quantization in PyTorch 2 export."""
    logging.info("Quantizing Model...")
    logging.debug(f"Original model: {model}")
    quantizer = ArmQuantizer()

    # If we set is_per_channel to True, we also need to add the out variants of
    # quantize_per_channel/dequantize_per_channel.
    operator_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(operator_config)
    m = prepare_pt2e(model, quantizer)

    dataset = get_calibration_data(
        model_name, example_inputs, evaluator_name, evaluator_config
    )

    # The dataset can be either a tuple of tensors or a DataLoader;
    # both cases need to be handled.
    if isinstance(dataset, DataLoader):
        for sample, _ in dataset:
            m(sample)
    else:
        m(*dataset)

    m = convert_pt2e(m)
    logging.debug(f"Quantized model: {m}")
    return m


# Simple example models
class AddModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + x

    example_input = (torch.ones(5, dtype=torch.int32),)
    can_delegate = True


class AddModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y

    example_input = (
        torch.ones(5, dtype=torch.int32),
        torch.ones(5, dtype=torch.int32),
    )
    can_delegate = True


class AddModule3(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return (x + y, x + x)

    example_input = (
        torch.ones(5, dtype=torch.int32),
        torch.ones(5, dtype=torch.int32),
    )
    can_delegate = True


class SoftmaxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.softmax = torch.nn.Softmax(dim=0)

    def forward(self, x):
        z = self.softmax(x)
        return z

    example_input = (torch.ones(2, 2),)
    can_delegate = False


models = {
    "add": AddModule,
    "add2": AddModule2,
    "add3": AddModule3,
    "softmax": SoftmaxModule,
}

calibration_data = {
    "add": (torch.randn(1, 5),),
    "add2": (
        torch.randn(1, 5),
        torch.randn(1, 5),
    ),
    "add3": (
        torch.randn(32, 5),
        torch.randn(32, 5),
    ),
    "softmax": (torch.randn(32, 2, 2),),
}

evaluators = {
    "generic": GenericModelEvaluator,
    "mv2": MobileNetV2Evaluator,
}

targets = [
    "ethos-u55-32",
    "ethos-u55-64",
    "ethos-u55-128",
    "ethos-u55-256",
    "ethos-u85-128",
    "ethos-u85-256",
    "ethos-u85-512",
    "ethos-u85-1024",
    "ethos-u85-2048",
    "TOSA",
]


def get_calibration_data(
    model_name: str,
    example_inputs: Tuple[torch.Tensor],
    evaluator_name: str | None,
    evaluator_config: str | None,
):
    # First, if the model is being evaluated, use the evaluator's calibration function if it has one.
    if evaluator_name is not None:
        evaluator = evaluators[evaluator_name]

        if hasattr(evaluator, "get_calibrator"):
            assert evaluator_config is not None

            config_path = Path(evaluator_config)
            with config_path.open() as f:
                config = json.load(f)

            if evaluator_name == "mv2":
                return evaluator.get_calibrator(
                    training_dataset_path=config["training_dataset_path"]
                )
            else:
                raise RuntimeError(f"Unknown evaluator: {evaluator_name}")

    # If the model is in the calibration_data dictionary, get the data from there.
    # This is used for the simple model examples provided above.
    if model_name in calibration_data:
        return calibration_data[model_name]

    # As a last resort, fall back to the script's previous behavior and return the example inputs.
    return example_inputs


def get_compile_spec(
    target: str, intermediates: Optional[str] = None
) -> ArmCompileSpecBuilder:
    spec_builder = None
    if target == "TOSA":
        spec_builder = (
            ArmCompileSpecBuilder()
            .tosa_compile_spec("TOSA-0.80.0+BI")
            .set_permute_memory_format(True)
        )
    elif "ethos-u55" in target:
        spec_builder = (
            ArmCompileSpecBuilder()
            .ethosu_compile_spec(
                target,
                system_config="Ethos_U55_High_End_Embedded",
                memory_mode="Shared_Sram",
                extra_flags="--debug-force-regor --output-format=raw",
            )
            .set_permute_memory_format(True)
            .set_quantize_io(True)
        )
    elif "ethos-u85" in target:
        spec_builder = (
            ArmCompileSpecBuilder()
            .ethosu_compile_spec(
                target,
                system_config="Ethos_U85_SYS_DRAM_Mid",
                memory_mode="Shared_Sram",
                extra_flags="--output-format=raw",
            )
            .set_permute_memory_format(True)
            .set_quantize_io(True)
        )

    if intermediates is not None:
        spec_builder.dump_intermediate_artifacts_to(intermediates)

    return spec_builder.build()


def evaluate_model(
    model_name: str,
    intermediates: str,
    model_fp32: torch.nn.Module,
    model_int8: torch.nn.Module,
    example_inputs: Tuple[torch.Tensor],
    evaluator_name: str,
    evaluator_config: str | None,
) -> None:
    evaluator = evaluators[evaluator_name]

    # Get the path of the TOSA flatbuffer that is dumped to the intermediates folder.
    intermediates_path = Path(intermediates)
    tosa_paths = list(intermediates_path.glob("*.tosa"))

    if evaluator.REQUIRES_CONFIG:
        assert evaluator_config is not None

        config_path = Path(evaluator_config)
        with config_path.open() as f:
            config = json.load(f)

        if evaluator_name == "mv2":
            init_evaluator = evaluator(
                model_name,
                model_fp32,
                model_int8,
                example_inputs,
                str(tosa_paths[0]),
                config["batch_size"],
                config["validation_dataset_path"],
            )
        else:
            raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
    else:
        init_evaluator = evaluator(
            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
        )

    quant_metrics = init_evaluator.evaluate()
    output_json_path = intermediates_path / "quant_metrics.json"

    with output_json_path.open("w") as json_file:
        json.dump(quant_metrics, json_file)


def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None):
    graph_module = edge.exported_program().graph_module
    delegation_info = get_delegation_info(graph_module)
    df = delegation_info.get_operator_delegation_dataframe()
    table = tabulate(df, headers="keys", tablefmt="fancy_grid")
    delegation_info_string = f"Delegation info:\n{delegation_info.get_summary()}\nDelegation table:\n{table}\n"
    logging.info(delegation_info_string)
    if intermediate_files_folder is not None:
        delegation_file_path = os.path.join(
            intermediate_files_folder, "delegation_info.txt"
        )
        with open(delegation_file_path, "w") as file:
            file.write(delegation_info_string)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {set(list(models.keys()) + list(MODEL_NAME_TO_MODEL.keys()))}",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=False,
        help="Flag for producing an ArmBackend delegated model",
    )
    parser.add_argument(
        "-t",
        "--target",
        action="store",
        required=False,
        default="ethos-u55-128",
        choices=targets,
        help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. Valid targets are {targets}",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        required=False,
        nargs="?",
        const="generic",
        choices=["generic", "mv2"],
        help="Flag for running evaluation of the model.",
    )
    parser.add_argument(
        "-c",
        "--evaluate_config",
        required=False,
        default=None,
        help="Provide path to evaluator config, if it is required.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce a quantized model",
    )
    parser.add_argument(
        "-s",
        "--so_library",
        required=False,
        default=None,
        help="Provide path to the .so library. E.g., cmake-out/examples/portable/custom_ops/libcustom_ops_aot_lib.so",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Set the logging level to debug."
    )
    parser.add_argument(
        "-i",
        "--intermediates",
        action="store",
        required=False,
        help="Store intermediate output (like TOSA artefacts) somewhere.",
    )
    parser.add_argument(
        "-o",
        "--output",
        action="store",
        required=False,
        help="Location for outputs, if not the default of cwd.",
    )
    args = parser.parse_args()

    if args.evaluate and (
        not args.quantize or args.intermediates is None or (not args.delegate)
    ):
        raise RuntimeError(
            "--evaluate requires --quantize, --intermediates and --delegate to be enabled."
        )

    if args.debug:
        logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)

    if args.quantize and not args.so_library:
        logging.warning(
            "Quantization enabled without supplying path to libcustom_ops_aot_lib using the -s flag. "
            + "This is required for running quantized models with unquantized input."
438 ) 439 440 # if we have custom ops, register them before processing the model 441 if args.so_library is not None: 442 logging.info(f"Loading custom ops from {args.so_library}") 443 torch.ops.load_library(args.so_library) 444 445 if ( 446 args.model_name in models.keys() 447 and args.delegate is True 448 and models[args.model_name].can_delegate is False 449 ): 450 raise RuntimeError(f"Model {args.model_name} cannot be delegated.") 451 452 return args 453 454 455if __name__ == "__main__": 456 args = get_args() 457 458 # Pick model from one of the supported lists 459 model, example_inputs = get_model_and_inputs_from_name(args.model_name) 460 model = model.eval() 461 462 # export_for_training under the assumption we quantize, the exported form also works 463 # in to_edge if we don't quantize 464 exported_program = torch.export.export_for_training(model, example_inputs) 465 model = exported_program.module() 466 model_fp32 = model 467 468 # Quantize if required 469 model_int8 = None 470 if args.quantize: 471 model = quantize( 472 model, args.model_name, example_inputs, args.evaluate, args.evaluate_config 473 ) 474 model_int8 = model 475 # Wrap quantized model back into an exported_program 476 exported_program = torch.export.export_for_training(model, example_inputs) 477 478 if args.intermediates: 479 os.makedirs(args.intermediates, exist_ok=True) 480 481 if args.delegate: 482 # As we can target multiple output encodings from ArmBackend, one must 483 # be specified. 484 compile_spec = get_compile_spec(args.target, args.intermediates) 485 edge = to_edge_transform_and_lower( 486 exported_program, 487 partitioner=[ArmPartitioner(compile_spec)], 488 compile_config=EdgeCompileConfig( 489 _check_ir_validity=False, 490 _skip_dim_order=True, 491 ), 492 ) 493 else: 494 edge = to_edge_transform_and_lower( 495 exported_program, 496 compile_config=EdgeCompileConfig( 497 _check_ir_validity=False, 498 _skip_dim_order=True, 499 ), 500 ) 501 502 dump_delegation_info(edge, args.intermediates) 503 504 try: 505 exec_prog = edge.to_executorch( 506 config=ExecutorchBackendConfig(extract_delegate_segments=False) 507 ) 508 except RuntimeError as e: 509 if "Missing out variants" in str(e.args[0]): 510 raise RuntimeError( 511 e.args[0] 512 + ".\nThis likely due to an external so library not being loaded. Supply a path to it with the -s flag." 513 ).with_traceback(e.__traceback__) from None 514 else: 515 raise e 516 517 model_name = os.path.basename(os.path.splitext(args.model_name)[0]) 518 output_name = f"{model_name}" + ( 519 f"_arm_delegate_{args.target}" 520 if args.delegate is True 521 else f"_arm_{args.target}" 522 ) 523 524 if args.output is not None: 525 output_name = os.path.join(args.output, output_name) 526 527 save_pte_program(exec_prog, output_name) 528 529 if args.evaluate: 530 evaluate_model( 531 args.model_name, 532 args.intermediates, 533 model_fp32, 534 model_int8, 535 example_inputs, 536 args.evaluate, 537 args.evaluate_config, 538 ) 539