/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace Eigen {
struct ThreadPoolDevice;
}

namespace tensorflow {

class RunHandler;

// RunHandlerPool is a fixed-size pool of pre-allocated RunHandlers
// that can be used for tracking inter-op work for a given Session::Run().
// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
// 'active' when its unique_ptr is returned by Get() and is being used by a
// client. It becomes 'inactive' once more when its unique_ptr is destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When a Session::Run() is invoked, obtain a handler by:
//     auto handler = run_handler_pool_->Get();
//
// * Use the handler for scheduling all inter-op work by:
//     handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  explicit RunHandlerPool(int num_inter_op_threads);

  RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'.
  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
  // and is being used by a client.  It becomes 'inactive' once more when the
  // unique_ptr is destroyed.
  //
  // Will block unless there is an inactive handler.
  std::unique_ptr<RunHandler> Get(
      int64_t step_id = 0, int64_t timeout_in_ms = 0,
      const RunOptions::Experimental::RunHandlerPoolOptions& options =
          RunOptions::Experimental::RunHandlerPoolOptions());

  // Gets the priorities of the active handlers. The returned values are in
  // the same order as the active handler list.
  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() const;

 private:
  class Impl;
  friend class RunHandler;

  std::unique_ptr<Impl> impl_;
};
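
// A minimal end-to-end usage sketch (illustrative only, not part of this
// header); the pool size, step id, and closure body are hypothetical:
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/8);
//   {
//     auto handler = pool.Get(/*step_id=*/1);
//     handler->ScheduleInterOpClosure([] { /* unit of inter-op work */ });
//   }  // The unique_ptr is destroyed here; the handler returns to the pool.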

// RunHandler can be used to schedule inter/intra-op closures to run on a
// global pool shared across all Session::Run(s). The closures are enqueued on
// a handler-specific queue, from which the work is stolen in priority order
// (based on the time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
//
// This class can be used instead of directly scheduling closures on a global
// pool since it maintains a global view across all sessions and optimizes pool
// scheduling to improve (median and tail) latency.
//
// This class is thread safe.
class RunHandler {
 public:
  void ScheduleInterOpClosure(std::function<void()> fn);
  thread::ThreadPoolInterface* AsIntraThreadPoolInterface();

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};
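
// An illustrative sketch of wiring a handler's intra-op interface into an
// Eigen device; the thread count is a hypothetical choice and
// 'run_handler_pool_' is assumed to be a caller-owned RunHandlerPool:
//
//   auto handler = run_handler_pool_->Get();
//   Eigen::ThreadPoolDevice device(handler->AsIntraThreadPoolInterface(),
//                                  /*num_threads=*/4);
//   // Kernels can evaluate expressions on 'device' while inter-op closures
//   // go through handler->ScheduleInterOpClosure(...).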

namespace internal {

// TODO(azaks): Refactor with thread::ThreadPool
class RunHandlerEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

 public:
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
                        const string& name);

  EnvThread* CreateThread(std::function<void()> f,
                          const std::string& thread_name);

  Task CreateTask(std::function<void()> f);

  void ExecuteTask(const Task& t);
};

typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
  Waiter() {
    next = this;
    prev = this;
  }
  condition_variable cv;
  mutex mu;
  Waiter* next;
  Waiter* prev;
};
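
// An illustrative sketch (not part of this header) of the LIFO discipline on
// the circular list, where 'head' stands for the sentinel node (e.g.
// queue_waiters_ below) and 'InsertAtHead' is a hypothetical helper:
//
//   void InsertAtHead(Waiter* head, Waiter* w) {
//     w->next = head->next;  // The newest waiter sits right after the
//     w->prev = head;        // sentinel, so it is also the first to be woken.
//     head->next->prev = w;
//     head->next = w;
//   }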

class ThreadWorkSource {
 public:
  ThreadWorkSource();

  ~ThreadWorkSource();

  Task EnqueueTask(Task t, bool is_blocking);

  Task PopBlockingTask();

  Task PopNonBlockingTask(int start_index, bool search_from_all_queue);

  void WaitForWork(int max_sleep_micros);

  int TaskQueueSize(bool is_blocking);

  int64_t GetTracemeId();

  void SetTracemeId(int64_t value);

  void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);

  int64_t GetInflightTaskCount(bool is_blocking);

  void IncrementInflightTaskCount(bool is_blocking);

  void DecrementInflightTaskCount(bool is_blocking);

  unsigned NonBlockingWorkShardingFactor();

  std::string ToString();

 private:
  struct NonBlockingQueue {
    mutex queue_op_mu;
    char pad[128];
    Queue queue;
  };

  int32 non_blocking_work_sharding_factor_;
  Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;

  std::atomic<int64_t> blocking_inflight_;
  std::atomic<int64_t> non_blocking_inflight_;

  Queue blocking_work_queue_;
  mutex blocking_queue_op_mu_;
  char pad_[128];
  mutex waiters_mu_;
  Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
  std::atomic<int64_t> traceme_id_;

  mutex run_handler_waiter_mu_;
  uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
  mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
  Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
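
// An illustrative producer/consumer sketch against this interface; 'env',
// 'work', and 'step_id' are hypothetical caller-side names, and the exact
// ordering of the inflight-count updates is an assumption:
//
//   ThreadWorkSource source;
//   source.SetTracemeId(step_id);
//   source.EnqueueTask(env.CreateTask(std::move(work)), /*is_blocking=*/true);
//   // On a worker thread:
//   Task popped = source.PopBlockingTask();
//   if (popped.f) {
//     source.IncrementInflightTaskCount(/*is_blocking=*/true);
//     env.ExecuteTask(popped);
//     source.DecrementInflightTaskCount(/*is_blocking=*/true);
//   }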

class RunHandlerThreadPool {
 public:
  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                       Env* env, const ThreadOptions& thread_options,
                       const string& name,
                       Eigen::MaxSizeVector<mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  void Start();

  void StartOneThreadForTesting();

  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                      std::function<void()> fn);

  // Sets the work queues from which the thread 'tid' can steal its work.
  // The request with start_request_idx will be attempted first. Other requests
  // will be attempted in FIFO order based on their arrival time.
  void SetThreadWorkSources(
      int tid, int start_request_idx, uint64 version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  PerThread* GetPerThread();

  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Searches for tasks in the request range from searching_range_start to
  // searching_range_end. If there are no tasks in the search range and
  // may_steal_blocking_work is true, then searches all requests.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  void WaitForWork(bool is_blocking, int thread_id,
                   int32_t max_blocking_inflight);

  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);

 private:
  struct ThreadData {
    ThreadData();
    mutex mu;
    uint64 new_version;
    condition_variable sources_not_empty;
    std::unique_ptr<Thread> thread;
    int current_index;
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64 current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    int sub_thread_pool_id;
  };

  const int num_threads_;
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  internal::RunHandlerEnvironment env_;
  std::atomic<bool> cancelled_;
  string name_;
  Eigen::MaxSizeVector<mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  bool use_sub_thread_pool_;
  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search for tasks from the given
  // start_request_percentage to end_request_percentage of the request list,
  // in a round-robin fashion.
  std::vector<double> sub_thread_pool_start_request_percentage_;
  std::vector<double> sub_thread_pool_end_request_percentage_;
};
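
// An illustrative construction sketch (internal use only); the thread counts,
// 'kMaxHandlers', 'tws', and 'fn' are hypothetical, and pre-sizing the waiter
// vectors with resize() is an assumption about how callers prepare them:
//
//   Eigen::MaxSizeVector<mutex> waiters_mu(kMaxHandlers);
//   Eigen::MaxSizeVector<Waiter> queue_waiters(kMaxHandlers);
//   waiters_mu.resize(kMaxHandlers);
//   queue_waiters.resize(kMaxHandlers);
//   RunHandlerThreadPool pool(/*num_blocking_threads=*/4,
//                             /*num_non_blocking_threads=*/2, Env::Default(),
//                             ThreadOptions(), "run_handler_pool",
//                             &waiters_mu, &queue_waiters);
//   pool.Start();
//   pool.AddWorkToQueue(tws, /*is_blocking=*/true, std::move(fn));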

}  // namespace internal

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_