xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_aborted_state_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/secagg_server_aborted_state.h"
18 
19 #include <memory>
20 #include <string>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/container/node_hash_set.h"
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
27 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
28 #include "fcp/secagg/shared/secagg_messages.pb.h"
29 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
30 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
31 #include "fcp/tracing/test_tracing_recorder.h"
32 
33 namespace fcp {
34 namespace secagg {
35 namespace {
36 
37 using ::testing::Eq;
38 
CreateSecAggServerProtocolImpl(MockSecAggServerMetricsListener * metrics_listener=nullptr)39 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
40     MockSecAggServerMetricsListener* metrics_listener = nullptr) {
41   auto sender = std::unique_ptr<SendToClientsInterface>();
42   SecretSharingGraphFactory factory;
43   return std::make_unique<AesSecAggServerProtocolImpl>(
44       factory.CreateCompleteGraph(4, 3),  // total number of clients is 4
45       3,  // minimum_number_of_clients_to_proceed,
46       std::vector<InputVectorSpecification>(),  // input_vector_specs
47       std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
48       nullptr,  // prng_factory
49       sender.get(),
50       nullptr,  // prng_runner
51       std::vector<ClientStatus>(
52           4, DEAD_AFTER_SHARE_KEYS_RECEIVED),  // client_statuses
53       ServerVariant::NATIVE_V1);
54 }
55 
// Checks that IsAborted() is true on a freshly constructed aborted state.
TEST(SecaggServerAbortedStateTest, IsAbortedReturnsTrue) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto is_aborted = aborted_state.IsAborted();
  EXPECT_THAT(is_aborted, Eq(true));
}
68 
// Checks that IsCompletedSuccessfully() is false on an aborted state.
TEST(SecaggServerAbortedStateTest, IsCompletedSuccessfullyReturnsFalse) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto completed = aborted_state.IsCompletedSuccessfully();
  EXPECT_THAT(completed, Eq(false));
}
81 
// Checks that ErrorMessage() yields the same message the state was
// constructed with.
TEST(SecaggServerAbortedStateTest, ErrorMessageReturnsSelectedMessage) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto reported_message = aborted_state.ErrorMessage().value();
  EXPECT_THAT(reported_message, Eq(abort_reason));
}
94 
// Checks that ReadyForNextRound() is false on an aborted state.
TEST(SecaggServerAbortedStateTest, ReadyForNextRoundReturnsFalse) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto ready = aborted_state.ReadyForNextRound();
  EXPECT_THAT(ready, Eq(false));
}
107 
// Checks that NumberOfMessagesReceivedInThisRound() is 0 on an aborted state.
TEST(SecaggServerAbortedStateTest,
     NumberOfMessagesReceivedInThisRoundReturnsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto messages_received =
      aborted_state.NumberOfMessagesReceivedInThisRound();
  EXPECT_THAT(messages_received, Eq(0));
}
121 
// Checks that NumberOfClientsReadyForNextRound() is 0 on an aborted state.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsReadyForNextRoundReturnsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto clients_ready = aborted_state.NumberOfClientsReadyForNextRound();
  EXPECT_THAT(clients_ready, Eq(0));
}
135 
// Checks that NumberOfAliveClients() is 0 on an aborted state.
TEST(SecaggServerAbortedStateTest, NumberOfAliveClientsIsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto alive_clients = aborted_state.NumberOfAliveClients();
  EXPECT_THAT(alive_clients, Eq(0));
}
148 
// Checks that NumberOfClientsFailedBeforeSendingMaskedInput() echoes the
// count (4) passed to the constructor.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsFailedBeforeSendingMaskedInputIsAccurate) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto failed_before =
      aborted_state.NumberOfClientsFailedBeforeSendingMaskedInput();
  EXPECT_THAT(failed_before, Eq(4));
}
163 
// Checks that NumberOfClientsFailedAfterSendingMaskedInput() echoes the
// count (0) passed to the constructor.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsFailedAfterSendingMaskedInputReturnsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto failed_after =
      aborted_state.NumberOfClientsFailedAfterSendingMaskedInput();
  EXPECT_THAT(failed_after, Eq(0));
}
178 
// Checks that NumberOfClientsTerminatedWithoutUnmasking() echoes the count (0)
// passed to the constructor.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsTerminatedWithoutUnmaskingReturnsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto terminated =
      aborted_state.NumberOfClientsTerminatedWithoutUnmasking();
  EXPECT_THAT(terminated, Eq(0));
}
192 
// Checks that NumberOfPendingClients() is 0 on an aborted state.
TEST(SecaggServerAbortedStateTest, NumberOfPendingClientsReturnsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto pending_clients = aborted_state.NumberOfPendingClients();
  EXPECT_THAT(pending_clients, Eq(0));
}
205 
// Checks that NumberOfIncludedInputs() is 0 on an aborted state.
TEST(SecaggServerAbortedStateTest, NumberOfIncludedInputsReturnsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto included_inputs = aborted_state.NumberOfIncludedInputs();
  EXPECT_THAT(included_inputs, Eq(0));
}
218 
// Checks that IsNumberOfIncludedInputsCommitted() is true on an aborted state.
TEST(SecaggServerAbortedStateTest,
     IsNumberOfIncludedInputsCommittedReturnsTrue) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto committed = aborted_state.IsNumberOfIncludedInputsCommitted();
  EXPECT_THAT(committed, Eq(true));
}
232 
// Checks that MinimumMessagesNeededForNextRound() is 0 on an aborted state.
TEST(SecaggServerAbortedStateTest,
     MinimumMessagesNeededForNextRoundReturnsZero) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto messages_needed =
      aborted_state.MinimumMessagesNeededForNextRound();
  EXPECT_THAT(messages_needed, Eq(0));
}
246 
// Checks that minimum_number_of_clients_to_proceed() reports 3, the value the
// protocol impl from CreateSecAggServerProtocolImpl() is constructed with.
//
// NOTE(review): GoogleTest advises against underscores in test names; the name
// is kept as-is here so existing test filters keep matching.
TEST(SecaggServerAbortedStateTest,
     minimum_number_of_clients_to_proceedIsAccurate) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto minimum_clients =
      aborted_state.minimum_number_of_clients_to_proceed();
  EXPECT_THAT(minimum_clients, Eq(3));
}
260 
// Checks that HandleMessage() on an aborted state returns a non-OK status and
// that the metrics listener receives a MessageReceivedSizes() call for the
// message.
TEST(SecaggServerAbortedStateTest, HandleMessageRaisesError) {
  const std::string abort_reason = "test error message";
  // Deliberately a raw pointer: CreateSecAggServerProtocolImpl() wraps it in a
  // std::unique_ptr, so it must not be deleted in this test.
  auto* metrics_listener = new MockSecAggServerMetricsListener();

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(metrics_listener),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // Default-constructed message; the expectation below is for the
  // MESSAGE_CONTENT_NOT_SET case and this message's serialized size.
  ClientToServerWrapperMessage empty_message;
  const auto expected_case =
      ClientToServerWrapperMessage::MessageContentCase::MESSAGE_CONTENT_NOT_SET;
  const auto expected_size = empty_message.ByteSizeLong();
  EXPECT_CALL(*metrics_listener,
              MessageReceivedSizes(Eq(expected_case), Eq(false),
                                   Eq(expected_size)));

  const bool handled_ok = aborted_state.HandleMessage(0, empty_message).ok();
  EXPECT_THAT(handled_ok, Eq(false));
}
280 
// Checks that ProceedToNextRound() on an aborted state returns a non-OK
// result.
TEST(SecaggServerAbortedStateTest, ProceedToNextRoundRaisesError) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool proceeded_ok = aborted_state.ProceedToNextRound().ok();
  EXPECT_THAT(proceeded_ok, Eq(false));
}
293 
// Checks that Result() on an aborted state returns a non-OK result.
TEST(SecaggServerAbortedStateTest, ResultRaisesErrorStatus) {
  const std::string abort_reason = "test error message";

  SecAggServerAbortedState aborted_state(
      abort_reason, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const bool result_ok = aborted_state.Result().ok();
  EXPECT_THAT(result_ok, Eq(false));
}
306 
307 }  // namespace
308 }  // namespace secagg
309 }  // namespace fcp
310