xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_scheduler_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "fcp/secagg/server/secagg_scheduler.h"

#include <atomic>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "fcp/base/monitoring.h"
#include "fcp/base/scheduler.h"
#include "fcp/base/simulated_clock.h"
32 
33 namespace fcp {
34 namespace secagg {
35 namespace {
36 
37 using ::testing::_;
38 using ::testing::Eq;
39 using ::testing::IsFalse;
40 using ::testing::Lt;
41 using ::testing::StrictMock;
42 using ::testing::Test;
43 
44 class MockScheduler : public Scheduler {
45  public:
46   MOCK_METHOD(void, Schedule, (std::function<void()>), (override));
47   MOCK_METHOD(void, WaitUntilIdle, ());
48 };
49 
// Small value type used as the accumulator payload in these tests. A bare int
// is wrapped in a struct to keep Clang-tidy happy.
struct Integer {
  // Default-constructed instances hold zero.
  Integer() = default;
  explicit Integer(int v) : value{v} {}

  int value = 0;
};
56 
IntGenerators(int n)57 std::vector<std::function<std::unique_ptr<Integer>()>> IntGenerators(int n) {
58   std::vector<std::function<std::unique_ptr<Integer>()>> generators;
59   for (int i = 1; i <= n; ++i) {
60     generators.emplace_back([i]() { return std::make_unique<Integer>(i); });
61   }
62   return generators;
63 }
64 
__anon032271a60302(const Integer& l, const Integer& r) 65 constexpr auto multiply_accumulator = [](const Integer& l, const Integer& r) {
66   return std::make_unique<Integer>(l.value * r.value);
67 };
// Action installed on the mocked Schedule(): invokes the submitted task right
// away on the calling thread, so mocked schedulers behave synchronously.
constexpr auto call_fn = [](const std::function<void()>& task) {
  task();
};
69 
// ScheduleCallback() is expected to submit the callback to the sequential
// scheduler exactly once and never touch the parallel scheduler.
TEST(SecAggSchedulerTest, ScheduleCallback) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;

  // The sequential mock runs the task inline, so its effect is visible as soon
  // as ScheduleCallback() returns.
  EXPECT_CALL(sequential_scheduler, Schedule(_)).WillOnce(call_fn);
  EXPECT_CALL(parallel_scheduler, Schedule(_)).Times(0);

  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);

  int observed = 0;
  runner.ScheduleCallback([&observed]() { observed = 5; });
  EXPECT_THAT(observed, Eq(5));
}
83 
// Six generators (1..6) multiplied into an accumulator seeded with 1 must
// produce 720.
TEST(SecAggSchedulerTest, SingleCall) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;

  // Both mocks execute tasks inline. The test pins the exact number of
  // Schedule() calls: six on the parallel scheduler and seven on the
  // sequential one (presumably six accumulation steps plus the observer
  // callback — confirm against SecAggScheduler).
  EXPECT_CALL(parallel_scheduler, Schedule(_))
      .Times(6)
      .WillRepeatedly(call_fn);
  EXPECT_CALL(sequential_scheduler, Schedule(_))
      .Times(7)
      .WillRepeatedly(call_fn);

  // The runner only borrows the schedulers; both mocks are declared before it
  // and therefore outlive it.
  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);

  std::vector<std::function<std::unique_ptr<Integer>()>> factors =
      IntGenerators(6);

  Integer product;
  auto accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);
  for (const auto& factor : factors) {
    accumulator->Schedule(factor);
  }
  accumulator->SetAsyncObserver([&product, &accumulator]() {
    product = *(accumulator->GetResultAndCancel());
  });

  EXPECT_THAT(product.value, Eq(720));  // 6! = 720
}
110 
// Generators scheduled with a five-second delay must not contribute to the
// result until the simulated clock has advanced by the whole delay.
TEST(SecAggSchedulerTest, SingleCallWithDelay) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;
  SimulatedClock clock;

  // Both mocks execute tasks inline; each is expected to be used six times.
  EXPECT_CALL(parallel_scheduler, Schedule(_))
      .Times(6)
      .WillRepeatedly(call_fn);
  EXPECT_CALL(sequential_scheduler, Schedule(_))
      .Times(6)
      .WillRepeatedly(call_fn);

  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler, &clock);

  std::vector<std::function<std::unique_ptr<Integer>()>> factors =
      IntGenerators(6);

  Integer product;
  auto accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);
  for (const auto& factor : factors) {
    accumulator->Schedule(factor, absl::Seconds(5));
  }
  accumulator->SetAsyncObserver([&product, &accumulator]() {
    product = *(accumulator->GetResultAndCancel());
  });

  // No simulated time has passed yet: nothing may have been accumulated.
  EXPECT_THAT(product.value, Eq(0));

  // One second in (four seconds short of the delay): still nothing.
  clock.AdvanceTime(absl::Seconds(1));
  EXPECT_THAT(product.value, Eq(0));

  // Four more seconds reach the five-second mark; the result is now available.
  clock.AdvanceTime(absl::Seconds(4));
  EXPECT_THAT(product.value, Eq(720));  // 6! = 720
}
147 
// Two accumulators created from the same runner, one after the other, must
// each produce their own independent result.
TEST(SecAggSchedulerTest, TwoCalls) {
  StrictMock<MockScheduler> parallel_scheduler;
  StrictMock<MockScheduler> sequential_scheduler;

  // Both mocks execute tasks inline; the number of calls is not pinned here.
  EXPECT_CALL(parallel_scheduler, Schedule(_)).WillRepeatedly(call_fn);
  EXPECT_CALL(sequential_scheduler, Schedule(_)).WillRepeatedly(call_fn);

  // The runner only borrows the schedulers; both mocks are declared before it
  // and therefore outlive it.
  SecAggScheduler runner(&parallel_scheduler, &sequential_scheduler);

  // First batch: 1 * 2 * ... * 6.
  std::vector<std::function<std::unique_ptr<Integer>()>> first_batch =
      IntGenerators(6);

  Integer product;
  auto first_accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);
  for (const auto& factor : first_batch) {
    first_accumulator->Schedule(factor);
  }
  first_accumulator->SetAsyncObserver([&product, &first_accumulator]() {
    product = *(first_accumulator->GetResultAndCancel());
  });

  EXPECT_THAT(product.value, Eq(720));  // 6! = 720

  // Second batch: 1 * 2 * 3 * 4, on a brand-new accumulator.
  std::vector<std::function<std::unique_ptr<Integer>()>> second_batch =
      IntGenerators(4);
  auto second_accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), multiply_accumulator);

  for (const auto& factor : second_batch) {
    second_accumulator->Schedule(factor);
  }
  second_accumulator->SetAsyncObserver([&product, &second_accumulator]() {
    product = *(second_accumulator->GetResultAndCancel());
  });

  EXPECT_THAT(product.value, Eq(24));  // 4! = 24
}
187 
// Cancelling an accumulator while work is in flight: once Cancel() has
// returned, no further generator or accumulation callback may run and the
// async observer must never fire.
TEST(SecAggSchedulerAbortTest, Abort) {
  auto parallel_scheduler = fcp::CreateThreadPoolScheduler(4);
  auto sequential_scheduler = fcp::CreateThreadPoolScheduler(1);

  absl::Notification signal_abort;
  // Counts every generator and every accumulation callback that started.
  // Atomic because both lambdas below increment it and the test thread reads
  // it while scheduler tasks may still be running.
  std::atomic<int> callback_counter{0};

  std::vector<std::function<std::unique_ptr<Integer>()>> generators;
  for (int i = 1; i <= 100; ++i) {
    generators.emplace_back([&, i]() {
      callback_counter++;
      // Signal abort when running 10th parallel task
      if (i == 10) {
        signal_abort.Notify();
      }
      // Keep each task busy for a moment so that work is still pending when
      // the test thread calls Cancel().
      absl::SleepFor(absl::Milliseconds(1));
      return std::make_unique<Integer>(i);
    });
  }

  auto accumulator_func = [&](const Integer& l, const Integer& r) {
    callback_counter++;
    return std::make_unique<Integer>(l.value * r.value);
  };

  SecAggScheduler runner(parallel_scheduler.get(), sequential_scheduler.get());
  bool final_callback_called = false;
  auto accumulator = runner.CreateAccumulator<Integer>(
      std::make_unique<Integer>(1), accumulator_func);
  for (const auto& generator : generators) {
    accumulator->Schedule(generator);
  }
  accumulator->SetAsyncObserver([&]() { final_callback_called = true; });

  signal_abort.WaitForNotification();
  accumulator->Cancel();

  const int count_after_abort = callback_counter.load();
  FCP_LOG(INFO) << "count_after_abort = " << count_after_abort;

  // Wait for all scheduled tasks to finish
  runner.WaitUntilIdle();

  // The final number of callbacks should not change since returning from
  // Cancel().
  const int final_count = callback_counter.load();
  EXPECT_THAT(final_count, Eq(count_after_abort));
  // Compare as int: the counter is signed, whereas size() is unsigned, and
  // mixing the two inside the matcher is a signed/unsigned comparison.
  EXPECT_THAT(final_count, Lt(static_cast<int>(generators.size())));
  EXPECT_THAT(final_callback_called, IsFalse());
}
238 
239 // Tests that three batches of async work result in three calls to the callback,
240 // which can be overriden in between calls.
TEST(SecAggSchedulerTest,ThreeCallbackCalls)241 TEST(SecAggSchedulerTest, ThreeCallbackCalls) {
242   auto parallel_scheduler = fcp::CreateThreadPoolScheduler(4);
243   auto sequential_scheduler = fcp::CreateThreadPoolScheduler(1);
244 
245   SecAggScheduler runner(parallel_scheduler.get(), sequential_scheduler.get());
246 
247   std::vector<std::function<std::unique_ptr<Integer>()>> generators =
248       IntGenerators(3);
249 
250   auto accumulator = runner.CreateAccumulator<Integer>(
251       std::make_unique<Integer>(1), multiply_accumulator);
252   for (const auto& generator : generators) {
253     accumulator->Schedule(generator);
254   }
255   int callback_counter = 0;
256   accumulator->SetAsyncObserver([&]() { callback_counter++; });
257   runner.WaitUntilIdle();
258   EXPECT_THAT(callback_counter, Eq(1));
259   for (const auto& generator : generators) {
260     accumulator->Schedule(generator);
261   }
262   runner.WaitUntilIdle();
263   // The callback was not re-scheduled, so the second call to Schedule didn't
264   // trigger it. This results in unobserved work.
265   EXPECT_THAT(callback_counter, Eq(1));
266   bool has_work = accumulator->SetAsyncObserver([&]() { callback_counter++; });
267   runner.WaitUntilIdle();
268   EXPECT_TRUE(has_work);
269   EXPECT_THAT(callback_counter, Eq(2));
270   // The accumulator should be idle and without unobserved work at this point.
271   has_work = accumulator->SetAsyncObserver([&]() { callback_counter++; });
272   EXPECT_FALSE(has_work);
273   Integer result;
274   for (const auto& generator : generators) {
275     accumulator->Schedule(generator);
276   }
277   accumulator->SetAsyncObserver(
278       [&]() { result = *(accumulator->GetResultAndCancel()); });
279   runner.WaitUntilIdle();
280   // The last call to SetAsyncObserver overwrittes the previous callback.
281   EXPECT_THAT(callback_counter, Eq(2));
282   EXPECT_THAT(result.value, Eq(216));  // 6^3 = 216
283 }
284 
285 }  // namespace
286 }  // namespace secagg
287 }  // namespace fcp
288