//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"

#include <CpuExecutor.h>
#include <armnn/ArmNN.hpp>

19 namespace armnn_driver
20 {
21 using TimePoint = std::chrono::steady_clock::time_point;
22 
23 template<template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
24 class RequestThread_1_3
25 {
26 public:
27     /// Constructor creates the thread
28     RequestThread_1_3();
29 
30     /// Destructor terminates the thread
31     ~RequestThread_1_3();
32 
33     /// Add a message to the thread queue.
34     /// @param[in] model pointer to the prepared model handling the request
35     /// @param[in] memPools pointer to the memory pools vector for the tensors
36     /// @param[in] inputTensors pointer to the input tensors for the request
37     /// @param[in] outputTensors pointer to the output tensors for the request
38     /// @param[in] callback the android notification callback
39     void PostMsg(PreparedModel<HalVersion>* model,
40                  std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
41                  std::shared_ptr<armnn::InputTensors>& inputTensors,
42                  std::shared_ptr<armnn::OutputTensors>& outputTensors,
43                  CallbackContext callbackContext);
44 
45 private:
46     RequestThread_1_3(const RequestThread_1_3&) = delete;
47     RequestThread_1_3& operator=(const RequestThread_1_3&) = delete;
48 
49     /// storage for a prepared model and args for the asyncExecute call
50     struct AsyncExecuteData
51     {
AsyncExecuteDataarmnn_driver::RequestThread_1_3::AsyncExecuteData52         AsyncExecuteData(PreparedModel<HalVersion>* model,
53                          std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
54                          std::shared_ptr<armnn::InputTensors>& inputTensors,
55                          std::shared_ptr<armnn::OutputTensors>& outputTensors,
56                          CallbackContext callbackContext)
57             : m_Model(model)
58             , m_MemPools(memPools)
59             , m_InputTensors(inputTensors)
60             , m_OutputTensors(outputTensors)
61             , m_CallbackContext(callbackContext)
62         {
63         }
64 
65         PreparedModel<HalVersion>* m_Model;
66         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
67         std::shared_ptr<armnn::InputTensors> m_InputTensors;
68         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
69         CallbackContext m_CallbackContext;
70     };
71     enum class ThreadMsgType
72     {
73         EXIT,                   // exit the thread
74         REQUEST                 // user request to process
75     };
76 
77     /// storage for the thread message type and data
78     struct ThreadMsg
79     {
ThreadMsgarmnn_driver::RequestThread_1_3::ThreadMsg80         ThreadMsg(ThreadMsgType msgType,
81                   std::shared_ptr<AsyncExecuteData>& msgData)
82             : type(msgType)
83             , data(msgData)
84         {
85         }
86 
87         ThreadMsgType type;
88         std::shared_ptr<AsyncExecuteData> data;
89     };
90 
91     /// Add a prepared thread message to the thread queue.
92     /// @param[in] threadMsg the message to add to the queue
93     void PostMsg(std::shared_ptr<ThreadMsg>& pThreadMsg, V1_3::Priority priority = V1_3::Priority::MEDIUM);
94 
95     /// Entry point for the request thread
96     void Process();
97 
98     std::unique_ptr<std::thread> m_Thread;
99     std::queue<std::shared_ptr<ThreadMsg>> m_HighPriorityQueue;
100     std::queue<std::shared_ptr<ThreadMsg>> m_MediumPriorityQueue;
101     std::queue<std::shared_ptr<ThreadMsg>> m_LowPriorityQueue;
102     std::mutex m_Mutex;
103     std::condition_variable m_Cv;
104 };
105 
106 } // namespace armnn_driver
107