// xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_server_completed_state_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2018 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "fcp/secagg/server/secagg_server_completed_state.h"
18 
19 #include <memory>
20 #include <utility>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "absl/container/node_hash_set.h"
25 #include "fcp/base/monitoring.h"
26 #include "fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h"
27 #include "fcp/secagg/server/secagg_server_enums.pb.h"
28 #include "fcp/secagg/server/secret_sharing_graph_factory.h"
29 #include "fcp/secagg/shared/secagg_messages.pb.h"
30 #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h"
31 #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h"
32 #include "fcp/secagg/testing/test_matchers.h"
33 #include "fcp/tracing/test_tracing_recorder.h"
34 
35 namespace fcp {
36 namespace secagg {
37 namespace {
38 
39 using ::testing::Eq;
40 
CreateSecAggServerProtocolImpl(MockSendToClientsInterface * sender,MockSecAggServerMetricsListener * metrics_listener=nullptr)41 std::unique_ptr<AesSecAggServerProtocolImpl> CreateSecAggServerProtocolImpl(
42     MockSendToClientsInterface* sender,
43     MockSecAggServerMetricsListener* metrics_listener = nullptr) {
44   int total_number_of_clients = 4;
45   SecretSharingGraphFactory factory;
46   return std::make_unique<AesSecAggServerProtocolImpl>(
47       factory.CreateCompleteGraph(total_number_of_clients, 3),
48       3,  // minimum_number_of_clients_to_proceed
49       std::vector<InputVectorSpecification>(),
50       std::unique_ptr<MockSecAggServerMetricsListener>(metrics_listener),
51       nullptr,  // prng_factory
52       sender,
53       nullptr,  // prng_runner
54       std::vector<ClientStatus>(total_number_of_clients,
55                                 ClientStatus::UNMASKING_RESPONSE_RECEIVED),
56       ServerVariant::NATIVE_V1);
57 }
58 
CreateState(MockSendToClientsInterface * sender,int number_of_clients_failed_after_sending_masked_input=0,int number_of_clients_failed_before_sending_masked_input=0,int number_of_clients_terminated_without_unmasking=0,std::unique_ptr<SecAggVectorMap> map=std::unique_ptr<SecAggVectorMap> (),MockSecAggServerMetricsListener * metrics_listener=nullptr)59 SecAggServerCompletedState CreateState(
60     MockSendToClientsInterface* sender,
61     int number_of_clients_failed_after_sending_masked_input = 0,
62     int number_of_clients_failed_before_sending_masked_input = 0,
63     int number_of_clients_terminated_without_unmasking = 0,
64     std::unique_ptr<SecAggVectorMap> map = std::unique_ptr<SecAggVectorMap>(),
65     MockSecAggServerMetricsListener* metrics_listener = nullptr) {
66   std::unique_ptr<AesSecAggServerProtocolImpl> impl =
67       CreateSecAggServerProtocolImpl(sender, metrics_listener);
68   impl->SetResult(std::move(map));
69   return SecAggServerCompletedState(
70       std::move(impl), number_of_clients_failed_after_sending_masked_input,
71       number_of_clients_failed_before_sending_masked_input,
72       number_of_clients_terminated_without_unmasking);
73 }
74 
TEST(SecAggServerCompletedStateTest,IsAbortedReturnsFalse)75 TEST(SecAggServerCompletedStateTest, IsAbortedReturnsFalse) {
76   auto sender = std::make_unique<MockSendToClientsInterface>();
77   SecAggServerCompletedState completed_state = CreateState(sender.get());
78   EXPECT_THAT(completed_state.IsAborted(), Eq(false));
79 }
80 
TEST(SecAggServerCompletedStateTest,IsCompletedSuccessfullyReturnsTrue)81 TEST(SecAggServerCompletedStateTest, IsCompletedSuccessfullyReturnsTrue) {
82   auto sender = std::make_unique<MockSendToClientsInterface>();
83   SecAggServerCompletedState completed_state = CreateState(sender.get());
84   EXPECT_THAT(completed_state.IsCompletedSuccessfully(), Eq(true));
85 }
86 
TEST(SecAggServerCompletedStateTest,ErrorMessageRaisesError)87 TEST(SecAggServerCompletedStateTest, ErrorMessageRaisesError) {
88   auto sender = std::make_unique<MockSendToClientsInterface>();
89   SecAggServerCompletedState completed_state = CreateState(sender.get());
90   EXPECT_THAT(completed_state.ErrorMessage().ok(), Eq(false));
91 }
92 
TEST(SecAggServerCompletedStateTest,ReadyForNextRoundReturnsFalse)93 TEST(SecAggServerCompletedStateTest, ReadyForNextRoundReturnsFalse) {
94   auto sender = std::make_unique<MockSendToClientsInterface>();
95   SecAggServerCompletedState completed_state = CreateState(sender.get());
96   EXPECT_THAT(completed_state.ReadyForNextRound(), Eq(false));
97 }
98 
TEST(SecAggServerCompletedStateTest,NumberOfMessagesReceivedInThisRoundReturnsZero)99 TEST(SecAggServerCompletedStateTest,
100      NumberOfMessagesReceivedInThisRoundReturnsZero) {
101   auto sender = std::make_unique<MockSendToClientsInterface>();
102   SecAggServerCompletedState completed_state = CreateState(sender.get());
103   EXPECT_THAT(completed_state.NumberOfMessagesReceivedInThisRound(), Eq(0));
104 }
105 
TEST(SecAggServerCompletedStateTest,NumberOfClientsReadyForNextRoundReturnsZero)106 TEST(SecAggServerCompletedStateTest,
107      NumberOfClientsReadyForNextRoundReturnsZero) {
108   auto sender = std::make_unique<MockSendToClientsInterface>();
109   SecAggServerCompletedState completed_state = CreateState(sender.get());
110   EXPECT_THAT(completed_state.NumberOfClientsReadyForNextRound(), Eq(0));
111 }
112 
TEST(SecAggServerCompletedStateTest,NumberOfAliveClientsIsAccurate)113 TEST(SecAggServerCompletedStateTest, NumberOfAliveClientsIsAccurate) {
114   auto sender = std::make_unique<MockSendToClientsInterface>();
115   SecAggServerCompletedState completed_state = CreateState(
116       sender.get(), 0,  // number_of_clients_failed_after_sending_masked_input
117       0,                // number_of_clients_failed_before_sending_masked_input
118       1);               // number_of_clients_terminated_without_unmasking
119   EXPECT_THAT(completed_state.NumberOfAliveClients(), Eq(3));
120 }
121 
TEST(SecAggServerCompletedStateTest,NumberOfClientsFailedBeforeSendingMaskedInputIsAccurate)122 TEST(SecAggServerCompletedStateTest,
123      NumberOfClientsFailedBeforeSendingMaskedInputIsAccurate) {
124   auto sender = std::make_unique<MockSendToClientsInterface>();
125   SecAggServerCompletedState completed_state = CreateState(
126       sender.get(), 0,  // number_of_clients_failed_after_sending_masked_input
127       1,                // number_of_clients_failed_before_sending_masked_input
128       0);               // number_of_clients_terminated_without_unmasking
129   EXPECT_THAT(completed_state.NumberOfClientsFailedBeforeSendingMaskedInput(),
130               Eq(1));
131 }
132 
TEST(SecAggServerCompletedStateTest,NumberOfClientsFailedAfterSendingMaskedInputIsAccurate)133 TEST(SecAggServerCompletedStateTest,
134      NumberOfClientsFailedAfterSendingMaskedInputIsAccurate) {
135   auto sender = std::make_unique<MockSendToClientsInterface>();
136   SecAggServerCompletedState completed_state = CreateState(
137       sender.get(), 1,  // number_of_clients_failed_after_sending_masked_input
138       0,                // number_of_clients_failed_before_sending_masked_input
139       0);               // number_of_clients_terminated_without_unmasking
140   EXPECT_THAT(completed_state.NumberOfClientsFailedAfterSendingMaskedInput(),
141               Eq(1));
142 }
143 
TEST(SecAggServerCompletedStateTest,NumberOfClientsTerminatedWithoutUnmaskingIsAccurate)144 TEST(SecAggServerCompletedStateTest,
145      NumberOfClientsTerminatedWithoutUnmaskingIsAccurate) {
146   auto sender = std::make_unique<MockSendToClientsInterface>();
147   SecAggServerCompletedState completed_state = CreateState(
148       sender.get(), 0,  // number_of_clients_failed_after_sending_masked_input
149       0,                // number_of_clients_failed_before_sending_masked_input
150       1);               // number_of_clients_terminated_without_unmasking
151   EXPECT_THAT(completed_state.NumberOfClientsTerminatedWithoutUnmasking(),
152               Eq(1));
153 }
154 
TEST(SecAggServerCompletedStateTest,NumberOfPendingClientsReturnsZero)155 TEST(SecAggServerCompletedStateTest, NumberOfPendingClientsReturnsZero) {
156   auto sender = std::make_unique<MockSendToClientsInterface>();
157   SecAggServerCompletedState completed_state = CreateState(sender.get());
158   EXPECT_THAT(completed_state.NumberOfPendingClients(), Eq(0));
159 }
160 
TEST(SecAggServerCompletedStateTest,NumberOfIncludedInputsIsAccurate)161 TEST(SecAggServerCompletedStateTest, NumberOfIncludedInputsIsAccurate) {
162   auto sender = std::make_unique<MockSendToClientsInterface>();
163   SecAggServerCompletedState completed_state = CreateState(
164       sender.get(), 1,  // number_of_clients_failed_after_sending_masked_input
165       0,                // number_of_clients_failed_before_sending_masked_input
166       0);               // number_of_clients_terminated_without_unmasking
167   EXPECT_THAT(completed_state.NumberOfIncludedInputs(), Eq(4));
168 
169   SecAggServerCompletedState completed_state_2 = CreateState(
170       sender.get(), 0,  // number_of_clients_failed_after_sending_masked_input
171       1,                // number_of_clients_failed_before_sending_masked_input
172       0);               // number_of_clients_terminated_without_unmasking
173   EXPECT_THAT(completed_state_2.NumberOfIncludedInputs(), Eq(3));
174 }
175 
TEST(SecAggServerCompletedStateTest,IsNumberOfIncludedInputsCommittedReturnsTrue)176 TEST(SecAggServerCompletedStateTest,
177      IsNumberOfIncludedInputsCommittedReturnsTrue) {
178   auto sender = std::make_unique<MockSendToClientsInterface>();
179   SecAggServerCompletedState completed_state = CreateState(sender.get());
180   EXPECT_THAT(completed_state.IsNumberOfIncludedInputsCommitted(), Eq(true));
181 }
182 
TEST(SecAggServerCompletedStateTest,MinimumMessagesNeededForNextRoundReturnsZero)183 TEST(SecAggServerCompletedStateTest,
184      MinimumMessagesNeededForNextRoundReturnsZero) {
185   auto sender = std::make_unique<MockSendToClientsInterface>();
186   SecAggServerCompletedState completed_state = CreateState(sender.get());
187   EXPECT_THAT(completed_state.MinimumMessagesNeededForNextRound(), Eq(0));
188 }
189 
TEST(SecAggServerCompletedStateTest,MinimumNumberOfClientsToProceedIsAccurate)190 TEST(SecAggServerCompletedStateTest,
191      MinimumNumberOfClientsToProceedIsAccurate) {
192   auto sender = std::make_unique<MockSendToClientsInterface>();
193   SecAggServerCompletedState completed_state = CreateState(sender.get());
194   EXPECT_THAT(completed_state.minimum_number_of_clients_to_proceed(), Eq(3));
195 }
196 
TEST(SecAggServerCompletedStateTest,HandleMessageRaisesError)197 TEST(SecAggServerCompletedStateTest, HandleMessageRaisesError) {
198   auto sender = std::make_unique<MockSendToClientsInterface>();
199   MockSecAggServerMetricsListener* metrics =
200       new MockSecAggServerMetricsListener();
201 
202   SecAggServerCompletedState completed_state = CreateState(
203       sender.get(), 0,  // number_of_clients_failed_after_sending_masked_input
204       0,                // number_of_clients_failed_before_sending_masked_input
205       0,                // number_of_clients_terminated_without_unmasking
206       std::unique_ptr<SecAggVectorMap>(), metrics);
207 
208   ClientToServerWrapperMessage client_message;
209   EXPECT_CALL(*metrics, MessageReceivedSizes(
210                             Eq(ClientToServerWrapperMessage::
211                                    MessageContentCase::MESSAGE_CONTENT_NOT_SET),
212                             Eq(false), Eq(client_message.ByteSizeLong())));
213   EXPECT_THAT(completed_state.HandleMessage(0, client_message).ok(), Eq(false));
214 }
215 
TEST(SecAggServerCompletedStateTest,ProceedToNextRoundRaisesError)216 TEST(SecAggServerCompletedStateTest, ProceedToNextRoundRaisesError) {
217   auto sender = std::make_unique<MockSendToClientsInterface>();
218   SecAggServerCompletedState completed_state = CreateState(sender.get());
219   EXPECT_THAT(completed_state.ProceedToNextRound().ok(), Eq(false));
220 }
221 
TEST(SecAggServerCompletedStateTest,ResultGivesStoredResult)222 TEST(SecAggServerCompletedStateTest, ResultGivesStoredResult) {
223   std::vector<uint64_t> vec = {1, 3, 6, 10};
224   auto result_map = std::make_unique<SecAggVectorMap>();
225   auto sender = std::make_unique<MockSendToClientsInterface>();
226   result_map->emplace("foobar", SecAggVector(vec, 32));
227   SecAggServerCompletedState completed_state =
228       CreateState(sender.get(),
229                   0,  // number_of_clients_failed_after_sending_masked_input
230                   0,  // number_of_clients_failed_before_sending_masked_input
231                   0,  // number_of_clients_terminated_without_unmasking
232                   std::move(result_map));
233 
234   auto result = completed_state.Result();
235   ASSERT_THAT(result.ok(), Eq(true));
236   EXPECT_THAT(*result.value(),
237               testing::MatchesSecAggVector("foobar", SecAggVector(vec, 32)));
238 }
239 
TEST(SecAggServerCompletedStateTest,ConstructorRecordsSuccessMetric)240 TEST(SecAggServerCompletedStateTest, ConstructorRecordsSuccessMetric) {
241   TestTracingRecorder tracing_recorder;
242   auto sender = std::make_unique<MockSendToClientsInterface>();
243   MockSecAggServerMetricsListener* metrics =
244       new MockSecAggServerMetricsListener();
245 
246   EXPECT_CALL(*metrics, ProtocolOutcomes(Eq(SecAggServerOutcome::SUCCESS)));
247   SecAggServerCompletedState completed_state =
248       CreateState(sender.get(),
249                   0,  // number_of_clients_failed_after_sending_masked_input
250                   0,  // number_of_clients_failed_before_sending_masked_input
251                   0,  // number_of_clients_terminated_without_unmasking
252                   std::unique_ptr<SecAggVectorMap>(), metrics);
253 
254   EXPECT_THAT(tracing_recorder.FindAllEvents<SecAggProtocolOutcome>(),
255               ElementsAre(IsEvent<SecAggProtocolOutcome>(
256                   Eq(TracingSecAggServerOutcome_Success))));
257 }
258 
259 }  // namespace
260 }  // namespace secagg
261 }  // namespace fcp
262