xref: /aosp_15_r20/external/tensorflow/tensorflow/core/platform/mutex.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PLATFORM_MUTEX_H_
17 #define TENSORFLOW_CORE_PLATFORM_MUTEX_H_
18 
19 #include <chrono>  // NOLINT
20 // for std::try_to_lock_t and std::cv_status
21 #include <condition_variable>  // NOLINT
22 #include <mutex>               // NOLINT
23 
24 #include "tensorflow/core/platform/platform.h"
25 #include "tensorflow/core/platform/thread_annotations.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 // Include appropriate platform-dependent implementation details of mutex etc.
29 #if defined(PLATFORM_GOOGLE)
30 #include "tensorflow/tsl/platform/google/mutex_data.h"
31 #elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) ||    \
32     defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
33     defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
34 #include "tensorflow/tsl/platform/default/mutex_data.h"
35 #else
36 #error Define the appropriate PLATFORM_<foo> macro for this platform
37 #endif
38 
39 namespace tensorflow {
40 
// Outcome reported by WaitForMilliseconds(): either the wait timed out, or the
// waiter may have been woken by a notification.
enum ConditionResult {
  kCond_Timeout = 0,
  kCond_MaybeNotified = 1,
};

// Tag type used to select mutex's constexpr constructor
// (pass LINKER_INITIALIZED).
enum LinkerInitialized {
  LINKER_INITIALIZED = 0,
};

// Forward declarations; both classes are defined later in this header.
class Condition;
class condition_variable;
46 
47 // Mimic std::mutex + C++17's shared_mutex, adding a LinkerInitialized
48 // constructor interface.  This type is as fast as mutex, but is also a shared
49 // lock, and provides conditional critical sections (via Await()), as an
50 // alternative to condition variables.
51 class TF_LOCKABLE mutex {
52  public:
53   mutex();
54   // The default implementation of the underlying mutex is safe to use after
55   // the linker initialization to zero.
mutex(LinkerInitialized x)56   explicit constexpr mutex(LinkerInitialized x):mu_(absl::kConstInit) {}
57 
58   void lock() TF_EXCLUSIVE_LOCK_FUNCTION();
59   bool try_lock() TF_EXCLUSIVE_TRYLOCK_FUNCTION(true);
60   void unlock() TF_UNLOCK_FUNCTION();
61 
62   void lock_shared() TF_SHARED_LOCK_FUNCTION();
63   bool try_lock_shared() TF_SHARED_TRYLOCK_FUNCTION(true);
64   void unlock_shared() TF_UNLOCK_FUNCTION();
65 
66   // -------
67   // Conditional critical sections.
68   // These represent an alternative to condition variables that is easier to
69   // use.  The predicate must be encapsulated in a function (via Condition),
70   // but there is no need to use a while-loop, and no need to signal the
71   // condition.  Example:  suppose "mu" protects "counter"; we wish one thread
72   // to wait until counter is decremented to zero by another thread.
73   //   // Predicate expressed as a function:
74   //   static bool IntIsZero(int* pi) { return *pi == 0; }
75   //
76   //   // Waiter:
77   //   mu.lock();
78   //   mu.Await(Condition(&IntIsZero, &counter));   // no loop needed
79   //   // lock is held and counter==0...
80   //   mu.unlock();
81   //
82   //   // Decrementer:
83   //   mu.lock();
84   //   counter--;
85   //   mu.unlock();    // no need to signal; mutex will check condition
86   //
87   // A mutex may be used with condition variables and conditional critical
88   // sections at the same time.  Conditional critical sections are easier to
89   // use, but if there are multiple conditions that are simultaneously false,
90   // condition variables may be faster.
91 
92   // Unlock *this and wait until cond.Eval() is true, then atomically reacquire
93   // *this in the same mode in which it was previously held and return.
94   void Await(const Condition& cond);
95 
96   // Unlock *this and wait until either cond.Eval is true, or abs_deadline_ns
97   // has been reached, then atomically reacquire *this in the same mode in
98   // which it was previously held, and return whether cond.Eval() is true.
99   // See tensorflow/core/platform/env_time.h for the time interface.
100   bool AwaitWithDeadline(const Condition& cond, uint64 abs_deadline_ns);
101   // -------
102 
103  private:
104   friend class condition_variable;
105   internal::MuData mu_;
106 };
107 
// A Condition wraps a predicate over state protected by a mutex, for use with
// mutex::Await() / mutex::AwaitWithDeadline().  The predicate must not modify
// that state.  The mutex is held whenever the predicate runs, but the
// predicate may run:
//   - any number of times;
//   - on any thread that uses the mutex; and/or
//   - with the mutex held in either mode (shared or exclusive).
// Lambdas are accepted only if they capture nothing and are converted to a
// plain function pointer with a leading '+', e.g.
//   Condition(+[](int *pi)->bool { return *pi == 0; }, &i)
class Condition {
 public:
  // Predicate value is (*fn)(arg).
  template <typename T>
  Condition(bool (*fn)(T*), T* arg);

  // Predicate value is (object->*pred)(), for a non-const member function.
  template <typename T>
  Condition(T* object, bool (T::*pred)());

  // Predicate value is (object->*pred)(), for a const member function.
  template <typename T>
  Condition(T* object, bool (T::*pred)() const);

  // Predicate value is *flag.  Only the pointer is stored, so later changes
  // to *flag are observed by subsequent evaluations.
  explicit Condition(const bool* flag);

  // Evaluates the wrapped predicate and returns its current value.
  bool Eval() const { return eval_(this); }

 private:
  // Private: a Condition must be built from one of the constructors above.
  Condition();

  // Evaluator installed by the constructor.  It interprets the type-erased
  // fields below and returns the predicate's value.
  bool (*eval_)(const Condition*);
  // Free-function predicate, erased to void*; called as (*function_)(arg_).
  bool (*function_)(void*);
  // Member-function predicate, erased to a Condition member; called on arg_.
  bool (Condition::*method_)();
  // Erased argument, object, or flag pointer.
  void* arg_;

  // Evaluators that the constructors may install in eval_.
  template <typename T>
  static bool CallFunction(const Condition* cond);  // invokes function_
  template <typename T>
  static bool CallMethod(const Condition* cond);    // invokes method_
  static bool ReturnBool(const Condition* cond);    // reads *(bool *)arg_
};
143 
144 // Mimic a subset of the std::unique_lock<tensorflow::mutex> functionality.
145 class TF_SCOPED_LOCKABLE mutex_lock {
146  public:
147   typedef ::tensorflow::mutex mutex_type;
148 
mutex_lock(mutex_type & mu)149   explicit mutex_lock(mutex_type& mu) TF_EXCLUSIVE_LOCK_FUNCTION(mu)
150       : mu_(&mu) {
151     mu_->lock();
152   }
153 
mutex_lock(mutex_type & mu,std::try_to_lock_t)154   mutex_lock(mutex_type& mu, std::try_to_lock_t) TF_EXCLUSIVE_LOCK_FUNCTION(mu)
155       : mu_(&mu) {
156     if (!mu.try_lock()) {
157       mu_ = nullptr;
158     }
159   }
160 
161   // Manually nulls out the source to prevent double-free.
162   // (std::move does not null the source pointer by default.)
mutex_lock(mutex_lock && ml)163   mutex_lock(mutex_lock&& ml) noexcept TF_EXCLUSIVE_LOCK_FUNCTION(ml.mu_)
164       : mu_(ml.mu_) {
165     ml.mu_ = nullptr;
166   }
TF_UNLOCK_FUNCTION()167   ~mutex_lock() TF_UNLOCK_FUNCTION() {
168     if (mu_ != nullptr) {
169       mu_->unlock();
170     }
171   }
mutex()172   mutex_type* mutex() { return mu_; }
173 
174   explicit operator bool() const { return mu_ != nullptr; }
175 
176  private:
177   mutex_type* mu_;
178 };
179 
// Compile-time trap for a guard declared without a variable name, e.g.
// `mutex_lock (mu);`, which does not keep `mu` locked for the rest of the
// scope as intended.  After this point, any function-style use of the name
// `mutex_lock` fails with the message below.  Normal declarations such as
// `mutex_lock l(mu);` are unaffected, because there the name is not directly
// followed by '('.
#define mutex_lock(x) \
  static_assert(false, "mutex_lock_decl_missing_var_name");
182 
183 // Mimic a subset of the std::shared_lock<tensorflow::mutex> functionality.
184 // Name chosen to minimize conflicts with the tf_shared_lock macro, below.
185 class TF_SCOPED_LOCKABLE tf_shared_lock {
186  public:
187   typedef ::tensorflow::mutex mutex_type;
188 
tf_shared_lock(mutex_type & mu)189   explicit tf_shared_lock(mutex_type& mu) TF_SHARED_LOCK_FUNCTION(mu)
190       : mu_(&mu) {
191     mu_->lock_shared();
192   }
193 
tf_shared_lock(mutex_type & mu,std::try_to_lock_t)194   tf_shared_lock(mutex_type& mu, std::try_to_lock_t) TF_SHARED_LOCK_FUNCTION(mu)
195       : mu_(&mu) {
196     if (!mu.try_lock_shared()) {
197       mu_ = nullptr;
198     }
199   }
200 
201   // Manually nulls out the source to prevent double-free.
202   // (std::move does not null the source pointer by default.)
tf_shared_lock(tf_shared_lock && ml)203   tf_shared_lock(tf_shared_lock&& ml) noexcept TF_SHARED_LOCK_FUNCTION(ml.mu_)
204       : mu_(ml.mu_) {
205     ml.mu_ = nullptr;
206   }
TF_UNLOCK_FUNCTION()207   ~tf_shared_lock() TF_UNLOCK_FUNCTION() {
208     if (mu_ != nullptr) {
209       mu_->unlock_shared();
210     }
211   }
mutex()212   mutex_type* mutex() { return mu_; }
213 
214   explicit operator bool() const { return mu_ != nullptr; }
215 
216  private:
217   mutex_type* mu_;
218 };
219 
// Compile-time trap for a guard declared without a variable name, e.g.
// `tf_shared_lock (mu);`, which does not keep `mu` share-locked for the rest
// of the scope as intended.  After this point, any function-style use of the
// name `tf_shared_lock` fails with the message below.  Normal declarations
// such as `tf_shared_lock l(mu);` are unaffected.
#define tf_shared_lock(x) \
  static_assert(false, "tf_shared_lock_decl_missing_var_name");
223 
224 // Mimic std::condition_variable.
225 class condition_variable {
226  public:
227   condition_variable();
228 
229   void wait(mutex_lock& lock);
230   template <class Rep, class Period>
231   std::cv_status wait_for(mutex_lock& lock,
232                           std::chrono::duration<Rep, Period> dur);
233   void notify_one();
234   void notify_all();
235 
236  private:
237   friend ConditionResult WaitForMilliseconds(mutex_lock* mu,
238                                              condition_variable* cv,
239                                              int64_t ms);
240   internal::CVData cv_;
241 };
242 
243 // Like "cv->wait(*mu)", except that it only waits for up to "ms" milliseconds.
244 //
245 // Returns kCond_Timeout if the timeout expired without this
246 // thread noticing a signal on the condition variable.  Otherwise may
247 // return either kCond_Timeout or kCond_MaybeNotified
WaitForMilliseconds(mutex_lock * mu,condition_variable * cv,int64_t ms)248 inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
249                                            condition_variable* cv, int64_t ms) {
250   std::cv_status s = cv->wait_for(*mu, std::chrono::milliseconds(ms));
251   return (s == std::cv_status::timeout) ? kCond_Timeout : kCond_MaybeNotified;
252 }
253 
254 // ------------------------------------------------------------
255 // Implementation details follow.   Clients should ignore them.
256 
257 // private static
258 template <typename T>
CallFunction(const Condition * cond)259 inline bool Condition::CallFunction(const Condition* cond) {
260   bool (*fn)(T*) = reinterpret_cast<bool (*)(T*)>(cond->function_);
261   return (*fn)(static_cast<T*>(cond->arg_));
262 }
263 
264 template <typename T>
Condition(bool (* func)(T *),T * arg)265 inline Condition::Condition(bool (*func)(T*), T* arg)
266     : eval_(&CallFunction<T>),
267       function_(reinterpret_cast<bool (*)(void*)>(func)),
268       method_(nullptr),
269       arg_(const_cast<void*>(static_cast<const void*>(arg))) {}
270 
271 // private static
272 template <typename T>
CallMethod(const Condition * cond)273 inline bool Condition::CallMethod(const Condition* cond) {
274   bool (T::*m)() = reinterpret_cast<bool (T::*)()>(cond->method_);
275   return (static_cast<T*>(cond->arg_)->*m)();
276 }
277 
278 template <typename T>
Condition(T * obj,bool (T::* method)())279 inline Condition::Condition(T* obj, bool (T::*method)())
280     : eval_(&CallMethod<T>),
281       function_(nullptr),
282       method_(reinterpret_cast<bool (Condition::*)()>(method)),
283       arg_(const_cast<void*>(static_cast<const void*>(obj))) {}
284 
285 template <typename T>
Condition(T * obj,bool (T::* method)()const)286 inline Condition::Condition(T* obj, bool (T::*method)() const)
287     : eval_(&CallMethod<T>),
288       function_(nullptr),
289       method_(reinterpret_cast<bool (Condition::*)()>(method)),
290       arg_(const_cast<void*>(static_cast<const void*>(obj))) {}
291 
292 // private static
ReturnBool(const Condition * cond)293 inline bool Condition::ReturnBool(const Condition* cond) {
294   return *static_cast<bool*>(cond->arg_);
295 }
296 
Condition(const bool * flag)297 inline Condition::Condition(const bool* flag)
298     : eval_(&ReturnBool),
299       function_(nullptr),
300       method_(nullptr),
301       arg_(const_cast<void*>(static_cast<const void*>(flag))) {}
302 
303 }  // namespace tensorflow
304 
305 // Include appropriate platform-dependent implementation details of mutex etc.
306 #if defined(PLATFORM_GOOGLE)
307 #include "tensorflow/tsl/platform/google/mutex.h"
308 #elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) ||    \
309     defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
310     defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
311 #include "tensorflow/tsl/platform/default/mutex.h"
312 #else
313 #error Define the appropriate PLATFORM_<foo> macro for this platform
314 #endif
315 
316 #endif  // TENSORFLOW_CORE_PLATFORM_MUTEX_H_
317