1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "absl/synchronization/notification.h"
16
17 #include <thread> // NOLINT(build/c++11)
18 #include <tuple>
19 #include <vector>
20
21 #include "gtest/gtest.h"
22 #include "absl/base/attributes.h"
23 #include "absl/base/config.h"
24 #include "absl/base/internal/tracing.h"
25 #include "absl/synchronization/mutex.h"
26 #include "absl/time/time.h"
27
28 namespace absl {
29 ABSL_NAMESPACE_BEGIN
30
31 // A thread-safe class that holds a counter.
32 class ThreadSafeCounter {
33 public:
ThreadSafeCounter()34 ThreadSafeCounter() : count_(0) {}
35
Increment()36 void Increment() {
37 MutexLock lock(&mutex_);
38 ++count_;
39 }
40
Get() const41 int Get() const {
42 MutexLock lock(&mutex_);
43 return count_;
44 }
45
WaitUntilGreaterOrEqual(int n)46 void WaitUntilGreaterOrEqual(int n) {
47 MutexLock lock(&mutex_);
48 auto cond = [this, n]() { return count_ >= n; };
49 mutex_.Await(Condition(&cond));
50 }
51
52 private:
53 mutable Mutex mutex_;
54 int count_;
55 };
56
57 // Runs the |i|'th worker thread for the tests in BasicTests(). Increments the
58 // |ready_counter|, waits on the |notification|, and then increments the
59 // |done_counter|.
RunWorker(int i,ThreadSafeCounter * ready_counter,Notification * notification,ThreadSafeCounter * done_counter)60 static void RunWorker(int i, ThreadSafeCounter* ready_counter,
61 Notification* notification,
62 ThreadSafeCounter* done_counter) {
63 ready_counter->Increment();
64 notification->WaitForNotification();
65 done_counter->Increment();
66 }
67
68 // Tests that the |notification| properly blocks and awakens threads. Assumes
69 // that the |notification| is not yet triggered. If |notify_before_waiting| is
70 // true, the |notification| is triggered before any threads are created, so the
71 // threads never block in WaitForNotification(). Otherwise, the |notification|
72 // is triggered at a later point when most threads are likely to be blocking in
73 // WaitForNotification().
BasicTests(bool notify_before_waiting,Notification * notification)74 static void BasicTests(bool notify_before_waiting, Notification* notification) {
75 EXPECT_FALSE(notification->HasBeenNotified());
76 EXPECT_FALSE(
77 notification->WaitForNotificationWithTimeout(absl::Milliseconds(0)));
78 EXPECT_FALSE(notification->WaitForNotificationWithDeadline(absl::Now()));
79
80 const absl::Duration delay = absl::Milliseconds(50);
81 const absl::Time start = absl::Now();
82 EXPECT_FALSE(notification->WaitForNotificationWithTimeout(delay));
83 const absl::Duration elapsed = absl::Now() - start;
84
85 // Allow for a slight early return, to account for quality of implementation
86 // issues on various platforms.
87 const absl::Duration slop = absl::Milliseconds(5);
88 EXPECT_LE(delay - slop, elapsed)
89 << "WaitForNotificationWithTimeout returned " << delay - elapsed
90 << " early (with " << slop << " slop), start time was " << start;
91
92 ThreadSafeCounter ready_counter;
93 ThreadSafeCounter done_counter;
94
95 if (notify_before_waiting) {
96 notification->Notify();
97 }
98
99 // Create a bunch of threads that increment the |done_counter| after being
100 // notified.
101 const int kNumThreads = 10;
102 std::vector<std::thread> workers;
103 for (int i = 0; i < kNumThreads; ++i) {
104 workers.push_back(std::thread(&RunWorker, i, &ready_counter, notification,
105 &done_counter));
106 }
107
108 if (!notify_before_waiting) {
109 ready_counter.WaitUntilGreaterOrEqual(kNumThreads);
110
111 // Workers have not been notified yet, so the |done_counter| should be
112 // unmodified.
113 EXPECT_EQ(0, done_counter.Get());
114
115 notification->Notify();
116 }
117
118 // After notifying and then joining the workers, both counters should be
119 // fully incremented.
120 notification->WaitForNotification(); // should exit immediately
121 EXPECT_TRUE(notification->HasBeenNotified());
122 EXPECT_TRUE(notification->WaitForNotificationWithTimeout(absl::Seconds(0)));
123 EXPECT_TRUE(notification->WaitForNotificationWithDeadline(absl::Now()));
124 for (std::thread& worker : workers) {
125 worker.join();
126 }
127 EXPECT_EQ(kNumThreads, ready_counter.Get());
128 EXPECT_EQ(kNumThreads, done_counter.Get());
129 }
130
// Runs BasicTests() once per notification ordering, using a fresh
// Notification for each run.
TEST(NotificationTest, SanityTest) {
  Notification notified_after_workers_start;
  Notification notified_before_workers_start;
  BasicTests(/*notify_before_waiting=*/false, &notified_after_workers_start);
  BasicTests(/*notify_before_waiting=*/true, &notified_before_workers_start);
}
136
137 #if ABSL_HAVE_ATTRIBUTE_WEAK
138
139 namespace base_internal {
140
141 namespace {
142
143 using TraceRecord = std::tuple<const void*, ObjectKind>;
144
145 thread_local TraceRecord tls_signal;
146 thread_local TraceRecord tls_wait;
147 thread_local TraceRecord tls_continue;
148 thread_local TraceRecord tls_observed;
149
150 } // namespace
151
152 // Strong extern "C" implementation.
153 extern "C" {
154
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)155 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceWait)(const void* object,
156 ObjectKind kind) {
157 tls_wait = {object, kind};
158 }
159
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)160 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceContinue)(const void* object,
161 ObjectKind kind) {
162 tls_continue = {object, kind};
163 }
164
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)165 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceSignal)(const void* object,
166 ObjectKind kind) {
167 tls_signal = {object, kind};
168 }
169
ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceObserved)170 void ABSL_INTERNAL_C_SYMBOL(AbslInternalTraceObserved)(const void* object,
171 ObjectKind kind) {
172 tls_observed = {object, kind};
173 }
174
175 } // extern "C"
176
// Notify() must invoke the signal hook with the notification's address and
// ObjectKind::kNotification.
TEST(NotificationTest, TracesNotify) {
  Notification notification;
  const TraceRecord expected(&notification, ObjectKind::kNotification);

  tls_signal = TraceRecord();  // discard anything recorded earlier
  notification.Notify();
  EXPECT_EQ(tls_signal, expected);
}
183
// WaitForNotification() on an already-triggered notification must invoke both
// the wait hook and the continue hook with the notification's address and
// ObjectKind::kNotification.
TEST(NotificationTest, TracesWaitForNotification) {
  Notification notification;
  notification.Notify();
  const TraceRecord expected(&notification, ObjectKind::kNotification);

  // Discard anything recorded earlier on this thread.
  tls_continue = TraceRecord();
  tls_wait = TraceRecord();

  notification.WaitForNotification();
  EXPECT_EQ(tls_wait, expected);
  EXPECT_EQ(tls_continue, expected);
}
192
// Checks the wait/continue hooks around WaitForNotificationWithTimeout(),
// first on an untriggered notification and then on a triggered one.
TEST(NotificationTest, TracesWaitForNotificationWithTimeout) {
  Notification notification;
  const TraceRecord with_object(&notification, ObjectKind::kNotification);
  const TraceRecord without_object(nullptr, ObjectKind::kNotification);

  // Untriggered: the wait hook still names the notification, while the
  // continue hook is expected to report a null object.
  tls_continue = TraceRecord();
  tls_wait = TraceRecord();
  notification.WaitForNotificationWithTimeout(absl::Milliseconds(1));
  EXPECT_EQ(tls_wait, with_object);
  EXPECT_EQ(tls_continue, without_object);

  // Triggered: both hooks are expected to name the notification.
  notification.Notify();
  tls_continue = TraceRecord();
  tls_wait = TraceRecord();
  notification.WaitForNotificationWithTimeout(absl::Milliseconds(1));
  EXPECT_EQ(tls_wait, with_object);
  EXPECT_EQ(tls_continue, with_object);
}
207
// Checks what the observed hook records around HasBeenNotified(), before and
// after the notification is triggered.
TEST(NotificationTest, TracesHasBeenNotified) {
  Notification notification;

  // Untriggered: the slot is expected to hold (nullptr, kUnknown) afterwards.
  tls_observed = TraceRecord();
  ASSERT_FALSE(notification.HasBeenNotified());
  EXPECT_EQ(tls_observed, TraceRecord(nullptr, ObjectKind::kUnknown));

  // Triggered: the slot is expected to name this notification.
  notification.Notify();
  tls_observed = TraceRecord();
  ASSERT_TRUE(notification.HasBeenNotified());
  EXPECT_EQ(tls_observed,
            TraceRecord(&notification, ObjectKind::kNotification));
}
220
221 } // namespace base_internal
222
223 #endif // ABSL_HAVE_ATTRIBUTE_WEAK
224
225 ABSL_NAMESPACE_END
226 } // namespace absl
227