#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

# Example script for exporting simple models to flatbuffer

import argparse
import copy
import logging

import torch
from examples.apple.mps.scripts.bench_utils import bench_torch, compare_outputs
from executorch import exir
from executorch.backends.apple.mps import MPSBackend
from executorch.backends.apple.mps.partition import MPSPartitioner
from executorch.devtools import BundledProgram, generate_etrecord
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (
    serialize_from_bundled_program_to_flatbuffer,
)

from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
)
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge, save_pte_program

from ....models import MODEL_NAME_TO_MODEL
from ....models.model_factory import EagerModelFactory

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def get_bundled_program(executorch_program, example_inputs, expected_output):
    method_test_suites = [
        MethodTestSuite(
            method_name="forward",
            test_cases=[
                MethodTestCase(
                    inputs=example_inputs, expected_outputs=[expected_output]
                )
            ],
        )
    ]
    logging.info(f"Expected output: {expected_output}")

    bundled_program = BundledProgram(executorch_program, method_test_suites)
    bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer(
        bundled_program
    )
    return bundled_program_buffer
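

# Usage sketch (illustrative; not executed at import time). The buffer returned
# by get_bundled_program is a complete flatbuffer that bundles the program with
# its test case, so it can be written straight to disk; the __main__ block
# below does exactly this. In isolation it looks like (filename hypothetical):
#
#   buffer = get_bundled_program(executorch_program, example_inputs, expected_output)
#   with open("model_bundled.pte", "wb") as f:
#       f.write(buffer)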


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )

    parser.add_argument(
        "--use_fp16",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to automatically convert float32 operations to float16 operations.",
    )

    parser.add_argument(
        "--use_partitioner",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Use the MPS partitioner to run the model instead of whole-graph lowering.",
    )

    parser.add_argument(
        "--bench_pytorch",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Benchmark the ExecuTorch MPS forward pass against the PyTorch MPS forward pass.",
    )

    parser.add_argument(
        "-b",
        "--bundled",
        action="store_true",
        required=False,
        default=False,
        help="Flag for bundling inputs and outputs in the final flatbuffer program.",
    )

    parser.add_argument(
        "-c",
        "--check_correctness",
        action="store_true",
        required=False,
        default=False,
        help="Whether to compare the ExecuTorch MPS results with the PyTorch forward pass.",
    )

    parser.add_argument(
        "--generate_etrecord",
        action="store_true",
        required=False,
        default=False,
        help="Generate ETRecord metadata to link with runtime results (used for profiling).",
    )

    parser.add_argument(
        "--checkpoint",
        required=False,
        default=None,
        help="Checkpoint for the llama model.",
    )

    parser.add_argument(
        "--params",
        required=False,
        default=None,
        help="Params file for the llama model.",
    )

    args = parser.parse_args()
    return args


def get_model_config(args):
    model_config = {}
    model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
    model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]

    if args.model_name == "llama2":
        if args.checkpoint:
            model_config["checkpoint"] = args.checkpoint
        if args.params:
            model_config["params"] = args.params
        model_config["use_kv_cache"] = True
    return model_config
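

# For reference, a sketch of the dict shape get_model_config produces; module
# and class names come from the MODEL_NAME_TO_MODEL registry, so the values
# below are illustrative placeholders, not real entries:
#
#   {
#       "module_name": "<python module containing the model>",
#       "model_class_name": "<model class to instantiate>",
#       # plus "checkpoint"/"params"/"use_kv_cache" when exporting llama2
#   }
#
# EagerModelFactory.create_model(**model_config) then imports that module and
# returns the eager model, its example inputs, and two more values this script
# ignores.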


if __name__ == "__main__":
    args = parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")

    model_config = get_model_config(args)
    model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)

    model = model.eval()

    # Keep copies of the model and inputs to check against the PyTorch forward pass
    if args.check_correctness or args.bench_pytorch:
        model_copy = copy.deepcopy(model)
        inputs_copy = []
        for t in example_inputs:
            inputs_copy.append(t.detach().clone())
        inputs_copy = tuple(inputs_copy)

    # Pre-autograd export; eventually this will become torch.export.
    with torch.no_grad():
        model = torch.export.export_for_training(model, example_inputs).module()
        edge: EdgeProgramManager = export_to_edge(
            model,
            example_inputs,
            edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

    edge_program_manager_copy = copy.deepcopy(edge)

    compile_specs = [CompileSpec("use_fp16", bytes([args.use_fp16]))]

    logging.info(f"Edge IR graph:\n{edge.exported_program().graph}")
    if args.use_partitioner:
        edge = edge.to_backend(MPSPartitioner(compile_specs=compile_specs))
        logging.info(f"Lowered graph:\n{edge.exported_program().graph}")

        executorch_program = edge.to_executorch(
            config=ExecutorchBackendConfig(extract_delegate_segments=False)
        )
    else:
        lowered_module = to_backend(
            MPSBackend.__name__, edge.exported_program(), compile_specs
        )
        executorch_program: ExecutorchProgramManager = export_to_edge(
            lowered_module,
            example_inputs,
            edge_compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))

    dtype = "float16" if args.use_fp16 else "float32"
    model_name = f"{args.model_name}_mps_{dtype}"

    if args.bundled:
        expected_output = model(*example_inputs)
        bundled_program_buffer = get_bundled_program(
            executorch_program, example_inputs, expected_output
        )
        model_name = f"{model_name}_bundled.pte"

    if args.generate_etrecord:
        etrecord_path = "etrecord.bin"
        logging.info("generating etrecord.bin")
        generate_etrecord(etrecord_path, edge_program_manager_copy, executorch_program)

    if args.bundled:
        with open(model_name, "wb") as file:
            file.write(bundled_program_buffer)
        logging.info(f"Saved bundled program to {model_name}")
    else:
        save_pte_program(executorch_program, model_name)

    if args.bench_pytorch:
        bench_torch(executorch_program, model_copy, example_inputs, model_name)

    if args.check_correctness:
        compare_outputs(
            executorch_program, model_copy, inputs_copy, model_name, args.use_fp16
        )
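
# Example invocation, assuming this script lives at
# examples/apple/mps/scripts/mps_example.py in the ExecuTorch repo (the
# relative imports above require running it as a module from the repo root)
# and that "mv2" is registered in MODEL_NAME_TO_MODEL:
#
#   python3 -m examples.apple.mps.scripts.mps_example --model_name mv2 --bundled
#
# With the default --use_fp16, this writes mv2_mps_float16_bundled.pte to the
# current directory.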