/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_
#define TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tfrt/host_context/task_function.h"  // from @tf_runtime

namespace Eigen {
struct ThreadPoolDevice;
}

namespace tfrt {
namespace tf {

class RunHandler;

// Options for RunHandler.
struct RunHandlerOptions {
  RunHandlerOptions() : priority(0) {}

  // Request priority.
  int priority;
};

// RunHandlerPool is a fixed-size pool of pre-allocated RunHandlers that can
// be used for tracking op work for a given inference request. RunHandler(s)
// in the pool are initially 'inactive'. A RunHandler becomes 'active' when
// its unique_ptr is returned by Get() and is being used by a client. It
// becomes 'inactive' once more when its unique_ptr gets destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When an inference request is invoked, obtain a handler by:
//   auto handler = run_handler_pool_->Get();
//
// * Use handler for scheduling all inter-op work by:
//   handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  struct Options {
    // The number of main threads.
    int num_inter_op_threads = 1;

    // The number of complementary threads.
    int num_intra_op_threads = 1;

    // The maximum number of concurrent handlers.
    int max_concurrent_handler = 128;

    // The number of sub thread pools configured.
    int num_sub_thread_pool = 1;

    // The number of threads in each sub thread pool. The length of the
    // vector should equal num_sub_thread_pool.
    std::vector<int> num_threads_in_sub_thread_pool = {1};

    // The percentage of requests the first N sub thread pools handle. The
    // length of the vector should equal num_sub_thread_pool. For example,
    // {0.5, 1} means the first sub thread pool will handle the first 50% of
    // requests based on priority and the second sub thread pool will handle
    // the remaining 50% of requests based on priority.
    std::vector<double> sub_thread_request_percentage = {1.0};

    // Sleep time for non-blocking threads if there is no pending task.
    int non_blocking_threads_sleep_time_micro_sec = 1000;

    // Max sleep time for blocking threads if there is no pending task and no
    // new task wakes up the thread.
    int blocking_threads_max_sleep_time_micro_sec = 1000;

    // If true, use adaptive waiting time.
    bool use_adaptive_waiting_time = true;

    // If true, threads won't wake themselves up if there are no active
    // requests.
    bool wait_if_no_active_request = true;

    // If true, threads will be woken up by new tasks.
    bool enable_wake_up = true;
  };
  explicit RunHandlerPool(Options options);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'.
  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
  // and is being used by a client. It becomes 'inactive' once more when the
  // unique_ptr is destroyed.
  //
  // Will block unless there is an inactive handler.
  std::unique_ptr<RunHandler> Get(
      int64_t step_id = 0, int64_t timeout_in_ms = 0,
      const RunHandlerOptions& options = RunHandlerOptions());

  // Gets the priorities of the active handlers, in the same order as the
  // active handler list.
  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() const;

  // Blocks until the system is quiescent (no pending work and no inflight
  // work).
  void Quiesce() const;

 private:
  class Impl;
  friend class RunHandler;

  std::unique_ptr<Impl> impl_;
};
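
// The snippet below is an illustrative sketch, not part of this header's
// API surface: it shows one plausible way to configure a pool with two sub
// thread pools, following the {0.5, 1} example in the Options comments
// above. All concrete values are assumptions chosen for illustration.
//
//   RunHandlerPool::Options pool_options;
//   pool_options.num_inter_op_threads = 4;
//   pool_options.num_intra_op_threads = 4;
//   pool_options.num_sub_thread_pool = 2;
//   pool_options.num_threads_in_sub_thread_pool = {2, 2};
//   // First sub thread pool takes the higher-priority half of requests.
//   pool_options.sub_thread_request_percentage = {0.5, 1.0};
//   RunHandlerPool pool(pool_options);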

// RunHandler can be used to schedule inter/intra-op closures to run on a
// global pool shared across all Session::Run(s). The closures are enqueued
// to a handler-specific queue, from which the work is stolen in priority
// order (time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
//
// This class can be used instead of directly scheduling closures on a global
// pool since it maintains a global view across all sessions and optimizes
// pool scheduling to improve (median and tail) latency.
//
// This class is thread safe.
class RunHandler {
 public:
  void ScheduleInterOpClosure(TaskFunction fn);
  void ScheduleIntraOpClosure(TaskFunction fn);

  tensorflow::thread::ThreadPoolInterface* AsIntraThreadPoolInterface() const;

  int NumThreads() const;

  int64_t step_id() const;

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};
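
// Illustrative sketch (assumed usage, with a placeholder closure body): a
// handler obtained from the pool schedules per-request work, and destroying
// the unique_ptr makes the handler 'inactive' again.
//
//   auto handler = pool.Get(/*step_id=*/42);
//   handler->ScheduleInterOpClosure(TaskFunction([] { /* inter-op work */ }));
//   handler.reset();  // Handler becomes 'inactive' and returns to the pool.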

namespace internal {

// TODO(azaks): Refactor with thread::ThreadPool
class RunHandlerEnvironment {
 public:
  typedef tensorflow::Thread EnvThread;
  struct TaskImpl {
    TaskFunction f;
    tensorflow::Context context;
    uint64_t trace_id;
  };
  tensorflow::Env* const env_;
  const tensorflow::ThreadOptions thread_options_;
  const std::string name_;

 public:
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  RunHandlerEnvironment(tensorflow::Env* env,
                        const tensorflow::ThreadOptions& thread_options,
                        const std::string& name);

  EnvThread* CreateThread(std::function<void()> f);

  Task CreateTask(TaskFunction f);

  void ExecuteTask(const Task& t);
};

typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
  Waiter() {
    next = this;
    prev = this;
  }
  tensorflow::condition_variable cv;
  int num_waiting_threads = 0;
  tensorflow::mutex mu;
  Waiter* next;
  Waiter* prev;
};

class ThreadWorkSource {
 public:
  ThreadWorkSource();

  ~ThreadWorkSource();

  Task EnqueueTask(Task t, bool is_blocking, bool enable_wake_up);

  Task PopBlockingTask();

  Task PopNonBlockingTask(int start_index, bool search_from_all_queue);

  int TaskQueueSize(bool is_blocking);

  int64_t GetTracemeId();

  void SetTracemeId(int64_t value);

  void SetWaiter(uint64_t version, Waiter* waiter, tensorflow::mutex* mutex);

  int64_t GetInflightTaskCount(bool is_blocking);

  void IncrementInflightTaskCount(bool is_blocking);

  void DecrementInflightTaskCount(bool is_blocking);

  int64_t GetPendingTaskCount();

  void IncrementPendingTaskCount();

  void DecrementPendingTaskCount();

  unsigned NonBlockingWorkShardingFactor();

  std::string ToString();

 private:
  struct NonBlockingQueue {
    tensorflow::mutex queue_op_mu;
    char pad[128];
    Queue queue;
  };

  int32_t non_blocking_work_sharding_factor_;
  Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;

  // The number of tasks that are executing now.
  std::atomic<int64_t> blocking_inflight_;
  std::atomic<int64_t> non_blocking_inflight_;

  // The number of tasks that are enqueued and not finished.
  std::atomic<int64_t> pending_tasks_;

  Queue blocking_work_queue_;
  tensorflow::mutex blocking_queue_op_mu_;
  char pad_[128];
  tensorflow::mutex waiters_mu_;
  Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
  std::atomic<int64_t> traceme_id_;

  tensorflow::mutex run_handler_waiter_mu_;
  uint64_t version_ TF_GUARDED_BY(run_handler_waiter_mu_);
  tensorflow::mutex* sub_thread_pool_waiter_mu_
      TF_GUARDED_BY(run_handler_waiter_mu_);
  Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
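
// Illustrative sketch of how the counters above are meant to pair up on a
// worker thread; 'tws' and 'env' are assumed pointers to a ThreadWorkSource
// and a RunHandlerEnvironment, and the exact bookkeeping in the
// implementation may differ.
//
//   Task t = tws->PopBlockingTask();
//   if (t.f) {
//     tws->DecrementPendingTaskCount();
//     tws->IncrementInflightTaskCount(/*is_blocking=*/true);
//     env->ExecuteTask(t);
//     tws->DecrementInflightTaskCount(/*is_blocking=*/true);
//   }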

class RunHandlerThreadPool {
 public:
  struct Options {
    int num_blocking_threads;
    int num_non_blocking_threads;
    bool wait_if_no_active_request;
    int non_blocking_threads_sleep_time_micro_sec;
    int blocking_threads_max_sleep_time_micro_sec;
    bool use_adaptive_waiting_time;
    bool enable_wake_up;
    int max_concurrent_handler;
    std::vector<int> num_threads_in_sub_thread_pool;
    std::vector<double> sub_thread_request_percentage;

    Options(int num_blocking_threads, int num_non_blocking_threads,
            bool wait_if_no_active_request,
            int non_blocking_threads_sleep_time_micro_sec,
            int blocking_threads_max_sleep_time_micro_sec,
            bool use_adaptive_waiting_time, bool enable_wake_up,
            int max_concurrent_handler,
            const std::vector<int>& num_threads_in_sub_thread_pool,
            const std::vector<double>& sub_thread_request_percentage)
        : num_blocking_threads(num_blocking_threads),
          num_non_blocking_threads(num_non_blocking_threads),
          wait_if_no_active_request(wait_if_no_active_request),
          non_blocking_threads_sleep_time_micro_sec(
              non_blocking_threads_sleep_time_micro_sec),
          blocking_threads_max_sleep_time_micro_sec(
              blocking_threads_max_sleep_time_micro_sec),
          use_adaptive_waiting_time(use_adaptive_waiting_time),
          enable_wake_up(enable_wake_up),
          max_concurrent_handler(max_concurrent_handler),
          num_threads_in_sub_thread_pool(num_threads_in_sub_thread_pool),
          sub_thread_request_percentage(sub_thread_request_percentage) {}
  };

  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  RunHandlerThreadPool(Options options, tensorflow::Env* env,
                       const tensorflow::ThreadOptions& thread_options,
                       const std::string& name,
                       Eigen::MaxSizeVector<tensorflow::mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  void Start();

  void StartOneThreadForTesting();

  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                      TaskFunction fn);

  // Sets the work queues from which the thread 'tid' can steal its work.
  void SetThreadWorkSources(
      int tid, uint64_t version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  PerThread* GetPerThread();

  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Searches for tasks in the request range from searching_range_start to
  // searching_range_end. If there are no tasks in the search range and
  // may_steal_blocking_work is true, then searches all requests.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  void WaitForWorkInSubThreadPool(int thread_id, bool is_blocking,
                                  int sub_thread_pool_id);

 private:
  struct ThreadData {
    ThreadData();
    tensorflow::mutex mu;
    uint64_t new_version;
    tensorflow::condition_variable sources_not_empty;
    std::unique_ptr<tensorflow::Thread> thread;
    int current_index;
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64_t current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    int sub_thread_pool_id;
  };

  const int num_threads_;
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  const bool adaptive_sleep_time_;
  const bool wait_if_no_active_request_;
  const int non_blocking_thread_sleep_time_;
  const int blocking_thread_max_waiting_time_;
  const bool enable_wake_up_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  internal::RunHandlerEnvironment env_;
  std::atomic<bool> cancelled_;
  std::string name_;
  Eigen::MaxSizeVector<tensorflow::mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search for tasks from the
  // end_request_percentage of the previous sub thread pool to their own
  // end_request_percentage in a round-robin fashion.
  std::vector<double> sub_thread_pool_end_request_percentage_;
};

}  // namespace internal

class RunHandlerWorkQueue : public tensorflow::tfrt_stub::WorkQueueInterface {
 public:
  explicit RunHandlerWorkQueue(std::unique_ptr<RunHandler> run_handler)
      : run_handler_(std::move(run_handler)) {
    DCHECK(run_handler_);
  }
  ~RunHandlerWorkQueue() override = default;

  std::string name() const override { return "run_handler"; }

  int GetParallelismLevel() const override;

  void AddTask(TaskFunction work) override;

  Optional<TaskFunction> AddBlockingTask(TaskFunction work,
                                         bool allow_queuing) override;

  void Await(
      llvm::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> values) override;

  bool IsInWorkerThread() const override;

  void Quiesce() override {
    LOG(FATAL) << "RunHandlerWorkQueue::Quiesce() is not "  // Crash OK
                  "implemented, and is supposed to be removed.";
  }

 private:
  std::unique_ptr<RunHandler> run_handler_;
};
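
// Illustrative sketch (assumed names): adapting a pooled handler to the
// WorkQueueInterface so TFRT execution can schedule work onto the handler's
// queues. 'run_handler_pool' is an assumed RunHandlerPool instance.
//
//   auto work_queue = std::make_unique<RunHandlerWorkQueue>(
//       run_handler_pool->Get(/*step_id=*/42));
//   work_queue->AddTask(TaskFunction([] { /* inter-op work */ }));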

}  // end namespace tf
}  // end namespace tfrt

#endif  // TENSORFLOW_CORE_TFRT_EXPERIMENTAL_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_