xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/profiling/memory_usage_monitor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_
17 #define TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_
18 
#include <cstdint>
#include <memory>
#include <thread>  // NOLINT(build/c++11)

#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/lite/profiling/memory_info.h"
26 
27 namespace tflite {
28 namespace profiling {
29 namespace memory {
30 
31 // This class could help to tell the peak memory footprint of a running program.
32 // It achieves this by spawning a thread to check the memory usage periodically
33 // at a pre-defined frequency.
34 class MemoryUsageMonitor {
35  public:
36   // A helper class that does memory usage sampling. This allows injecting an
37   // external dependency for the sake of testing or providing platform-specific
38   // implementations.
39   class Sampler {
40    public:
~Sampler()41     virtual ~Sampler() {}
IsSupported()42     virtual bool IsSupported() { return MemoryUsage::IsSupported(); }
GetMemoryUsage()43     virtual MemoryUsage GetMemoryUsage() {
44       return tflite::profiling::memory::GetMemoryUsage();
45     }
SleepFor(const absl::Duration & duration)46     virtual void SleepFor(const absl::Duration& duration) {
47       absl::SleepFor(duration);
48     }
49   };
50 
51   static constexpr float kInvalidMemUsageMB = -1.0f;
52 
53   explicit MemoryUsageMonitor(int sampling_interval_ms = 50)
MemoryUsageMonitor(sampling_interval_ms,std::make_unique<Sampler> ())54       : MemoryUsageMonitor(sampling_interval_ms, std::make_unique<Sampler>()) {}
55   MemoryUsageMonitor(int sampling_interval_ms,
56                      std::unique_ptr<Sampler> sampler);
~MemoryUsageMonitor()57   ~MemoryUsageMonitor() { StopInternal(); }
58 
59   void Start();
60   void Stop();
61 
62   // For simplicity, we will return kInvalidMemUsageMB for the either following
63   // conditions:
64   // 1. getting memory usage isn't supported on the platform.
65   // 2. the memory usage is being monitored (i.e. we've created the
66   // 'check_memory_thd_'.
GetPeakMemUsageInMB()67   float GetPeakMemUsageInMB() const {
68     if (!is_supported_ || check_memory_thd_ != nullptr) {
69       return kInvalidMemUsageMB;
70     }
71     return peak_max_rss_kb_ / 1024.0;
72   }
73 
74   MemoryUsageMonitor(MemoryUsageMonitor&) = delete;
75   MemoryUsageMonitor& operator=(const MemoryUsageMonitor&) = delete;
76   MemoryUsageMonitor(MemoryUsageMonitor&&) = delete;
77   MemoryUsageMonitor& operator=(const MemoryUsageMonitor&&) = delete;
78 
79  private:
80   void StopInternal();
81 
82   std::unique_ptr<Sampler> sampler_ = nullptr;
83   bool is_supported_ = false;
84   std::unique_ptr<absl::Notification> stop_signal_ = nullptr;
85   absl::Duration sampling_interval_;
86   std::unique_ptr<std::thread> check_memory_thd_ = nullptr;
87   int64_t peak_max_rss_kb_ = static_cast<int64_t>(kInvalidMemUsageMB * 1024);
88 };
89 
90 }  // namespace memory
91 }  // namespace profiling
92 }  // namespace tflite
93 
94 #endif  // TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_
95