xref: /aosp_15_r20/external/executorch/extension/parallel/test/thread_parallel_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <gtest/gtest.h>
10 
11 #include <array>
12 #include <mutex>
13 
14 #include <executorch/extension/parallel/thread_parallel.h>
15 #include <executorch/runtime/platform/platform.h>
16 
17 using namespace ::testing;
18 using ::executorch::extension::parallel_for;
19 
20 class ParallelTest : public ::testing::Test {
21  protected:
SetUp()22   void SetUp() override {
23     data_.fill(0);
24     sum_of_all_elements_ = 0;
25   }
26 
RunTask(int64_t begin,int64_t end)27   void RunTask(int64_t begin, int64_t end) {
28     for (int64_t j = begin; j < end; ++j) {
29       // Check that we haven't written to this index before
30       EXPECT_EQ(data_[j], 0);
31       data_[j] = j;
32     }
33   }
34 
RunExclusiveTask(int64_t begin,int64_t end)35   void RunExclusiveTask(int64_t begin, int64_t end) {
36     for (int64_t j = begin; j < end; ++j) {
37       // Check that we haven't written to this index before
38       EXPECT_EQ(data_[j], 0);
39       std::lock_guard<std::mutex> lock(mutex_);
40       data_[j] = j;
41       sum_of_all_elements_ += data_[j];
42     }
43   }
44 
45   std::array<int, 10> data_;
46   std::mutex mutex_;
47   int sum_of_all_elements_;
48 };
49 
TEST_F(ParallelTest, TestAllInvoked) {
  // With a grain size of 1 over [0, 10), every slot must end up holding its
  // own index; RunTask reports any slot that was already written.
  const auto write_indices = [this](int64_t first, int64_t last) {
    RunTask(first, last);
  };
  const bool completed = parallel_for(0, 10, 1, write_indices);
  EXPECT_TRUE(completed);

  int64_t expected = 0;
  for (const int value : data_) {
    EXPECT_EQ(value, expected);
    ++expected;
  }
}
59 
TEST_F(ParallelTest, TestAllInvokedWithMutex) {
  // Same coverage check as TestAllInvoked, but through the mutex-protected
  // task, which also accumulates a running sum that must equal 0 + 1 + ... + 9.
  const bool completed = parallel_for(
      0, 10, 1, [this](int64_t first, int64_t last) {
        RunExclusiveTask(first, last);
      });
  EXPECT_TRUE(completed);

  int expected_total = 0;
  for (int64_t idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(data_[idx], idx);
    expected_total += static_cast<int>(idx);
  }
  EXPECT_EQ(sum_of_all_elements_, expected_total);
}
72 
TEST_F(ParallelTest, TestInvalidRange) {
  // NOTE(review): the platform layer is initialized before the call expected
  // to fail — presumably because the rejection path logs; confirm.
  et_pal_init();

  // begin (10) > end (0): the call is expected to report failure and leave
  // all state untouched.
  const bool accepted = parallel_for(
      10, 0, 1, [this](int64_t first, int64_t last) {
        RunExclusiveTask(first, last);
      });
  EXPECT_FALSE(accepted);

  for (const int value : data_) {
    EXPECT_EQ(value, 0);
  }
  EXPECT_EQ(sum_of_all_elements_, 0);
}
84 
TEST_F(ParallelTest, TestInvalidRange2) {
  // NOTE(review): the platform layer is initialized before the call expected
  // to fail — presumably because the rejection path logs; confirm.
  et_pal_init();

  // Off-by-one inverted range (begin = 6, end = 5) must also be rejected
  // without running the task.
  const auto exclusive_task = [this](int64_t first, int64_t last) {
    RunExclusiveTask(first, last);
  };
  const bool accepted = parallel_for(6, 5, 1, exclusive_task);
  EXPECT_FALSE(accepted);

  for (int64_t slot = 0; slot < 10; ++slot) {
    EXPECT_EQ(data_[slot], 0);
  }
  EXPECT_EQ(sum_of_all_elements_, 0);
}
96 
TEST_F(ParallelTest, TestInvokePartialFromBeginning) {
  const bool completed =
      parallel_for(0, 5, 1, [this](int64_t first, int64_t last) {
        RunTask(first, last);
      });
  EXPECT_TRUE(completed);

  // Slots [0, 5) hold their own index; slots [5, 10) must stay zero.
  for (int64_t i = 0; i < 10; ++i) {
    const int64_t expected = (i < 5) ? i : 0;
    EXPECT_EQ(data_[i], expected);
  }
}
109 
TEST_F(ParallelTest, TestInvokePartialToEnd) {
  const auto write_indices = [this](int64_t first, int64_t last) {
    RunTask(first, last);
  };
  EXPECT_TRUE(parallel_for(5, 10, 1, write_indices));

  // Slots [0, 5) must stay zero; slots [5, 10) hold their own index.
  for (int64_t i = 0; i < 10; ++i) {
    const int64_t expected = (i >= 5) ? i : 0;
    EXPECT_EQ(data_[i], expected);
  }
}
122 
TEST_F(ParallelTest, TestInvokePartialMiddle) {
  const bool completed =
      parallel_for(2, 8, 1, [this](int64_t first, int64_t last) {
        RunTask(first, last);
      });
  EXPECT_TRUE(completed);

  // Only the interior window [2, 8) is written; both edges stay zero.
  for (int64_t i = 0; i < 10; ++i) {
    const bool inside = i >= 2 && i < 8;
    EXPECT_EQ(data_[i], inside ? i : 0);
  }
}
138 
TEST_F(ParallelTest, TestChunkSize2) {
  // With a grain size of 2, every index in [0, 10) must still be written.
  constexpr int64_t kGrainSize = 2;
  const bool completed =
      parallel_for(0, 10, kGrainSize, [this](int64_t first, int64_t last) {
        RunTask(first, last);
      });
  EXPECT_TRUE(completed);

  int64_t expected = 0;
  for (const int value : data_) {
    EXPECT_EQ(value, expected++);
  }
}
148 
TEST_F(ParallelTest, TestChunkSize2Middle) {
  // An odd-length window [3, 8) with grain size 2: only that window is
  // written, everything outside it stays zero.
  constexpr int64_t kGrainSize = 2;
  const auto write_indices = [this](int64_t first, int64_t last) {
    RunTask(first, last);
  };
  EXPECT_TRUE(parallel_for(3, 8, kGrainSize, write_indices));

  for (int64_t i = 0; i < 10; ++i) {
    const bool touched = (i >= 3) && (i < 8);
    EXPECT_EQ(data_[i], touched ? i : 0);
  }
}
164 
TEST_F(ParallelTest, TestChunkSize3) {
  // 10 is not a multiple of the grain size 3; all indices must still be
  // covered, including the short trailing remainder.
  constexpr int64_t kGrainSize = 3;
  const bool completed =
      parallel_for(0, 10, kGrainSize, [this](int64_t first, int64_t last) {
        RunTask(first, last);
      });
  EXPECT_TRUE(completed);

  for (int64_t idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(data_[idx], idx);
  }
}
174 
TEST_F(ParallelTest, TestChunkSize6) {
  // A grain size of 6 leaves a remainder of 4 over [0, 10); every index must
  // still be covered.
  constexpr int64_t kGrainSize = 6;
  const auto write_indices = [this](int64_t first, int64_t last) {
    RunTask(first, last);
  };
  EXPECT_TRUE(parallel_for(0, 10, kGrainSize, write_indices));

  int64_t expected = 0;
  for (const int value : data_) {
    EXPECT_EQ(value, expected);
    ++expected;
  }
}
184 
TEST_F(ParallelTest, TestChunkSizeTooLarge) {
  // The grain size (11) exceeds the range length (10); the call must still
  // succeed and write every index.
  constexpr int64_t kGrainSize = 11;
  const bool completed =
      parallel_for(0, 10, kGrainSize, [this](int64_t first, int64_t last) {
        RunTask(first, last);
      });
  EXPECT_TRUE(completed);

  for (int64_t idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(data_[idx], idx);
  }
}
194