1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/secagg_server_completed_state.h"
18
19 #include <memory>
20 #include <utility>
21
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/container/node_hash_set.h"
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
27 #include "fcp/secagg/server/secagg_server_enums.pb.h"
28 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
29 #include "fcp/secagg/shared/secagg_messages.pb.h"
30 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
31 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
32 #include "fcp/secagg/testing/test_matchers.h"
33 #include "fcp/tracing/test_tracing_recorder.h"
34
35 namespace fcp {
36 namespace secagg {
37 namespace {
38
39 using ::testing::Eq;
40
CreateSecAggServerProtocolImpl(MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)41 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
42 MockSendToClientsInterface* sender,
43 MockSecAggServerMetricsListener* metrics_listener = nullptr) {
44 int total_number_of_clients = 4;
45 SecretSharingGraphFactory factory;
46 return std::make_unique<AesSecAggServerProtocolImpl>(
47 factory.CreateCompleteGraph(total_number_of_clients, 3),
48 3, // minimum_number_of_clients_to_proceed
49 std::vector<InputVectorSpecification>(),
50 std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
51 nullptr, // prng_factory
52 sender,
53 nullptr, // prng_runner
54 std::vector<ClientStatus>(total_number_of_clients,
55 ClientStatus::UNMASKING_RESPONSE_RECEIVED),
56 ServerVariant::NATIVE_V1);
57 }
58
CreateState(MockSendToClientsInterface * sender,int number_of_clients_failed_after_sending_masked_input=0,int number_of_clients_failed_before_sending_masked_input=0,int number_of_clients_terminated_without_unmasking=0,std::unique_ptr<SecAggVectorMap> map=std::unique_ptr<SecAggVectorMap> (),MockSecAggServerMetricsListener * metrics_listener=nullptr)59 SecAggServerCompletedState CreateState(
60 MockSendToClientsInterface* sender,
61 int number_of_clients_failed_after_sending_masked_input = 0,
62 int number_of_clients_failed_before_sending_masked_input = 0,
63 int number_of_clients_terminated_without_unmasking = 0,
64 std::unique_ptr<SecAggVectorMap> map = std::unique_ptr<SecAggVectorMap>(),
65 MockSecAggServerMetricsListener* metrics_listener = nullptr) {
66 std::unique_ptr<AesSecAggServerProtocolImpl> impl =
67 CreateSecAggServerProtocolImpl(sender, metrics_listener);
68 impl->SetResult(std::move(map));
69 return SecAggServerCompletedState(
70 std::move(impl), number_of_clients_failed_after_sending_masked_input,
71 number_of_clients_failed_before_sending_masked_input,
72 number_of_clients_terminated_without_unmasking);
73 }
74
// A completed state must not report itself as aborted.
TEST(SecAggServerCompletedStateTest, IsAbortedReturnsFalse) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto aborted = state.IsAborted();
  EXPECT_THAT(aborted, Eq(false));
}
80
// A completed state reports successful completion.
TEST(SecAggServerCompletedStateTest, IsCompletedSuccessfullyReturnsTrue) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto completed_successfully = state.IsCompletedSuccessfully();
  EXPECT_THAT(completed_successfully, Eq(true));
}
86
// ErrorMessage() is expected to return a non-OK result on a completed
// (non-aborted) state.
TEST(SecAggServerCompletedStateTest, ErrorMessageRaisesError) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto error_message = state.ErrorMessage();
  EXPECT_THAT(error_message.ok(), Eq(false));
}
92
// The completed state never reports readiness for another round.
TEST(SecAggServerCompletedStateTest, ReadyForNextRoundReturnsFalse) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto ready = state.ReadyForNextRound();
  EXPECT_THAT(ready, Eq(false));
}
98
// No messages are counted as received in the completed state.
TEST(SecAggServerCompletedStateTest,
     NumberOfMessagesReceivedInThisRoundReturnsZero) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto received = state.NumberOfMessagesReceivedInThisRound();
  EXPECT_THAT(received, Eq(0));
}
105
// No client is counted as ready for a further round once completed.
TEST(SecAggServerCompletedStateTest,
     NumberOfClientsReadyForNextRoundReturnsZero) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto ready_clients = state.NumberOfClientsReadyForNextRound();
  EXPECT_THAT(ready_clients, Eq(0));
}
112
// With one of the fixture's four clients terminated without unmasking, three
// clients remain alive.
TEST(SecAggServerCompletedStateTest, NumberOfAliveClientsIsAccurate) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  constexpr int kFailedAfterSendingMaskedInput = 0;
  constexpr int kFailedBeforeSendingMaskedInput = 0;
  constexpr int kTerminatedWithoutUnmasking = 1;
  auto state = CreateState(mock_sender.get(), kFailedAfterSendingMaskedInput,
                           kFailedBeforeSendingMaskedInput,
                           kTerminatedWithoutUnmasking);

  EXPECT_THAT(state.NumberOfAliveClients(), Eq(3));
}
121
// The "failed before sending masked input" counter given to CreateState is
// reported back unchanged.
TEST(SecAggServerCompletedStateTest,
     NumberOfClientsFailedBeforeSendingMaskedInputIsAccurate) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  constexpr int kFailedAfterSendingMaskedInput = 0;
  constexpr int kFailedBeforeSendingMaskedInput = 1;
  constexpr int kTerminatedWithoutUnmasking = 0;
  auto state = CreateState(mock_sender.get(), kFailedAfterSendingMaskedInput,
                           kFailedBeforeSendingMaskedInput,
                           kTerminatedWithoutUnmasking);

  EXPECT_THAT(state.NumberOfClientsFailedBeforeSendingMaskedInput(), Eq(1));
}
132
// The "failed after sending masked input" counter given to CreateState is
// reported back unchanged.
TEST(SecAggServerCompletedStateTest,
     NumberOfClientsFailedAfterSendingMaskedInputIsAccurate) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  constexpr int kFailedAfterSendingMaskedInput = 1;
  constexpr int kFailedBeforeSendingMaskedInput = 0;
  constexpr int kTerminatedWithoutUnmasking = 0;
  auto state = CreateState(mock_sender.get(), kFailedAfterSendingMaskedInput,
                           kFailedBeforeSendingMaskedInput,
                           kTerminatedWithoutUnmasking);

  EXPECT_THAT(state.NumberOfClientsFailedAfterSendingMaskedInput(), Eq(1));
}
143
// The "terminated without unmasking" counter given to CreateState is reported
// back unchanged.
TEST(SecAggServerCompletedStateTest,
     NumberOfClientsTerminatedWithoutUnmaskingIsAccurate) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  constexpr int kFailedAfterSendingMaskedInput = 0;
  constexpr int kFailedBeforeSendingMaskedInput = 0;
  constexpr int kTerminatedWithoutUnmasking = 1;
  auto state = CreateState(mock_sender.get(), kFailedAfterSendingMaskedInput,
                           kFailedBeforeSendingMaskedInput,
                           kTerminatedWithoutUnmasking);

  EXPECT_THAT(state.NumberOfClientsTerminatedWithoutUnmasking(), Eq(1));
}
154
// No client is left pending once the protocol has completed.
TEST(SecAggServerCompletedStateTest, NumberOfPendingClientsReturnsZero) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto pending = state.NumberOfPendingClients();
  EXPECT_THAT(pending, Eq(0));
}
160
// Checks how the two kinds of dropout affect the included-input count for the
// fixture's four clients.
TEST(SecAggServerCompletedStateTest, NumberOfIncludedInputsIsAccurate) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();

  // A client that failed *after* sending its masked input is still counted
  // among the included inputs: all 4 are included.
  auto dropped_after_masking = CreateState(
      mock_sender.get(),
      /*number_of_clients_failed_after_sending_masked_input=*/1,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0);
  EXPECT_THAT(dropped_after_masking.NumberOfIncludedInputs(), Eq(4));

  // A client that failed *before* sending its masked input is excluded,
  // leaving 3 included inputs.
  auto dropped_before_masking = CreateState(
      mock_sender.get(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/1,
      /*number_of_clients_terminated_without_unmasking=*/0);
  EXPECT_THAT(dropped_before_masking.NumberOfIncludedInputs(), Eq(3));
}
175
// By completion, the number of included inputs is reported as committed.
TEST(SecAggServerCompletedStateTest,
     IsNumberOfIncludedInputsCommittedReturnsTrue) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto committed = state.IsNumberOfIncludedInputsCommitted();
  EXPECT_THAT(committed, Eq(true));
}
182
// The completed state needs no further messages.
TEST(SecAggServerCompletedStateTest,
     MinimumMessagesNeededForNextRoundReturnsZero) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto needed = state.MinimumMessagesNeededForNextRound();
  EXPECT_THAT(needed, Eq(0));
}
189
// The minimum-clients value configured by CreateSecAggServerProtocolImpl (3)
// is visible through the completed state.
TEST(SecAggServerCompletedStateTest,
     MinimumNumberOfClientsToProceedIsAccurate) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto minimum = state.minimum_number_of_clients_to_proceed();
  EXPECT_THAT(minimum, Eq(3));
}
196
// Handling any client message in the completed state fails, and the received
// message's size is still reported to the metrics listener.
TEST(SecAggServerCompletedStateTest, HandleMessageRaisesError) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership of the listener is handed to the protocol impl created inside
  // CreateState; the raw pointer is kept only for setting expectations.
  auto* mock_metrics = new MockSecAggServerMetricsListener();

  auto state = CreateState(
      mock_sender.get(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0,
      /*map=*/std::unique_ptr<SecAggVectorMap>(), mock_metrics);

  // A default-constructed wrapper has no message content set.
  ClientToServerWrapperMessage empty_message;
  const auto expected_size = empty_message.ByteSizeLong();
  EXPECT_CALL(*mock_metrics,
              MessageReceivedSizes(
                  Eq(ClientToServerWrapperMessage::MessageContentCase::
                         MESSAGE_CONTENT_NOT_SET),
                  Eq(false), Eq(expected_size)));

  const auto status = state.HandleMessage(0, empty_message);
  EXPECT_THAT(status.ok(), Eq(false));
}
215
// There is nothing after the completed state, so advancing must fail.
TEST(SecAggServerCompletedStateTest, ProceedToNextRoundRaisesError) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();
  auto state = CreateState(mock_sender.get());

  const auto next_round = state.ProceedToNextRound();
  EXPECT_THAT(next_round.ok(), Eq(false));
}
221
// Result() hands back the vector map that was stored on the protocol impl.
TEST(SecAggServerCompletedStateTest, ResultGivesStoredResult) {
  auto mock_sender = std::make_unique<MockSendToClientsInterface>();

  std::vector<uint64_t> values = {1, 3, 6, 10};
  auto stored_map = std::make_unique<SecAggVectorMap>();
  stored_map->emplace("foobar", SecAggVector(values, 32));

  auto state = CreateState(
      mock_sender.get(),
      /*number_of_clients_failed_after_sending_masked_input=*/0,
      /*number_of_clients_failed_before_sending_masked_input=*/0,
      /*number_of_clients_terminated_without_unmasking=*/0,
      std::move(stored_map));

  auto returned = state.Result();
  ASSERT_THAT(returned.ok(), Eq(true));
  EXPECT_THAT(*returned.value(),
              testing::MatchesSecAggVector("foobar", SecAggVector(values, 32)));
}
239
// Constructing the completed state reports a SUCCESS outcome both to the
// metrics listener and as a tracing event.
TEST(SecAggServerCompletedStateTest, ConstructorRecordsSuccessMetric) {
  // Declared before the state so that it is active while the state's
  // constructor runs and can capture the emitted outcome event.
  TestTracingRecorder tracing_recorder;
  auto sender = std::make_unique<MockSendToClientsInterface>();
  // Ownership is handed to the protocol impl created inside CreateState.
  MockSecAggServerMetricsListener* metrics =
      new MockSecAggServerMetricsListener();

  // The expectation must be in place before construction, since the call
  // under test happens in the constructor.
  EXPECT_CALL(*metrics, ProtocolOutcomes(Eq(SecAggServerOutcome::SUCCESS)));
  SecAggServerCompletedState completed_state =
      CreateState(sender.get(),
                  0,  // number_of_clients_failed_after_sending_masked_input
                  0,  // number_of_clients_failed_before_sending_masked_input
                  0,  // number_of_clients_terminated_without_unmasking
                  std::unique_ptr<SecAggVectorMap>(), metrics);

  // Fully qualified: only ::testing::Eq is brought into scope by a
  // using-declaration, and a plain `testing::` prefix could resolve to
  // fcp::secagg::testing instead of gMock.
  EXPECT_THAT(tracing_recorder.FindAllEvents<SecAggProtocolOutcome>(),
              ::testing::ElementsAre(IsEvent<SecAggProtocolOutcome>(
                  Eq(TracingSecAggServerOutcome_Success))));
}
258
259 } // namespace
260 } // namespace secagg
261 } // namespace fcp
262