1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_ 17 #define TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_ 18 19 #include <memory> 20 #include <thread> // NOLINT(build/c++11) 21 22 #include "absl/synchronization/notification.h" 23 #include "absl/time/clock.h" 24 #include "absl/time/time.h" 25 #include "tensorflow/lite/profiling/memory_info.h" 26 27 namespace tflite { 28 namespace profiling { 29 namespace memory { 30 31 // This class could help to tell the peak memory footprint of a running program. 32 // It achieves this by spawning a thread to check the memory usage periodically 33 // at a pre-defined frequency. 34 class MemoryUsageMonitor { 35 public: 36 // A helper class that does memory usage sampling. This allows injecting an 37 // external dependency for the sake of testing or providing platform-specific 38 // implementations. 39 class Sampler { 40 public: ~Sampler()41 virtual ~Sampler() {} IsSupported()42 virtual bool IsSupported() { return MemoryUsage::IsSupported(); } GetMemoryUsage()43 virtual MemoryUsage GetMemoryUsage() { 44 return tflite::profiling::memory::GetMemoryUsage(); 45 } SleepFor(const absl::Duration & duration)46 virtual void SleepFor(const absl::Duration& duration) { 47 absl::SleepFor(duration); 48 } 49 }; 50 51 static constexpr float kInvalidMemUsageMB = -1.0f; 52 53 explicit MemoryUsageMonitor(int sampling_interval_ms = 50) MemoryUsageMonitor(sampling_interval_ms,std::make_unique<Sampler> ())54 : MemoryUsageMonitor(sampling_interval_ms, std::make_unique<Sampler>()) {} 55 MemoryUsageMonitor(int sampling_interval_ms, 56 std::unique_ptr<Sampler> sampler); ~MemoryUsageMonitor()57 ~MemoryUsageMonitor() { StopInternal(); } 58 59 void Start(); 60 void Stop(); 61 62 // For simplicity, we will return kInvalidMemUsageMB for the either following 63 // conditions: 64 // 1. getting memory usage isn't supported on the platform. 65 // 2. the memory usage is being monitored (i.e. we've created the 66 // 'check_memory_thd_'. GetPeakMemUsageInMB()67 float GetPeakMemUsageInMB() const { 68 if (!is_supported_ || check_memory_thd_ != nullptr) { 69 return kInvalidMemUsageMB; 70 } 71 return peak_max_rss_kb_ / 1024.0; 72 } 73 74 MemoryUsageMonitor(MemoryUsageMonitor&) = delete; 75 MemoryUsageMonitor& operator=(const MemoryUsageMonitor&) = delete; 76 MemoryUsageMonitor(MemoryUsageMonitor&&) = delete; 77 MemoryUsageMonitor& operator=(const MemoryUsageMonitor&&) = delete; 78 79 private: 80 void StopInternal(); 81 82 std::unique_ptr<Sampler> sampler_ = nullptr; 83 bool is_supported_ = false; 84 std::unique_ptr<absl::Notification> stop_signal_ = nullptr; 85 absl::Duration sampling_interval_; 86 std::unique_ptr<std::thread> check_memory_thd_ = nullptr; 87 int64_t peak_max_rss_kb_ = static_cast<int64_t>(kInvalidMemUsageMB * 1024); 88 }; 89 90 } // namespace memory 91 } // namespace profiling 92 } // namespace tflite 93 94 #endif // TENSORFLOW_LITE_PROFILING_MEMORY_USAGE_MONITOR_H_ 95