1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_ 16 #define TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_ 17 18 #include <memory> 19 20 #include "tensorflow/core/platform/strcat.h" 21 #include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h" 22 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" 23 #include "tfrt/host_context/execution_context.h" // from @tf_runtime 24 #include "tfrt/support/thread_environment.h" // from @tf_runtime 25 #include "third_party/concurrent_work_queue/lib/blocking_work_queue.h" 26 #include "third_party/concurrent_work_queue/lib/non_blocking_work_queue.h" 27 28 namespace tfrt { 29 namespace tf { 30 31 // Concurrent Work Queue based on Run Handler thread Pool. All tasks are queued 32 // based on requests. 33 class RunHandlerThreadWorkQueue 34 : public tensorflow::tfrt_stub::WorkQueueInterface { 35 public: 36 struct Options { 37 // The number of threads used for the main thread pool. 38 int num_main_threads; 39 40 // The number of threads used for complementary thread pool. 41 int num_complementary_threads; 42 43 // Timeout for InitRequest(). 44 // The timeout may trigger as the work queue limits the number of concurrent 45 // in-flight requests for better latency. 46 int64_t init_timeout_ms; 47 48 // The number of max concurrent handlers. 49 int max_concurrent_handler = 128; 50 51 // The number of sub thread pool configed. 52 int num_sub_thread_pool = 1; 53 54 // The number of threads in each sub thread pool. The length of the vector 55 // should equal to num_sub_thread_pool. 56 std::vector<int> num_threads_in_sub_thread_pool = {1}; 57 58 // The percentage of requests the first N sub thread pool handles. The 59 // length of the vector should equal to num_sub_thread_pool. 60 std::vector<double> sub_thread_request_percentage = {1.0}; 61 62 // Sleep time for non blocking threads if there is no pending task. 63 int non_blocking_threads_sleep_time_micro_sec = 1000; 64 65 // Max sleep time for blocking threads if there is no pending task and no 66 // new task wakes up the thread. 67 int blocking_threads_max_sleep_time_micro_sec = 1000; 68 69 // If true, use adaptive waiting time. 70 bool use_adaptive_waiting_time = true; 71 72 // If true, threads won't wake itself up if there is no active requests. 73 bool wait_if_no_active_request = true; 74 75 // If true, threads will be waken up by new tasks. 76 bool enable_wake_up = true; 77 }; 78 79 explicit RunHandlerThreadWorkQueue(const Options& options); ~RunHandlerThreadWorkQueue()80 ~RunHandlerThreadWorkQueue() override {} 81 name()82 std::string name() const override { 83 return tensorflow::strings::StrCat( 84 "RunHandlerThreadWorkQueue C++ work queue (", options_.num_main_threads, 85 " main threads, ", options_.num_complementary_threads, 86 " complementary threads)"); 87 } 88 89 tensorflow::StatusOr< 90 std::unique_ptr<tensorflow::tfrt_stub::WorkQueueInterface>> 91 InitializeRequest(tfrt::RequestContextBuilder* request_context_builder, 92 tensorflow::thread::ThreadPoolInterface** 93 intra_op_threadpool) const override; 94 GetParallelismLevel()95 int GetParallelismLevel() const override { 96 return options_.num_main_threads + options_.num_complementary_threads; 97 } 98 99 void AddTask(TaskFunction work) override; 100 101 Optional<TaskFunction> AddBlockingTask(TaskFunction work, 102 bool allow_queuing) override; 103 104 void Quiesce() override; 105 106 void Await(ArrayRef<RCReference<AsyncValue>> values) override; 107 108 bool IsInWorkerThread() const override; 109 110 private: 111 Options options_; 112 113 // Handler Pool. 114 // Each request will require a handler from the pool, and release the handler 115 // back to the pool once it is done. 116 std::unique_ptr<RunHandlerPool> handler_pool_; 117 118 // An id assigned to each request for tracing purpose. 119 static std::atomic_int_fast64_t step_id_counter_; 120 121 // QuiescingState for non_blocking_work_queue_ and blocking_work_queue_. 122 std::unique_ptr<::tfrt::internal::QuiescingState> quiescing_state_; 123 124 // Nonblocking queue used for cases without execution context. 125 ::tfrt::internal::NonBlockingWorkQueue<ThreadingEnvironment> 126 non_blocking_work_queue_; 127 128 // Blocking queue used for cases without execution context. 129 ::tfrt::internal::BlockingWorkQueue<ThreadingEnvironment> 130 blocking_work_queue_; 131 }; 132 133 } // namespace tf 134 } // namespace tfrt 135 136 #endif // TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_ 137