xref: /aosp_15_r20/external/executorch/examples/apple/mps/executor_runner/mps_executor_runner.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2//  Copyright (c) 2023 Apple Inc. All rights reserved.
3//  Provided subject to the LICENSE file in the top level directory.
4//
5
6/**
7 * @file
8 *
9 * This tool can run Executorch model files that use operators that
10 * are covered by the MPSDelegate or the portable kernels.
11 *
12 * It uses the original bundled input data from the flatbuffer file.
13 */
14
15#include <memory>
16#include <numeric>
17#include <iomanip>
18#include <iostream>
19
20#include <gflags/gflags.h>
21
22#include <executorch/extension/data_loader/buffer_data_loader.h>
23#include <executorch/extension/data_loader/file_data_loader.h>
24#include <executorch/extension/evalue_util/print_evalue.h>
25#include <executorch/extension/runner_util/inputs.h>
26#include <executorch/runtime/core/result.h>
27#include <executorch/runtime/executor/method.h>
28#include <executorch/runtime/executor/program.h>
29#include <executorch/runtime/platform/log.h>
30#include <executorch/runtime/platform/profiler.h>
31#include <executorch/runtime/platform/runtime.h>
32#include <executorch/runtime/platform/runtime.h>
33#include <executorch/devtools/bundled_program/bundled_program.h>
34#include <executorch/devtools/etdump/etdump_flatcc.h>
35
36#include <chrono>
37using namespace std::chrono;
38
// Statically allocated pool that backs the method allocator (runtime C++
// metadata/objects created while loading a method). Sized generously; method
// loading fails with an error if this is exhausted.
static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB

// ---------------------------------------------------------------------------
// Command-line flags (gflags). Each flag's third argument is the help string
// shown by --help / ShowUsageWithFlags().
// ---------------------------------------------------------------------------

DEFINE_string(model_path, "model.ff", "Model serialized in flatbuffer format.");
DEFINE_string(
    prof_result_path,
    "prof_result.bin",
    "Executorch profiler output path.");

DEFINE_bool(
    bundled_program,
    false,
    "True for running bundled program, false for executorch_flatbuffer::program");

DEFINE_int32(
    testset_idx,
    0,
    "Index of bundled verification set to be run "
    "by bundled model for verification");

DEFINE_int32(
    num_runs,
    1,
    "Number of total runs");

DEFINE_bool(
    profile,
    false,
    "True for showing profile data (e.g execution time)");

DEFINE_bool(
    skip_warmup,
    false,
    "If true, a warmup iteration won't be executed.");

DEFINE_string(
    etdump_path,
    "etdump.etdp",
    "If etdump generation is enabled an etdump will be written out to this path");

DEFINE_bool(
    print_output,
    false,
    "Print the output of the ET model to stdout, if needs.");

DEFINE_bool(dump_outputs, false, "Dump outputs to etdump file");

DEFINE_bool(
    dump_intermediate_outputs,
    false,
    "Dump intermediate outputs to etdump file.");

DEFINE_string(
    debug_output_path,
    "debug_output.bin",
    "Path to dump debug outputs to.");

DEFINE_int32(
    debug_buffer_size,
    262144, // 256 KB
    "Size of the debug buffer in bytes to allocate for intermediate outputs and program outputs logging.");

// ---------------------------------------------------------------------------
// Convenience aliases so the body of main() can use the short type names.
// ---------------------------------------------------------------------------

using executorch::etdump::ETDumpGen;
using executorch::etdump::ETDumpResult;
using executorch::extension::BufferCleanup;
using executorch::extension::BufferDataLoader;
using executorch::extension::FileDataLoader;
using executorch::runtime::DataLoader;
using executorch::runtime::EValue;
using executorch::runtime::Error;
using executorch::runtime::EventTracerDebugLogLevel;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MemoryManager;
using executorch::runtime::Method;
using executorch::runtime::MethodMeta;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::Span;

