1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_
16 #define TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_
17 
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/support/thread_environment.h"  // from @tf_runtime
#include "third_party/concurrent_work_queue/lib/blocking_work_queue.h"
#include "third_party/concurrent_work_queue/lib/non_blocking_work_queue.h"
27 
28 namespace tfrt {
29 namespace tf {
30 
31 // Concurrent Work Queue based on Run Handler thread Pool. All tasks are queued
32 // based on requests.
33 class RunHandlerThreadWorkQueue
34     : public tensorflow::tfrt_stub::WorkQueueInterface {
35  public:
36   struct Options {
37     // The number of threads used for the main thread pool.
38     int num_main_threads;
39 
40     // The number of threads used for complementary thread pool.
41     int num_complementary_threads;
42 
43     // Timeout for InitRequest().
44     // The timeout may trigger as the work queue limits the number of concurrent
45     // in-flight requests for better latency.
46     int64_t init_timeout_ms;
47 
48     // The number of max concurrent handlers.
49     int max_concurrent_handler = 128;
50 
51     // The number of sub thread pool configed.
52     int num_sub_thread_pool = 1;
53 
54     // The number of threads in each sub thread pool. The length of the vector
55     // should equal to num_sub_thread_pool.
56     std::vector<int> num_threads_in_sub_thread_pool = {1};
57 
58     // The percentage of requests the first N sub thread pool handles. The
59     // length of the vector should equal to num_sub_thread_pool.
60     std::vector<double> sub_thread_request_percentage = {1.0};
61 
62     // Sleep time for non blocking threads if there is no pending task.
63     int non_blocking_threads_sleep_time_micro_sec = 1000;
64 
65     // Max sleep time for blocking threads if there is no pending task and no
66     // new task wakes up the thread.
67     int blocking_threads_max_sleep_time_micro_sec = 1000;
68 
69     // If true, use adaptive waiting time.
70     bool use_adaptive_waiting_time = true;
71 
72     // If true, threads won't wake itself up if there is no active requests.
73     bool wait_if_no_active_request = true;
74 
75     // If true, threads will be waken up by new tasks.
76     bool enable_wake_up = true;
77   };
78 
79   explicit RunHandlerThreadWorkQueue(const Options& options);
~RunHandlerThreadWorkQueue()80   ~RunHandlerThreadWorkQueue() override {}
81 
name()82   std::string name() const override {
83     return tensorflow::strings::StrCat(
84         "RunHandlerThreadWorkQueue C++ work queue (", options_.num_main_threads,
85         " main threads, ", options_.num_complementary_threads,
86         " complementary threads)");
87   }
88 
89   tensorflow::StatusOr<
90       std::unique_ptr<tensorflow::tfrt_stub::WorkQueueInterface>>
91   InitializeRequest(tfrt::RequestContextBuilder* request_context_builder,
92                     tensorflow::thread::ThreadPoolInterface**
93                         intra_op_threadpool) const override;
94 
GetParallelismLevel()95   int GetParallelismLevel() const override {
96     return options_.num_main_threads + options_.num_complementary_threads;
97   }
98 
99   void AddTask(TaskFunction work) override;
100 
101   Optional<TaskFunction> AddBlockingTask(TaskFunction work,
102                                          bool allow_queuing) override;
103 
104   void Quiesce() override;
105 
106   void Await(ArrayRef<RCReference<AsyncValue>> values) override;
107 
108   bool IsInWorkerThread() const override;
109 
110  private:
111   Options options_;
112 
113   // Handler Pool.
114   // Each request will require a handler from the pool, and release the handler
115   // back to the pool once it is done.
116   std::unique_ptr<RunHandlerPool> handler_pool_;
117 
118   // An id assigned to each request for tracing purpose.
119   static std::atomic_int_fast64_t step_id_counter_;
120 
121   // QuiescingState for non_blocking_work_queue_ and blocking_work_queue_.
122   std::unique_ptr<::tfrt::internal::QuiescingState> quiescing_state_;
123 
124   // Nonblocking queue used for cases without execution context.
125   ::tfrt::internal::NonBlockingWorkQueue<ThreadingEnvironment>
126       non_blocking_work_queue_;
127 
128   // Blocking queue used for cases without execution context.
129   ::tfrt::internal::BlockingWorkQueue<ThreadingEnvironment>
130       blocking_work_queue_;
131 };
132 
133 }  // namespace tf
134 }  // namespace tfrt
135 
136 #endif  // TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_
137