xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/run_handler_thread_pool/run_handler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h"
#include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime

namespace tfrt {
namespace tf {
namespace {

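// Each ThreadWorkSource keeps its pending work in fixed-capacity Eigen
// RunQueues; 1024 is the per-queue capacity. When a queue is full, PushFront
// hands the task back to the caller, which then runs it inline (see
// RunHandlerThreadPool::AddWorkToQueue).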
typedef typename internal::RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

}  // namespace

namespace internal {
RunHandlerEnvironment::RunHandlerEnvironment(
    tensorflow::Env* env, const tensorflow::ThreadOptions& thread_options,
    const std::string& name)
    : env_(env), thread_options_(thread_options), name_(name) {}

RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
    std::function<void()> f) {
  return env_->StartThread(thread_options_, name_, [=]() {
    // Set the processor flag to flush denormals to zero.
    tensorflow::port::ScopedFlushDenormal flush;
    // Set the processor rounding mode to ROUND TO NEAREST.
    tensorflow::port::ScopedSetRound round(FE_TONEAREST);
    if (thread_options_.numa_node != tensorflow::port::kNUMANoAffinity) {
      tensorflow::port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
    }
    f();
  });
}

RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(TaskFunction f) {
  uint64_t id = 0;
  if (tensorflow::tracing::EventCollector::IsEnabled()) {
    id = tensorflow::tracing::GetUniqueArg();
    tensorflow::tracing::RecordEvent(
        tensorflow::tracing::EventCategory::kScheduleClosure, id);
  }
  return Task{
      std::unique_ptr<TaskImpl>(new TaskImpl{
          std::move(f),
          tensorflow::Context(tensorflow::ContextKind::kThread),
          id,
      }),
  };
}

void RunHandlerEnvironment::ExecuteTask(const Task& t) {
  tensorflow::WithContext wc(t.f->context);
  tensorflow::tracing::ScopedRegion region(
      tensorflow::tracing::EventCategory::kRunClosure, t.f->trace_id);
  t.f->f();
}

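// Parks the calling thread on `queue_head`'s LIFO waiter list and then sleeps
// on the waiter's condition variable for at most `sleep_micros` (scaled by
// the number of threads already waiting when `adaptive_sleep_time` is set).
// Wake-up is best effort: the thread also wakes on timeout, so a missed
// notification only delays it.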
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, tensorflow::mutex* mutex,
                  int sleep_micros, bool adaptive_sleep_time) {
  {
    tensorflow::mutex_lock l(*mutex);
    CHECK_EQ(waiter->next, waiter);  // Crash OK.
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.

    // Add the waiter to the LIFO queue.
    waiter->prev = queue_head;
    waiter->next = queue_head->next;
    waiter->next->prev = waiter;
    waiter->prev->next = waiter;
  }
  {
    tensorflow::mutex_lock l(waiter->mu);
    waiter->num_waiting_threads++;
    int max_sleep_micros = adaptive_sleep_time
                               ? sleep_micros * waiter->num_waiting_threads
                               : sleep_micros;
    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
    waiter->num_waiting_threads--;
  }

  tensorflow::mutex_lock l(*mutex);
  // Remove the waiter from the LIFO queue. Note that even when a waiter wakes
  // up due to a notification, we cannot conclude it is no longer in the queue:
  // a thread preempted right before notifying may resume only after the waiter
  // has been re-added.
  if (waiter->next != waiter) {
    CHECK(waiter->prev != waiter);  // Crash OK.
    waiter->next->prev = waiter->prev;
    waiter->prev->next = waiter->next;
    waiter->next = waiter;
    waiter->prev = waiter;
  } else {
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
  }
}

ThreadWorkSource::ThreadWorkSource()
    : non_blocking_work_sharding_factor_(
          static_cast<int32_t>(ParamFromEnvWithDefault(
              "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
      non_blocking_work_queues_(non_blocking_work_sharding_factor_),
      blocking_inflight_(0),
      non_blocking_inflight_(0),
      pending_tasks_(0),
      traceme_id_(0),
      version_(0),
      sub_thread_pool_waiter_(nullptr) {
  queue_waiters_.next = &queue_waiters_;
  queue_waiters_.prev = &queue_waiters_;
  for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
    non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
  }
}

ThreadWorkSource::~ThreadWorkSource() {
  for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
    delete non_blocking_work_queues_[i];
  }
}

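// Pushes `t` onto the blocking queue or one of the sharded non-blocking
// queues and, if requested, wakes up one waiting thread. If the target
// RunQueue is full, PushFront returns the task and it is handed back to the
// caller, which is expected to run it inline (see
// RunHandlerThreadPool::AddWorkToQueue).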
Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking,
                                   bool enable_wake_up) {
  uint64_t id = t.f->trace_id;
  tensorflow::profiler::TraceMe activity(
      [id, is_blocking] {
        return tensorflow::profiler::TraceMeEncode(
            "Enqueue", {{"id", id}, {"is_blocking", is_blocking}});
      },
      tensorflow::profiler::TraceMeLevel::kInfo);
  tensorflow::mutex* mu = nullptr;
  Queue* task_queue = nullptr;
  thread_local int64_t closure_counter = 0;

  if (!is_blocking) {
    int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
    task_queue = &(non_blocking_work_queues_[queue_index]->queue);
    mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
  } else {
    task_queue = &blocking_work_queue_;
    mu = &blocking_queue_op_mu_;
  }

  {
    tensorflow::mutex_lock l(*mu);
    // For a given queue, only one thread can call PushFront.
    t = task_queue->PushFront(std::move(t));
  }
  IncrementPendingTaskCount();

  if (enable_wake_up) {
    // Try to wake up a waiting thread, if there is one, for the given sub
    // thread pool. We could potentially wake up threads in other sub thread
    // pools if we cannot find any waiting threads in the given sub thread
    // pool. The wake-up logic is best effort, as the target thread may be
    // right before being added to the waiting queue or about to start waiting
    // on the condition variable. However, a waiting thread will wake up after
    // a short timeout even if a notification is missed.
    Waiter* w = nullptr;

    Waiter* waiter_queue;
    tensorflow::mutex* waiter_queue_mu;
    {
      // When we use multiple sub thread pools, free threads wait on the sub
      // thread pool waiting queues, so wake up threads from those queues.
      // The waiting queues are defined at RunHandlerPool.
      // Get the waiter_queue and the corresponding mutex. Note that the thread
      // work source may change afterwards if a new request comes in or an old
      // request finishes.
      tensorflow::tf_shared_lock lock(run_handler_waiter_mu_);
      waiter_queue = sub_thread_pool_waiter_;
      waiter_queue_mu = sub_thread_pool_waiter_mu_;
    }
    {
      tensorflow::mutex_lock l(*waiter_queue_mu);
      if (waiter_queue->next != waiter_queue) {
        // Remove a waiter from the LIFO queue.
        w = waiter_queue->next;

        CHECK(w->prev != w);  // Crash OK.
        CHECK(w->next != w);  // Crash OK.

        w->next->prev = w->prev;
        w->prev->next = w->next;

        // Use `w->next == w` to indicate that the waiter has been removed
        // from the queue.
        w->next = w;
        w->prev = w;
      }
    }
    if (w != nullptr) {
      w->cv.notify_one();
    }
  }
  VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
          << traceme_id_.load(std::memory_order_relaxed);
  return t;
}

Task ThreadWorkSource::PopBlockingTask() {
  return blocking_work_queue_.PopBack();
}

Task ThreadWorkSource::PopNonBlockingTask(int start_index,
                                          bool search_from_all_queue) {
  Task t;
  unsigned sharding_factor = NonBlockingWorkShardingFactor();
  for (unsigned j = 0; j < sharding_factor; ++j) {
    t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
            ->queue.PopBack();
    if (t.f) {
      return t;
    }
    if (!search_from_all_queue) {
      break;
    }
  }
  return t;
}

int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
  if (is_blocking) {
    return blocking_work_queue_.Size();
  } else {
    unsigned total_size = 0;
    for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
      total_size += non_blocking_work_queues_[i]->queue.Size();
    }
    return total_size;
  }
}

int64_t ThreadWorkSource::GetTracemeId() {
  return traceme_id_.load(std::memory_order_relaxed);
}

void ThreadWorkSource::SetTracemeId(int64_t value) { traceme_id_ = value; }

void ThreadWorkSource::SetWaiter(uint64_t version, Waiter* waiter,
                                 tensorflow::mutex* mutex) {
  {
    tensorflow::tf_shared_lock lock(run_handler_waiter_mu_);
    // Most requests will not change their sub thread pool on recomputation.
    // As an optimization, avoid taking the exclusive lock in that case to
    // reduce contention.
    if (sub_thread_pool_waiter_ == waiter) {
      return;
    }
    // If the current version is newer, there is no need to update.
    if (version_ > version) {
      return;
    }
  }

  tensorflow::mutex_lock l(run_handler_waiter_mu_);
  sub_thread_pool_waiter_ = waiter;
  sub_thread_pool_waiter_mu_ = mutex;
  version_ = version;
}

int64_t ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  return counter->load(std::memory_order_relaxed);
}

void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_add(1, std::memory_order_relaxed);
}

void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_sub(1, std::memory_order_relaxed);
}

int64_t ThreadWorkSource::GetPendingTaskCount() {
  return pending_tasks_.load(std::memory_order_acquire);
}

void ThreadWorkSource::IncrementPendingTaskCount() {
  pending_tasks_.fetch_add(1, std::memory_order_relaxed);
}

void ThreadWorkSource::DecrementPendingTaskCount() {
  // std::memory_order_release prevents reordering with op execution.
  pending_tasks_.fetch_sub(1, std::memory_order_release);
}

unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
  return non_blocking_work_sharding_factor_;
}

std::string ThreadWorkSource::ToString() {
  return tensorflow::strings::StrCat(
      "traceme_id = ", GetTracemeId(),
      ", inter queue size = ", TaskQueueSize(true),
      ", inter inflight = ", GetInflightTaskCount(true),
      ", intra queue size = ", TaskQueueSize(false),
      ", intra inflight = ", GetInflightTaskCount(false));
}

RunHandlerThreadPool::RunHandlerThreadPool(
    Options options, tensorflow::Env* env,
    const tensorflow::ThreadOptions& thread_options, const std::string& name,
    Eigen::MaxSizeVector<tensorflow::mutex>* waiters_mu,
    Eigen::MaxSizeVector<Waiter>* queue_waiters)
    : num_threads_(options.num_blocking_threads +
                   options.num_non_blocking_threads),
      num_blocking_threads_(options.num_blocking_threads),
      num_non_blocking_threads_(options.num_non_blocking_threads),
      adaptive_sleep_time_(options.use_adaptive_waiting_time),
      wait_if_no_active_request_(options.wait_if_no_active_request),
      non_blocking_thread_sleep_time_(
          options.non_blocking_threads_sleep_time_micro_sec),
      blocking_thread_max_waiting_time_(
          options.blocking_threads_max_sleep_time_micro_sec),
      enable_wake_up_(options.enable_wake_up),
      thread_data_(num_threads_),
      env_(env, thread_options, name),
      name_(name),
      waiters_mu_(waiters_mu),
      queue_waiters_(queue_waiters),
      num_threads_in_sub_thread_pool_(options.num_threads_in_sub_thread_pool),
      sub_thread_pool_end_request_percentage_(
          options.sub_thread_request_percentage) {
  thread_data_.resize(num_threads_);
  for (int i = 0; i < num_threads_; ++i) {
    thread_data_[i].new_thread_work_sources =
        std::make_unique<Eigen::MaxSizeVector<ThreadWorkSource*>>(
            options.max_concurrent_handler);
    thread_data_[i].current_thread_work_sources =
        std::make_unique<Eigen::MaxSizeVector<ThreadWorkSource*>>(
            options.max_concurrent_handler);
  }
  VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
          << num_blocking_threads_ << " blocking threads and "
          << num_non_blocking_threads_ << " non-blocking threads.";
}

RunHandlerThreadPool::~RunHandlerThreadPool() {
  VLOG(1) << "Exiting RunHandlerThreadPool " << name_;

  cancelled_ = true;
  for (size_t i = 0; i < thread_data_.size(); ++i) {
    {
      tensorflow::mutex_lock l(thread_data_[i].mu);
      thread_data_[i].sources_not_empty.notify_all();
    }
    thread_data_[i].thread.reset();
  }
}

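// Spawns all worker threads. Threads with index < num_blocking_threads_ may
// run blocking (inter-op) work; the remaining threads only run non-blocking
// (intra-op) work. Each thread is assigned the first sub thread pool whose
// entry in num_threads_in_sub_thread_pool_ exceeds the thread's index,
// falling back to the last sub thread pool.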
void RunHandlerThreadPool::Start() {
  cancelled_ = false;
  int num_blocking_threads = num_blocking_threads_;
  for (int i = 0; i < num_threads_; i++) {
    int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
    for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
      if (i < num_threads_in_sub_thread_pool_[j]) {
        sub_thread_pool_id = j;
        break;
      }
    }
    thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
    thread_data_[i].thread.reset(
        env_.CreateThread([this, i, num_blocking_threads]() {
          WorkerLoop(i, i < num_blocking_threads);
        }));
  }
}

void RunHandlerThreadPool::StartOneThreadForTesting() {
  cancelled_ = false;
  thread_data_[0].sub_thread_pool_id = 0;
  thread_data_[0].thread.reset(
      env_.CreateThread([this]() { WorkerLoop(0, true); }));
}

void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
                                          bool is_blocking, TaskFunction fn) {
  Task t = env_.CreateTask(std::move(fn));
  t = tws->EnqueueTask(std::move(t), is_blocking, enable_wake_up_);
  if (t.f) {
    VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
            << tws->GetTracemeId();
    env_.ExecuteTask(t);
  }
}

void RunHandlerThreadPool::SetThreadWorkSources(
    int tid, uint64_t version,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
  tensorflow::mutex_lock l(thread_data_[tid].mu);
  if (version > thread_data_[tid].new_version) {
    thread_data_[tid].new_version = version;
  } else {
    // A newer version has already been set; no need to update.
    return;
  }
  auto original_size = thread_data_[tid].current_thread_work_sources->size();
  thread_data_[tid].new_thread_work_sources->resize(0);
  for (int i = 0; i < thread_work_sources.size(); ++i) {
    thread_data_[tid].new_thread_work_sources->emplace_back(
        thread_work_sources[i]);
  }
  if (original_size == 0) {
    thread_data_[tid].sources_not_empty.notify_all();
  }
}

RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
  thread_local RunHandlerThreadPool::PerThread per_thread_;
  RunHandlerThreadPool::PerThread* pt = &per_thread_;
  return pt;
}

int RunHandlerThreadPool::CurrentThreadId() const {
  const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
  if (pt->pool == this) {
    return pt->thread_id;
  } else {
    return -1;
  }
}

int RunHandlerThreadPool::NumThreads() const { return num_threads_; }

int RunHandlerThreadPool::NumBlockingThreads() const {
  return num_blocking_threads_;
}

int RunHandlerThreadPool::NumNonBlockingThreads() const {
  return num_non_blocking_threads_;
}

RunHandlerThreadPool::ThreadData::ThreadData()
    : new_version(0), current_index(0), current_version(0) {}

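// Searches the work sources in [searching_range_start, searching_range_end),
// resuming from the thread's saved round-robin index. Blocking tasks are only
// popped while the source's blocking inflight count is below
// `max_blocking_inflight`; otherwise the search falls through to the
// non-blocking queues of the same source.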
Task RunHandlerThreadPool::FindTask(
    int searching_range_start, int searching_range_end, int thread_id,
    int sub_thread_pool_id, int max_blocking_inflight,
    bool may_steal_blocking_work,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
  Task t;
  int current_index = thread_data_[thread_id].current_index;
  *task_from_blocking_queue = false;

  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
    if (current_index >= searching_range_end ||
        current_index < searching_range_start) {
      current_index = searching_range_start;
    }
    *tws = thread_work_sources[current_index];
    ++current_index;

    // For a blocking thread, search for blocking tasks first.
    if (may_steal_blocking_work &&
        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
      t = (*tws)->PopBlockingTask();
      if (t.f) {
        *task_from_blocking_queue = true;
        break;
      }
    }

    // Search for non-blocking tasks.
    t = (*tws)->PopNonBlockingTask(thread_id, true);
    if (t.f) {
      break;
    }
  }
  thread_data_[thread_id].current_index = current_index;
  return t;
}

// Main worker thread loop.
void RunHandlerThreadPool::WorkerLoop(int thread_id,
                                      bool may_steal_blocking_work) {
  PerThread* pt = GetPerThread();
  pt->pool = this;
  pt->thread_id = thread_id;
  static constexpr int32_t kMaxBlockingInflight = 10;

  while (!cancelled_) {
    Task t;
    ThreadWorkSource* tws = nullptr;
    bool task_from_blocking_queue = true;
    int sub_thread_pool_id;
    // Get the current thread work sources.
    {
      tensorflow::mutex_lock l(thread_data_[thread_id].mu);
      if (thread_data_[thread_id].current_version <
          thread_data_[thread_id].new_version) {
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
      }
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
    int active_requests = thread_work_sources->size();
    if (may_steal_blocking_work) {
      // Each thread first looks for tasks from requests that belong to its
      // own sub thread pool.
      int search_range_start =
          sub_thread_pool_id == 0
              ? 0
              : active_requests *
                    sub_thread_pool_end_request_percentage_[sub_thread_pool_id -
                                                            1];
      int search_range_end =
          active_requests *
          sub_thread_pool_end_request_percentage_[sub_thread_pool_id];
      search_range_end = std::min(
          active_requests, std::max(search_range_end, search_range_start + 1));
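      // Worked example (illustrative values): with 10 active requests and
      // sub_thread_pool_end_request_percentage_ = {0.2, 1.0}, sub thread
      // pool 0 searches requests [0, 2) and sub thread pool 1 searches
      // [2, 10). The std::max/std::min clamp above guarantees a non-empty,
      // in-bounds range even when the percentages truncate to an empty one.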

      t = FindTask(search_range_start, search_range_end, thread_id,
                   sub_thread_pool_id, kMaxBlockingInflight,
                   /*may_steal_blocking_work=*/true, *thread_work_sources,
                   &task_from_blocking_queue, &tws);
      if (!t.f) {
        // Search across all requests if the thread cannot find a task from
        // requests that belong to its own sub thread pool.
        t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                     kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/true, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
      }
    } else {
      // A non-blocking thread always searches across all pending requests.
      t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                   kMaxBlockingInflight,
                   /*may_steal_blocking_work=*/false, *thread_work_sources,
                   &task_from_blocking_queue, &tws);
    }
    if (t.f) {
      VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra")
              << " work from " << tws->GetTracemeId();
      tws->IncrementInflightTaskCount(task_from_blocking_queue);
      env_.ExecuteTask(t);
      tws->DecrementInflightTaskCount(task_from_blocking_queue);
      tws->DecrementPendingTaskCount();
    } else {
      tensorflow::profiler::TraceMe activity(
          [thread_id] {
            return tensorflow::profiler::TraceMeEncode(
                "Sleeping", {{"thread_id", thread_id}});
          },
          tensorflow::profiler::TraceMeLevel::kInfo);
      if (VLOG_IS_ON(4)) {
        for (int i = 0; i < thread_work_sources->size(); ++i) {
          VLOG(4) << "source id " << i << " "
                  << (*thread_work_sources)[i]->ToString();
        }
      }
      WaitForWorkInSubThreadPool(thread_id, may_steal_blocking_work,
                                 sub_thread_pool_id);
    }
  }
}

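// Blocks an idle worker thread. When wait_if_no_active_request_ is set and
// there is no active request, the thread waits on `sources_not_empty` and
// returns once a request arrives. Otherwise a non-blocking thread sleeps for
// a fixed interval, while a blocking thread either parks on its sub thread
// pool's waiter queue (when wake-up is enabled) or sleeps for the maximum
// waiting time.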
void RunHandlerThreadPool::WaitForWorkInSubThreadPool(int thread_id,
                                                      bool is_blocking,
                                                      int sub_thread_pool_id) {
  if (wait_if_no_active_request_) {
    tensorflow::mutex_lock l(thread_data_[thread_id].mu);
    if (thread_data_[thread_id].new_version >
        thread_data_[thread_id].current_version) {
      thread_data_[thread_id].current_thread_work_sources.swap(
          thread_data_[thread_id].new_thread_work_sources);
      thread_data_[thread_id].current_version =
          thread_data_[thread_id].new_version;
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    bool already_wait = false;
    while (!cancelled_ && thread_work_sources->empty()) {
      // Wait until there is a new request.
      thread_data_[thread_id].sources_not_empty.wait(l);
      if (thread_data_[thread_id].new_version >
          thread_data_[thread_id].current_version) {
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_work_sources =
            thread_data_[thread_id].current_thread_work_sources.get();
      }
      already_wait = true;
    }
    if (already_wait || cancelled_) {
      return;
    }
  }

  // A non-blocking thread just sleeps.
  if (!is_blocking) {
    tensorflow::Env::Default()->SleepForMicroseconds(
        non_blocking_thread_sleep_time_);
    return;
  }

  if (enable_wake_up_) {
    thread_local Waiter waiter;
    WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
                 &(*waiters_mu_)[sub_thread_pool_id],
                 blocking_thread_max_waiting_time_, adaptive_sleep_time_);
  } else {
    tensorflow::Env::Default()->SleepForMicroseconds(
        blocking_thread_max_waiting_time_);
  }
}

}  // namespace internal

// Contains the concrete implementation of the RunHandler.
// The externally visible RunHandler class simply forwards the work to this
// one.
class RunHandler::Impl {
 public:
  explicit Impl(RunHandlerPool::Impl* pool_impl);

  ~Impl() {}

  // Returns the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
  uint64_t start_time_us() const { return start_time_us_; }
  int64_t step_id() const { return step_id_; }
  void ScheduleInterOpClosure(TaskFunction fn);
  void ScheduleIntraOpClosure(TaskFunction fn);

  void Reset(int64_t step_id, const RunHandlerOptions& options);

  RunHandlerPool::Impl* pool_impl() { return pool_impl_; }

  tensorflow::thread::ThreadPoolInterface* thread_pool_interface() {
    return &eigen_thread_pool_;
  }

  internal::ThreadWorkSource* tws() { return &tws_; }

  int64_t priority() const { return options_.priority; }

 private:
  class RunHandlerEigenThreadPool
      : public tensorflow::thread::ThreadPoolInterface {
   public:
    explicit RunHandlerEigenThreadPool(RunHandler::Impl* run_handler)
        : run_handler_(run_handler) {
      DCHECK(run_handler);
    }

    void Schedule(std::function<void()> fn) override {
      run_handler_->ScheduleIntraOpClosure(tensorflow::tfrt_stub::WrapWork(
          run_handler_->tws()->GetTracemeId(), "intra", std::move(fn)));
    }

    int NumThreads() const override;
    int CurrentThreadId() const override;

   private:
    RunHandler::Impl* run_handler_;
  };

  RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
  RunHandlerEigenThreadPool eigen_thread_pool_;
  uint64_t start_time_us_;
  int64_t step_id_;
  internal::ThreadWorkSource tws_;
  RunHandlerOptions options_;
};

// Contains shared state across all run handlers present in the pool. Also
// responsible for pool management decisions.
// This class is thread safe.
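//
// A minimal usage sketch of the pool (illustrative only; the option values
// below are assumptions, and the full set of fields is defined in
// run_handler.h):
//
//   RunHandlerPool::Options pool_options;
//   pool_options.num_inter_op_threads = 8;   // assumed values
//   pool_options.num_intra_op_threads = 8;
//   RunHandlerPool pool(pool_options);
//   std::unique_ptr<RunHandler> handler =
//       pool.Get(/*step_id=*/1, /*timeout_in_ms=*/0, RunHandlerOptions());
//   handler->ScheduleInterOpClosure(TaskFunction([] { /* inter-op work */ }));
//   pool.Quiesce();   // Wait for scheduled work to drain.
//   handler.reset();  // Return the handler to the pool.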
class RunHandlerPool::Impl {
 public:
  explicit Impl(Options options)
      : max_handlers_(options.max_concurrent_handler),
        waiters_mu_(options.num_sub_thread_pool),
        queue_waiters_(options.num_sub_thread_pool),
        run_handler_thread_pool_(new internal::RunHandlerThreadPool(
            internal::RunHandlerThreadPool::Options(
                options.num_inter_op_threads, options.num_intra_op_threads,
                options.wait_if_no_active_request,
                options.non_blocking_threads_sleep_time_micro_sec,
                options.blocking_threads_max_sleep_time_micro_sec,
                options.use_adaptive_waiting_time, options.enable_wake_up,
                options.max_concurrent_handler,
                options.num_threads_in_sub_thread_pool,
                options.sub_thread_request_percentage),
            tensorflow::Env::Default(), tensorflow::ThreadOptions(),
            "tf_run_handler_pool", &waiters_mu_, &queue_waiters_)),
        iterations_(0),
        version_(0),
        wait_if_no_active_request_(options.wait_if_no_active_request),
        sub_thread_pool_end_request_percentage_(
            options.sub_thread_request_percentage) {
    VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
    free_handlers_.reserve(max_handlers_);
    handlers_.reserve(max_handlers_);
    for (int i = 0; i < max_handlers_; ++i) {
      handlers_.emplace_back(new RunHandler::Impl(this));
      free_handlers_.push_back(handlers_.back().get());
    }
    queue_waiters_.resize(options.num_sub_thread_pool);
    waiters_mu_.resize(options.num_sub_thread_pool);
    for (auto& queue_waiter : queue_waiters_) {
      queue_waiter.next = &queue_waiter;
      queue_waiter.prev = &queue_waiter;
    }
    run_handler_thread_pool_->Start();
  }

  ~Impl() {
    // Sanity check that all handlers have been returned to the pool before
    // destruction.
    DCHECK_EQ(handlers_.size(), max_handlers_);
    DCHECK_EQ(free_handlers_.size(), handlers_.size());
    DCHECK_EQ(sorted_active_handlers_.size(), 0);
    // Stop the threads in run_handler_thread_pool_ before freeing other
    // pointers. Otherwise a thread may try to access a pointer after the
    // pointer has been freed.
    run_handler_thread_pool_.reset();
  }

  internal::RunHandlerThreadPool* run_handler_thread_pool() {
    return run_handler_thread_pool_.get();
  }

  bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return !free_handlers_.empty();
  }

  std::unique_ptr<RunHandler> Get(int64_t step_id, int64_t timeout_in_ms,
                                  const RunHandlerOptions& options)
      TF_LOCKS_EXCLUDED(mu_) {
    thread_local std::unique_ptr<
        Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
        thread_work_sources =
            std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
                new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
                    max_handlers_));
    uint64_t version;
    int num_active_requests;
    RunHandler::Impl* handler_impl;
    {
      tensorflow::mutex_lock l(mu_);
      if (!has_free_handler()) {
        tensorflow::profiler::TraceMe activity(
            [step_id] {
              return tensorflow::profiler::TraceMeEncode(
                  "WaitingForHandler", {{"step_id", step_id}});
            },
            tensorflow::profiler::TraceMeLevel::kInfo);
        if (timeout_in_ms == 0) {
          mu_.Await(tensorflow::Condition(this, &Impl::has_free_handler));
        } else if (!mu_.AwaitWithDeadline(
                       tensorflow::Condition(this, &Impl::has_free_handler),
                       tensorflow::EnvTime::NowNanos() +
                           timeout_in_ms * 1000 * 1000)) {
          return nullptr;
        }
      }
      // Remove the last entry from free_handlers_ and insert it into
      // sorted_active_handlers_ in priority order.
      handler_impl = free_handlers_.back();
      handler_impl->Reset(step_id, options);
      free_handlers_.pop_back();

      num_active_requests = sorted_active_handlers_.size() + 1;
      thread_work_sources->resize(num_active_requests);
      int priority = options.priority;
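      // Walk the priority-sorted active handler list, splicing the new
      // handler in front of the first handler with a strictly lower priority
      // (or at the end), and collect every handler's ThreadWorkSource in the
      // same order so that higher-priority requests come first in
      // thread_work_sources.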
      auto it = sorted_active_handlers_.cbegin();
      bool new_handler_inserted = false;
      for (int i = 0; i < num_active_requests; ++i) {
        if (!new_handler_inserted && (it == sorted_active_handlers_.cend() ||
                                      priority > (*it)->priority())) {
          sorted_active_handlers_.insert(it, handler_impl);
          new_handler_inserted = true;
          // Point to the newly added handler.
          --it;
        }
        (*thread_work_sources)[i] = (*it)->tws();
        ++it;
      }
      version = ++version_;
    }
    RecomputePoolStats(num_active_requests, version, *thread_work_sources);
    return tensorflow::WrapUnique<RunHandler>(new RunHandler(handler_impl));
  }

  void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) {
    tensorflow::mutex_lock l(mu_);
    DCHECK_GT(sorted_active_handlers_.size(), 0);

    CHECK_EQ(handler->tws()->TaskQueueSize(true), 0);   // Crash OK.
    CHECK_EQ(handler->tws()->TaskQueueSize(false), 0);  // Crash OK.

    uint64_t now = tensorflow::EnvTime::NowMicros();
    double elapsed = (now - handler->start_time_us()) / 1000.0;
    time_hist_.Add(elapsed);

    // Erase from and update sorted_active_handlers_. Add the handler to the
    // end of free_handlers_.
    auto iter = std::find(sorted_active_handlers_.begin(),
                          sorted_active_handlers_.end(), handler);
    DCHECK(iter != sorted_active_handlers_.end())
        << "Unexpected handler: " << handler
        << " is being requested for release";

    // Remove this handler from the active list and add it to the list of free
    // handlers.
    sorted_active_handlers_.erase(iter);
    free_handlers_.push_back(handler);
    DCHECK_LE(free_handlers_.size(), max_handlers_);
    LogInfo();

    // We do not recompute pool stats every time. The side effect is that
    // there may be empty thread work sources in the queue. However, any new
    // request will trigger recomputation.
    if (wait_if_no_active_request_ && sorted_active_handlers_.empty()) {
      thread_local auto thread_work_sources =
          std::make_unique<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
              max_handlers_);
      thread_work_sources->resize(0);
      auto version = ++version_;
      RecomputePoolStats(0, version, *thread_work_sources);
    }
  }

  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting()
      TF_LOCKS_EXCLUDED(mu_) {
    tensorflow::mutex_lock l(mu_);
    std::vector<int64_t> ret;
    for (const auto& handler_impl : sorted_active_handlers_) {
      ret.push_back(handler_impl->priority());
    }
    return ret;
  }

  void Quiesce() TF_LOCKS_EXCLUDED(mu_) {
    while (true) {
      {
        tensorflow::tf_shared_lock l(mu_);
        bool all_empty = true;
        for (const auto& handler : sorted_active_handlers_) {
          if (handler->tws()->GetPendingTaskCount() != 0) {
            all_empty = false;
            break;
          }
        }
        if (all_empty) {
          break;
        }
      }
      // Sleep while waiting for pending tasks to drain.
      const int sleep_time = 50000;
      tensorflow::Env::Default()->SleepForMicroseconds(sleep_time);
    }
  }

 private:
  void RecomputePoolStats(
      int num_active_requests, uint64_t version,
      const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
          thread_work_sources);

  void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Maximum number of handlers pre-created at pool construction time. The
  // number is chosen expecting that each handler may want at least one
  // inter-op thread for execution (during compute-intensive workloads like
  // inference).
  const int max_handlers_;

  Eigen::MaxSizeVector<tensorflow::mutex> waiters_mu_;
  Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;

  std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
  // Thread-compatible part, used only under the lock in RunHandlerPool.
  // Handlers are sorted by priority; ties keep request arrival order.
  // TODO(azaks): sort by the remaining latency budget.
  // TODO(chaox): Consider other data structures for maintaining the sorted
  // active handlers if the search overhead (currently O(n)) becomes the
  // bottleneck.
  std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_);
  std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_);
  std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_);

  // Histogram of elapsed runtime of every handler (in ms).
  tensorflow::histogram::Histogram time_hist_ TF_GUARDED_BY(mu_);

  int64_t iterations_ TF_GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  int64_t version_ TF_GUARDED_BY(mu_);
  bool wait_if_no_active_request_;
  const std::vector<double> sub_thread_pool_end_request_percentage_;
};

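// Assigns each active request's ThreadWorkSource to a sub thread pool waiter
// queue according to its position in the priority-sorted request list and
// sub_thread_pool_end_request_percentage_, then pushes the new version of the
// work source list to every worker thread.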
void RunHandlerPool::Impl::RecomputePoolStats(
    int num_active_requests, uint64_t version,
    const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
        thread_work_sources) {
  int sub_thread_pool_id = 0;
  for (int i = 0; i < num_active_requests; ++i) {
    while (
        sub_thread_pool_id <
            sub_thread_pool_end_request_percentage_.size() - 1 &&
        i >= num_active_requests *
                 sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) {
      sub_thread_pool_id++;
    }
    thread_work_sources[i]->SetWaiter(version,
                                      &queue_waiters_[sub_thread_pool_id],
                                      &waiters_mu_[sub_thread_pool_id]);
  }

  int num_threads = run_handler_thread_pool()->NumThreads();
  int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads();
  int num_non_blocking_threads = num_threads - num_blocking_threads;

  for (int i = 0; i < num_blocking_threads + num_non_blocking_threads; ++i) {
    run_handler_thread_pool()->SetThreadWorkSources(i, version,
                                                    thread_work_sources);
  }
}

void RunHandlerPool::Impl::LogInfo() {
  if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
    int num_active_requests = sorted_active_handlers_.size();
    VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
    VLOG(1) << "Active session runs: " << num_active_requests;
    uint64_t now = tensorflow::Env::Default()->NowMicros();
    std::string times_str = "";
    std::string ids_str = "";
    auto it = sorted_active_handlers_.cbegin();
    for (int i = 0; i < num_active_requests; ++i) {
      if (i > 0) {
        times_str += " ";
        ids_str += " ";
      }

      times_str += tensorflow::strings::StrCat(
          (now - (*it)->start_time_us()) / 1000.0, " ms.");
      ids_str += tensorflow::strings::StrCat((*it)->tws()->GetTracemeId());
      ++it;
    }
    VLOG(1) << "Elapsed times are: " << times_str;
    VLOG(1) << "Step ids are: " << ids_str;
  }
}

RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
    : pool_impl_(pool_impl), eigen_thread_pool_(this) {
  Reset(0, RunHandlerOptions());
}

void RunHandler::Impl::ScheduleInterOpClosure(TaskFunction fn) {
  VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
                                                        std::move(fn));
}

void RunHandler::Impl::ScheduleIntraOpClosure(TaskFunction fn) {
  VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false,
                                                        std::move(fn));
}

void RunHandler::Impl::Reset(int64_t step_id,
                             const RunHandlerOptions& options) {
  start_time_us_ = tensorflow::Env::Default()->NowMicros();
  step_id_ = step_id;
  options_ = options;
  tws_.SetTracemeId(step_id);
}

int RunHandler::Impl::RunHandlerEigenThreadPool::NumThreads() const {
  return run_handler_->pool_impl_->run_handler_thread_pool()->NumThreads();
}

int RunHandler::Impl::RunHandlerEigenThreadPool::CurrentThreadId() const {
  return run_handler_->pool_impl_->run_handler_thread_pool()->CurrentThreadId();
}

RunHandlerPool::RunHandlerPool(Options options) : impl_(new Impl(options)) {}

RunHandlerPool::~RunHandlerPool() {}

std::unique_ptr<RunHandler> RunHandlerPool::Get(
    int64_t step_id, int64_t timeout_in_ms, const RunHandlerOptions& options) {
  return impl_->Get(step_id, timeout_in_ms, options);
}

std::vector<int64_t> RunHandlerPool::GetActiveHandlerPrioritiesForTesting()
    const {
  return impl_->GetActiveHandlerPrioritiesForTesting();
}

void RunHandlerPool::Quiesce() const { impl_->Quiesce(); }

RunHandler::RunHandler(Impl* impl) : impl_(impl) {}

void RunHandler::ScheduleInterOpClosure(TaskFunction fn) {
  impl_->ScheduleInterOpClosure(std::move(fn));
}

void RunHandler::ScheduleIntraOpClosure(TaskFunction fn) {
  impl_->ScheduleIntraOpClosure(std::move(fn));
}

int RunHandler::NumThreads() const {
  return impl_->pool_impl()->run_handler_thread_pool()->NumThreads();
}

int64_t RunHandler::step_id() const { return impl_->step_id(); }

tensorflow::thread::ThreadPoolInterface*
RunHandler::AsIntraThreadPoolInterface() const {
  return impl_->thread_pool_interface();
}

RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }

int RunHandlerWorkQueue::GetParallelismLevel() const {
  return run_handler_->NumThreads();
}

void RunHandlerWorkQueue::AddTask(TaskFunction work) {
  run_handler_->ScheduleInterOpClosure(tensorflow::tfrt_stub::WrapWork(
      run_handler_->step_id(), "inter", std::move(work)));
}

Optional<TaskFunction> RunHandlerWorkQueue::AddBlockingTask(
    TaskFunction work, bool allow_queuing) {
  LOG_EVERY_N_SEC(ERROR, 10)
      << "RunHandlerWorkQueue::AddBlockingTask() is not supposed to be called.";
  return work;
}

void RunHandlerWorkQueue::Await(ArrayRef<RCReference<AsyncValue>> values) {
  tfrt::Await(values);
}

bool RunHandlerWorkQueue::IsInWorkerThread() const {
  // Simply return true here, as this method is not used in the saved model
  // workflow and will soon be deprecated.
  //
  // TODO(b/198671794): Remove this method once it is removed from base.
  return true;
}

}  // namespace tf
}  // namespace tfrt