xref: /aosp_15_r20/external/android-nn-driver/RequestThread.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"

#include <CpuExecutor.h>
#include <armnn/ArmNN.hpp>
18 
19 namespace armnn_driver
20 {
/// Shorthand for a point in time measured on std::chrono::steady_clock.
using TimePoint = std::chrono::steady_clock::time_point;

/// The earliest representable TimePoint (TimePoint::min()).
/// NOTE(review): presumably used as an "unset" sentinel value — confirm at the use sites.
static constexpr TimePoint g_Min = TimePoint::min();
23 
24 template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
25 class RequestThread
26 {
27 public:
28     /// Constructor creates the thread
29     RequestThread();
30 
31     /// Destructor terminates the thread
32     ~RequestThread();
33 
34     /// Add a message to the thread queue.
35     /// @param[in] model pointer to the prepared model handling the request
36     /// @param[in] memPools pointer to the memory pools vector for the tensors
37     /// @param[in] inputTensors pointer to the input tensors for the request
38     /// @param[in] outputTensors pointer to the output tensors for the request
39     /// @param[in] callback the android notification callback
40     void PostMsg(PreparedModel<HalVersion>* model,
41                  std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
42                  std::shared_ptr<armnn::InputTensors>& inputTensors,
43                  std::shared_ptr<armnn::OutputTensors>& outputTensors,
44                  CallbackContext callbackContext);
45 
46 private:
47     RequestThread(const RequestThread&) = delete;
48     RequestThread& operator=(const RequestThread&) = delete;
49 
50     /// storage for a prepared model and args for the asyncExecute call
51     struct AsyncExecuteData
52     {
AsyncExecuteDataarmnn_driver::RequestThread::AsyncExecuteData53         AsyncExecuteData(PreparedModel<HalVersion>* model,
54                          std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
55                          std::shared_ptr<armnn::InputTensors>& inputTensors,
56                          std::shared_ptr<armnn::OutputTensors>& outputTensors,
57                          CallbackContext callbackContext)
58             : m_Model(model)
59             , m_MemPools(memPools)
60             , m_InputTensors(inputTensors)
61             , m_OutputTensors(outputTensors)
62             , m_CallbackContext(callbackContext)
63         {
64         }
65 
66         PreparedModel<HalVersion>* m_Model;
67         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
68         std::shared_ptr<armnn::InputTensors> m_InputTensors;
69         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
70         CallbackContext m_CallbackContext;
71     };
72     enum class ThreadMsgType
73     {
74         EXIT,                   // exit the thread
75         REQUEST                 // user request to process
76     };
77 
78     /// storage for the thread message type and data
79     struct ThreadMsg
80     {
ThreadMsgarmnn_driver::RequestThread::ThreadMsg81         ThreadMsg(ThreadMsgType msgType,
82                   std::shared_ptr<AsyncExecuteData>& msgData)
83             : type(msgType)
84             , data(msgData)
85         {
86         }
87 
88         ThreadMsgType type;
89         std::shared_ptr<AsyncExecuteData> data;
90     };
91 
92     /// Add a prepared thread message to the thread queue.
93     /// @param[in] threadMsg the message to add to the queue
94     void PostMsg(std::shared_ptr<ThreadMsg>& pThreadMsg);
95 
96     /// Entry point for the request thread
97     void Process();
98 
99     std::unique_ptr<std::thread> m_Thread;
100     std::queue<std::shared_ptr<ThreadMsg>> m_Queue;
101     std::mutex m_Mutex;
102     std::condition_variable m_Cv;
103 };
104 
105 } // namespace armnn_driver
106