xref: /aosp_15_r20/external/armnn/src/armnn/AsyncExecutionCallback.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/IAsyncExecutionCallback.hpp>
9 #include <armnn/IWorkingMemHandle.hpp>
10 #include <armnn/Types.hpp>
11 
12 #include <condition_variable>
13 #include <mutex>
14 #include <thread>
15 #include <queue>
16 #include <unordered_map>
17 
18 namespace armnn
19 {
20 
21 namespace experimental
22 {
23 
24 using InferenceId = uint64_t;
25 class AsyncExecutionCallback final : public IAsyncExecutionCallback
26 {
27 private:
28     static InferenceId nextID;
29 
30 public:
AsyncExecutionCallback(std::queue<InferenceId> & notificationQueue,std::mutex & mutex,std::condition_variable & condition)31     AsyncExecutionCallback(std::queue<InferenceId>& notificationQueue
32 #if !defined(ARMNN_DISABLE_THREADS)
33                            , std::mutex& mutex
34                            , std::condition_variable& condition
35 #endif
36         )
37         : m_NotificationQueue(notificationQueue)
38 #if !defined(ARMNN_DISABLE_THREADS)
39         , m_Mutex(mutex)
40         , m_Condition(condition)
41 #endif
42         , m_InferenceId(++nextID)
43     {}
44 
~AsyncExecutionCallback()45     ~AsyncExecutionCallback()
46     {}
47 
48     void Notify(armnn::Status status, InferenceTimingPair timeTaken);
49 
GetInferenceId()50     InferenceId GetInferenceId()
51     {
52         return m_InferenceId;
53     }
54 
55     armnn::Status GetStatus() const;
56     HighResolutionClock GetStartTime() const;
57     HighResolutionClock GetEndTime() const;
58 
59 private:
60     std::queue<InferenceId>& m_NotificationQueue;
61 #if !defined(ARMNN_DISABLE_THREADS)
62     std::mutex&              m_Mutex;
63     std::condition_variable& m_Condition;
64 #endif
65 
66     HighResolutionClock m_StartTime;
67     HighResolutionClock m_EndTime;
68     armnn::Status       m_Status = Status::Failure;
69     InferenceId m_InferenceId;
70 };
71 
72 // Manager to create and monitor AsyncExecutionCallbacks
73 // GetNewCallback will create a callback for use in Threadpool::Schedule
74 // GetNotifiedCallback will return the first callback to be notified (finished execution)
75 class AsyncCallbackManager
76 {
77 public:
78     std::shared_ptr<AsyncExecutionCallback> GetNewCallback();
79     std::shared_ptr<AsyncExecutionCallback> GetNotifiedCallback();
80 
81 private:
82 #if !defined(ARMNN_DISABLE_THREADS)
83     std::mutex              m_Mutex;
84     std::condition_variable m_Condition;
85 #endif
86     std::unordered_map<InferenceId, std::shared_ptr<AsyncExecutionCallback>> m_Callbacks;
87     std::queue<InferenceId> m_NotificationQueue;
88 };
89 
90 } // namespace experimental
91 
92 } // namespace armnn
93