1*5f39d1b3SJooyung Han // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han // http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han
15*5f39d1b3SJooyung Han #include <atomic> // NOLINT
16*5f39d1b3SJooyung Han #include <vector>
17*5f39d1b3SJooyung Han #include <iostream>
18*5f39d1b3SJooyung Han #include <cstdlib>
19*5f39d1b3SJooyung Han
20*5f39d1b3SJooyung Han #include "../internal/multi_thread_gemm.h"
21*5f39d1b3SJooyung Han #include "../profiling/pthread_everywhere.h"
22*5f39d1b3SJooyung Han #include "test.h"
23*5f39d1b3SJooyung Han
24*5f39d1b3SJooyung Han namespace gemmlowp {
25*5f39d1b3SJooyung Han
26*5f39d1b3SJooyung Han class Thread {
27*5f39d1b3SJooyung Han public:
Thread(BlockingCounter * blocking_counter,int number_of_times_to_decrement)28*5f39d1b3SJooyung Han Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
29*5f39d1b3SJooyung Han : blocking_counter_(blocking_counter),
30*5f39d1b3SJooyung Han number_of_times_to_decrement_(number_of_times_to_decrement),
31*5f39d1b3SJooyung Han made_the_last_decrement_(false),
32*5f39d1b3SJooyung Han finished_(false) {
33*5f39d1b3SJooyung Han #if defined GEMMLOWP_USE_PTHREAD
34*5f39d1b3SJooyung Han // Limit the stack size so as not to deplete memory when creating
35*5f39d1b3SJooyung Han // many threads.
36*5f39d1b3SJooyung Han pthread_attr_t attr;
37*5f39d1b3SJooyung Han int err = pthread_attr_init(&attr);
38*5f39d1b3SJooyung Han if (!err) {
39*5f39d1b3SJooyung Han size_t stack_size;
40*5f39d1b3SJooyung Han err = pthread_attr_getstacksize(&attr, &stack_size);
41*5f39d1b3SJooyung Han if (!err && stack_size > max_stack_size_) {
42*5f39d1b3SJooyung Han err = pthread_attr_setstacksize(&attr, max_stack_size_);
43*5f39d1b3SJooyung Han }
44*5f39d1b3SJooyung Han if (!err) {
45*5f39d1b3SJooyung Han err = pthread_create(&thread_, &attr, ThreadFunc, this);
46*5f39d1b3SJooyung Han }
47*5f39d1b3SJooyung Han }
48*5f39d1b3SJooyung Han if (err) {
49*5f39d1b3SJooyung Han std::cerr << "Failed to create a thread.\n";
50*5f39d1b3SJooyung Han std::abort();
51*5f39d1b3SJooyung Han }
52*5f39d1b3SJooyung Han #else
53*5f39d1b3SJooyung Han pthread_create(&thread_, nullptr, ThreadFunc, this);
54*5f39d1b3SJooyung Han #endif
55*5f39d1b3SJooyung Han }
56*5f39d1b3SJooyung Han
~Thread()57*5f39d1b3SJooyung Han ~Thread() { Join(); }
58*5f39d1b3SJooyung Han
Join()59*5f39d1b3SJooyung Han bool Join() {
60*5f39d1b3SJooyung Han while (!finished_.load()) {
61*5f39d1b3SJooyung Han }
62*5f39d1b3SJooyung Han return made_the_last_decrement_;
63*5f39d1b3SJooyung Han }
64*5f39d1b3SJooyung Han
65*5f39d1b3SJooyung Han private:
66*5f39d1b3SJooyung Han Thread(const Thread& other) = delete;
67*5f39d1b3SJooyung Han
ThreadFunc()68*5f39d1b3SJooyung Han void ThreadFunc() {
69*5f39d1b3SJooyung Han for (int i = 0; i < number_of_times_to_decrement_; i++) {
70*5f39d1b3SJooyung Han Check(!made_the_last_decrement_);
71*5f39d1b3SJooyung Han made_the_last_decrement_ = blocking_counter_->DecrementCount();
72*5f39d1b3SJooyung Han }
73*5f39d1b3SJooyung Han finished_.store(true);
74*5f39d1b3SJooyung Han }
75*5f39d1b3SJooyung Han
ThreadFunc(void * ptr)76*5f39d1b3SJooyung Han static void* ThreadFunc(void* ptr) {
77*5f39d1b3SJooyung Han static_cast<Thread*>(ptr)->ThreadFunc();
78*5f39d1b3SJooyung Han return nullptr;
79*5f39d1b3SJooyung Han }
80*5f39d1b3SJooyung Han
81*5f39d1b3SJooyung Han static constexpr size_t max_stack_size_ = 256 * 1024;
82*5f39d1b3SJooyung Han BlockingCounter* const blocking_counter_;
83*5f39d1b3SJooyung Han const int number_of_times_to_decrement_;
84*5f39d1b3SJooyung Han pthread_t thread_;
85*5f39d1b3SJooyung Han bool made_the_last_decrement_;
86*5f39d1b3SJooyung Han // finished_ is used to manually implement Join() by busy-waiting.
87*5f39d1b3SJooyung Han // I wanted to use pthread_join / std::thread::join, but the behavior
88*5f39d1b3SJooyung Han // observed on Android was that pthread_join aborts when the thread has
89*5f39d1b3SJooyung Han // already joined before calling pthread_join, making that hard to use.
90*5f39d1b3SJooyung Han // It appeared simplest to just implement this simple spinlock, and that
91*5f39d1b3SJooyung Han // is good enough as this is just a test.
92*5f39d1b3SJooyung Han std::atomic<bool> finished_;
93*5f39d1b3SJooyung Han };
94*5f39d1b3SJooyung Han
test_blocking_counter(BlockingCounter * blocking_counter,int num_threads,int num_decrements_per_thread,int num_decrements_to_wait_for)95*5f39d1b3SJooyung Han void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
96*5f39d1b3SJooyung Han int num_decrements_per_thread,
97*5f39d1b3SJooyung Han int num_decrements_to_wait_for) {
98*5f39d1b3SJooyung Han std::vector<Thread*> threads;
99*5f39d1b3SJooyung Han blocking_counter->Reset(num_decrements_to_wait_for);
100*5f39d1b3SJooyung Han for (int i = 0; i < num_threads; i++) {
101*5f39d1b3SJooyung Han threads.push_back(new Thread(blocking_counter, num_decrements_per_thread));
102*5f39d1b3SJooyung Han }
103*5f39d1b3SJooyung Han blocking_counter->Wait();
104*5f39d1b3SJooyung Han
105*5f39d1b3SJooyung Han int num_threads_that_made_the_last_decrement = 0;
106*5f39d1b3SJooyung Han for (int i = 0; i < num_threads; i++) {
107*5f39d1b3SJooyung Han if (threads[i]->Join()) {
108*5f39d1b3SJooyung Han num_threads_that_made_the_last_decrement++;
109*5f39d1b3SJooyung Han }
110*5f39d1b3SJooyung Han delete threads[i];
111*5f39d1b3SJooyung Han }
112*5f39d1b3SJooyung Han Check(num_threads_that_made_the_last_decrement == 1);
113*5f39d1b3SJooyung Han }
114*5f39d1b3SJooyung Han
test_blocking_counter()115*5f39d1b3SJooyung Han void test_blocking_counter() {
116*5f39d1b3SJooyung Han BlockingCounter* blocking_counter = new BlockingCounter;
117*5f39d1b3SJooyung Han
118*5f39d1b3SJooyung Han // repeating the entire test sequence ensures that we test
119*5f39d1b3SJooyung Han // non-monotonic changes.
120*5f39d1b3SJooyung Han for (int repeat = 1; repeat <= 2; repeat++) {
121*5f39d1b3SJooyung Han for (int num_threads = 1; num_threads <= 5; num_threads++) {
122*5f39d1b3SJooyung Han for (int num_decrements_per_thread = 1;
123*5f39d1b3SJooyung Han num_decrements_per_thread <= 4 * 1024;
124*5f39d1b3SJooyung Han num_decrements_per_thread *= 16) {
125*5f39d1b3SJooyung Han test_blocking_counter(blocking_counter, num_threads,
126*5f39d1b3SJooyung Han num_decrements_per_thread,
127*5f39d1b3SJooyung Han num_threads * num_decrements_per_thread);
128*5f39d1b3SJooyung Han }
129*5f39d1b3SJooyung Han }
130*5f39d1b3SJooyung Han }
131*5f39d1b3SJooyung Han delete blocking_counter;
132*5f39d1b3SJooyung Han }
133*5f39d1b3SJooyung Han
134*5f39d1b3SJooyung Han } // end namespace gemmlowp
135*5f39d1b3SJooyung Han
main()136*5f39d1b3SJooyung Han int main() { gemmlowp::test_blocking_counter(); }
137