/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_
#define TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_

#include <cstddef>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tfrt/host_context/task_function.h"  // from @tf_runtime

namespace Eigen {
struct ThreadPoolDevice;
}

namespace tfrt {
namespace tf {

class RunHandler;

// Options for RunHandler.
struct RunHandlerOptions {
  RunHandlerOptions() : priority(0) {}

  // Request priority.
  int priority;
};

// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
// that can be used for tracking op work for a given inference request.
// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
// 'active' when its unique_ptr is returned by Get() and is being used by a
// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When an inference request is invoked, obtain a handler by:
// auto handler = run_handler_pool_->Get();
//
// * Use handler for scheduling all inter-op work by:
// handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  struct Options {
    // The number of main threads.
    int num_inter_op_threads = 1;

    // The number of complementary threads.
    int num_intra_op_threads = 1;

    // The maximum number of concurrent handlers.
    int max_concurrent_handler = 128;

    // The number of sub thread pools configured.
    int num_sub_thread_pool = 1;

    // The number of threads in each sub thread pool. The length of the vector
    // should equal num_sub_thread_pool.
    std::vector<int> num_threads_in_sub_thread_pool = {1};

    // The percentage of requests the first N sub thread pools handle. The
    // length of the vector should equal num_sub_thread_pool. For example,
    // {0.5, 1} means the first sub thread pool will handle the first 50% of
    // requests based on priority and the second sub thread pool will handle
    // the remaining 50% of requests based on priority.
    std::vector<double> sub_thread_request_percentage = {1.0};

    // Sleep time for non-blocking threads if there is no pending task.
    int non_blocking_threads_sleep_time_micro_sec = 1000;

    // Max sleep time for blocking threads if there is no pending task and no
    // new task wakes up the thread.
    int blocking_threads_max_sleep_time_micro_sec = 1000;

    // If true, use adaptive waiting time.
    bool use_adaptive_waiting_time = true;

    // If true, threads won't wake themselves up if there are no active
    // requests.
    bool wait_if_no_active_request = true;

    // If true, threads will be woken up by new tasks.
    bool enable_wake_up = true;
  };
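
  // A configuration sketch with illustrative (non-default) values: two sub
  // thread pools that split requests evenly by priority, matching the
  // {0.5, 1} example above.
  //
  //   RunHandlerPool::Options options;
  //   options.num_inter_op_threads = 4;
  //   options.num_intra_op_threads = 4;
  //   options.num_sub_thread_pool = 2;
  //   options.num_threads_in_sub_thread_pool = {2, 2};
  //   options.sub_thread_request_percentage = {0.5, 1.0};
  //   RunHandlerPool pool(options);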

  explicit RunHandlerPool(Options options);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'.
  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
  // and is being used by a client.  It becomes 'inactive' once more when the
  // unique_ptr is destroyed.
  //
  // Will block unless there is an inactive handler.
  std::unique_ptr<RunHandler> Get(
      int64_t step_id = 0, int64_t timeout_in_ms = 0,
      const RunHandlerOptions& options = RunHandlerOptions());

  // Gets the priorities of the active handlers. The result is in the same
  // order as the active handler list.
  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() const;

  // Blocks until the system is quiescent (no pending work and no inflight
  // work).
  void Quiesce() const;

 private:
  class Impl;
  friend class RunHandler;

  std::unique_ptr<Impl> impl_;
};
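
// A minimal usage sketch for one inference request (assumes `pool` is a live
// RunHandlerPool; the closure body and priority value are illustrative):
//
//   RunHandlerOptions opts;
//   opts.priority = 2;
//   auto handler = pool.Get(/*step_id=*/1, /*timeout_in_ms=*/100, opts);
//   handler->ScheduleInterOpClosure(TaskFunction([] { /* inter-op work */ }));
//   // `handler` becomes 'inactive' again when the unique_ptr is destroyed.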

// RunHandler can be used to schedule inter/intra-op closures to run on a
// global pool shared across all Session::Run(s). The closures are enqueued to
// a handler-specific queue, from which the work is stolen in priority order
// (based on the time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
//
// This class can be used instead of directly scheduling closures on a global
// pool since it maintains a global view across all sessions and optimizes pool
// scheduling to improve (median and tail) latency.
//
// This class is thread safe.
class RunHandler {
 public:
  void ScheduleInterOpClosure(TaskFunction fn);
  void ScheduleIntraOpClosure(TaskFunction fn);

  tensorflow::thread::ThreadPoolInterface* AsIntraThreadPoolInterface() const;

  int NumThreads() const;

  int64_t step_id() const;

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};
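
// A sketch of wiring a handler's intra-op interface into an Eigen device
// (assumes `handler` was obtained from RunHandlerPool::Get() and that the
// full Eigen::ThreadPoolDevice definition is visible at the point of use):
//
//   Eigen::ThreadPoolDevice device(handler->AsIntraThreadPoolInterface(),
//                                  handler->NumThreads());
//   // `device` can now be used to evaluate Eigen expressions so intra-op
//   // work runs through this handler.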

namespace internal {

// TODO(azaks): Refactor with thread::ThreadPool
class RunHandlerEnvironment {
 public:
  typedef tensorflow::Thread EnvThread;
  struct TaskImpl {
    TaskFunction f;
    tensorflow::Context context;
    uint64_t trace_id;
  };
  tensorflow::Env* const env_;
  const tensorflow::ThreadOptions thread_options_;
  const std::string name_;

 public:
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  RunHandlerEnvironment(tensorflow::Env* env,
                        const tensorflow::ThreadOptions& thread_options,
                        const std::string& name);

  EnvThread* CreateThread(std::function<void()> f);

  Task CreateTask(TaskFunction f);

  void ExecuteTask(const Task& t);
};

typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
  Waiter() {
    next = this;
    prev = this;
  }
  tensorflow::condition_variable cv;
  int num_waiting_threads = 0;
  tensorflow::mutex mu;
  Waiter* next;
  Waiter* prev;
};
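
// A sketch of LIFO insertion into the circular waiter list (illustrative
// only; the actual splicing logic lives in the implementation file). `head`
// is a sentinel whose next/prev initially point to itself:
//
//   Waiter head;
//   Waiter w;
//   w.next = head.next;    // Push at the head so the newest waiter is
//   w.prev = &head;        // found first (LIFO).
//   head.next->prev = &w;
//   head.next = &w;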

class ThreadWorkSource {
 public:
  ThreadWorkSource();

  ~ThreadWorkSource();

  Task EnqueueTask(Task t, bool is_blocking, bool enable_wake_up);

  Task PopBlockingTask();

  Task PopNonBlockingTask(int start_index, bool search_from_all_queue);

  int TaskQueueSize(bool is_blocking);

  int64_t GetTracemeId();

  void SetTracemeId(int64_t value);

  void SetWaiter(uint64_t version, Waiter* waiter, tensorflow::mutex* mutex);

  int64_t GetInflightTaskCount(bool is_blocking);

  void IncrementInflightTaskCount(bool is_blocking);

  void DecrementInflightTaskCount(bool is_blocking);

  int64_t GetPendingTaskCount();

  void IncrementPendingTaskCount();

  void DecrementPendingTaskCount();

  unsigned NonBlockingWorkShardingFactor();

  std::string ToString();

 private:
  struct NonBlockingQueue {
    tensorflow::mutex queue_op_mu;
    char pad[128];
    Queue queue;
  };

  int32_t non_blocking_work_sharding_factor_;
  Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;

  // The number of tasks that are executing now.
  std::atomic<int64_t> blocking_inflight_;
  std::atomic<int64_t> non_blocking_inflight_;

  // The number of tasks that are enqueued and not finished.
  std::atomic<int64_t> pending_tasks_;

  Queue blocking_work_queue_;
  tensorflow::mutex blocking_queue_op_mu_;
  char pad_[128];
  tensorflow::mutex waiters_mu_;
  Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
  std::atomic<int64_t> traceme_id_;

  tensorflow::mutex run_handler_waiter_mu_;
  uint64_t version_ TF_GUARDED_BY(run_handler_waiter_mu_);
  tensorflow::mutex* sub_thread_pool_waiter_mu_
      TF_GUARDED_BY(run_handler_waiter_mu_);
  Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
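
// A sketch of the enqueue/pop flow through a ThreadWorkSource (illustrative;
// in practice tasks are created and popped by the pool's worker threads, not
// by user code):
//
//   RunHandlerEnvironment env(tensorflow::Env::Default(),
//                             tensorflow::ThreadOptions(), "example");
//   ThreadWorkSource tws;
//   tws.EnqueueTask(env.CreateTask(TaskFunction([] { /* op work */ })),
//                   /*is_blocking=*/true, /*enable_wake_up=*/true);
//   Task t = tws.PopBlockingTask();
//   if (t.f) env.ExecuteTask(t);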

class RunHandlerThreadPool {
 public:
  struct Options {
    int num_blocking_threads;
    int num_non_blocking_threads;
    bool wait_if_no_active_request;
    int non_blocking_threads_sleep_time_micro_sec;
    int blocking_threads_max_sleep_time_micro_sec;
    bool use_adaptive_waiting_time;
    bool enable_wake_up;
    int max_concurrent_handler;
    std::vector<int> num_threads_in_sub_thread_pool;
    std::vector<double> sub_thread_request_percentage;
    Options(int num_blocking_threads, int num_non_blocking_threads,
            bool wait_if_no_active_request,
            int non_blocking_threads_sleep_time_micro_sec,
            int blocking_threads_max_sleep_time_micro_sec,
            bool use_adaptive_waiting_time, bool enable_wake_up,
            int max_concurrent_handler,
            const std::vector<int>& num_threads_in_sub_thread_pool,
            const std::vector<double>& sub_thread_request_percentage)
        : num_blocking_threads(num_blocking_threads),
          num_non_blocking_threads(num_non_blocking_threads),
          wait_if_no_active_request(wait_if_no_active_request),
          non_blocking_threads_sleep_time_micro_sec(
              non_blocking_threads_sleep_time_micro_sec),
          blocking_threads_max_sleep_time_micro_sec(
              blocking_threads_max_sleep_time_micro_sec),
          use_adaptive_waiting_time(use_adaptive_waiting_time),
          enable_wake_up(enable_wake_up),
          max_concurrent_handler(max_concurrent_handler),
          num_threads_in_sub_thread_pool(num_threads_in_sub_thread_pool),
          sub_thread_request_percentage(sub_thread_request_percentage) {}
  };
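
  // A construction sketch with illustrative values (one sub thread pool that
  // handles all requests):
  //
  //   RunHandlerThreadPool::Options options(
  //       /*num_blocking_threads=*/2, /*num_non_blocking_threads=*/2,
  //       /*wait_if_no_active_request=*/true,
  //       /*non_blocking_threads_sleep_time_micro_sec=*/1000,
  //       /*blocking_threads_max_sleep_time_micro_sec=*/1000,
  //       /*use_adaptive_waiting_time=*/true, /*enable_wake_up=*/true,
  //       /*max_concurrent_handler=*/128,
  //       /*num_threads_in_sub_thread_pool=*/{2},
  //       /*sub_thread_request_percentage=*/{1.0});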

  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  RunHandlerThreadPool(Options options, tensorflow::Env* env,
                       const tensorflow::ThreadOptions& thread_options,
                       const std::string& name,
                       Eigen::MaxSizeVector<tensorflow::mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  void Start();

  void StartOneThreadForTesting();

  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking, TaskFunction fn);

  // Set work queues from which the thread 'tid' can steal its work.
  void SetThreadWorkSources(
      int tid, uint64_t version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  PerThread* GetPerThread();

  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Searches for tasks in the request range from searching_range_start to
  // searching_range_end. If there are no tasks in the search range and
  // may_steal_blocking_work is true, then searches all requests.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  void WaitForWorkInSubThreadPool(int thread_id, bool is_blocking,
                                  int sub_thread_pool_id);

 private:
  struct ThreadData {
    ThreadData();
    tensorflow::mutex mu;
    uint64_t new_version;
    tensorflow::condition_variable sources_not_empty;
    std::unique_ptr<tensorflow::Thread> thread;
    int current_index;
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64_t current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    int sub_thread_pool_id;
  };

  const int num_threads_;
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  const bool adaptive_sleep_time_;
  const bool wait_if_no_active_request_;
  const int non_blocking_thread_sleep_time_;
  const int blocking_thread_max_waiting_time_;
  const bool enable_wake_up_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  internal::RunHandlerEnvironment env_;
  std::atomic<bool> cancelled_;
  std::string name_;
  Eigen::MaxSizeVector<tensorflow::mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search for tasks from the
  // end_request_percentage of the previous sub thread pool to their own
  // end_request_percentage, in a round-robin fashion.
  std::vector<double> sub_thread_pool_end_request_percentage_;
};

}  // namespace internal

// An implementation of WorkQueueInterface that dispatches work through a
// RunHandler.
class RunHandlerWorkQueue : public tensorflow::tfrt_stub::WorkQueueInterface {
 public:
  explicit RunHandlerWorkQueue(std::unique_ptr<RunHandler> run_handler)
      : run_handler_(std::move(run_handler)) {
    DCHECK(run_handler_);
  }
  ~RunHandlerWorkQueue() override = default;

  std::string name() const override { return "run_handler"; }

  int GetParallelismLevel() const override;

  void AddTask(TaskFunction work) override;

  Optional<TaskFunction> AddBlockingTask(TaskFunction work,
                                         bool allow_queuing) override;

  void Await(
      llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> values) override;

  bool IsInWorkerThread() const override;

  void Quiesce() override {
    LOG(FATAL) << "RunHandlerWorkQueue::Quiesce() is not "  // Crash OK
                  "implemented, and is supposed to be removed.";
  }

 private:
  std::unique_ptr<RunHandler> run_handler_;
};
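
// A minimal adaptation sketch (assumes `pool` is a live RunHandlerPool):
//
//   auto handler = pool.Get(/*step_id=*/42);
//   RunHandlerWorkQueue work_queue(std::move(handler));
//   work_queue.AddTask(TaskFunction([] { /* non-blocking work */ }));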

}  // end namespace tf
}  // end namespace tfrt

#endif  // TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_