1 // Copyright (C) 2019 The Android Open Source Project
2 // Copyright (C) 2019 Google Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 #include "aemu/base/threads/AndroidWorkPool.h"
16 
17 #include "aemu/base/threads/AndroidFunctorThread.h"
18 #include "aemu/base/synchronization/AndroidLock.h"
19 #include "aemu/base/synchronization/AndroidConditionVariable.h"
20 #include "aemu/base/synchronization/AndroidMessageChannel.h"
21 
22 #include <atomic>
23 #include <memory>
24 #include <unordered_map>
25 #include <sys/time.h>
26 
27 using gfxstream::guest::AutoLock;
28 using gfxstream::guest::ConditionVariable;
29 using gfxstream::guest::FunctorThread;
30 using gfxstream::guest::Lock;
31 using gfxstream::guest::MessageChannel;
32 
33 namespace gfxstream {
34 namespace guest {
35 
// Conversion factors used when turning microsecond timeouts into
// timeval / timespec deadlines.
static constexpr uint64_t kMicrosecondsPerSecond = 1000ULL * 1000ULL;
static constexpr uint64_t kNanosecondsPerMicrosecond = 1000ULL;
38 
39 class WaitGroup { // intrusive refcounted
40 public:
41 
WaitGroup(int numTasksRemaining)42     WaitGroup(int numTasksRemaining) :
43         mNumTasksInitial(numTasksRemaining),
44         mNumTasksRemaining(numTasksRemaining) { }
45 
46     ~WaitGroup() = default;
47 
getLock()48     gfxstream::guest::Lock& getLock() { return mLock; }
49 
acquire()50     void acquire() {
51         if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) {
52             ALOGE("%s: goofed, refcount0 acquire\n", __func__);
53             abort();
54         }
55     }
56 
release()57     bool release() {
58         if (0 == mRefCount) {
59             ALOGE("%s: goofed, refcount0 release\n", __func__);
60             abort();
61         }
62         if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) {
63             std::atomic_thread_fence(std::memory_order_acquire);
64             delete this;
65             return true;
66         }
67         return false;
68     }
69 
70     // wait on all of or any of the associated tasks to complete.
waitAllLocked(WorkPool::TimeoutUs timeout)71     bool waitAllLocked(WorkPool::TimeoutUs timeout) {
72         return conditionalTimeoutLocked(
73             [this] { return mNumTasksRemaining > 0; },
74             timeout);
75     }
76 
waitAnyLocked(WorkPool::TimeoutUs timeout)77     bool waitAnyLocked(WorkPool::TimeoutUs timeout) {
78         return conditionalTimeoutLocked(
79             [this] { return mNumTasksRemaining == mNumTasksInitial; },
80             timeout);
81     }
82 
83     // broadcasts to all waiters that there has been a new job that has completed
decrementBroadcast()84     bool decrementBroadcast() {
85         AutoLock<Lock> lock(mLock);
86         bool done =
87             (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst));
88         std::atomic_thread_fence(std::memory_order_acquire);
89         mCv.broadcast();
90         return done;
91     }
92 
93 private:
94 
doWait(WorkPool::TimeoutUs timeout)95     bool doWait(WorkPool::TimeoutUs timeout) {
96         if (timeout == ~0ULL) {
97             ALOGV("%s: uncond wait\n", __func__);
98             mCv.wait(&mLock);
99             return true;
100         } else {
101             return mCv.timedWait(&mLock, getDeadline(timeout));
102         }
103     }
104 
getDeadline(WorkPool::TimeoutUs relative)105     struct timespec getDeadline(WorkPool::TimeoutUs relative) {
106         struct timeval deadlineUs;
107         struct timespec deadlineNs;
108         gettimeofday(&deadlineUs, 0);
109 
110         deadlineUs.tv_sec += (relative / kMicrosecondsPerSecond);
111         deadlineUs.tv_usec += (relative % kMicrosecondsPerSecond);
112 
113         if (deadlineUs.tv_usec > kMicrosecondsPerSecond) {
114             deadlineUs.tv_sec += (deadlineUs.tv_usec / kMicrosecondsPerSecond);
115             deadlineUs.tv_usec = (deadlineUs.tv_usec % kMicrosecondsPerSecond);
116         }
117 
118         deadlineNs.tv_sec = deadlineUs.tv_sec;
119         deadlineNs.tv_nsec = deadlineUs.tv_usec * kNanosecondsPerMicrosecond;
120         return deadlineNs;
121     }
122 
currTimeUs()123     uint64_t currTimeUs() {
124         struct timeval tv;
125         gettimeofday(&tv, 0);
126         return (uint64_t)(tv.tv_sec * kMicrosecondsPerSecond + tv.tv_usec);
127     }
128 
conditionalTimeoutLocked(std::function<bool ()> conditionFunc,WorkPool::TimeoutUs timeout)129     bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) {
130         uint64_t currTime = currTimeUs();
131         WorkPool::TimeoutUs currTimeout = timeout;
132 
133         while (conditionFunc()) {
134             doWait(currTimeout);
135             if (conditionFunc()) {
136                 // Decrement timeout for wakeups
137                 uint64_t nextTime = currTimeUs();
138                 WorkPool::TimeoutUs waited =
139                     nextTime - currTime;
140                 currTime = nextTime;
141 
142                 if (currTimeout > waited) {
143                     currTimeout -= waited;
144                 } else {
145                     return conditionFunc();
146                 }
147             }
148         }
149 
150         return true;
151     }
152 
153     std::atomic<int> mRefCount = { 1 };
154     int mNumTasksInitial;
155     std::atomic<int> mNumTasksRemaining;
156 
157     Lock mLock;
158     ConditionVariable mCv;
159 };
160 
161 class WorkPoolThread {
162 public:
163     // State diagram for each work pool thread
164     //
165     // Unacquired: (Start state) When no one else has claimed the thread.
166     // Acquired: When the thread has been claimed for work,
167     // but work has not been issued to it yet.
168     // Scheduled: When the thread is running tasks from the acquirer.
169     // Exiting: cleanup
170     //
171     // Messages:
172     //
173     // Acquire
174     // Run
175     // Exit
176     //
177     // Transitions:
178     //
179     // Note: While task is being run, messages will come back with a failure value.
180     //
181     // Unacquired:
182     //     message Acquire -> Acquired. effect: return success value
183     //     message Run -> Unacquired. effect: return failure value
184     //     message Exit -> Exiting. effect: return success value
185     //
186     // Acquired:
187     //     message Acquire -> Acquired. effect: return failure value
188     //     message Run -> Scheduled. effect: run the task, return success
189     //     message Exit -> Exiting. effect: return success value
190     //
191     // Scheduled:
192     //     implicit effect: after task is run, transition back to Unacquired.
193     //     message Acquire -> Scheduled. effect: return failure value
194     //     message Run -> Scheduled. effect: return failure value
195     //     message Exit -> queue up exit message, then transition to Exiting after that is done.
196     //         effect: return success value
197     //
198     enum State {
199         Unacquired = 0,
200         Acquired = 1,
201         Scheduled = 2,
202         Exiting = 3,
203     };
204 
__anonf4013c230302null205     WorkPoolThread() : mThread([this] { threadFunc(); }) {
206         mThread.start();
207     }
208 
~WorkPoolThread()209     ~WorkPoolThread() {
210         exit();
211         mThread.wait();
212     }
213 
acquire()214     bool acquire() {
215         AutoLock<Lock> lock(mLock);
216         switch (mState) {
217             case State::Unacquired:
218                 mState = State::Acquired;
219                 return true;
220             case State::Acquired:
221             case State::Scheduled:
222             case State::Exiting:
223                 return false;
224             default:
225                 return false;
226         }
227     }
228 
run(WorkPool::WaitGroupHandle waitGroupHandle,WaitGroup * waitGroup,WorkPool::Task task)229     bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) {
230         AutoLock<Lock> lock(mLock);
231         switch (mState) {
232             case State::Unacquired:
233                 return false;
234             case State::Acquired: {
235                 mState = State::Scheduled;
236                 mToCleanupWaitGroupHandle = waitGroupHandle;
237                 waitGroup->acquire();
238                 mToCleanupWaitGroup = waitGroup;
239                 mShouldCleanupWaitGroup = false;
240                 TaskInfo msg = {
241                     Command::Run,
242                     waitGroup, task,
243                 };
244                 mRunMessages.send(msg);
245                 return true;
246             }
247             case State::Scheduled:
248             case State::Exiting:
249                 return false;
250             default:
251                 return false;
252         }
253     }
254 
shouldCleanupWaitGroup(WorkPool::WaitGroupHandle * waitGroupHandle,WaitGroup ** waitGroup)255     bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) {
256         AutoLock<Lock> lock(mLock);
257         bool res = mShouldCleanupWaitGroup;
258         *waitGroupHandle = mToCleanupWaitGroupHandle;
259         *waitGroup = mToCleanupWaitGroup;
260         mShouldCleanupWaitGroup = false;
261         return res;
262     }
263 
264 private:
265     enum Command {
266         Run = 0,
267         Exit = 1,
268     };
269 
270     struct TaskInfo {
271         Command cmd;
272         WaitGroup* waitGroup = nullptr;
273         WorkPool::Task task = {};
274     };
275 
exit()276     bool exit() {
277         AutoLock<Lock> lock(mLock);
278         TaskInfo msg { Command::Exit, };
279         mRunMessages.send(msg);
280         return true;
281     }
282 
threadFunc()283     void threadFunc() {
284         TaskInfo taskInfo;
285         bool done = false;
286 
287         while (!done) {
288             mRunMessages.receive(&taskInfo);
289             switch (taskInfo.cmd) {
290                 case Command::Run:
291                     doRun(taskInfo);
292                     break;
293                 case Command::Exit: {
294                     AutoLock<Lock> lock(mLock);
295                     mState = State::Exiting;
296                     break;
297                 }
298             }
299             AutoLock<Lock> lock(mLock);
300             done = mState == State::Exiting;
301         }
302     }
303 
304     // Assumption: the wait group refcount is >= 1 when entering
305     // this function (before decrement)..
306     // at least it doesn't get to 0
doRun(TaskInfo & msg)307     void doRun(TaskInfo& msg) {
308         WaitGroup* waitGroup = msg.waitGroup;
309 
310         if (msg.task) msg.task();
311 
312         bool lastTask =
313             waitGroup->decrementBroadcast();
314 
315         AutoLock<Lock> lock(mLock);
316         mState = State::Unacquired;
317 
318         if (lastTask) {
319             mShouldCleanupWaitGroup = true;
320         }
321 
322         waitGroup->release();
323     }
324 
325     FunctorThread mThread;
326     Lock mLock;
327     State mState = State::Unacquired;
328     MessageChannel<TaskInfo, 4> mRunMessages;
329     WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0;
330     WaitGroup* mToCleanupWaitGroup = nullptr;
331     bool mShouldCleanupWaitGroup = false;
332 };
333 
334 class WorkPool::Impl {
335 public:
Impl(int numInitialThreads)336     Impl(int numInitialThreads) : mThreads(numInitialThreads) {
337         for (size_t i = 0; i < mThreads.size(); ++i) {
338             mThreads[i].reset(new WorkPoolThread);
339         }
340     }
341 
342     ~Impl() = default;
343 
schedule(const std::vector<WorkPool::Task> & tasks)344     WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) {
345 
346         if (tasks.empty()) abort();
347 
348         AutoLock<Lock> lock(mLock);
349 
350         // Sweep old wait groups
351         for (size_t i = 0; i < mThreads.size(); ++i) {
352             WaitGroupHandle handle;
353             WaitGroup* waitGroup;
354             bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup);
355             if (cleanup) {
356                 mWaitGroups.erase(handle);
357                 waitGroup->release();
358             }
359         }
360 
361         WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked();
362         WaitGroup* waitGroup =
363             new WaitGroup(tasks.size());
364 
365         mWaitGroups[resHandle] = waitGroup;
366 
367         std::vector<size_t> threadIndices;
368 
369         while (threadIndices.size() < tasks.size()) {
370             for (size_t i = 0; i < mThreads.size(); ++i) {
371                 if (!mThreads[i]->acquire()) continue;
372                 threadIndices.push_back(i);
373                 if (threadIndices.size() == tasks.size()) break;
374             }
375             if (threadIndices.size() < tasks.size()) {
376                 mThreads.resize(mThreads.size() + 1);
377                 mThreads[mThreads.size() - 1].reset(new WorkPoolThread);
378             }
379         }
380 
381         // every thread here is acquired
382         for (size_t i = 0; i < threadIndices.size(); ++i) {
383             mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]);
384         }
385 
386         return resHandle;
387     }
388 
waitAny(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)389     bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
390         AutoLock<Lock> lock(mLock);
391         auto it = mWaitGroups.find(waitGroupHandle);
392         if (it == mWaitGroups.end()) return true;
393 
394         auto waitGroup = it->second;
395         waitGroup->acquire();
396         lock.unlock();
397 
398         bool waitRes = false;
399 
400         {
401             AutoLock<Lock> waitGroupLock(waitGroup->getLock());
402             waitRes = waitGroup->waitAnyLocked(timeout);
403         }
404 
405         waitGroup->release();
406 
407         return waitRes;
408     }
409 
waitAll(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)410     bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
411         auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle);
412         if (!waitGroup) return true;
413 
414         bool waitRes = false;
415 
416         {
417             AutoLock<Lock> waitGroupLock(waitGroup->getLock());
418             waitRes = waitGroup->waitAllLocked(timeout);
419         }
420 
421         waitGroup->release();
422 
423         return waitRes;
424     }
425 
426 private:
427     // Increments wait group refcount by 1.
acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle)428     WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) {
429         AutoLock<Lock> lock(mLock);
430         auto it = mWaitGroups.find(waitGroupHandle);
431         if (it == mWaitGroups.end()) return nullptr;
432 
433         auto waitGroup = it->second;
434         waitGroup->acquire();
435 
436         return waitGroup;
437     }
438 
439     using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>;
440 
genWaitGroupHandleLocked()441     WorkPool::WaitGroupHandle genWaitGroupHandleLocked() {
442         WorkPool::WaitGroupHandle res = mNextWaitGroupHandle;
443         ++mNextWaitGroupHandle;
444         return res;
445     }
446 
447     Lock mLock;
448     uint64_t mNextWaitGroupHandle = 0;
449     WaitGroupStore mWaitGroups;
450     std::vector<std::unique_ptr<WorkPoolThread>> mThreads;
451 };
452 
WorkPool(int numInitialThreads)453 WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { }
454 WorkPool::~WorkPool() = default;
455 
schedule(const std::vector<WorkPool::Task> & tasks)456 WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) {
457     return mImpl->schedule(tasks);
458 }
459 
waitAny(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)460 bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
461     return mImpl->waitAny(waitGroup, timeout);
462 }
463 
waitAll(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)464 bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
465     return mImpl->waitAll(waitGroup, timeout);
466 }
467 
468 } // namespace guest
469 } // namespace gfxstream
470