// Copyright (C) 2019 The Android Open Source Project
// Copyright (C) 2019 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "aemu/base/threads/AndroidWorkPool.h"

#include "aemu/base/threads/AndroidFunctorThread.h"
#include "aemu/base/synchronization/AndroidLock.h"
#include "aemu/base/synchronization/AndroidConditionVariable.h"
#include "aemu/base/synchronization/AndroidMessageChannel.h"

#include <log/log.h>  // ALOGE/ALOGV

#include <atomic>
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include <sys/time.h>
using gfxstream::guest::AutoLock;
using gfxstream::guest::ConditionVariable;
using gfxstream::guest::FunctorThread;
using gfxstream::guest::Lock;
using gfxstream::guest::MessageChannel;

namespace gfxstream {
namespace guest {

static constexpr uint64_t kMicrosecondsPerSecond = 1000000;
static constexpr uint64_t kNanosecondsPerMicrosecond = 1000;

class WaitGroup { // intrusive refcounted
public:

    WaitGroup(int numTasksRemaining) :
        mNumTasksInitial(numTasksRemaining),
        mNumTasksRemaining(numTasksRemaining) { }

    ~WaitGroup() = default;

    gfxstream::guest::Lock& getLock() { return mLock; }

    // Increments the refcount. Aborts if the object was already dead
    // (refcount 0), since that would be a use-after-free.
    void acquire() {
        if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) {
            ALOGE("%s: goofed, refcount0 acquire\n", __func__);
            abort();
        }
    }

    // Decrements the refcount, deleting the object when the last reference
    // is dropped. Returns true if this call destroyed the object.
    bool release() {
        if (0 == mRefCount) {
            ALOGE("%s: goofed, refcount0 release\n", __func__);
            abort();
        }
        if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) {
            std::atomic_thread_fence(std::memory_order_acquire);
            delete this;
            return true;
        }
        return false;
    }

    // Wait for all of, or any of, the associated tasks to complete.
    // Returns false if the wait timed out before the condition was met.
    bool waitAllLocked(WorkPool::TimeoutUs timeout) {
        return conditionalTimeoutLocked(
            [this] { return mNumTasksRemaining > 0; },
            timeout);
    }

    bool waitAnyLocked(WorkPool::TimeoutUs timeout) {
        return conditionalTimeoutLocked(
            [this] { return mNumTasksRemaining == mNumTasksInitial; },
            timeout);
    }

    // Decrements the remaining-task count and broadcasts to all waiters
    // that another task has completed. Returns true if it was the last one.
    bool decrementBroadcast() {
        AutoLock<Lock> lock(mLock);
        bool done =
            (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst));
        std::atomic_thread_fence(std::memory_order_acquire);
        mCv.broadcast();
        return done;
    }

private:

    // A timeout of ~0ULL means "wait unconditionally".
    bool doWait(WorkPool::TimeoutUs timeout) {
        if (timeout == ~0ULL) {
            ALOGV("%s: uncond wait\n", __func__);
            mCv.wait(&mLock);
            return true;
        } else {
            return mCv.timedWait(&mLock, getDeadline(timeout));
        }
    }

    // Converts a relative timeout in microseconds to an absolute deadline
    // suitable for timedWait().
    struct timespec getDeadline(WorkPool::TimeoutUs relative) {
        struct timeval deadlineUs;
        struct timespec deadlineNs;
        gettimeofday(&deadlineUs, 0);

        deadlineUs.tv_sec += (relative / kMicrosecondsPerSecond);
        deadlineUs.tv_usec += (relative % kMicrosecondsPerSecond);

        // Carry overflowed microseconds into seconds; tv_usec must stay
        // strictly below one second (>= here, since exactly one second
        // would otherwise yield an invalid tv_nsec of 10^9 below).
        if (deadlineUs.tv_usec >= kMicrosecondsPerSecond) {
            deadlineUs.tv_sec += (deadlineUs.tv_usec / kMicrosecondsPerSecond);
            deadlineUs.tv_usec = (deadlineUs.tv_usec % kMicrosecondsPerSecond);
        }

        deadlineNs.tv_sec = deadlineUs.tv_sec;
        deadlineNs.tv_nsec = deadlineUs.tv_usec * kNanosecondsPerMicrosecond;
        return deadlineNs;
    }
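
    // Worked example of the carry above (illustrative numbers): with
    // now = { tv_sec = 10, tv_usec = 900000 } and relative = 1500000 us,
    // the sums give { tv_sec = 11, tv_usec = 1400000 }, which the carry
    // normalizes to { tv_sec = 12, tv_usec = 400000 }, i.e. a deadline of
    // { tv_sec = 12, tv_nsec = 400000000 }.
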
    uint64_t currTimeUs() {
        struct timeval tv;
        gettimeofday(&tv, 0);
        return (uint64_t)(tv.tv_sec * kMicrosecondsPerSecond + tv.tv_usec);
    }

    // Waits (with mLock held) until conditionFunc() becomes false or the
    // timeout elapses. Returns true if the condition cleared, false if the
    // wait timed out with the condition still set.
    bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) {
        uint64_t currTime = currTimeUs();
        WorkPool::TimeoutUs currTimeout = timeout;

        while (conditionFunc()) {
            doWait(currTimeout);
            if (conditionFunc()) {
                // Woken up but not done yet; charge the time we already
                // waited against the remaining timeout.
                uint64_t nextTime = currTimeUs();
                WorkPool::TimeoutUs waited = nextTime - currTime;
                currTime = nextTime;

                if (currTimeout > waited) {
                    currTimeout -= waited;
                } else {
                    // Timeout exhausted; report whether the condition
                    // cleared anyway (negated so that true means success).
                    return !conditionFunc();
                }
            }
        }

        return true;
    }

    std::atomic<int> mRefCount = { 1 };
    int mNumTasksInitial;
    std::atomic<int> mNumTasksRemaining;

    Lock mLock;
    ConditionVariable mCv;
};
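
// Reference-count protocol for WaitGroup, as implemented above (a summary,
// not new behavior):
//   - The pool's wait group map holds the initial reference (refcount 1).
//   - Each WorkPoolThread running a task of the group holds one reference
//     for the duration of the task (run() acquires, doRun() releases).
//   - waitAny()/waitAll() hold one reference for the duration of the wait,
//     so the group stays alive even if it is swept from the map meanwhile.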

class WorkPoolThread {
public:
    // State diagram for each work pool thread
    //
    // Unacquired: (Start state) When no one else has claimed the thread.
    // Acquired: When the thread has been claimed for work,
    //     but work has not been issued to it yet.
    // Scheduled: When the thread is running tasks from the acquirer.
    // Exiting: When the thread is shutting down.
    //
    // Messages:
    //
    // Acquire
    // Run
    // Exit
    //
    // Transitions:
    //
    // Note: while a task is being run, messages come back with a failure value.
    //
    // Unacquired:
    //     message Acquire -> Acquired. effect: return success value
    //     message Run -> Unacquired. effect: return failure value
    //     message Exit -> Exiting. effect: return success value
    //
    // Acquired:
    //     message Acquire -> Acquired. effect: return failure value
    //     message Run -> Scheduled. effect: run the task, return success
    //     message Exit -> Exiting. effect: return success value
    //
    // Scheduled:
    //     implicit effect: after the task is run, transition back to Unacquired.
    //     message Acquire -> Scheduled. effect: return failure value
    //     message Run -> Scheduled. effect: return failure value
    //     message Exit -> queue up an exit message, then transition to Exiting
    //         once that is processed. effect: return success value
    //
    enum State {
        Unacquired = 0,
        Acquired = 1,
        Scheduled = 2,
        Exiting = 3,
    };
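
    // Illustrative lifecycle of a single task, tracing the transitions
    // above (a sketch derived from this file, not additional API):
    //
    //   scheduler thread:  acquire()      Unacquired -> Acquired   (returns true)
    //   scheduler thread:  run(h, wg, t)  Acquired   -> Scheduled  (sends Command::Run)
    //   worker thread:     doRun(t)       Scheduled  -> Unacquired (after t() runs)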

    WorkPoolThread() : mThread([this] { threadFunc(); }) {
        mThread.start();
    }

    ~WorkPoolThread() {
        exit();
        mThread.wait();
    }

    // Claims this thread for work. Succeeds only from the Unacquired state.
    bool acquire() {
        AutoLock<Lock> lock(mLock);
        switch (mState) {
            case State::Unacquired:
                mState = State::Acquired;
                return true;
            case State::Acquired:
            case State::Scheduled:
            case State::Exiting:
                return false;
            default:
                return false;
        }
    }

    // Queues one task on an Acquired thread and moves it to Scheduled.
    // Takes a reference on the wait group for the duration of the task.
    bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) {
        AutoLock<Lock> lock(mLock);
        switch (mState) {
            case State::Unacquired:
                return false;
            case State::Acquired: {
                mState = State::Scheduled;
                mToCleanupWaitGroupHandle = waitGroupHandle;
                waitGroup->acquire();
                mToCleanupWaitGroup = waitGroup;
                mShouldCleanupWaitGroup = false;
                TaskInfo msg = {
                    Command::Run,
                    waitGroup, task,
                };
                mRunMessages.send(msg);
                return true;
            }
            case State::Scheduled:
            case State::Exiting:
                return false;
            default:
                return false;
        }
    }

    // Reports (and clears) whether the last task of a wait group finished on
    // this thread, handing the group back to the pool for cleanup.
    bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) {
        AutoLock<Lock> lock(mLock);
        bool res = mShouldCleanupWaitGroup;
        *waitGroupHandle = mToCleanupWaitGroupHandle;
        *waitGroup = mToCleanupWaitGroup;
        mShouldCleanupWaitGroup = false;
        return res;
    }

private:
    enum Command {
        Run = 0,
        Exit = 1,
    };

    struct TaskInfo {
        Command cmd;
        WaitGroup* waitGroup = nullptr;
        WorkPool::Task task = {};
    };

    bool exit() {
        AutoLock<Lock> lock(mLock);
        TaskInfo msg { Command::Exit, };
        mRunMessages.send(msg);
        return true;
    }

    void threadFunc() {
        TaskInfo taskInfo;
        bool done = false;

        while (!done) {
            mRunMessages.receive(&taskInfo);
            switch (taskInfo.cmd) {
                case Command::Run:
                    doRun(taskInfo);
                    break;
                case Command::Exit: {
                    AutoLock<Lock> lock(mLock);
                    mState = State::Exiting;
                    break;
                }
            }
            AutoLock<Lock> lock(mLock);
            done = mState == State::Exiting;
        }
    }

    // Assumption: the wait group refcount is >= 1 on entry and cannot drop
    // to 0 before the release() below, because run() took a reference.
    void doRun(TaskInfo& msg) {
        WaitGroup* waitGroup = msg.waitGroup;

        if (msg.task) msg.task();

        bool lastTask =
            waitGroup->decrementBroadcast();

        AutoLock<Lock> lock(mLock);
        mState = State::Unacquired;

        if (lastTask) {
            mShouldCleanupWaitGroup = true;
        }

        waitGroup->release();
    }

    FunctorThread mThread;
    Lock mLock;
    State mState = State::Unacquired;
    MessageChannel<TaskInfo, 4> mRunMessages;
    WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0;
    WaitGroup* mToCleanupWaitGroup = nullptr;
    bool mShouldCleanupWaitGroup = false;
};

class WorkPool::Impl {
public:
    Impl(int numInitialThreads) : mThreads(numInitialThreads) {
        for (size_t i = 0; i < mThreads.size(); ++i) {
            mThreads[i].reset(new WorkPoolThread);
        }
    }

    ~Impl() = default;

    WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) {

        if (tasks.empty()) abort();

        AutoLock<Lock> lock(mLock);

        // Sweep wait groups whose last task completed since the last call.
        for (size_t i = 0; i < mThreads.size(); ++i) {
            WaitGroupHandle handle;
            WaitGroup* waitGroup;
            bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup);
            if (cleanup) {
                mWaitGroups.erase(handle);
                waitGroup->release();
            }
        }

        WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked();
        WaitGroup* waitGroup = new WaitGroup(tasks.size());

        mWaitGroups[resHandle] = waitGroup;

        // Claim one thread per task, growing the pool if there are not
        // enough idle threads.
        std::vector<size_t> threadIndices;

        while (threadIndices.size() < tasks.size()) {
            for (size_t i = 0; i < mThreads.size(); ++i) {
                if (!mThreads[i]->acquire()) continue;
                threadIndices.push_back(i);
                if (threadIndices.size() == tasks.size()) break;
            }
            if (threadIndices.size() < tasks.size()) {
                mThreads.resize(mThreads.size() + 1);
                mThreads[mThreads.size() - 1].reset(new WorkPoolThread);
            }
        }

        // Every claimed thread is in the Acquired state here, so run()
        // cannot fail.
        for (size_t i = 0; i < threadIndices.size(); ++i) {
            mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]);
        }

        return resHandle;
    }

    bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
        AutoLock<Lock> lock(mLock);
        auto it = mWaitGroups.find(waitGroupHandle);
        if (it == mWaitGroups.end()) return true;

        auto waitGroup = it->second;
        waitGroup->acquire();
        lock.unlock();

        bool waitRes = false;

        {
            AutoLock<Lock> waitGroupLock(waitGroup->getLock());
            waitRes = waitGroup->waitAnyLocked(timeout);
        }

        waitGroup->release();

        return waitRes;
    }

    bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
        auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle);
        if (!waitGroup) return true;

        bool waitRes = false;

        {
            AutoLock<Lock> waitGroupLock(waitGroup->getLock());
            waitRes = waitGroup->waitAllLocked(timeout);
        }

        waitGroup->release();

        return waitRes;
    }

private:
    // Returns the wait group for the handle with its refcount incremented
    // by 1, or nullptr if the handle is stale.
    WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) {
        AutoLock<Lock> lock(mLock);
        auto it = mWaitGroups.find(waitGroupHandle);
        if (it == mWaitGroups.end()) return nullptr;

        auto waitGroup = it->second;
        waitGroup->acquire();

        return waitGroup;
    }

    using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>;

    WorkPool::WaitGroupHandle genWaitGroupHandleLocked() {
        WorkPool::WaitGroupHandle res = mNextWaitGroupHandle;
        ++mNextWaitGroupHandle;
        return res;
    }

    Lock mLock;
    uint64_t mNextWaitGroupHandle = 0;
    WaitGroupStore mWaitGroups;
    std::vector<std::unique_ptr<WorkPoolThread>> mThreads;
};

WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { }
WorkPool::~WorkPool() = default;

WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) {
    return mImpl->schedule(tasks);
}

bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
    return mImpl->waitAny(waitGroup, timeout);
}

bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
    return mImpl->waitAll(waitGroup, timeout);
}
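
// Example usage (an illustrative sketch, not part of this file; it assumes
// WorkPool::Task is a std::function<void()>-compatible callable per the
// header, and that a timeout of ~0ULL means "wait forever", per doWait()):
//
//     WorkPool pool(2);
//     std::atomic<int> done { 0 };
//     auto handle = pool.schedule({
//         [&done] { done.fetch_add(1); },
//         [&done] { done.fetch_add(1); },
//     });
//     pool.waitAny(handle, 100000);  // wake once at least one task finished
//     pool.waitAll(handle, ~0ULL);   // block until both tasks complete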

} // namespace guest
} // namespace gfxstream