//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

/**
 * @file
 *
 * This tool can run Executorch model files that use operators that
 * are covered by the MPSDelegate or the portable kernels.
 *
 * It uses the original bundled input data from the flatbuffer file.
 */

#include <cinttypes>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <numeric>

#include <gflags/gflags.h>

#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/devtools/bundled_program/bundled_program.h>
#include <executorch/devtools/etdump/etdump_flatcc.h>

#include <chrono>
using namespace std::chrono;

static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB

DEFINE_string(model_path, "model.ff", "Model serialized in flatbuffer format.");
DEFINE_string(
    prof_result_path,
    "prof_result.bin",
    "Executorch profiler output path.");

DEFINE_bool(
    bundled_program,
    false,
    "True for running bundled program, false for executorch_flatbuffer::program");

DEFINE_int32(
    testset_idx,
    0,
    "Index of bundled verification set to be run "
    "by bundled model for verification");

DEFINE_int32(
    num_runs,
    1,
    "Number of total runs");

DEFINE_bool(
    profile,
    false,
    "True for showing profile data (e.g., execution time)");

DEFINE_bool(
    skip_warmup,
    false,
    "If true, a warmup iteration won't be executed.");

DEFINE_string(
    etdump_path,
    "etdump.etdp",
    "If etdump generation is enabled, an etdump will be written out to this path.");

DEFINE_bool(
    print_output,
    false,
    "Print the output of the ET model to stdout, if needed.");

DEFINE_bool(dump_outputs, false, "Dump outputs to etdump file.");

DEFINE_bool(
    dump_intermediate_outputs,
    false,
    "Dump intermediate outputs to etdump file.");

DEFINE_string(
    debug_output_path,
    "debug_output.bin",
    "Path to dump debug outputs to.");

DEFINE_int32(
    debug_buffer_size,
    262144, // 256 KB
    "Size of the debug buffer in bytes to allocate for intermediate outputs "
    "and program outputs logging.");

using executorch::etdump::ETDumpGen;
using executorch::etdump::ETDumpResult;
using executorch::extension::BufferCleanup;
using executorch::extension::BufferDataLoader;
using executorch::extension::FileDataLoader;
using executorch::runtime::DataLoader;
using executorch::runtime::EValue;
using executorch::runtime::Error;
using executorch::runtime::EventTracerDebugLogLevel;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MemoryManager;
using executorch::runtime::Method;
using executorch::runtime::MethodMeta;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::Span;

namespace bundled_program = executorch::bundled_program;

int main(int argc, char** argv) {
  {
    const char* usage = R"(MPS Executor Runner. Sample usage:
  mps_executor_runner --model_path model.pte)";
    gflags::SetUsageMessage(usage);
  }

  if (argc == 1) {
    ET_LOG(Error, "No options provided.");
    gflags::ShowUsageWithFlags(argv[0]);
    return 1;
  }

  executorch::runtime::runtime_init();

  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (argc != 1) {
    std::string msg = "Extra commandline args:";
    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
      msg += std::string(" ") + argv[i];
    }
    ET_LOG(Error, "%s", msg.c_str());
    return 1;
  }

  // Create a loader to get the data of the program file. There are other
  // DataLoaders that use mmap() or point to data that's already in memory, and
  // users can create their own DataLoaders to load from arbitrary sources.
  const char* model_path = FLAGS_model_path.c_str();
  Result<FileDataLoader> loader = FileDataLoader::from(model_path);
  ET_CHECK_MSG(
      loader.ok(),
      "FileDataLoader::from() failed: 0x%" PRIx32,
      (uint32_t)loader.error());

  // Read in the entire file.
  Result<FreeableBuffer> file_data = loader->load(
      0,
      loader->size().get(),
      DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
  ET_CHECK_MSG(
      file_data.ok(),
      "Could not load contents of file '%s': 0x%x",
      model_path,
      (unsigned int)file_data.error());

  // Find the offset to the embedded Program.
  const void* program_data;
  size_t program_data_len;
  Error status = bundled_program::get_program_data(
      const_cast<void*>(file_data->data()),
      file_data->size(),
      &program_data,
      &program_data_len);
  ET_CHECK_MSG(
      status == Error::Ok,
      "get_program_data() failed on file '%s': 0x%x",
      model_path,
      (unsigned int)status);

  // Wrap the buffer in a DataLoader.
  auto buffer_data_loader =
      BufferDataLoader(program_data, program_data_len);

  // Parse the program file. This is immutable, and can also be reused between
  // multiple execution invocations across multiple threads.
  Result<Program> program = Program::load(&buffer_data_loader);
  if (!program.ok()) {
    ET_LOG(Error, "Failed to parse model file %s", model_path);
    return 1;
  }
  ET_LOG(Info, "Model file %s is loaded.", model_path);

  // Use the first method in the program.
  const char* method_name = nullptr;
  {
    const auto method_name_result = program->get_method_name(0);
    ET_CHECK_MSG(method_name_result.ok(), "Program has no methods");
    method_name = *method_name_result;
  }
  ET_LOG(Info, "Using method %s", method_name);

  // MethodMeta describes the memory requirements of the method.
  Result<MethodMeta> method_meta = program->method_meta(method_name);
  ET_CHECK_MSG(
      method_meta.ok(),
      "Failed to get method_meta for %s: 0x%" PRIx32,
      method_name,
      (uint32_t)method_meta.error());

  //
  // The runtime does not use malloc/new; it allocates all memory using the
  // MemoryManager provided by the client. Clients are responsible for
  // allocating the memory ahead of time, or providing MemoryAllocator
  // subclasses that can do it dynamically.
  //

  // The method allocator is used to allocate all dynamic C++ metadata/objects
  // used to represent the loaded method.
  // This allocator is only used while loading a method of the program, which
  // will return an error if there was not enough memory.
  //
  // The amount of memory required depends on the loaded method and the runtime
  // code itself. The amount of memory here is usually determined by running the
  // method and seeing how much memory is actually used, though it's possible to
  // subclass MemoryAllocator so that it calls malloc() under the hood (see
  // MallocMemoryAllocator).
  //
  // In this example we use a statically allocated memory pool.
  MemoryAllocator method_allocator{
      MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)};

  // The memory-planned buffers will back the mutable tensors used by the
  // method. The sizes of these buffers were determined ahead of time during the
  // memory-planning passes.
  //
  // Each buffer typically corresponds to a different hardware memory bank. Most
  // mobile environments will only have a single buffer. Some embedded
  // environments may have more than one for, e.g., slow/large DRAM and
  // fast/small SRAM, or for memory associated with particular cores.
  std::vector<std::unique_ptr<uint8_t[]>> planned_buffers; // Owns the memory
  std::vector<Span<uint8_t>> planned_spans; // Passed to the allocator
  size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers();
  for (size_t id = 0; id < num_memory_planned_buffers; ++id) {
    // .get() will always succeed because id < num_memory_planned_buffers.
    size_t buffer_size =
        static_cast<size_t>(method_meta->memory_planned_buffer_size(id).get());
    ET_LOG(Info, "Setting up planned buffer %zu, size %zu.", id, buffer_size);
    planned_buffers.push_back(std::make_unique<uint8_t[]>(buffer_size));
    planned_spans.push_back({planned_buffers.back().get(), buffer_size});
  }
  HierarchicalAllocator planned_memory(
      {planned_spans.data(), planned_spans.size()});

  // Assemble all of the allocators into the MemoryManager that the Executor
  // will use.
  MemoryManager memory_manager(&method_allocator, &planned_memory);

  //
  // Load the method from the program, using the provided allocators. Running
  // the method can mutate the memory-planned buffers, so the method should only
  // be used by a single thread at a time, but it can be reused.
  //

  ETDumpGen etdump_gen;
  Result<Method> method =
      program->load_method(method_name, &memory_manager, &etdump_gen);
  ET_CHECK_MSG(
      method.ok(),
      "Loading of method %s failed with status 0x%" PRIx32,
      method_name,
      (uint32_t)method.error());
  ET_LOG(Info, "Method loaded.");

  void* debug_buffer = malloc(FLAGS_debug_buffer_size);
  if (FLAGS_dump_intermediate_outputs) {
    Span<uint8_t> buffer((uint8_t*)debug_buffer, FLAGS_debug_buffer_size);
    etdump_gen.set_debug_buffer(buffer);
    etdump_gen.set_event_tracer_debug_level(
        EventTracerDebugLogLevel::kIntermediateOutputs);
  } else if (FLAGS_dump_outputs) {
    Span<uint8_t> buffer((uint8_t*)debug_buffer, FLAGS_debug_buffer_size);
    etdump_gen.set_debug_buffer(buffer);
    etdump_gen.set_event_tracer_debug_level(
        EventTracerDebugLogLevel::kProgramOutputs);
  }

  // Prepare the inputs.
  std::unique_ptr<BufferCleanup> inputs;
  if (FLAGS_bundled_program) {
    ET_LOG(Info, "Loading bundled program...");
    // Use the inputs embedded in the bundled program.
    status = bundled_program::load_bundled_input(
        *method,
        file_data->data(),
        FLAGS_testset_idx);
    ET_CHECK_MSG(
        status == Error::Ok,
        "LoadBundledInput failed with status 0x%" PRIx32,
        (uint32_t)status);
  } else {
    ET_LOG(Info, "Loading non-bundled program...\n");
    // Use ones-initialized inputs.
    auto inputs_result = executorch::extension::prepare_input_tensors(*method);
    if (inputs_result.ok()) {
      // Will free the inputs when destroyed.
      inputs =
          std::make_unique<BufferCleanup>(std::move(inputs_result.get()));
    }
  }
  ET_LOG(Info, "Inputs prepared.");

  int num_iterations = FLAGS_num_runs + (FLAGS_skip_warmup ? 0 : 1);
  std::vector<float> exec_times;
  exec_times.reserve(FLAGS_num_runs);
  for (int i = 0; i < num_iterations; i++) {
    auto start_exec_time = high_resolution_clock::now();
    // Run the model.
    Error status = method->execute();
    auto end_exec_time = high_resolution_clock::now();
    auto duration =
        duration_cast<microseconds>(end_exec_time - start_exec_time);
    exec_times.push_back(duration.count());
    if (FLAGS_profile) {
      const float milliseconds = static_cast<float>(duration.count()) / 1000.f;
      ET_LOG(
          Info, "[Run %d] Inference time: %.3f milliseconds", i, milliseconds);
    }
    ET_CHECK_MSG(
        status == Error::Ok,
        "Execution of method %s failed with status 0x%" PRIx32,
        method_name,
        (uint32_t)status);
  }
  if (FLAGS_profile && FLAGS_num_runs) {
    auto itr = exec_times.begin();
    if (!FLAGS_skip_warmup)
      itr++;

    const float avg_time =
        (std::reduce(itr, exec_times.end()) /
         static_cast<float>(FLAGS_num_runs)) / 1000.f;
    std::cout << "Average inference time: " << std::setprecision(2)
              << std::fixed << avg_time << " milliseconds\n";
  }
  ET_LOG(Info, "Model executed successfully.");

  // Print the outputs.
  std::vector<EValue> outputs(method->outputs_size());
  status = method->get_outputs(outputs.data(), outputs.size());
  ET_CHECK(status == Error::Ok);
  // Print the first and last 100 elements of long lists of scalars.
  std::cout << executorch::extension::evalue_edge_items(100);
  for (size_t i = 0; i < outputs.size(); ++i) {
    std::cout << "Output " << i << ": " << outputs[i] << std::endl;
  }

  // Dump the etdump data containing profiling/debugging data to the specified
  // file.
  ETDumpResult result = etdump_gen.get_etdump_data();
  if (result.buf != nullptr && result.size > 0) {
    FILE* f = fopen(FLAGS_etdump_path.c_str(), "w+");
    fwrite((uint8_t*)result.buf, 1, result.size, f);
    fclose(f);
    free(result.buf);
  }

  // Handle the outputs.
  if (FLAGS_bundled_program) {
    double rtol = 1e-05;
    double atol = 1e-08;
    if (strstr(model_path, "fp16")) {
      rtol = 1e-01;
      atol = 1e-01;
    } else if (strstr(model_path, "mv3") ||
               strstr(model_path, "mv2") ||
               strstr(model_path, "conv") ||
               strstr(model_path, "vit") ||
               strstr(model_path, "resnet18") ||
               strstr(model_path, "resnet50") ||
               strstr(model_path, "emformer") ||
               strstr(model_path, "emformer_transcribe") ||
               strstr(model_path, "emformer_join") ||
               strstr(model_path, "edsr") ||
               strstr(model_path, "llama") ||
               strstr(model_path, "ic3") ||
               strstr(model_path, "ic4")) {
      atol = 1e-04;
    } else if (strstr(model_path, "mobilebert")) {
      atol = 1e-01;
      rtol = 1e-01;
    }
    status = bundled_program::verify_method_outputs(
        *method,
        file_data->data(),
        FLAGS_testset_idx,
        rtol,
        atol);
    ET_CHECK_MSG(
        status == Error::Ok,
        "Bundle verification failed with status 0x%" PRIx32,
        (uint32_t)status);
    ET_LOG(Info, "Model verified successfully.");
  }

  if (FLAGS_dump_outputs || FLAGS_dump_intermediate_outputs) {
    FILE* f = fopen(FLAGS_debug_output_path.c_str(), "w+");
    fwrite((uint8_t*)debug_buffer, 1, FLAGS_debug_buffer_size, f);
    fclose(f);
  }
  free(debug_buffer);

  return 0;
}
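
// Example invocations (a sketch; the model file names below are assumptions,
// so substitute the paths produced by your own build and model export):
//
//   # Plain .pte program with ones-initialized inputs:
//   ./mps_executor_runner --model_path mv2_mps.pte
//
//   # Bundled program: time 10 runs and verify outputs against the bundle:
//   ./mps_executor_runner --model_path mv2_mps_bundled.pte --bundled_program \
//       --num_runs 10 --profile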