1 // 2 // Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/IAsyncExecutionCallback.hpp> 9 #include <armnn/IWorkingMemHandle.hpp> 10 #include <armnn/Types.hpp> 11 12 #include <condition_variable> 13 #include <mutex> 14 #include <thread> 15 #include <queue> 16 #include <unordered_map> 17 18 namespace armnn 19 { 20 21 namespace experimental 22 { 23 24 using InferenceId = uint64_t; 25 class AsyncExecutionCallback final : public IAsyncExecutionCallback 26 { 27 private: 28 static InferenceId nextID; 29 30 public: AsyncExecutionCallback(std::queue<InferenceId> & notificationQueue,std::mutex & mutex,std::condition_variable & condition)31 AsyncExecutionCallback(std::queue<InferenceId>& notificationQueue 32 #if !defined(ARMNN_DISABLE_THREADS) 33 , std::mutex& mutex 34 , std::condition_variable& condition 35 #endif 36 ) 37 : m_NotificationQueue(notificationQueue) 38 #if !defined(ARMNN_DISABLE_THREADS) 39 , m_Mutex(mutex) 40 , m_Condition(condition) 41 #endif 42 , m_InferenceId(++nextID) 43 {} 44 ~AsyncExecutionCallback()45 ~AsyncExecutionCallback() 46 {} 47 48 void Notify(armnn::Status status, InferenceTimingPair timeTaken); 49 GetInferenceId()50 InferenceId GetInferenceId() 51 { 52 return m_InferenceId; 53 } 54 55 armnn::Status GetStatus() const; 56 HighResolutionClock GetStartTime() const; 57 HighResolutionClock GetEndTime() const; 58 59 private: 60 std::queue<InferenceId>& m_NotificationQueue; 61 #if !defined(ARMNN_DISABLE_THREADS) 62 std::mutex& m_Mutex; 63 std::condition_variable& m_Condition; 64 #endif 65 66 HighResolutionClock m_StartTime; 67 HighResolutionClock m_EndTime; 68 armnn::Status m_Status = Status::Failure; 69 InferenceId m_InferenceId; 70 }; 71 72 // Manager to create and monitor AsyncExecutionCallbacks 73 // GetNewCallback will create a callback for use in Threadpool::Schedule 74 // GetNotifiedCallback will return the first callback to be notified (finished execution) 75 class AsyncCallbackManager 76 { 77 public: 78 std::shared_ptr<AsyncExecutionCallback> GetNewCallback(); 79 std::shared_ptr<AsyncExecutionCallback> GetNotifiedCallback(); 80 81 private: 82 #if !defined(ARMNN_DISABLE_THREADS) 83 std::mutex m_Mutex; 84 std::condition_variable m_Condition; 85 #endif 86 std::unordered_map<InferenceId, std::shared_ptr<AsyncExecutionCallback>> m_Callbacks; 87 std::queue<InferenceId> m_NotificationQueue; 88 }; 89 90 } // namespace experimental 91 92 } // namespace armnn 93