xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/profiling/buffered_profiler.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_PROFILING_BUFFERED_PROFILER_H_
16 #define TENSORFLOW_LITE_PROFILING_BUFFERED_PROFILER_H_
17 
18 #include <cstdint>
19 #include <vector>
20 
21 #include "tensorflow/lite/core/api/profiler.h"
22 #include "tensorflow/lite/profiling/profile_buffer.h"
23 
24 namespace tflite {
25 namespace profiling {
26 
27 // Controls whether profiling is enabled or disabled and collects profiles.
28 // TFLite is used on platforms that don't have posix threads, so the profiler is
29 // kept as simple as possible. It is designed to be used only on a single
30 // thread.
31 //
// Profiles are collected using Scoped*Profile objects that begin and end a
// profile event.
// An example usage is shown below:
35 //
36 // Say Worker class has a DoWork method and we are interested in profiling
37 // the overall execution time for DoWork and time spent in Task1 and Task2
38 // functions.
39 //
// class Worker {
//  public:
//   void DoWork() {
//     ScopedProfile profile(&profiler, "DoWork");
//     Task1();
//     Task2();
//     .....
//   }
//
//   void Task1() {
//     ScopedProfile profile(&profiler, "Task1");
//     ....
//   }
//
//   void Task2() {
//     ScopedProfile profile(&profiler, "Task2");
//   }
//
//   BufferedProfiler profiler{1024 /*max_num_entries*/};
// };
//
// We instrument the functions that need to be profiled. Each ScopedProfile
// must be a named local: an unnamed temporary is destroyed at the end of its
// statement, so the recorded event would not cover the rest of the function.
62 //
// Profiles can be collected by enabling profiling and then retrieving the
// profile events.
//
//  void ProfileWorker() {
//    Worker worker;
//    worker.profiler.StartProfiling();
//    worker.DoWork();
//    worker.profiler.StopProfiling();
//    // Profiling is complete, extract profiles.
//    auto profile_events = worker.profiler.GetProfileEvents();
//  }
74 //
75 //
76 class BufferedProfiler : public tflite::Profiler {
77  public:
BufferedProfiler(uint32_t max_num_initial_entries,bool allow_dynamic_buffer_increase)78   BufferedProfiler(uint32_t max_num_initial_entries,
79                    bool allow_dynamic_buffer_increase)
80       : buffer_(max_num_initial_entries, false /*enabled*/,
81                 allow_dynamic_buffer_increase),
82         supported_event_types_(~static_cast<uint64_t>(
83             EventType::GENERAL_RUNTIME_INSTRUMENTATION_EVENT)) {}
84 
BufferedProfiler(uint32_t max_num_entries)85   explicit BufferedProfiler(uint32_t max_num_entries)
86       : BufferedProfiler(max_num_entries,
87                          false /*allow_dynamic_buffer_increase*/) {}
88 
BeginEvent(const char * tag,EventType event_type,int64_t event_metadata1,int64_t event_metadata2)89   uint32_t BeginEvent(const char* tag, EventType event_type,
90                       int64_t event_metadata1,
91                       int64_t event_metadata2) override {
92     if (!ShouldAddEvent(event_type)) return kInvalidEventHandle;
93     return buffer_.BeginEvent(tag, event_type, event_metadata1,
94                               event_metadata2);
95   }
96 
EndEvent(uint32_t event_handle)97   void EndEvent(uint32_t event_handle) override {
98     buffer_.EndEvent(event_handle);
99   }
100 
EndEvent(uint32_t event_handle,int64_t event_metadata1,int64_t event_metadata2)101   void EndEvent(uint32_t event_handle, int64_t event_metadata1,
102                 int64_t event_metadata2) override {
103     buffer_.EndEvent(event_handle, &event_metadata1, &event_metadata2);
104   }
105 
AddEvent(const char * tag,EventType event_type,uint64_t elapsed_time,int64_t event_metadata1,int64_t event_metadata2)106   void AddEvent(const char* tag, EventType event_type, uint64_t elapsed_time,
107                 int64_t event_metadata1, int64_t event_metadata2) override {
108     if (!ShouldAddEvent(event_type)) return;
109     buffer_.AddEvent(tag, event_type, elapsed_time, event_metadata1,
110                      event_metadata2);
111   }
112 
StartProfiling()113   void StartProfiling() { buffer_.SetEnabled(true); }
StopProfiling()114   void StopProfiling() { buffer_.SetEnabled(false); }
Reset()115   void Reset() { buffer_.Reset(); }
GetProfileEvents()116   std::vector<const ProfileEvent*> GetProfileEvents() {
117     std::vector<const ProfileEvent*> profile_events;
118     profile_events.reserve(buffer_.Size());
119     for (size_t i = 0; i < buffer_.Size(); i++) {
120       profile_events.push_back(buffer_.At(i));
121     }
122     return profile_events;
123   }
124 
125  protected:
ShouldAddEvent(EventType event_type)126   bool ShouldAddEvent(EventType event_type) {
127     return (static_cast<uint64_t>(event_type) & supported_event_types_) != 0;
128   }
129 
130  private:
GetProfileBuffer()131   ProfileBuffer* GetProfileBuffer() { return &buffer_; }
132   ProfileBuffer buffer_;
133   const uint64_t supported_event_types_;
134 };
135 
136 }  // namespace profiling
137 }  // namespace tflite
138 
139 #endif  // TENSORFLOW_LITE_PROFILING_BUFFERED_PROFILER_H_
140