# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Example script for exporting simple models to flatbuffer

import argparse
import copy
import logging

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.devtools import generate_etrecord
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program

from ..models import MODEL_NAME_TO_MODEL
from ..models.model_factory import EagerModelFactory
from . import MODEL_NAME_TO_OPTIONS
from .quantization.utils import quantize


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS)}",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce an 8-bit quantized model",
    )
    # NOTE: `store_true` combined with `default=True` means this flag is always
    # True, so the `not args.delegate` guard below is currently unreachable.
    # The CLI is kept as-is for backward compatibility.
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=True,
        help="Produce an XNNPACK delegated model",
    )
    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        help="Generate and save an ETRecord to the given file location",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")

    args = parser.parse_args()

    # ---- Argument validation ----------------------------------------------
    if not args.delegate:
        raise NotImplementedError(
            "T161880157: Quantization-only without delegation is not supported yet"
        )

    if args.model_name not in MODEL_NAME_TO_OPTIONS and args.quantize:
        # The original message ran "requested model" straight into
        # "Available models"; a sentence separator is added here.
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. or not quantizable right now, "
            "please contact executorch team if you want to learn why or how to support "
            "quantization for the requested model. "
            f"Available models are {list(MODEL_NAME_TO_OPTIONS)}."
        )

    # ---- Build and export the eager model ---------------------------------
    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    model = model.eval()
    # pre-autograd export. eventually this will become torch.export
    ep = torch.export.export_for_training(model, example_inputs)
    model = ep.module()

    if args.quantize:
        logging.info("Quantizing Model...")
        # TODO(T165162973): This pass shall eventually be folded into quantizer
        model = quantize(model, example_inputs)
        # Re-export so the program handed to the lowering step is the
        # quantized one. Without this, `ep` still refers to the
        # pre-quantization export and the quantized module is discarded.
        ep = torch.export.export_for_training(model, example_inputs)

    # ---- Lower to edge dialect and partition for XNNPACK ------------------
    edge = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
        compile_config=EdgeCompileConfig(
            # IR validity checking is turned off when quantizing.
            _check_ir_validity=not args.quantize,
            _skip_dim_order=True,  # TODO(T182187531): enable dim order in xnnpack
        ),
    )
    logging.info("Exported and lowered graph:\n%s", edge.exported_program().graph)

    # this is needed for the ETRecord as lowering modifies the graph in-place
    edge_copy = copy.deepcopy(edge)

    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    # ---- Optional ETRecord and final .pte output --------------------------
    if args.etrecord is not None:
        generate_etrecord(args.etrecord, edge_copy, exec_prog)
        logging.info("Saved ETRecord to %s", args.etrecord)

    quant_tag = "q8" if args.quantize else "fp32"
    model_name = f"{args.model_name}_xnnpack_{quant_tag}"
    save_pte_program(exec_prog, model_name, args.output_dir)