1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fcp/secagg/server/secagg_scheduler.h"

#include <atomic>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "fcp/base/monitoring.h"
#include "fcp/base/scheduler.h"
#include "fcp/base/simulated_clock.h"
32
33 namespace fcp {
34 namespace secagg {
35 namespace {
36
37 using ::testing::_;
38 using ::testing::Eq;
39 using ::testing::IsFalse;
40 using ::testing::Lt;
41 using ::testing::StrictMock;
42 using ::testing::Test;
43
44 class MockScheduler : public Scheduler {
45 public:
46 MOCK_METHOD(void, Schedule, (std::function<void()>), (override));
47 MOCK_METHOD(void, WaitUntilIdle, ());
48 };
49
// Wraps an int in a struct to keep Clang-tidy happy.
// A default-constructed Integer holds 0; the explicit constructor stores the
// given value.
struct Integer {
  Integer() = default;
  explicit Integer(int v) : value(v) {}
  int value = 0;
};
56
IntGenerators(int n)57 std::vector<std::function<std::unique_ptr<Integer>()>> IntGenerators(int n) {
58 std::vector<std::function<std::unique_ptr<Integer>()>> generators;
59 for (int i = 1; i <= n; ++i) {
60 generators.emplace_back([i]() { return std::make_unique<Integer>(i); });
61 }
62 return generators;
63 }
64
__anon032271a60302(const Integer& l, const Integer& r) 65 constexpr auto multiply_accumulator = [](const Integer& l, const Integer& r) {
66 return std::make_unique<Integer>(l.value * r.value);
67 };
// gMock action that runs the scheduled task immediately on the calling
// thread, which makes a mocked Scheduler behave synchronously.
constexpr auto call_fn = [](const std::function<void()>& task) -> void {
  std::invoke(task);
};
69
// ScheduleCallback must dispatch the callback through the sequential
// scheduler exactly once and never through the parallel scheduler.
TEST(SecAggSchedulerTest, ScheduleCallback) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;

  EXPECT_CALL(parallel_scheduler, Schedule(_)).Times(0);
  // call_fn runs the task inline, so the callback's effect is visible
  // immediately after ScheduleCallback returns.
  EXPECT_CALL(sequential_scheduler, Schedule(_)).WillOnce(call_fn);

  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);

  int r = 0;
  runner.ScheduleCallback([&r]() { r = 5; });
  EXPECT_THAT(r, Eq(5));
}
83
// One accumulator over six generators: each generator runs on the parallel
// scheduler, results are combined on the sequential scheduler, and the async
// observer reads the final product 1*2*...*6.
TEST(SecAggSchedulerTest, SingleCall) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;

  EXPECT_CALL(parallel_scheduler, Schedule(_)).Times(6).WillRepeatedly(call_fn);
  // Presumably six accumulation steps plus one observer invocation.
  EXPECT_CALL(sequential_scheduler, Schedule(_))
      .Times(7)
      .WillRepeatedly(call_fn);

  // Technically unsafe, but we know the pointers will be valid as long as
  // runner is alive.
  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);

  std::vector<std::function<std::unique_ptr<Integer>()>> generators =
      IntGenerators(6);

  Integer result;
  auto accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);
  for (const auto& generator : generators) {
    accumulator->Schedule(generator);
  }
  accumulator->SetAsyncObserver(
      [&]() { result = *(accumulator->GetResultAndCancel()); });
  EXPECT_THAT(result.value, Eq(720));  // 6! = 720
}
110
// Generators scheduled with a 5 s delay must not run until the simulated
// clock has advanced by the full delay; afterwards the observer sees 6!.
TEST(SecAggSchedulerTest, SingleCallWithDelay) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;
  SimulatedClock clock;

  EXPECT_CALL(parallel_scheduler, Schedule(_)).Times(6).WillRepeatedly(call_fn);
  EXPECT_CALL(sequential_scheduler, Schedule(_))
      .Times(6)
      .WillRepeatedly(call_fn);

  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler, &clock);

  std::vector<std::function<std::unique_ptr<Integer>()>> generators =
      IntGenerators(6);

  Integer result;
  auto accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);
  for (const auto& generator : generators) {
    accumulator->Schedule(generator, absl::Seconds(5));
  }
  accumulator->SetAsyncObserver(
      [&]() { result = *(accumulator->GetResultAndCancel()); });

  // Generators are still delayed.
  EXPECT_THAT(result.value, Eq(0));

  // Advance time by one second.
  clock.AdvanceTime(absl::Seconds(1));
  // Generators are still delayed.
  EXPECT_THAT(result.value, Eq(0));

  // Advance time by another 4 seconds, reaching the full 5 s delay.
  clock.AdvanceTime(absl::Seconds(4));
  EXPECT_THAT(result.value, Eq(720));  // 6! = 720
}
147
// Two accumulators created one after another from the same runner each
// produce their own independent result (6! and then 4!).
TEST(SecAggSchedulerTest, TwoCalls) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;

  EXPECT_CALL(parallel_scheduler, Schedule(_)).WillRepeatedly(call_fn);
  EXPECT_CALL(sequential_scheduler, Schedule(_)).WillRepeatedly(call_fn);

  // Technically unsafe, but we know the pointers will be valid as long as
  // runner is alive.
  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);

  // First call
  std::vector<std::function<std::unique_ptr<Integer>()>> generators =
      IntGenerators(6);

  Integer result;
  auto accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);
  for (const auto& generator : generators) {
    accumulator->Schedule(generator);
  }
  accumulator->SetAsyncObserver(
      [&]() { result = *(accumulator->GetResultAndCancel()); });

  EXPECT_THAT(result.value, Eq(720));  // 6! = 720

  // Second call
  std::vector<std::function<std::unique_ptr<Integer>()>> generators2 =
      IntGenerators(4);
  auto accumulator2 = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);

  for (const auto& generator : generators2) {
    accumulator2->Schedule(generator);
  }
  accumulator2->SetAsyncObserver(
      [&]() { result = *(accumulator2->GetResultAndCancel()); });
  EXPECT_THAT(result.value, Eq(24));  // 4! = 24
}
187
// Cancelling an accumulator while its generators are running on real thread
// pools must stop further generator/accumulator callbacks and must suppress
// the async observer.
TEST(SecAggSchedulerAbortTest, Abort) {
  auto parallel_scheduler = fcp::CreateThreadPoolScheduler(4);
  auto sequential_scheduler = fcp::CreateThreadPoolScheduler(1);

  absl::Notification signal_abort;
  // Incremented by every generator and accumulator invocation; written from
  // pool threads, hence atomic.
  std::atomic<int> callback_counter{0};

  std::vector<std::function<std::unique_ptr<Integer>()>> generators;
  generators.reserve(100);
  for (int i = 1; i <= 100; ++i) {
    generators.emplace_back([&, i]() {
      callback_counter++;
      // Signal abort when the generator for i == 10 runs. Tasks on the
      // 4-thread pool may start out of order, so this is not necessarily
      // the 10th task to run.
      if (i == 10) {
        signal_abort.Notify();
      }
      absl::SleepFor(absl::Milliseconds(1));
      return std::make_unique<Integer>(i);
    });
  }

  auto accumulator_func = [&](const Integer& l, const Integer& r) {
    callback_counter++;
    return std::make_unique<Integer>(l.value * r.value);
  };

  SecAggScheduler runner(parallel_scheduler.get(), sequential_scheduler.get());
  bool final_callback_called = false;
  auto accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), accumulator_func);
  for (const auto& generator : generators) {
    accumulator->Schedule(generator);
  }
  accumulator->SetAsyncObserver([&]() { final_callback_called = true; });

  signal_abort.WaitForNotification();
  accumulator->Cancel();

  int count_after_abort = callback_counter.load();
  FCP_LOG(INFO) << "count_after_abort = " << count_after_abort;

  // Wait for all scheduled tasks to finish
  runner.WaitUntilIdle();

  // The number of callbacks must not change after Cancel() has returned.
  int final_count = callback_counter.load();
  EXPECT_THAT(final_count, Eq(count_after_abort));
  // Compare as int to avoid a signed/unsigned comparison against size().
  EXPECT_THAT(final_count, Lt(static_cast<int>(generators.size())));
  EXPECT_THAT(final_callback_called, IsFalse());
}
238
239 // Tests that three batches of async work result in three calls to the callback,
240 // which can be overriden in between calls.
TEST(SecAggSchedulerTest,ThreeCallbackCalls)241 TEST(SecAggSchedulerTest, ThreeCallbackCalls) {
242 auto parallel_scheduler = fcp::CreateThreadPoolScheduler(4);
243 auto sequential_scheduler = fcp::CreateThreadPoolScheduler(1);
244
245 SecAggScheduler runner(parallel_scheduler.get(), sequential_scheduler.get());
246
247 std::vector<std::function<std::unique_ptr<Integer>()>> generators =
248 IntGenerators(3);
249
250 auto accumulator = runner.CreateAccumulator<Integer>(
251 std::make_unique<Integer>(1), multiply_accumulator);
252 for (const auto& generator : generators) {
253 accumulator->Schedule(generator);
254 }
255 int callback_counter = 0;
256 accumulator->SetAsyncObserver([&]() { callback_counter++; });
257 runner.WaitUntilIdle();
258 EXPECT_THAT(callback_counter, Eq(1));
259 for (const auto& generator : generators) {
260 accumulator->Schedule(generator);
261 }
262 runner.WaitUntilIdle();
263 // The callback was not re-scheduled, so the second call to Schedule didn't
264 // trigger it. This results in unobserved work.
265 EXPECT_THAT(callback_counter, Eq(1));
266 bool has_work = accumulator->SetAsyncObserver([&]() { callback_counter++; });
267 runner.WaitUntilIdle();
268 EXPECT_TRUE(has_work);
269 EXPECT_THAT(callback_counter, Eq(2));
270 // The accumulator should be idle and without unobserved work at this point.
271 has_work = accumulator->SetAsyncObserver([&]() { callback_counter++; });
272 EXPECT_FALSE(has_work);
273 Integer result;
274 for (const auto& generator : generators) {
275 accumulator->Schedule(generator);
276 }
277 accumulator->SetAsyncObserver(
278 [&]() { result = *(accumulator->GetResultAndCancel()); });
279 runner.WaitUntilIdle();
280 // The last call to SetAsyncObserver overwrittes the previous callback.
281 EXPECT_THAT(callback_counter, Eq(2));
282 EXPECT_THAT(result.value, Eq(216)); // 6^3 = 216
283 }
284
285 } // namespace
286 } // namespace secagg
287 } // namespace fcp
288