1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/core/error.h> 12 #include <executorch/runtime/core/event_tracer_hooks_delegate.h> 13 14 #include <xnnpack.h> 15 #include <vector> 16 17 namespace executorch { 18 namespace backends { 19 namespace xnnpack { 20 namespace delegate { 21 namespace profiling { 22 23 enum class XNNProfilerState { Uninitialized, Ready, Running }; 24 25 class XNNProfiler { 26 public: 27 XNNProfiler(); 28 29 /** 30 * Initialize the profiler. This must be called after model is 31 * compiled and before calling begin_execution. 32 */ 33 executorch::runtime::Error initialize(xnn_runtime_t runtime); 34 35 /** 36 * Start a new profiling session. This is typically invoked 37 * immediately before invoking the XNNPACK runtime as part 38 * of a forward pass. 39 */ 40 executorch::runtime::Error start( 41 executorch::runtime::EventTracer* event_tracer); 42 43 /** 44 * End a profiling session. This is typically invoked immediately 45 * after the XNNPACK runtime invocation completes. 46 */ 47 executorch::runtime::Error end(); 48 49 private: 50 #if defined(ET_EVENT_TRACER_ENABLED) || defined(ENABLE_XNNPACK_PROFILING) 51 executorch::runtime::EventTracer* event_tracer_; 52 xnn_runtime_t runtime_; 53 XNNProfilerState state_; 54 55 size_t op_count_; 56 std::vector<char> op_names_; 57 std::vector<uint64_t> op_timings_; 58 uint64_t run_count_; 59 et_timestamp_t start_time_; 60 61 #ifdef ENABLE_XNNPACK_PROFILING 62 // State needed to track average timing. Track the running sum of 63 // timing for each op, as well as the number of invocations. The 64 // running average can be found as sum / run_count. 65 std::vector<uint64_t> op_timings_sum_; 66 #endif 67 68 executorch::runtime::Error get_runtime_operator_names(); 69 executorch::runtime::Error get_runtime_num_operators(); 70 executorch::runtime::Error get_runtime_operator_timings(); 71 72 void log_operator_timings(); 73 74 /** 75 * Submit the trace to the ET event tracer. 76 */ 77 void submit_trace(); 78 #endif 79 }; 80 81 } // namespace profiling 82 } // namespace delegate 83 } // namespace xnnpack 84 } // namespace backends 85 } // namespace executorch 86