xref: /aosp_15_r20/external/android-nn-driver/RequestThread.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017 Arm Ltd. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li 
6*3e777be0SXin Li #pragma once
7*3e777be0SXin Li 
8*3e777be0SXin Li #include <queue>
9*3e777be0SXin Li #include <thread>
10*3e777be0SXin Li #include <mutex>
11*3e777be0SXin Li #include <condition_variable>
12*3e777be0SXin Li 
13*3e777be0SXin Li #include "ArmnnDriver.hpp"
14*3e777be0SXin Li #include "ArmnnDriverImpl.hpp"
15*3e777be0SXin Li 
16*3e777be0SXin Li #include <CpuExecutor.h>
17*3e777be0SXin Li #include <armnn/ArmNN.hpp>
18*3e777be0SXin Li 
19*3e777be0SXin Li namespace armnn_driver
20*3e777be0SXin Li {
/// Time point type of the steady clock, used for timing within the driver.
using TimePoint = std::chrono::steady_clock::time_point;

/// Smallest value representable by TimePoint.
static const TimePoint g_Min = TimePoint::min();
23*3e777be0SXin Li 
24*3e777be0SXin Li template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
25*3e777be0SXin Li class RequestThread
26*3e777be0SXin Li {
27*3e777be0SXin Li public:
28*3e777be0SXin Li     /// Constructor creates the thread
29*3e777be0SXin Li     RequestThread();
30*3e777be0SXin Li 
31*3e777be0SXin Li     /// Destructor terminates the thread
32*3e777be0SXin Li     ~RequestThread();
33*3e777be0SXin Li 
34*3e777be0SXin Li     /// Add a message to the thread queue.
35*3e777be0SXin Li     /// @param[in] model pointer to the prepared model handling the request
36*3e777be0SXin Li     /// @param[in] memPools pointer to the memory pools vector for the tensors
37*3e777be0SXin Li     /// @param[in] inputTensors pointer to the input tensors for the request
38*3e777be0SXin Li     /// @param[in] outputTensors pointer to the output tensors for the request
39*3e777be0SXin Li     /// @param[in] callback the android notification callback
40*3e777be0SXin Li     void PostMsg(PreparedModel<HalVersion>* model,
41*3e777be0SXin Li                  std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
42*3e777be0SXin Li                  std::shared_ptr<armnn::InputTensors>& inputTensors,
43*3e777be0SXin Li                  std::shared_ptr<armnn::OutputTensors>& outputTensors,
44*3e777be0SXin Li                  CallbackContext callbackContext);
45*3e777be0SXin Li 
46*3e777be0SXin Li private:
47*3e777be0SXin Li     RequestThread(const RequestThread&) = delete;
48*3e777be0SXin Li     RequestThread& operator=(const RequestThread&) = delete;
49*3e777be0SXin Li 
50*3e777be0SXin Li     /// storage for a prepared model and args for the asyncExecute call
51*3e777be0SXin Li     struct AsyncExecuteData
52*3e777be0SXin Li     {
AsyncExecuteDataarmnn_driver::RequestThread::AsyncExecuteData53*3e777be0SXin Li         AsyncExecuteData(PreparedModel<HalVersion>* model,
54*3e777be0SXin Li                          std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
55*3e777be0SXin Li                          std::shared_ptr<armnn::InputTensors>& inputTensors,
56*3e777be0SXin Li                          std::shared_ptr<armnn::OutputTensors>& outputTensors,
57*3e777be0SXin Li                          CallbackContext callbackContext)
58*3e777be0SXin Li             : m_Model(model)
59*3e777be0SXin Li             , m_MemPools(memPools)
60*3e777be0SXin Li             , m_InputTensors(inputTensors)
61*3e777be0SXin Li             , m_OutputTensors(outputTensors)
62*3e777be0SXin Li             , m_CallbackContext(callbackContext)
63*3e777be0SXin Li         {
64*3e777be0SXin Li         }
65*3e777be0SXin Li 
66*3e777be0SXin Li         PreparedModel<HalVersion>* m_Model;
67*3e777be0SXin Li         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
68*3e777be0SXin Li         std::shared_ptr<armnn::InputTensors> m_InputTensors;
69*3e777be0SXin Li         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
70*3e777be0SXin Li         CallbackContext m_CallbackContext;
71*3e777be0SXin Li     };
72*3e777be0SXin Li     enum class ThreadMsgType
73*3e777be0SXin Li     {
74*3e777be0SXin Li         EXIT,                   // exit the thread
75*3e777be0SXin Li         REQUEST                 // user request to process
76*3e777be0SXin Li     };
77*3e777be0SXin Li 
78*3e777be0SXin Li     /// storage for the thread message type and data
79*3e777be0SXin Li     struct ThreadMsg
80*3e777be0SXin Li     {
ThreadMsgarmnn_driver::RequestThread::ThreadMsg81*3e777be0SXin Li         ThreadMsg(ThreadMsgType msgType,
82*3e777be0SXin Li                   std::shared_ptr<AsyncExecuteData>& msgData)
83*3e777be0SXin Li             : type(msgType)
84*3e777be0SXin Li             , data(msgData)
85*3e777be0SXin Li         {
86*3e777be0SXin Li         }
87*3e777be0SXin Li 
88*3e777be0SXin Li         ThreadMsgType type;
89*3e777be0SXin Li         std::shared_ptr<AsyncExecuteData> data;
90*3e777be0SXin Li     };
91*3e777be0SXin Li 
92*3e777be0SXin Li     /// Add a prepared thread message to the thread queue.
93*3e777be0SXin Li     /// @param[in] threadMsg the message to add to the queue
94*3e777be0SXin Li     void PostMsg(std::shared_ptr<ThreadMsg>& pThreadMsg);
95*3e777be0SXin Li 
96*3e777be0SXin Li     /// Entry point for the request thread
97*3e777be0SXin Li     void Process();
98*3e777be0SXin Li 
99*3e777be0SXin Li     std::unique_ptr<std::thread> m_Thread;
100*3e777be0SXin Li     std::queue<std::shared_ptr<ThreadMsg>> m_Queue;
101*3e777be0SXin Li     std::mutex m_Mutex;
102*3e777be0SXin Li     std::condition_variable m_Cv;
103*3e777be0SXin Li };
104*3e777be0SXin Li 
105*3e777be0SXin Li } // namespace armnn_driver
106