namespace bundled_program = executorch::bundled_program;
120
121int main(int argc, char** argv) {
122  {
123    const char* usage = R"(MPS Executor Runner. Sample usage:
124  mps_executor_runner --model_path model.pte)";
125    gflags::SetUsageMessage(usage);
126  }
127
128  if (argc == 1) {
129    ET_LOG(Error, "No options provided.");
130    gflags::ShowUsageWithFlags(argv[0]);
131    return 1;
132  }
133
134  executorch::runtime::runtime_init();
135
136  gflags::ParseCommandLineFlags(&argc, &argv, true);
137  if (argc != 1) {
138    std::string msg = "Extra commandline args:";
139    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
140      msg += std::string(" ") + argv[i];
141    }
142    ET_LOG(Error, "%s", msg.c_str());
143    return 1;
144  }
145
146  // Create a loader to get the data of the program file. There are other
147  // DataLoaders that use mmap() or point to data that's already in memory, and
148  // users can create their own DataLoaders to load from arbitrary sources.
149  const char* model_path = FLAGS_model_path.c_str();
150  Result<FileDataLoader> loader = FileDataLoader::from(model_path);
151  ET_CHECK_MSG(
152      loader.ok(), "FileDataLoader::from() failed: 0x%" PRIx32, loader.error());
153
154  // Read in the entire file.
155  Result<FreeableBuffer> file_data = loader->load(0, loader->size().get(), DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
156  ET_CHECK_MSG(
157      file_data.ok(),
158      "Could not load contents of file '%s': 0x%x",
159      model_path,
160      (unsigned int)file_data.error());
161
162  // Find the offset to the embedded Program.
163  const void* program_data;
164  size_t program_data_len;
165  Error status = bundled_program::get_program_data(
166      const_cast<void*>(file_data->data()),
167      file_data->size(),
168      &program_data,
169      &program_data_len);
170  ET_CHECK_MSG(
171      status == Error::Ok,
172      "get_program_data() failed on file '%s': 0x%x",
173      model_path,
174      (unsigned int)status);
175
176  // Wrap the buffer in a DataLoader.
177  auto buffer_data_loader =
178      BufferDataLoader(program_data, program_data_len);
179
180  // Parse the program file. This is immutable, and can also be reused between
181  // multiple execution invocations across multiple threads.
182  Result<Program> program = Program::load(&buffer_data_loader);
183  if (!program.ok()) {
184    ET_LOG(Error, "Failed to parse model file %s", model_path);
185    return 1;
186  }
187  ET_LOG(Info, "Model file %s is loaded.", model_path);
188
189  // Use the first method in the program.
190  const char* method_name = nullptr;
191  {
192    const auto method_name_result = program->get_method_name(0);
193    ET_CHECK_MSG(method_name_result.ok(), "Program has no methods");
194    method_name = *method_name_result;
195  }
196  ET_LOG(Info, "Using method %s", method_name);
197
198  // MethodMeta describes the memory requirements of the method.
199  Result<MethodMeta> method_meta = program->method_meta(method_name);
200  ET_CHECK_MSG(
201      method_meta.ok(),
202      "Failed to get method_meta for %s: 0x%" PRIx32,
203      method_name,
204      (uint32_t)method_meta.error());
205
206  //
207  // The runtime does not use malloc/new; it allocates all memory using the
208  // MemoryManger provided by the client. Clients are responsible for allocating
209  // the memory ahead of time, or providing MemoryAllocator subclasses that can
210  // do it dynamically.
211  //
212
213  // The method allocator is used to allocate all dynamic C++ metadata/objects
214  // used to represent the loaded method. This allocator is only used during
215  // loading a method of the program, which will return an error if there was
216  // not enough memory.
217  //
218  // The amount of memory required depends on the loaded method and the runtime
219  // code itself. The amount of memory here is usually determined by running the
220  // method and seeing how much memory is actually used, though it's possible to
221  // subclass MemoryAllocator so that it calls malloc() under the hood (see
222  // MallocMemoryAllocator).
223  //
224  // In this example we use a statically allocated memory pool.
225  MemoryAllocator method_allocator{
226      MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)};
227
228  // The memory-planned buffers will back the mutable tensors used by the
229  // method. The sizes of these buffers were determined ahead of time during the
230  // memory-planning pasees.
231  //
232  // Each buffer typically corresponds to a different hardware memory bank. Most
233  // mobile environments will only have a single buffer. Some embedded
234  // environments may have more than one for, e.g., slow/large DRAM and
235  // fast/small SRAM, or for memory associated with particular cores.
236  std::vector<std::unique_ptr<uint8_t[]>> planned_buffers; // Owns the memory
237  std::vector<Span<uint8_t>> planned_spans; // Passed to the allocator
238  size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers();
239  for (size_t id = 0; id < num_memory_planned_buffers; ++id) {
240    // .get() will always succeed because id < num_memory_planned_buffers.
241    size_t buffer_size =
242        static_cast<size_t>(method_meta->memory_planned_buffer_size(id).get());
243    ET_LOG(Info, "Setting up planned buffer %zu, size %zu.", id, buffer_size);
244    planned_buffers.push_back(std::make_unique<uint8_t[]>(buffer_size));
245    planned_spans.push_back({planned_buffers.back().get(), buffer_size});
246  }
247  HierarchicalAllocator planned_memory(
248      {planned_spans.data(), planned_spans.size()});
249
250  // Assemble all of the allocators into the MemoryManager that the Executor
251  // will use.
252  MemoryManager memory_manager(&method_allocator, &planned_memory);
253
254  //
255  // Load the method from the program, using the provided allocators. Running
256  // the method can mutate the memory-planned buffers, so the method should only
257  // be used by a single thread at at time, but it can be reused.
258  //
259
260  ETDumpGen etdump_gen;
261  Result<Method> method =
262      program->load_method(method_name, &memory_manager, &etdump_gen);
263  ET_CHECK_MSG(
264      method.ok(),
265      "Loading of method %s failed with status 0x%" PRIx32,
266      method_name,
267      (uint32_t)method.error());
268  ET_LOG(Info, "Method loaded.");
269
270  void* debug_buffer = malloc(FLAGS_debug_buffer_size);
271  if (FLAGS_dump_intermediate_outputs) {
272    Span<uint8_t> buffer((uint8_t*)debug_buffer, FLAGS_debug_buffer_size);
273    etdump_gen.set_debug_buffer(buffer);
274    etdump_gen.set_event_tracer_debug_level(
275        EventTracerDebugLogLevel::kIntermediateOutputs);
276  } else if (FLAGS_dump_outputs) {
277    Span<uint8_t> buffer((uint8_t*)debug_buffer, FLAGS_debug_buffer_size);
278    etdump_gen.set_debug_buffer(buffer);
279    etdump_gen.set_event_tracer_debug_level(
280        EventTracerDebugLogLevel::kProgramOutputs);
281  }
282
283  // Prepare the inputs.
284  std::unique_ptr<BufferCleanup> inputs;
285  if (FLAGS_bundled_program) {
286    ET_LOG(Info, "Loading bundled program...");
287    // Use the inputs embedded in the bundled program.
288    status = bundled_program::load_bundled_input(
289        *method,
290        file_data->data(),
291        FLAGS_testset_idx);
292    ET_CHECK_MSG(
293        status == Error::Ok,
294        "LoadBundledInput failed with status 0x%" PRIx32,
295        status);
296  } else {
297    ET_LOG(Info, "Loading non-bundled program...\n");
298    // Use ones-initialized inputs.
299    auto inputs_result = executorch::extension::prepare_input_tensors(*method);
300    if (inputs_result.ok()) {
301      // Will free the inputs when destroyed.
302      inputs =
303          std::make_unique<BufferCleanup>(std::move(inputs_result.get()));
304    }
305  }
306  ET_LOG(Info, "Inputs prepared.");
307
308  int num_iterations = FLAGS_num_runs + (FLAGS_skip_warmup ? 0 : 1);
309  std::vector<float> exec_times;
310  exec_times.reserve(FLAGS_num_runs);
311  for (int i = 0; i < num_iterations; i++) {
312    auto start_exec_time = high_resolution_clock::now();
313    // Run the model.
314    Error status = method->execute();
315    auto end_exec_time = high_resolution_clock::now();
316    auto duration = duration_cast<microseconds>(end_exec_time - start_exec_time);
317    exec_times.push_back(duration.count());
318    if (FLAGS_profile) {
319      const float miliseconds = static_cast<float>(duration.count()) / 1000.f;
320      ET_LOG(Info, "[Run %d] Inference time: %.3f miliseconds", i, miliseconds);
321    }
322    ET_CHECK_MSG(
323        status == Error::Ok,
324        "Execution of method %s failed with status 0x%" PRIx32,
325        method_name,
326        status);
327  }
328  if (FLAGS_profile && FLAGS_num_runs) {
329    auto itr = exec_times.begin();
330    if (!FLAGS_skip_warmup)
331      itr++;
332
333    const float avg_time = (std::reduce(itr, exec_times.end()) / static_cast<float>(FLAGS_num_runs)) / 1000.f;
334    std::cout << "Average inference time: " << std::setprecision(2) << std::fixed << avg_time << " miliseconds\n";
335  }
336  ET_LOG(Info, "Model executed successfully.");
337
338  // Print the outputs.
339  std::vector<EValue> outputs(method->outputs_size());
340  status = method->get_outputs(outputs.data(), outputs.size());
341  ET_CHECK(status == Error::Ok);
342  // Print the first and last 100 elements of long lists of scalars.
343  std::cout << executorch::extension::evalue_edge_items(100);
344  for (int i = 0; i < outputs.size(); ++i) {
345    std::cout << "Output " << i << ": " << outputs[i] << std::endl;
346  }
347
348  // Dump the etdump data containing profiling/debugging data to the specified
349  // file.
350  ETDumpResult result = etdump_gen.get_etdump_data();
351  if (result.buf != nullptr && result.size > 0) {
352    FILE* f = fopen(FLAGS_etdump_path.c_str(), "w+");
353    fwrite((uint8_t*)result.buf, 1, result.size, f);
354    fclose(f);
355    free(result.buf);
356  }
357
358  // Handle the outputs.
359  if (FLAGS_bundled_program) {
360    double rtol = 1e-05;
361    double atol = 1e-08;
362    if (strstr(model_path, "fp16")) {
363      rtol = 1e-01;
364      atol = 1e-01;
365    } else if (strstr(model_path, "mv3")           ||
366        strstr(model_path, "mv2")                  ||
367        strstr(model_path, "conv")                 ||
368        strstr(model_path, "vit")                  ||
369        strstr(model_path, "resnet18")             ||
370        strstr(model_path, "resnet50")             ||
371        strstr(model_path, "emformer")             ||
372        strstr(model_path, "emformer_transcribe")  ||
373        strstr(model_path, "emformer_join")        ||
374        strstr(model_path, "edsr")                 ||
375        strstr(model_path, "llama")                ||
376        strstr(model_path, "ic3")                  ||
377        strstr(model_path, "ic4")) {
378      atol = 1e-04;
379    } else if (strstr(model_path, "mobilebert")) {
380      atol = 1e-01;
381      rtol = 1e-01;
382    }
383    status = bundled_program::verify_method_outputs(
384        *method,
385        file_data->data(),
386        FLAGS_testset_idx,
387        rtol,
388        atol
389    );
390    ET_CHECK_MSG(
391        status == Error::Ok,
392        "Bundle verification failed with status 0x%" PRIx32,
393        status);
394    ET_LOG(Info, "Model verified successfully.");
395  }
396
397  if (FLAGS_dump_outputs || FLAGS_dump_intermediate_outputs) {
398    FILE* f = fopen(FLAGS_debug_output_path.c_str(), "w+");
399    fwrite((uint8_t*)debug_buffer, 1, FLAGS_debug_buffer_size, f);
400    fclose(f);
401  }
402  free(debug_buffer);
403
404  return 0;
405}
406