1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_aborted_state.h"
18
19 #include <memory>
20 #include <string>
21
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/container/node_hash_set.h"
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
27 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
28 #include "fcp/secagg/shared/secagg_messages.pb.h"
29 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
30 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
31 #include "fcp/tracing/test_tracing_recorder.h"
32
33 namespace fcp {
34 namespace secagg {
35 namespace {
36
37 using ::testing::Eq;
38
CreateSecAggServerProtocolImpl(MockSecAggServerMetricsListener * metrics_listener=nullptr)39 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
40 MockSecAggServerMetricsListener* metrics_listener = nullptr) {
41 auto sender = std::unique_ptr<SendToClientsInterface>();
42 SecretSharingGraphFactory factory;
43 return std::make_unique<AesSecAggServerProtocolImpl>(
44 factory.CreateCompleteGraph(4, 3), // total number of clients is 4
45 3, // minimum_number_of_clients_to_proceed,
46 std::vector<InputVectorSpecification>(), // input_vector_specs
47 std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
48 nullptr, // prng_factory
49 sender.get(),
50 nullptr, // prng_runner
51 std::vector<ClientStatus>(
52 4, DEAD_AFTER_SHARE_KEYS_RECEIVED), // client_statuses
53 ServerVariant::NATIVE_V1);
54 }
55
// An aborted state must report that it is aborted.
TEST(SecaggServerAbortedStateTest, IsAbortedReturnsTrue) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_TRUE(state.IsAborted());
}
68
// An aborted state must never claim successful completion.
TEST(SecaggServerAbortedStateTest, IsCompletedSuccessfullyReturnsFalse) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.IsCompletedSuccessfully());
}
81
// The error message passed to the constructor must be returned unchanged by
// ErrorMessage().
TEST(SecaggServerAbortedStateTest, ErrorMessageReturnsSelectedMessage) {
  const std::string test_error_message = "test error message";

  SecAggServerAbortedState aborted_state(
      test_error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // NOTE(review): ErrorMessage() is assumed to return a StatusOr<std::string>.
  // Check ok() before value(), so an unexpected error status is reported as a
  // test failure instead of aborting the whole test binary inside value().
  auto error_message = aborted_state.ErrorMessage();
  ASSERT_TRUE(error_message.ok());
  EXPECT_THAT(error_message.value(), Eq(test_error_message));
}
94
// An aborted state is terminal, so it is never ready to advance.
TEST(SecaggServerAbortedStateTest, ReadyForNextRoundReturnsFalse) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.ReadyForNextRound());
}
107
// No messages are counted as received while in the aborted state.
TEST(SecaggServerAbortedStateTest,
     NumberOfMessagesReceivedInThisRoundReturnsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto received = state.NumberOfMessagesReceivedInThisRound();
  EXPECT_THAT(received, Eq(0));
}
121
// No client is counted as ready to advance from the aborted state.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsReadyForNextRoundReturnsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto ready = state.NumberOfClientsReadyForNextRound();
  EXPECT_THAT(ready, Eq(0));
}
135
// With every client marked dead by the helper, none should count as alive.
TEST(SecaggServerAbortedStateTest, NumberOfAliveClientsIsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto alive = state.NumberOfAliveClients();
  EXPECT_THAT(alive, Eq(0));
}
148
// The "failed before sending masked input" count given to the constructor is
// reported back unchanged.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsFailedBeforeSendingMaskedInputIsAccurate) {
  constexpr int kFailedBeforeMaskedInput = 4;
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/
      kFailedBeforeMaskedInput,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_THAT(state.NumberOfClientsFailedBeforeSendingMaskedInput(),
              Eq(kFailedBeforeMaskedInput));
}
163
// The "failed after sending masked input" count is zero when constructed as 0.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsFailedAfterSendingMaskedInputReturnsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto failed_after = state.NumberOfClientsFailedAfterSendingMaskedInput();
  EXPECT_THAT(failed_after, Eq(0));
}
178
// The "terminated without unmasking" count is zero when constructed as 0.
TEST(SecaggServerAbortedStateTest,
     NumberOfClientsTerminatedWithoutUnmaskingReturnsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto terminated = state.NumberOfClientsTerminatedWithoutUnmasking();
  EXPECT_THAT(terminated, Eq(0));
}
192
// No clients are reported as pending in the aborted state.
TEST(SecaggServerAbortedStateTest, NumberOfPendingClientsReturnsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto pending = state.NumberOfPendingClients();
  EXPECT_THAT(pending, Eq(0));
}
205
// No inputs are reported as included in the aborted state.
TEST(SecaggServerAbortedStateTest, NumberOfIncludedInputsReturnsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto included = state.NumberOfIncludedInputs();
  EXPECT_THAT(included, Eq(0));
}
218
// The aborted state reports its (zero) included-input count as committed.
TEST(SecaggServerAbortedStateTest,
     IsNumberOfIncludedInputsCommittedReturnsTrue) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_TRUE(state.IsNumberOfIncludedInputsCommitted());
}
232
// No further messages are required because no next round follows an abort.
TEST(SecaggServerAbortedStateTest,
     MinimumMessagesNeededForNextRoundReturnsZero) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto needed = state.MinimumMessagesNeededForNextRound();
  EXPECT_THAT(needed, Eq(0));
}
246
// The aborted state exposes the minimum client count configured on the
// protocol impl (CreateSecAggServerProtocolImpl passes 3).
TEST(SecaggServerAbortedStateTest,
     minimum_number_of_clients_to_proceedIsAccurate) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  const auto minimum = state.minimum_number_of_clients_to_proceed();
  EXPECT_THAT(minimum, Eq(3));
}
260
// Handling any client message in the aborted state fails. The received
// message's type and size are still reported to the metrics listener.
TEST(SecaggServerAbortedStateTest, HandleMessageRaisesError) {
  const std::string error_message = "test error message";
  // Ownership of this mock is handed to the protocol impl built below; the
  // raw pointer is retained only so expectations can be set on it.
  auto* metrics_listener = new MockSecAggServerMetricsListener();

  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(metrics_listener),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  // A default-constructed wrapper has no message content set.
  ClientToServerWrapperMessage empty_message;
  EXPECT_CALL(*metrics_listener,
              MessageReceivedSizes(
                  Eq(ClientToServerWrapperMessage::MessageContentCase::
                         MESSAGE_CONTENT_NOT_SET),
                  Eq(false), Eq(empty_message.ByteSizeLong())));

  const bool handled_ok = state.HandleMessage(0, empty_message).ok();
  EXPECT_FALSE(handled_ok);
}
280
// Attempting to advance out of the aborted state yields a non-OK status.
TEST(SecaggServerAbortedStateTest, ProceedToNextRoundRaisesError) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.ProceedToNextRound().ok());
}
293
// No aggregation result is available from the aborted state.
TEST(SecaggServerAbortedStateTest, ResultRaisesErrorStatus) {
  const std::string error_message = "test error message";
  SecAggServerAbortedState state(
      error_message, CreateSecAggServerProtocolImpl(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/4,
      /*number_of_clients_terminated_without_unmasking=*/0);

  EXPECT_FALSE(state.Result().ok());
}
306
307 } // namespace
308 } // namespace secagg
309 } // namespace fcp
310