1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_PLATFORM_MUTEX_H_
17 #define TENSORFLOW_CORE_PLATFORM_MUTEX_H_
18
19 #include <chrono> // NOLINT
20 // for std::try_to_lock_t and std::cv_status
21 #include <condition_variable> // NOLINT
22 #include <mutex> // NOLINT
23
24 #include "tensorflow/core/platform/platform.h"
25 #include "tensorflow/core/platform/thread_annotations.h"
26 #include "tensorflow/core/platform/types.h"
27
28 // Include appropriate platform-dependent implementation details of mutex etc.
29 #if defined(PLATFORM_GOOGLE)
30 #include "tensorflow/tsl/platform/google/mutex_data.h"
31 #elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
32 defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
33 defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
34 #include "tensorflow/tsl/platform/default/mutex_data.h"
35 #else
36 #error Define the appropriate PLATFORM_<foo> macro for this platform
37 #endif
38
39 namespace tensorflow {
40
// Outcome of a timed wait on a condition_variable (see WaitForMilliseconds).
enum ConditionResult {
  kCond_Timeout = 0,        // The wait ended because the timeout expired.
  kCond_MaybeNotified = 1,  // The wait may have ended due to a notification.
};

// Tag that selects mutex's constexpr constructor, intended for mutexes that
// rely on static (linker) zero-initialization.
enum LinkerInitialized {
  LINKER_INITIALIZED = 0,
};

class Condition;
class condition_variable;
46
47 // Mimic std::mutex + C++17's shared_mutex, adding a LinkerInitialized
48 // constructor interface. This type is as fast as mutex, but is also a shared
49 // lock, and provides conditional critical sections (via Await()), as an
50 // alternative to condition variables.
51 class TF_LOCKABLE mutex {
52 public:
53 mutex();
54 // The default implementation of the underlying mutex is safe to use after
55 // the linker initialization to zero.
mutex(LinkerInitialized x)56 explicit constexpr mutex(LinkerInitialized x):mu_(absl::kConstInit) {}
57
58 void lock() TF_EXCLUSIVE_LOCK_FUNCTION();
59 bool try_lock() TF_EXCLUSIVE_TRYLOCK_FUNCTION(true);
60 void unlock() TF_UNLOCK_FUNCTION();
61
62 void lock_shared() TF_SHARED_LOCK_FUNCTION();
63 bool try_lock_shared() TF_SHARED_TRYLOCK_FUNCTION(true);
64 void unlock_shared() TF_UNLOCK_FUNCTION();
65
66 // -------
67 // Conditional critical sections.
68 // These represent an alternative to condition variables that is easier to
69 // use. The predicate must be encapsulated in a function (via Condition),
70 // but there is no need to use a while-loop, and no need to signal the
71 // condition. Example: suppose "mu" protects "counter"; we wish one thread
72 // to wait until counter is decremented to zero by another thread.
73 // // Predicate expressed as a function:
74 // static bool IntIsZero(int* pi) { return *pi == 0; }
75 //
76 // // Waiter:
77 // mu.lock();
78 // mu.Await(Condition(&IntIsZero, &counter)); // no loop needed
79 // // lock is held and counter==0...
80 // mu.unlock();
81 //
82 // // Decrementer:
83 // mu.lock();
84 // counter--;
85 // mu.unlock(); // no need to signal; mutex will check condition
86 //
87 // A mutex may be used with condition variables and conditional critical
88 // sections at the same time. Conditional critical sections are easier to
89 // use, but if there are multiple conditions that are simultaneously false,
90 // condition variables may be faster.
91
92 // Unlock *this and wait until cond.Eval() is true, then atomically reacquire
93 // *this in the same mode in which it was previously held and return.
94 void Await(const Condition& cond);
95
96 // Unlock *this and wait until either cond.Eval is true, or abs_deadline_ns
97 // has been reached, then atomically reacquire *this in the same mode in
98 // which it was previously held, and return whether cond.Eval() is true.
99 // See tensorflow/core/platform/env_time.h for the time interface.
100 bool AwaitWithDeadline(const Condition& cond, uint64 abs_deadline_ns);
101 // -------
102
103 private:
104 friend class condition_variable;
105 internal::MuData mu_;
106 };
107
// A Condition represents a predicate on state protected by a mutex. The
// function must have no side-effects on that state. When passed to
// mutex::Await(), the function will be called with the mutex held. It may be
// called:
//   - any number of times;
//   - by any thread using the mutex; and/or
//   - with the mutex held in any mode (read or write).
// If you must use a lambda, prefix the lambda with +, and capture no variables.
// For example: Condition(+[](int *pi)->bool { return *pi == 0; }, &i)
class Condition {
 public:
  // Value is (*func)(arg).
  template <typename T>
  Condition(bool (*func)(T* arg), T* arg);

  // Value is (obj->*method)().
  template <typename T>
  Condition(T* obj, bool (T::*method)());

  // Value is (obj->*method)() for a const member function.
  template <typename T>
  Condition(T* obj, bool (T::*method)() const);

  // Value is *flag.
  explicit Condition(const bool* flag);

  // Return the value of the predicate represented by this Condition by
  // dispatching to the trampoline chosen at construction.
  bool Eval() const { return eval_(this); }

 private:
  // Type-erased predicate state. At most one of function_ / method_ is
  // non-null, depending on which constructor ran; arg_ is the argument or
  // object the predicate is applied to. Keep this declaration order: the
  // out-of-line constructors initialize the members in exactly this order.
  bool (*eval_)(const Condition*);  // CallFunction, CallMethod, or ReturnBool
  bool (*function_)(void*);         // predicate of form (*function_)(arg_)
  bool (Condition::*method_)();     // predicate of form (arg_->*method_)()
  void* arg_;

  // Private: clients must use one of the public constructors.
  Condition();

  // Trampolines that may be stored in eval_.
  template <typename T>
  static bool CallFunction(const Condition* cond);  // invokes function_
  template <typename T>
  static bool CallMethod(const Condition* cond);  // invokes method_
  static bool ReturnBool(const Condition* cond);  // reads *(bool*)arg_
};
143
144 // Mimic a subset of the std::unique_lock<tensorflow::mutex> functionality.
145 class TF_SCOPED_LOCKABLE mutex_lock {
146 public:
147 typedef ::tensorflow::mutex mutex_type;
148
mutex_lock(mutex_type & mu)149 explicit mutex_lock(mutex_type& mu) TF_EXCLUSIVE_LOCK_FUNCTION(mu)
150 : mu_(&mu) {
151 mu_->lock();
152 }
153
mutex_lock(mutex_type & mu,std::try_to_lock_t)154 mutex_lock(mutex_type& mu, std::try_to_lock_t) TF_EXCLUSIVE_LOCK_FUNCTION(mu)
155 : mu_(&mu) {
156 if (!mu.try_lock()) {
157 mu_ = nullptr;
158 }
159 }
160
161 // Manually nulls out the source to prevent double-free.
162 // (std::move does not null the source pointer by default.)
mutex_lock(mutex_lock && ml)163 mutex_lock(mutex_lock&& ml) noexcept TF_EXCLUSIVE_LOCK_FUNCTION(ml.mu_)
164 : mu_(ml.mu_) {
165 ml.mu_ = nullptr;
166 }
TF_UNLOCK_FUNCTION()167 ~mutex_lock() TF_UNLOCK_FUNCTION() {
168 if (mu_ != nullptr) {
169 mu_->unlock();
170 }
171 }
mutex()172 mutex_type* mutex() { return mu_; }
173
174 explicit operator bool() const { return mu_ != nullptr; }
175
176 private:
177 mutex_type* mu_;
178 };
179
// Catch bug where variable name is omitted, e.g. mutex_lock (mu);
// The macro is function-like, so it expands only when the next token after
// `mutex_lock` is `(`. Named declarations such as `mutex_lock l(mu);` are
// unaffected. Any expansion is a compile-time error.
#define mutex_lock(x) \
  static_assert(false, "mutex_lock_decl_missing_var_name");
182
183 // Mimic a subset of the std::shared_lock<tensorflow::mutex> functionality.
184 // Name chosen to minimize conflicts with the tf_shared_lock macro, below.
185 class TF_SCOPED_LOCKABLE tf_shared_lock {
186 public:
187 typedef ::tensorflow::mutex mutex_type;
188
tf_shared_lock(mutex_type & mu)189 explicit tf_shared_lock(mutex_type& mu) TF_SHARED_LOCK_FUNCTION(mu)
190 : mu_(&mu) {
191 mu_->lock_shared();
192 }
193
tf_shared_lock(mutex_type & mu,std::try_to_lock_t)194 tf_shared_lock(mutex_type& mu, std::try_to_lock_t) TF_SHARED_LOCK_FUNCTION(mu)
195 : mu_(&mu) {
196 if (!mu.try_lock_shared()) {
197 mu_ = nullptr;
198 }
199 }
200
201 // Manually nulls out the source to prevent double-free.
202 // (std::move does not null the source pointer by default.)
tf_shared_lock(tf_shared_lock && ml)203 tf_shared_lock(tf_shared_lock&& ml) noexcept TF_SHARED_LOCK_FUNCTION(ml.mu_)
204 : mu_(ml.mu_) {
205 ml.mu_ = nullptr;
206 }
TF_UNLOCK_FUNCTION()207 ~tf_shared_lock() TF_UNLOCK_FUNCTION() {
208 if (mu_ != nullptr) {
209 mu_->unlock_shared();
210 }
211 }
mutex()212 mutex_type* mutex() { return mu_; }
213
214 explicit operator bool() const { return mu_ != nullptr; }
215
216 private:
217 mutex_type* mu_;
218 };
219
// Catch bug where variable name is omitted, e.g. tf_shared_lock (mu);
// The macro is function-like, so it expands only when the next token after
// `tf_shared_lock` is `(`. Named declarations such as
// `tf_shared_lock l(mu);` are unaffected. Any expansion is a compile-time
// error.
#define tf_shared_lock(x) \
  static_assert(false, "tf_shared_lock_decl_missing_var_name");
223
224 // Mimic std::condition_variable.
225 class condition_variable {
226 public:
227 condition_variable();
228
229 void wait(mutex_lock& lock);
230 template <class Rep, class Period>
231 std::cv_status wait_for(mutex_lock& lock,
232 std::chrono::duration<Rep, Period> dur);
233 void notify_one();
234 void notify_all();
235
236 private:
237 friend ConditionResult WaitForMilliseconds(mutex_lock* mu,
238 condition_variable* cv,
239 int64_t ms);
240 internal::CVData cv_;
241 };
242
243 // Like "cv->wait(*mu)", except that it only waits for up to "ms" milliseconds.
244 //
245 // Returns kCond_Timeout if the timeout expired without this
246 // thread noticing a signal on the condition variable. Otherwise may
247 // return either kCond_Timeout or kCond_MaybeNotified
WaitForMilliseconds(mutex_lock * mu,condition_variable * cv,int64_t ms)248 inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
249 condition_variable* cv, int64_t ms) {
250 std::cv_status s = cv->wait_for(*mu, std::chrono::milliseconds(ms));
251 return (s == std::cv_status::timeout) ? kCond_Timeout : kCond_MaybeNotified;
252 }
253
254 // ------------------------------------------------------------
255 // Implementation details follow. Clients should ignore them.
256
257 // private static
258 template <typename T>
CallFunction(const Condition * cond)259 inline bool Condition::CallFunction(const Condition* cond) {
260 bool (*fn)(T*) = reinterpret_cast<bool (*)(T*)>(cond->function_);
261 return (*fn)(static_cast<T*>(cond->arg_));
262 }
263
264 template <typename T>
Condition(bool (* func)(T *),T * arg)265 inline Condition::Condition(bool (*func)(T*), T* arg)
266 : eval_(&CallFunction<T>),
267 function_(reinterpret_cast<bool (*)(void*)>(func)),
268 method_(nullptr),
269 arg_(const_cast<void*>(static_cast<const void*>(arg))) {}
270
271 // private static
272 template <typename T>
CallMethod(const Condition * cond)273 inline bool Condition::CallMethod(const Condition* cond) {
274 bool (T::*m)() = reinterpret_cast<bool (T::*)()>(cond->method_);
275 return (static_cast<T*>(cond->arg_)->*m)();
276 }
277
278 template <typename T>
Condition(T * obj,bool (T::* method)())279 inline Condition::Condition(T* obj, bool (T::*method)())
280 : eval_(&CallMethod<T>),
281 function_(nullptr),
282 method_(reinterpret_cast<bool (Condition::*)()>(method)),
283 arg_(const_cast<void*>(static_cast<const void*>(obj))) {}
284
285 template <typename T>
Condition(T * obj,bool (T::* method)()const)286 inline Condition::Condition(T* obj, bool (T::*method)() const)
287 : eval_(&CallMethod<T>),
288 function_(nullptr),
289 method_(reinterpret_cast<bool (Condition::*)()>(method)),
290 arg_(const_cast<void*>(static_cast<const void*>(obj))) {}
291
292 // private static
ReturnBool(const Condition * cond)293 inline bool Condition::ReturnBool(const Condition* cond) {
294 return *static_cast<bool*>(cond->arg_);
295 }
296
Condition(const bool * flag)297 inline Condition::Condition(const bool* flag)
298 : eval_(&ReturnBool),
299 function_(nullptr),
300 method_(nullptr),
301 arg_(const_cast<void*>(static_cast<const void*>(flag))) {}
302
303 } // namespace tensorflow
304
305 // Include appropriate platform-dependent implementation details of mutex etc.
306 #if defined(PLATFORM_GOOGLE)
307 #include "tensorflow/tsl/platform/google/mutex.h"
308 #elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
309 defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
310 defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
311 #include "tensorflow/tsl/platform/default/mutex.h"
312 #else
313 #error Define the appropriate PLATFORM_<foo> macro for this platform
314 #endif
315
316 #endif // TENSORFLOW_CORE_PLATFORM_MUTEX_H_
317