1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <gtest/gtest.h>
10
11 #include <array>
12 #include <mutex>
13
14 #include <executorch/extension/parallel/thread_parallel.h>
15 #include <executorch/runtime/platform/platform.h>
16
17 using namespace ::testing;
18 using ::executorch::extension::parallel_for;
19
20 class ParallelTest : public ::testing::Test {
21 protected:
SetUp()22 void SetUp() override {
23 data_.fill(0);
24 sum_of_all_elements_ = 0;
25 }
26
RunTask(int64_t begin,int64_t end)27 void RunTask(int64_t begin, int64_t end) {
28 for (int64_t j = begin; j < end; ++j) {
29 // Check that we haven't written to this index before
30 EXPECT_EQ(data_[j], 0);
31 data_[j] = j;
32 }
33 }
34
RunExclusiveTask(int64_t begin,int64_t end)35 void RunExclusiveTask(int64_t begin, int64_t end) {
36 for (int64_t j = begin; j < end; ++j) {
37 // Check that we haven't written to this index before
38 EXPECT_EQ(data_[j], 0);
39 std::lock_guard<std::mutex> lock(mutex_);
40 data_[j] = j;
41 sum_of_all_elements_ += data_[j];
42 }
43 }
44
45 std::array<int, 10> data_;
46 std::mutex mutex_;
47 int sum_of_all_elements_;
48 };
49
// With chunk size 1, every index in [0, 10) must end up holding its own value.
TEST_F(ParallelTest, TestAllInvoked) {
  EXPECT_TRUE(parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  int64_t i = 0;
  while (i < 10) {
    EXPECT_EQ(data_[i], i);
    ++i;
  }
}
59
// Same coverage check as TestAllInvoked, but the task writes under a mutex and
// accumulates a running sum, which must equal 0 + 1 + ... + 9.
TEST_F(ParallelTest, TestAllInvokedWithMutex) {
  EXPECT_TRUE(parallel_for(0, 10, 1, [this](int64_t begin, int64_t end) {
    this->RunExclusiveTask(begin, end);
  }));

  int expected_sum = 0;
  for (int64_t i = 0; i != static_cast<int64_t>(data_.size()); ++i) {
    EXPECT_EQ(data_[i], i);
    expected_sum += static_cast<int>(i);
  }
  EXPECT_EQ(sum_of_all_elements_, expected_sum);
}
72
// A reversed range (begin 10 > end 0) must make parallel_for return false and
// leave every element, and the running sum, untouched.
TEST_F(ParallelTest, TestInvalidRange) {
  // NOTE(review): presumably required because the rejection path logs via the
  // platform layer — confirm.
  et_pal_init();
  EXPECT_FALSE(parallel_for(10, 0, 1, [this](int64_t begin, int64_t end) {
    this->RunExclusiveTask(begin, end);
  }));

  const int64_t kSize = 10;
  for (int64_t i = 0; i != kSize; ++i) {
    EXPECT_EQ(data_[i], 0);
  }
  EXPECT_EQ(sum_of_all_elements_, 0);
}
84
// A reversed range that lies inside the array bounds (begin 6 > end 5) must
// also be rejected without running the task.
TEST_F(ParallelTest, TestInvalidRange2) {
  // NOTE(review): presumably required because the rejection path logs via the
  // platform layer — confirm.
  et_pal_init();
  EXPECT_FALSE(parallel_for(6, 5, 1, [this](int64_t begin, int64_t end) {
    this->RunExclusiveTask(begin, end);
  }));

  int64_t i = 0;
  while (i < static_cast<int64_t>(data_.size())) {
    EXPECT_EQ(data_[i], 0);
    ++i;
  }
  EXPECT_EQ(sum_of_all_elements_, 0);
}
96
// Running over [0, 5) must write exactly the first half of the array.
TEST_F(ParallelTest, TestInvokePartialFromBeginning) {
  EXPECT_TRUE(parallel_for(0, 5, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  for (int64_t i = 0; i < 10; ++i) {
    if (i < 5) {
      EXPECT_EQ(data_[i], i);
    } else {
      EXPECT_EQ(data_[i], 0);
    }
  }
}
109
// Running over [5, 10) must write exactly the second half of the array.
TEST_F(ParallelTest, TestInvokePartialToEnd) {
  EXPECT_TRUE(parallel_for(5, 10, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  for (int64_t i = 0; i < 10; ++i) {
    const bool untouched = i < 5;
    if (untouched) {
      EXPECT_EQ(data_[i], 0);
    } else {
      EXPECT_EQ(data_[i], i);
    }
  }
}
122
// Running over [2, 8) must write only the middle indices, leaving both edges
// untouched.
TEST_F(ParallelTest, TestInvokePartialMiddle) {
  EXPECT_TRUE(parallel_for(2, 8, 1, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  for (int64_t i = 0; i < 10; ++i) {
    const bool in_range = i >= 2 && i < 8;
    if (in_range) {
      EXPECT_EQ(data_[i], i);
    } else {
      EXPECT_EQ(data_[i], 0);
    }
  }
}
138
// Chunk size 2 evenly divides the 10-element range; all indices must be set.
TEST_F(ParallelTest, TestChunkSize2) {
  EXPECT_TRUE(parallel_for(0, 10, 2, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  for (int64_t i = 0; i != static_cast<int64_t>(data_.size()); ++i) {
    EXPECT_EQ(data_[i], i);
  }
}
148
// Chunk size 2 over the odd-length sub-range [3, 8): the last chunk is short,
// and indices outside the sub-range must stay untouched.
TEST_F(ParallelTest, TestChunkSize2Middle) {
  EXPECT_TRUE(parallel_for(3, 8, 2, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  int64_t i = 0;
  while (i < 10) {
    if (i >= 3 && i < 8) {
      EXPECT_EQ(data_[i], i);
    } else {
      EXPECT_EQ(data_[i], 0);
    }
    ++i;
  }
}
164
// Chunk size 3 does not divide 10, so the final chunk is a remainder of one
// element; every index must still be covered.
TEST_F(ParallelTest, TestChunkSize3) {
  EXPECT_TRUE(parallel_for(0, 10, 3, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  int64_t i = 0;
  while (i != 10) {
    EXPECT_EQ(data_[i], i);
    ++i;
  }
}
174
// Chunk size 6 splits the range into one full chunk and a 4-element remainder.
TEST_F(ParallelTest, TestChunkSize6) {
  EXPECT_TRUE(parallel_for(0, 10, 6, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  const int64_t kSize = static_cast<int64_t>(data_.size());
  for (int64_t i = 0; i < kSize; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}
184
// A chunk size (11) larger than the whole range (10) must still succeed and
// cover every index.
TEST_F(ParallelTest, TestChunkSizeTooLarge) {
  EXPECT_TRUE(parallel_for(0, 10, 11, [this](int64_t begin, int64_t end) {
    this->RunTask(begin, end);
  }));

  for (int64_t i = 0; i != 10; ++i) {
    EXPECT_EQ(data_[i], i);
  }
}
194