/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace Eigen {
struct ThreadPoolDevice;
}

namespace tensorflow {

class RunHandler;

// RunHandlerPool is a fixed-size pool of pre-allocated RunHandlers
// that can be used for tracking inter-op work for a given Session::Run().
// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
// 'active' when its unique_ptr is returned by Get() and is being used by a
// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When a Session::Run() is invoked, obtain a handler by:
//   auto handler = run_handler_pool_->Get();
//
// * Use handler for scheduling all inter-op work by:
//   handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  explicit RunHandlerPool(int num_inter_op_threads);

  RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'.
  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
  // and is being used by a client. It becomes 'inactive' once more when the
  // unique_ptr is destroyed.
  //
  // Will block unless there is an inactive handler.
  std::unique_ptr<RunHandler> Get(
      int64_t step_id = 0, int64_t timeout_in_ms = 0,
      const RunOptions::Experimental::RunHandlerPoolOptions& options =
          RunOptions::Experimental::RunHandlerPoolOptions());

  // Gets the priorities of the active handlers. The returned priorities are
  // in the same order as the active handler list.
  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() const;

 private:
  class Impl;
  friend class RunHandler;

  std::unique_ptr<Impl> impl_;
};

// RunHandler can be used to schedule inter/intra-op closures to run on a
// global pool shared across all Session::Run(s). The closures are enqueued to
// a handler-specific queue, from which the work is stolen in a priority order
// (time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
//
// This class can be used instead of directly scheduling closures on a global
// pool since it maintains a global view across all sessions and optimizes pool
// scheduling to improve (median and tail) latency.
//
// This class is thread safe.
class RunHandler {
 public:
  void ScheduleInterOpClosure(std::function<void()> fn);
  thread::ThreadPoolInterface* AsIntraThreadPoolInterface();

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};
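
// A fuller usage sketch (illustrative only; the driver code, thread counts,
// and the `step_id` value below are assumptions, not prescribed by this API):
//
//   // At startup, create a single pool sized for inter/intra-op parallelism.
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/16);
//
//   // Per Session::Run(): acquire a handler. This blocks until a handler
//   // becomes inactive (subject to the optional timeout).
//   std::unique_ptr<RunHandler> handler = pool.Get(/*step_id=*/42);
//
//   // Schedule inter-op closures through the handler so that scheduling is
//   // coordinated across all concurrently running steps.
//   handler->ScheduleInterOpClosure([]() { /* run an inter-op task */ });
//
//   // Intra-op work can reuse the handler by wrapping its thread pool
//   // interface in an Eigen device.
//   Eigen::ThreadPoolDevice device(handler->AsIntraThreadPoolInterface(),
//                                  /*num_threads=*/16);
//
//   // Destroying (or resetting) the handler marks it 'inactive' and returns
//   // it to the pool.
//   handler.reset();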

namespace internal {

// TODO(azaks): Refactor with thread::ThreadPool
class RunHandlerEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

 public:
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
                        const string& name);

  EnvThread* CreateThread(std::function<void()> f,
                          const std::string& thread_name);

  Task CreateTask(std::function<void()> f);

  void ExecuteTask(const Task& t);
};

typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
  Waiter() {
    next = this;
    prev = this;
  }
  condition_variable cv;
  mutex mu;
  Waiter* next;
  Waiter* prev;
};

class ThreadWorkSource {
 public:
  ThreadWorkSource();

  ~ThreadWorkSource();

  Task EnqueueTask(Task t, bool is_blocking);

  Task PopBlockingTask();

  Task PopNonBlockingTask(int start_index, bool search_from_all_queue);

  void WaitForWork(int max_sleep_micros);

  int TaskQueueSize(bool is_blocking);

  int64_t GetTracemeId();

  void SetTracemeId(int64_t value);

  void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);

  int64_t GetInflightTaskCount(bool is_blocking);

  void IncrementInflightTaskCount(bool is_blocking);

  void DecrementInflightTaskCount(bool is_blocking);

  unsigned NonBlockingWorkShardingFactor();

  std::string ToString();

 private:
  struct NonBlockingQueue {
    mutex queue_op_mu;
    char pad[128];
    Queue queue;
  };

  int32 non_blocking_work_sharding_factor_;
  Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;

  std::atomic<int64_t> blocking_inflight_;
  std::atomic<int64_t> non_blocking_inflight_;

  Queue blocking_work_queue_;
  mutex blocking_queue_op_mu_;
  char pad_[128];
  mutex waiters_mu_;
  Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
  std::atomic<int64_t> traceme_id_;

  mutex run_handler_waiter_mu_;
  uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
  mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
  Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
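
// Illustrative producer/consumer sketch around ThreadWorkSource, based only
// on the declarations above; the `env` and `tws` objects, and the exact
// overflow semantics of EnqueueTask's return value, are assumptions made for
// the example:
//
//   // Producer side (scheduling an inter-op closure for one request):
//   Task task = env.CreateTask([]() { /* op body */ });
//   tws.EnqueueTask(std::move(task), /*is_blocking=*/true);
//
//   // Consumer side (a worker thread draining the blocking queue):
//   Task t = tws.PopBlockingTask();
//   if (t.f) {
//     tws.IncrementInflightTaskCount(/*is_blocking=*/true);
//     env.ExecuteTask(t);
//     tws.DecrementInflightTaskCount(/*is_blocking=*/true);
//   }
//
//   // Workers that find no work park on the LIFO Waiter list
//   // (queue_waiters_) via WaitForWork() until a producer wakes them.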

class RunHandlerThreadPool {
 public:
  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                       Env* env, const ThreadOptions& thread_options,
                       const string& name,
                       Eigen::MaxSizeVector<mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  void Start();

  void StartOneThreadForTesting();

  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                      std::function<void()> fn);

  // Set work queues from which the thread 'tid' can steal its work.
  // The request with start_request_idx will be attempted first. Other requests
  // will be attempted in FIFO order based on their arrival time.
  void SetThreadWorkSources(
      int tid, int start_request_idx, uint64 version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  PerThread* GetPerThread();

  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Searches for tasks in the request range [searching_range_start,
  // searching_range_end). If there are no tasks in the search range and
  // may_steal_blocking_work is true, then searches all requests.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  void WaitForWork(bool is_blocking, int thread_id,
                   int32_t max_blocking_inflight);

  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);

 private:
  struct ThreadData {
    ThreadData();
    mutex mu;
    uint64 new_version;
    condition_variable sources_not_empty;
    std::unique_ptr<Thread> thread;
    int current_index;
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64 current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    int sub_thread_pool_id;
  };

  const int num_threads_;
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  internal::RunHandlerEnvironment env_;
  std::atomic<bool> cancelled_;
  string name_;
  Eigen::MaxSizeVector<mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  bool use_sub_thread_pool_;
  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search for tasks from the given
  // start_request_percentage to end_request_percentage of the active requests,
  // in a round-robin fashion. For example, with a range of [0.0, 0.5), threads
  // in that sub thread pool only search the first half of the request list.
  std::vector<double> sub_thread_pool_start_request_percentage_;
  std::vector<double> sub_thread_pool_end_request_percentage_;
};

}  // namespace internal

}  // end namespace tensorflow.

#endif  // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_