/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h"
#include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime

namespace tfrt {
namespace tf {
namespace {

typedef typename internal::RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

}  // namespace

namespace internal {
RunHandlerEnvironment::RunHandlerEnvironment(
    tensorflow::Env* env, const tensorflow::ThreadOptions& thread_options,
    const std::string& name)
    : env_(env), thread_options_(thread_options), name_(name) {}

RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
    std::function<void()> f) {
  return env_->StartThread(thread_options_, name_, [=]() {
    // Set the processor flag to flush denormals to zero.
    tensorflow::port::ScopedFlushDenormal flush;
    // Set the processor rounding mode to ROUND TO NEAREST.
    tensorflow::port::ScopedSetRound round(FE_TONEAREST);
    if (thread_options_.numa_node != tensorflow::port::kNUMANoAffinity) {
      tensorflow::port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
    }
    f();
  });
}

RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(TaskFunction f) {
  uint64_t id = 0;
  if (tensorflow::tracing::EventCollector::IsEnabled()) {
    id = tensorflow::tracing::GetUniqueArg();
    tensorflow::tracing::RecordEvent(
        tensorflow::tracing::EventCategory::kScheduleClosure, id);
  }
  return Task{
      std::unique_ptr<TaskImpl>(new TaskImpl{
          std::move(f),
          tensorflow::Context(tensorflow::ContextKind::kThread),
          id,
      }),
  };
}

void RunHandlerEnvironment::ExecuteTask(const Task& t) {
  tensorflow::WithContext wc(t.f->context);
  tensorflow::tracing::ScopedRegion region(
      tensorflow::tracing::EventCategory::kRunClosure, t.f->trace_id);
  t.f->f();
}

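// Parks the calling thread on `waiter` after linking it into the doubly-linked
// LIFO list rooted at `queue_head`. A waiter whose next/prev pointers point to
// itself is not in any list. The thread sleeps on the waiter's condition
// variable for at most `sleep_micros` microseconds; with `adaptive_sleep_time`
// the timeout is scaled by the number of threads already waiting on this
// waiter. On wake-up the waiter is unlinked again unless a notifier has
// already removed it from the list.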
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, tensorflow::mutex* mutex,
                  int sleep_micros, bool adaptive_sleep_time) {
  {
    tensorflow::mutex_lock l(*mutex);
    CHECK_EQ(waiter->next, waiter);  // Crash OK.
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.

    // Add waiter to the LIFO queue.
    waiter->prev = queue_head;
    waiter->next = queue_head->next;
    waiter->next->prev = waiter;
    waiter->prev->next = waiter;
  }
  {
    tensorflow::mutex_lock l(waiter->mu);
    waiter->num_waiting_threads++;
    int max_sleep_micros = adaptive_sleep_time
                               ? sleep_micros * waiter->num_waiting_threads
                               : sleep_micros;
    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
    waiter->num_waiting_threads--;
  }

  tensorflow::mutex_lock l(*mutex);
  // Remove waiter from the LIFO queue. Note that even when a waiter wakes up
  // due to a notification we cannot conclude the waiter is not in the queue.
  // This is because a thread preempted right before notifying may resume after
  // the waiter got re-added.
  if (waiter->next != waiter) {
    CHECK(waiter->prev != waiter);  // Crash OK.
    waiter->next->prev = waiter->prev;
    waiter->prev->next = waiter->next;
    waiter->next = waiter;
    waiter->prev = waiter;
  } else {
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
  }
}

ThreadWorkSource::ThreadWorkSource()
    : non_blocking_work_sharding_factor_(
          static_cast<int32_t>(ParamFromEnvWithDefault(
              "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
      non_blocking_work_queues_(non_blocking_work_sharding_factor_),
      blocking_inflight_(0),
      non_blocking_inflight_(0),
      pending_tasks_(0),
      traceme_id_(0),
      version_(0),
      sub_thread_pool_waiter_(nullptr) {
  queue_waiters_.next = &queue_waiters_;
  queue_waiters_.prev = &queue_waiters_;
  for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
    non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
  }
}

ThreadWorkSource::~ThreadWorkSource() {
  for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
    delete non_blocking_work_queues_[i];
  }
}

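// Enqueues `t` on the blocking queue, or on one of the sharded non-blocking
// queues chosen round-robin via a thread-local counter. PushFront may only be
// called by one thread at a time for a given queue, hence the per-queue mutex.
// If the queue is full, the task is handed back to the caller instead of being
// enqueued. When `enable_wake_up` is set, one waiter from the sub thread
// pool's waiter list is notified on a best-effort basis.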
Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking,
                                   bool enable_wake_up) {
  uint64_t id = t.f->trace_id;
  tensorflow::profiler::TraceMe activity(
      [id, is_blocking] {
        return tensorflow::profiler::TraceMeEncode(
            "Enqueue", {{"id", id}, {"is_blocking", is_blocking}});
      },
      tensorflow::profiler::TraceMeLevel::kInfo);
  tensorflow::mutex* mu = nullptr;
  Queue* task_queue = nullptr;
  thread_local int64_t closure_counter = 0;

  if (!is_blocking) {
    int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
    task_queue = &(non_blocking_work_queues_[queue_index]->queue);
    mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
  } else {
    task_queue = &blocking_work_queue_;
    mu = &blocking_queue_op_mu_;
  }

  {
    tensorflow::mutex_lock l(*mu);
    // For a given queue, only one thread can call PushFront.
    t = task_queue->PushFront(std::move(t));
  }
  IncrementPendingTaskCount();

  if (enable_wake_up) {
    // Try to wake up a waiting thread, if there is any, for the given sub
    // thread pool. We could potentially wake up threads in other sub thread
    // pools if we cannot find any waiting threads in the given sub thread
    // pool. The wake-up logic is best effort, as the thread may be just about
    // to be added to the waiting queue or to start waiting on the condition
    // variable. However, a thread will wake up after a short period of time in
    // case a notification is missed.
    Waiter* w = nullptr;

    Waiter* waiter_queue;
    tensorflow::mutex* waiter_queue_mu;
    {
      // When we use multiple sub thread pools, free threads wait on sub
      // thread pool waiting queues. Wake up threads from sub thread pool
      // waiting queues.
      // The waiting queues are defined at RunHandlerPool.
      // Get the waiter_queue and corresponding mutex. Note that the thread
      // work source may change afterwards if a new request comes in or an old
      // request finishes.
      tensorflow::tf_shared_lock lock(run_handler_waiter_mu_);
      waiter_queue = sub_thread_pool_waiter_;
      waiter_queue_mu = sub_thread_pool_waiter_mu_;
    }
    {
      tensorflow::mutex_lock l(*waiter_queue_mu);
      if (waiter_queue->next != waiter_queue) {
        // Remove waiter from the LIFO queue.
        w = waiter_queue->next;

        CHECK(w->prev != w);  // Crash OK.
        CHECK(w->next != w);  // Crash OK.

        w->next->prev = w->prev;
        w->prev->next = w->next;

        // Use `w->next == w` to indicate that the waiter has been removed
        // from the queue.
        w->next = w;
        w->prev = w;
      }
    }
    if (w != nullptr) {
      w->cv.notify_one();
    }
  }
  VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
          << traceme_id_.load(std::memory_order_relaxed);
  return t;
}

Task ThreadWorkSource::PopBlockingTask() {
  return blocking_work_queue_.PopBack();
}

Task ThreadWorkSource::PopNonBlockingTask(int start_index,
                                          bool search_from_all_queue) {
  Task t;
  unsigned sharding_factor = NonBlockingWorkShardingFactor();
  for (unsigned j = 0; j < sharding_factor; ++j) {
    t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
            ->queue.PopBack();
    if (t.f) {
      return t;
    }
    if (!search_from_all_queue) {
      break;
    }
  }
  return t;
}

int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
  if (is_blocking) {
    return blocking_work_queue_.Size();
  } else {
    unsigned total_size = 0;
    for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
      total_size += non_blocking_work_queues_[i]->queue.Size();
    }
    return total_size;
  }
}

int64_t ThreadWorkSource::GetTracemeId() {
  return traceme_id_.load(std::memory_order_relaxed);
}

void ThreadWorkSource::SetTracemeId(int64_t value) { traceme_id_ = value; }

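// Points this work source at the waiter list (and its mutex) of the sub
// thread pool it is currently assigned to. A shared lock is taken first so
// that the common case, where the waiter is unchanged or a newer version has
// already been installed, avoids the exclusive lock; only an actual update
// takes run_handler_waiter_mu_ exclusively.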
void ThreadWorkSource::SetWaiter(uint64_t version, Waiter* waiter,
                                 tensorflow::mutex* mutex) {
  {
    tensorflow::tf_shared_lock lock(run_handler_waiter_mu_);
    // Most requests will not change their sub thread pool on recomputation.
    // As an optimization, avoid taking the exclusive lock in that case to
    // reduce contention.
    if (sub_thread_pool_waiter_ == waiter) {
      return;
    }
    // If the current version is newer, there is no need to update.
    if (version_ > version) {
      return;
    }
  }

  tensorflow::mutex_lock l(run_handler_waiter_mu_);
  sub_thread_pool_waiter_ = waiter;
  sub_thread_pool_waiter_mu_ = mutex;
  version_ = version;
}

int64_t ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  return counter->load(std::memory_order_relaxed);
}

void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_add(1, std::memory_order_relaxed);
}

void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_sub(1, std::memory_order_relaxed);
}

int64_t ThreadWorkSource::GetPendingTaskCount() {
  return pending_tasks_.load(std::memory_order_acquire);
}

void ThreadWorkSource::IncrementPendingTaskCount() {
  pending_tasks_.fetch_add(1, std::memory_order_relaxed);
}

void ThreadWorkSource::DecrementPendingTaskCount() {
  // std::memory_order_release prevents reordering with op execution.
  pending_tasks_.fetch_sub(1, std::memory_order_release);
}

unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
  return non_blocking_work_sharding_factor_;
}

std::string ThreadWorkSource::ToString() {
  return tensorflow::strings::StrCat(
      "traceme_id = ", GetTracemeId(),
      ", inter queue size = ", TaskQueueSize(true),
      ", inter inflight = ", GetInflightTaskCount(true),
      ", intra queue size = ", TaskQueueSize(false),
      ", intra inflight = ", GetInflightTaskCount(false));
}

RunHandlerThreadPool::RunHandlerThreadPool(
    Options options, tensorflow::Env* env,
    const tensorflow::ThreadOptions& thread_options, const std::string& name,
    Eigen::MaxSizeVector<tensorflow::mutex>* waiters_mu,
    Eigen::MaxSizeVector<Waiter>* queue_waiters)
    : num_threads_(options.num_blocking_threads +
                   options.num_non_blocking_threads),
      num_blocking_threads_(options.num_blocking_threads),
      num_non_blocking_threads_(options.num_non_blocking_threads),
      adaptive_sleep_time_(options.use_adaptive_waiting_time),
      wait_if_no_active_request_(options.wait_if_no_active_request),
      non_blocking_thread_sleep_time_(
          options.non_blocking_threads_sleep_time_micro_sec),
      blocking_thread_max_waiting_time_(
          options.blocking_threads_max_sleep_time_micro_sec),
      enable_wake_up_(options.enable_wake_up),
      thread_data_(num_threads_),
      env_(env, thread_options, name),
      name_(name),
      waiters_mu_(waiters_mu),
      queue_waiters_(queue_waiters),
      num_threads_in_sub_thread_pool_(options.num_threads_in_sub_thread_pool),
      sub_thread_pool_end_request_percentage_(
          options.sub_thread_request_percentage) {
  thread_data_.resize(num_threads_);
  for (int i = 0; i < num_threads_; ++i) {
    thread_data_[i].new_thread_work_sources =
        std::make_unique<Eigen::MaxSizeVector<ThreadWorkSource*>>(
            options.max_concurrent_handler);
    thread_data_[i].current_thread_work_sources =
        std::make_unique<Eigen::MaxSizeVector<ThreadWorkSource*>>(
            options.max_concurrent_handler);
  }
  VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
          << num_blocking_threads_ << " blocking threads and "
          << num_non_blocking_threads_ << " non-blocking threads.";
}

RunHandlerThreadPool::~RunHandlerThreadPool() {
  VLOG(1) << "Exiting RunHandlerThreadPool " << name_;

  cancelled_ = true;
  for (size_t i = 0; i < thread_data_.size(); ++i) {
    {
      tensorflow::mutex_lock l(thread_data_[i].mu);
      thread_data_[i].sources_not_empty.notify_all();
    }
    thread_data_[i].thread.reset();
  }
}

void RunHandlerThreadPool::Start() {
  cancelled_ = false;
  int num_blocking_threads = num_blocking_threads_;
  for (int i = 0; i < num_threads_; i++) {
    int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
    for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
      if (i < num_threads_in_sub_thread_pool_[j]) {
        sub_thread_pool_id = j;
        break;
      }
    }
    thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
    thread_data_[i].thread.reset(
        env_.CreateThread([this, i, num_blocking_threads]() {
          WorkerLoop(i, i < num_blocking_threads);
        }));
  }
}

void RunHandlerThreadPool::StartOneThreadForTesting() {
  cancelled_ = false;
  thread_data_[0].sub_thread_pool_id = 0;
  thread_data_[0].thread.reset(
      env_.CreateThread([this]() { WorkerLoop(0, true); }));
}

void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
                                          bool is_blocking, TaskFunction fn) {
  Task t = env_.CreateTask(std::move(fn));
  t = tws->EnqueueTask(std::move(t), is_blocking, enable_wake_up_);
  if (t.f) {
    VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
            << tws->GetTracemeId();
    env_.ExecuteTask(t);
  }
}

void RunHandlerThreadPool::SetThreadWorkSources(
    int tid, uint64_t version,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
  tensorflow::mutex_lock l(thread_data_[tid].mu);
  if (version > thread_data_[tid].new_version) {
    thread_data_[tid].new_version = version;
  } else {
    // A newer version has already been installed. No need to update.
    return;
  }
  auto original_size = thread_data_[tid].current_thread_work_sources->size();
  thread_data_[tid].new_thread_work_sources->resize(0);
  for (int i = 0; i < thread_work_sources.size(); ++i) {
    thread_data_[tid].new_thread_work_sources->emplace_back(
        thread_work_sources[i]);
  }
  if (original_size == 0) {
    thread_data_[tid].sources_not_empty.notify_all();
  }
}

RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
  thread_local RunHandlerThreadPool::PerThread per_thread_;
  RunHandlerThreadPool::PerThread* pt = &per_thread_;
  return pt;
}

int RunHandlerThreadPool::CurrentThreadId() const {
  const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
  if (pt->pool == this) {
    return pt->thread_id;
  } else {
    return -1;
  }
}

int RunHandlerThreadPool::NumThreads() const { return num_threads_; }

int RunHandlerThreadPool::NumBlockingThreads() const {
  return num_blocking_threads_;
}

int RunHandlerThreadPool::NumNonBlockingThreads() const {
  return num_non_blocking_threads_;
}

RunHandlerThreadPool::ThreadData::ThreadData()
    : new_version(0), current_index(0), current_version(0) {}

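// Scans the work sources in [searching_range_start, searching_range_end),
// resuming from the thread's persisted current_index so that consecutive
// searches rotate through requests instead of always favoring the first one.
// Blocking tasks are preferred when allowed and the per-request blocking
// inflight count is below max_blocking_inflight; otherwise a non-blocking
// task is popped. Returns an empty Task if nothing is found.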
Task RunHandlerThreadPool::FindTask(
    int searching_range_start, int searching_range_end, int thread_id,
    int sub_thread_pool_id, int max_blocking_inflight,
    bool may_steal_blocking_work,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
  Task t;
  int current_index = thread_data_[thread_id].current_index;
  *task_from_blocking_queue = false;

  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
    if (current_index >= searching_range_end ||
        current_index < searching_range_start) {
      current_index = searching_range_start;
    }
    *tws = thread_work_sources[current_index];
    ++current_index;

    // For a blocking thread, search for blocking tasks first.
    if (may_steal_blocking_work &&
        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
      t = (*tws)->PopBlockingTask();
      if (t.f) {
        *task_from_blocking_queue = true;
        break;
      }
    }

    // Search for non-blocking tasks.
    t = (*tws)->PopNonBlockingTask(thread_id, true);
    if (t.f) {
      break;
    }
  }
  thread_data_[thread_id].current_index = current_index;
  return t;
}

// Main worker thread loop.
void RunHandlerThreadPool::WorkerLoop(int thread_id,
                                      bool may_steal_blocking_work) {
  PerThread* pt = GetPerThread();
  pt->pool = this;
  pt->thread_id = thread_id;
  static constexpr int32_t kMaxBlockingInflight = 10;

  while (!cancelled_) {
    Task t;
    ThreadWorkSource* tws = nullptr;
    bool task_from_blocking_queue = true;
    int sub_thread_pool_id;
    // Get the current thread work sources.
    {
      tensorflow::mutex_lock l(thread_data_[thread_id].mu);
      if (thread_data_[thread_id].current_version <
          thread_data_[thread_id].new_version) {
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
      }
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
    int active_requests = thread_work_sources->size();
    if (may_steal_blocking_work) {
      // Each thread will first look for tasks from requests that belong to
      // its own sub thread pool.
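      // For example, with sub_thread_pool_end_request_percentage_ =
      // {0.5, 1.0} and 10 active requests, a thread in sub thread pool 0
      // searches requests [0, 5) and a thread in sub thread pool 1 searches
      // requests [5, 10); the min()/max() below keep the range in bounds and
      // non-empty whenever there are active requests.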
      int search_range_start =
          sub_thread_pool_id == 0
              ? 0
              : active_requests *
                    sub_thread_pool_end_request_percentage_[sub_thread_pool_id -
                                                            1];
      int search_range_end =
          active_requests *
          sub_thread_pool_end_request_percentage_[sub_thread_pool_id];
      search_range_end = std::min(
          active_requests, std::max(search_range_end, search_range_start + 1));

      t = FindTask(search_range_start, search_range_end, thread_id,
                   sub_thread_pool_id, kMaxBlockingInflight,
                   /*may_steal_blocking_work=*/true, *thread_work_sources,
                   &task_from_blocking_queue, &tws);
      if (!t.f) {
        // Search across all requests if the thread cannot find a task from
        // requests that belong to its own sub thread pool.
        t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                     kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/true, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
      }
    } else {
      // Non-blocking threads always search across all pending requests.
      t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                   kMaxBlockingInflight,
                   /*may_steal_blocking_work=*/false, *thread_work_sources,
                   &task_from_blocking_queue, &tws);
    }
    if (t.f) {
      VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra")
              << " work from " << tws->GetTracemeId();
      tws->IncrementInflightTaskCount(task_from_blocking_queue);
      env_.ExecuteTask(t);
      tws->DecrementInflightTaskCount(task_from_blocking_queue);
      tws->DecrementPendingTaskCount();
    } else {
      tensorflow::profiler::TraceMe activity(
          [thread_id] {
            return tensorflow::profiler::TraceMeEncode(
                "Sleeping", {{"thread_id", thread_id}});
          },
          tensorflow::profiler::TraceMeLevel::kInfo);
      if (VLOG_IS_ON(4)) {
        for (int i = 0; i < thread_work_sources->size(); ++i) {
          VLOG(4) << "source id " << i << " "
                  << (*thread_work_sources)[i]->ToString();
        }
      }
      WaitForWorkInSubThreadPool(thread_id, may_steal_blocking_work,
                                 sub_thread_pool_id);
    }
  }
}

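// Called by a worker that found no task. If wait_if_no_active_request_ is set
// and there are no active requests, the thread blocks on sources_not_empty
// until new work sources arrive. Otherwise, non-blocking threads simply sleep
// for a fixed interval, while blocking threads either park on the sub thread
// pool's waiter list (when wake-up is enabled) or sleep for the configured
// maximum waiting time.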
void RunHandlerThreadPool::WaitForWorkInSubThreadPool(int thread_id,
                                                      bool is_blocking,
                                                      int sub_thread_pool_id) {
  if (wait_if_no_active_request_) {
    tensorflow::mutex_lock l(thread_data_[thread_id].mu);
    if (thread_data_[thread_id].new_version >
        thread_data_[thread_id].current_version) {
      thread_data_[thread_id].current_thread_work_sources.swap(
          thread_data_[thread_id].new_thread_work_sources);
      thread_data_[thread_id].current_version =
          thread_data_[thread_id].new_version;
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    bool already_wait = false;
    while (!cancelled_ && thread_work_sources->empty()) {
      // Wait until there is a new request.
      thread_data_[thread_id].sources_not_empty.wait(l);
      if (thread_data_[thread_id].new_version >
          thread_data_[thread_id].current_version) {
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_work_sources =
            thread_data_[thread_id].current_thread_work_sources.get();
      }
      already_wait = true;
    }
    if (already_wait || cancelled_) {
      return;
    }
  }

  // Non-blocking threads just sleep.
  if (!is_blocking) {
    tensorflow::Env::Default()->SleepForMicroseconds(
        non_blocking_thread_sleep_time_);
    return;
  }

  if (enable_wake_up_) {
    thread_local Waiter waiter;
    WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
                 &(*waiters_mu_)[sub_thread_pool_id],
                 blocking_thread_max_waiting_time_, adaptive_sleep_time_);
  } else {
    tensorflow::Env::Default()->SleepForMicroseconds(
        blocking_thread_max_waiting_time_);
  }
}

}  // namespace internal

// Contains the concrete implementation of the RunHandler.
// The externally visible RunHandler class simply forwards work to this one.
class RunHandler::Impl {
 public:
  explicit Impl(RunHandlerPool::Impl* pool_impl);

  ~Impl() {}

  // Returns the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
  uint64_t start_time_us() const { return start_time_us_; }
  int64_t step_id() const { return step_id_; }
  void ScheduleInterOpClosure(TaskFunction fn);
  void ScheduleIntraOpClosure(TaskFunction fn);

  void Reset(int64_t step_id, const RunHandlerOptions& options);

  RunHandlerPool::Impl* pool_impl() { return pool_impl_; }

  tensorflow::thread::ThreadPoolInterface* thread_pool_interface() {
    return &eigen_thread_pool_;
  }

  internal::ThreadWorkSource* tws() { return &tws_; }

  int64_t priority() const { return options_.priority; }

 private:
  class RunHandlerEigenThreadPool
      : public tensorflow::thread::ThreadPoolInterface {
   public:
    explicit RunHandlerEigenThreadPool(RunHandler::Impl* run_handler)
        : run_handler_(run_handler) {
      DCHECK(run_handler);
    }

    void Schedule(std::function<void()> fn) override {
      run_handler_->ScheduleIntraOpClosure(tensorflow::tfrt_stub::WrapWork(
          run_handler_->tws()->GetTracemeId(), "intra", std::move(fn)));
    }

    int NumThreads() const override;
    int CurrentThreadId() const override;

   private:
    RunHandler::Impl* run_handler_;
  };

  RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
  RunHandlerEigenThreadPool eigen_thread_pool_;
  uint64_t start_time_us_;
  int64_t step_id_;
  internal::ThreadWorkSource tws_;
  RunHandlerOptions options_;
};

// Contains shared state across all run handlers present in the pool. Also
// responsible for pool management decisions.
// This class is thread safe.
class RunHandlerPool::Impl {
 public:
  explicit Impl(Options options)
      : max_handlers_(options.max_concurrent_handler),
        waiters_mu_(options.num_sub_thread_pool),
        queue_waiters_(options.num_sub_thread_pool),
        run_handler_thread_pool_(new internal::RunHandlerThreadPool(
            internal::RunHandlerThreadPool::Options(
                options.num_inter_op_threads, options.num_intra_op_threads,
                options.wait_if_no_active_request,
                options.non_blocking_threads_sleep_time_micro_sec,
                options.blocking_threads_max_sleep_time_micro_sec,
                options.use_adaptive_waiting_time, options.enable_wake_up,
                options.max_concurrent_handler,
                options.num_threads_in_sub_thread_pool,
                options.sub_thread_request_percentage),
            tensorflow::Env::Default(), tensorflow::ThreadOptions(),
            "tf_run_handler_pool", &waiters_mu_, &queue_waiters_)),
        iterations_(0),
        version_(0),
        wait_if_no_active_request_(options.wait_if_no_active_request),
        sub_thread_pool_end_request_percentage_(
            options.sub_thread_request_percentage) {
    VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
    free_handlers_.reserve(max_handlers_);
    handlers_.reserve(max_handlers_);
    for (int i = 0; i < max_handlers_; ++i) {
      handlers_.emplace_back(new RunHandler::Impl(this));
      free_handlers_.push_back(handlers_.back().get());
    }
    queue_waiters_.resize(options.num_sub_thread_pool);
    waiters_mu_.resize(options.num_sub_thread_pool);
    for (auto& queue_waiter : queue_waiters_) {
      queue_waiter.next = &queue_waiter;
      queue_waiter.prev = &queue_waiter;
    }
    run_handler_thread_pool_->Start();
  }

  ~Impl() {
    // Sanity check that all handlers have been returned to the pool before
    // destruction.
    DCHECK_EQ(handlers_.size(), max_handlers_);
    DCHECK_EQ(free_handlers_.size(), handlers_.size());
    DCHECK_EQ(sorted_active_handlers_.size(), 0);
    // Stop the threads in run_handler_thread_pool_ before freeing other
    // pointers. Otherwise a thread may try to access a pointer after the
    // pointer has been freed.
    run_handler_thread_pool_.reset();
  }

  internal::RunHandlerThreadPool* run_handler_thread_pool() {
    return run_handler_thread_pool_.get();
  }

  bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return !free_handlers_.empty();
  }

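  // Returns a free handler reset for `step_id`, blocking until one becomes
  // available (a `timeout_in_ms` of 0 means wait forever; otherwise nullptr is
  // returned if the deadline expires). The new handler is inserted into
  // sorted_active_handlers_, which is kept ordered by descending priority, and
  // the updated list of thread work sources is then broadcast to the worker
  // threads via RecomputePoolStats().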
  std::unique_ptr<RunHandler> Get(int64_t step_id, int64_t timeout_in_ms,
                                  const RunHandlerOptions& options)
      TF_LOCKS_EXCLUDED(mu_) {
    thread_local std::unique_ptr<
        Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
        thread_work_sources =
            std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
                new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
                    max_handlers_));
    uint64_t version;
    int num_active_requests;
    RunHandler::Impl* handler_impl;
    {
      tensorflow::mutex_lock l(mu_);
      if (!has_free_handler()) {
        tensorflow::profiler::TraceMe activity(
            [step_id] {
              return tensorflow::profiler::TraceMeEncode(
                  "WaitingForHandler", {{"step_id", step_id}});
            },
            tensorflow::profiler::TraceMeLevel::kInfo);
        if (timeout_in_ms == 0) {
          mu_.Await(tensorflow::Condition(this, &Impl::has_free_handler));
        } else if (!mu_.AwaitWithDeadline(
                       tensorflow::Condition(this, &Impl::has_free_handler),
                       tensorflow::EnvTime::NowNanos() +
                           timeout_in_ms * 1000 * 1000)) {
          return nullptr;
        }
      }
      // Remove the last entry from free_handlers_ and insert it into
      // sorted_active_handlers_ according to its priority.
      handler_impl = free_handlers_.back();
      handler_impl->Reset(step_id, options);
      free_handlers_.pop_back();

      num_active_requests = sorted_active_handlers_.size() + 1;
      thread_work_sources->resize(num_active_requests);
      int priority = options.priority;
      auto it = sorted_active_handlers_.cbegin();
      bool new_handler_inserted = false;
      for (int i = 0; i < num_active_requests; ++i) {
        if (!new_handler_inserted && (it == sorted_active_handlers_.cend() ||
                                      priority > (*it)->priority())) {
          sorted_active_handlers_.insert(it, handler_impl);
          new_handler_inserted = true;
          // Point to the newly added handler.
          --it;
        }
        (*thread_work_sources)[i] = (*it)->tws();
        ++it;
      }
      version = ++version_;
    }
    RecomputePoolStats(num_active_requests, version, *thread_work_sources);
    return tensorflow::WrapUnique<RunHandler>(new RunHandler(handler_impl));
  }

  void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) {
    tensorflow::mutex_lock l(mu_);
    DCHECK_GT(sorted_active_handlers_.size(), 0);

    CHECK_EQ(handler->tws()->TaskQueueSize(true), 0);   // Crash OK.
    CHECK_EQ(handler->tws()->TaskQueueSize(false), 0);  // Crash OK.

    uint64_t now = tensorflow::EnvTime::NowMicros();
    double elapsed = (now - handler->start_time_us()) / 1000.0;
    time_hist_.Add(elapsed);

    // Erase from and update sorted_active_handlers_. Add it to the end of
    // free_handlers_.
    auto iter = std::find(sorted_active_handlers_.begin(),
                          sorted_active_handlers_.end(), handler);
    DCHECK(iter != sorted_active_handlers_.end())
        << "Unexpected handler: " << handler
        << " is being requested for release";

    // Remove this handler from the active list and add it to the list of free
    // handlers.
    sorted_active_handlers_.erase(iter);
    free_handlers_.push_back(handler);
    DCHECK_LE(free_handlers_.size(), max_handlers_);
    LogInfo();

    // We do not recompute pool stats all the time. The side effect is that
    // there may be empty thread work sources in the queue. However, any new
    // request will trigger recomputation.
    if (wait_if_no_active_request_ && sorted_active_handlers_.empty()) {
      thread_local auto thread_work_sources =
          std::make_unique<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
              max_handlers_);
      thread_work_sources->resize(0);
      auto version = ++version_;
      RecomputePoolStats(0, version, *thread_work_sources);
    }
  }

  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting()
      TF_LOCKS_EXCLUDED(mu_) {
    tensorflow::mutex_lock l(mu_);
    std::vector<int64_t> ret;
    for (const auto& handler_impl : sorted_active_handlers_) {
      ret.push_back(handler_impl->priority());
    }
    return ret;
  }

  void Quiesce() TF_LOCKS_EXCLUDED(mu_) {
    while (true) {
      {
        tensorflow::tf_shared_lock l(mu_);
        bool all_empty = true;
        for (const auto& handler : sorted_active_handlers_) {
          if (handler->tws()->GetPendingTaskCount() != 0) {
            all_empty = false;
            break;
          }
        }
        if (all_empty) {
          break;
        }
      }
      // Sleep while waiting for pending tasks to drain.
      const int sleep_time = 50000;
      tensorflow::Env::Default()->SleepForMicroseconds(sleep_time);
    }
  }

 private:
  void RecomputePoolStats(
      int num_active_requests, uint64_t version,
      const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
          thread_work_sources);

  void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Maximum number of handlers pre-created during pool construction time. The
  // number has been chosen expecting that each handler might want at least one
  // inter-op thread for execution (during compute-intensive workloads like
  // inference).
  const int max_handlers_;

  Eigen::MaxSizeVector<tensorflow::mutex> waiters_mu_;
  Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;

  std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
  // Thread compatible part used only by lock under RunHandlerPool.
  // Handlers are sorted by priority (and by start time within a priority).
  // TODO(azaks): sort by the remaining latency budget.
  // TODO(chaox): Consider other data structures for maintaining the sorted
  // active handlers if the searching overhead (currently O(n)) becomes the
  // bottleneck.
  std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_);
  std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_);
  std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_);

  // Histogram of elapsed runtime of every handler (in ms).
  tensorflow::histogram::Histogram time_hist_ TF_GUARDED_BY(mu_);

  int64_t iterations_ TF_GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  int64_t version_ TF_GUARDED_BY(mu_);
  bool wait_if_no_active_request_;
  const std::vector<double> sub_thread_pool_end_request_percentage_;
};

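// Assigns each active request (thread work source) to a sub thread pool based
// on its position in the priority-sorted list and the cumulative
// sub_thread_pool_end_request_percentage_ thresholds, then publishes the
// updated list of work sources to every worker thread under `version` so that
// stale updates are ignored.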
void RunHandlerPool::Impl::RecomputePoolStats(
    int num_active_requests, uint64_t version,
    const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
        thread_work_sources) {
  int sub_thread_pool_id = 0;
  for (int i = 0; i < num_active_requests; ++i) {
    while (
        sub_thread_pool_id <
            sub_thread_pool_end_request_percentage_.size() - 1 &&
        i >= num_active_requests *
                 sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) {
      sub_thread_pool_id++;
    }
    thread_work_sources[i]->SetWaiter(version,
                                      &queue_waiters_[sub_thread_pool_id],
                                      &waiters_mu_[sub_thread_pool_id]);
  }

  int num_threads = run_handler_thread_pool()->NumThreads();
  int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads();
  int num_non_blocking_threads = num_threads - num_blocking_threads;

  for (int i = 0; i < num_blocking_threads + num_non_blocking_threads; ++i) {
    run_handler_thread_pool()->SetThreadWorkSources(i, version,
                                                    thread_work_sources);
  }
}

void RunHandlerPool::Impl::LogInfo() {
  if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
    int num_active_requests = sorted_active_handlers_.size();
    VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
    VLOG(1) << "Active session runs: " << num_active_requests;
    uint64_t now = tensorflow::Env::Default()->NowMicros();
    std::string times_str = "";
    std::string ids_str = "";
    auto it = sorted_active_handlers_.cbegin();
    for (int i = 0; i < num_active_requests; ++i) {
      if (i > 0) {
        times_str += " ";
        ids_str += " ";
      }

      times_str += tensorflow::strings::StrCat(
          (now - (*it)->start_time_us()) / 1000.0, " ms.");
      ids_str += tensorflow::strings::StrCat((*it)->tws()->GetTracemeId());
      ++it;
    }
    VLOG(1) << "Elapsed times are: " << times_str;
    VLOG(1) << "Step ids are: " << ids_str;
  }
}

RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
    : pool_impl_(pool_impl), eigen_thread_pool_(this) {
  Reset(0, RunHandlerOptions());
}

void RunHandler::Impl::ScheduleInterOpClosure(TaskFunction fn) {
  VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
                                                        std::move(fn));
}

void RunHandler::Impl::ScheduleIntraOpClosure(TaskFunction fn) {
  VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false,
                                                        std::move(fn));
}

void RunHandler::Impl::Reset(int64_t step_id,
                             const RunHandlerOptions& options) {
  start_time_us_ = tensorflow::Env::Default()->NowMicros();
  step_id_ = step_id;
  options_ = options;
  tws_.SetTracemeId(step_id);
}

int RunHandler::Impl::RunHandlerEigenThreadPool::NumThreads() const {
  return run_handler_->pool_impl_->run_handler_thread_pool()->NumThreads();
}

int RunHandler::Impl::RunHandlerEigenThreadPool::CurrentThreadId() const {
  return run_handler_->pool_impl_->run_handler_thread_pool()->CurrentThreadId();
}

RunHandlerPool::RunHandlerPool(Options options) : impl_(new Impl(options)) {}

RunHandlerPool::~RunHandlerPool() {}

std::unique_ptr<RunHandler> RunHandlerPool::Get(
    int64_t step_id, int64_t timeout_in_ms, const RunHandlerOptions& options) {
  return impl_->Get(step_id, timeout_in_ms, options);
}

std::vector<int64_t> RunHandlerPool::GetActiveHandlerPrioritiesForTesting()
    const {
  return impl_->GetActiveHandlerPrioritiesForTesting();
}

void RunHandlerPool::Quiesce() const { impl_->Quiesce(); }

RunHandler::RunHandler(Impl* impl) : impl_(impl) {}

void RunHandler::ScheduleInterOpClosure(TaskFunction fn) {
  impl_->ScheduleInterOpClosure(std::move(fn));
}

void RunHandler::ScheduleIntraOpClosure(TaskFunction fn) {
  impl_->ScheduleInterOpClosure(std::move(fn));
}

int RunHandler::NumThreads() const {
  return impl_->pool_impl()->run_handler_thread_pool()->NumThreads();
}

int64_t RunHandler::step_id() const { return impl_->step_id(); }

tensorflow::thread::ThreadPoolInterface*
RunHandler::AsIntraThreadPoolInterface() const {
  return impl_->thread_pool_interface();
}

RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }

int RunHandlerWorkQueue::GetParallelismLevel() const {
  return run_handler_->NumThreads();
}

void RunHandlerWorkQueue::AddTask(TaskFunction work) {
  run_handler_->ScheduleInterOpClosure(tensorflow::tfrt_stub::WrapWork(
      run_handler_->step_id(), "inter", std::move(work)));
}

Optional<TaskFunction> RunHandlerWorkQueue::AddBlockingTask(
    TaskFunction work, bool allow_queuing) {
  LOG_EVERY_N_SEC(ERROR, 10)
      << "RunHandlerWorkQueue::AddBlockingTask() is not supposed to be called.";
  return work;
}

void RunHandlerWorkQueue::Await(ArrayRef<RCReference<AsyncValue>> values) {
  tfrt::Await(values);
}

bool RunHandlerWorkQueue::IsInWorkerThread() const {
  // Simply return true here, as this method is not used in the SavedModel
  // workflow and will soon be deprecated.
  //
  // TODO(b/198671794): Remove this method once it is removed from base.
  return true;
}

}  // namespace tf
}  // namespace tfrt