/* * Copyright 2019 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "fcp/secagg/server/secagg_server.h" #include #include #include #include #include #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "fcp/base/monitoring.h" #include "fcp/secagg/server/secagg_server_enums.pb.h" #include "fcp/secagg/server/secagg_server_state.h" #include "fcp/secagg/server/tracing_schema.h" #include "fcp/secagg/shared/aes_ctr_prng_factory.h" #include "fcp/secagg/shared/input_vector_specification.h" #include "fcp/secagg/shared/secagg_messages.pb.h" #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h" #include "fcp/secagg/testing/server/mock_secagg_server_metrics_listener.h" #include "fcp/secagg/testing/server/mock_send_to_clients_interface.h" #include "fcp/secagg/testing/server/test_secagg_experiments.h" #include "fcp/testing/testing.h" #include "fcp/tracing/test_tracing_recorder.h" namespace fcp { namespace secagg { namespace { using ::testing::_; using ::testing::Eq; std::unique_ptr CreateServer( SendToClientsInterface* sender, SecAggServerMetricsListener* metrics = new MockSecAggServerMetricsListener(), std::unique_ptr experiments = std::make_unique()) { SecureAggregationRequirements threat_model; threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER); threat_model.set_adversarial_client_rate(.3); threat_model.set_estimated_dropout_rate(.3); std::unique_ptr prng_factory; std::vector input_vector_specs; input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32)); auto status_or_server = SecAggServer::Create( 100, // minimum_number_of_clients_to_proceed 1000, // total_number_of_clients input_vector_specs, sender, std::unique_ptr(metrics), /*prng_runner=*/nullptr, std::move(experiments), threat_model); EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status(); return std::move(status_or_server.value()); } template auto TraceRecorderHas(const M&... matchers) { return ElementsAre(AllOf( IsSpan(), ElementsAre( IsEvent( 1000, // number_of_clients 219, // degree 116, // threshold 700, // minimum_number_of_clients_to_proceed false), // is_r2_async_aggregation_enabled AllOf(IsSpan(), ElementsAre(matchers...))))); } TEST(SecaggServerTest, ConstructedWithCorrectState) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); auto server = CreateServer(sender.get()); EXPECT_THAT(server->IsAborted(), Eq(false)); EXPECT_THAT(server->NumberOfNeighbors(), Eq(219)); EXPECT_THAT(server->IsCompletedSuccessfully(), Eq(false)); EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS)); EXPECT_THAT(tracing_recorder.root(), TraceRecorderHas(IsSpan( SecAggServerTraceState_R0AdvertiseKeys))); } TEST(SecaggServerTest, FullgraphSecAggExperimentTakesEffect) { // Tests FullgraphSecAggExperiment by instatiating // a server under that experiment , and // checking that it results in the expected number of neighbors for the given // setting (1000 clients) and threat model (.3 dropout rate and .3 adversarial // client rate). SecureAggregationRequirements threat_model; threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER); threat_model.set_adversarial_client_rate(.3); threat_model.set_estimated_dropout_rate(.3); std::unique_ptr prng_factory; std::vector input_vector_specs; input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32)); auto sender = std::make_unique(); std::set experiment_names = {kFullgraphSecAggExperiment}; auto status_or_server = SecAggServer::Create( 100, // minimum_number_of_clients_to_proceed 1000, // total_number_of_clients input_vector_specs, sender.get(), std::unique_ptr( new MockSecAggServerMetricsListener()), /*prng_runner=*/nullptr, std::make_unique(experiment_names), threat_model); EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status(); EXPECT_THAT(status_or_server.value()->NumberOfNeighbors(), Eq(1000)); EXPECT_THAT(status_or_server.value()->IsAborted(), Eq(false)); EXPECT_THAT(status_or_server.value()->IsCompletedSuccessfully(), Eq(false)); EXPECT_THAT(status_or_server.value()->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS)); } TEST(SecaggServerTest, SubgraphSecAggResortsToFullGraphOnSmallCohorts) { // Tests that a small number of clients for which subgraph-secagg does not // have favorable parameters results in executiong the full-graph varian SecureAggregationRequirements threat_model; threat_model.set_adversary_class(AdversaryClass::CURIOUS_SERVER); threat_model.set_adversarial_client_rate(.45); threat_model.set_estimated_dropout_rate(.45); std::unique_ptr prng_factory; std::vector input_vector_specs; auto sender = std::make_unique(); input_vector_specs.push_back(InputVectorSpecification("foobar", 4, 32)); std::set experiment_names = {}; auto status_or_server = SecAggServer::Create( 5, // minimum_number_of_clients_to_proceed 25, // total_number_of_clients input_vector_specs, sender.get(), std::unique_ptr( new MockSecAggServerMetricsListener()), /*prng_runner=*/nullptr, std::make_unique(experiment_names), threat_model); EXPECT_THAT(status_or_server.ok(), true) << status_or_server.status(); EXPECT_THAT(status_or_server.value()->NumberOfNeighbors(), Eq(25)); EXPECT_THAT( status_or_server.value()->MinimumSurvivingNeighborsForReconstruction(), Eq(14)); EXPECT_THAT(status_or_server.value()->IsAborted(), Eq(false)); EXPECT_THAT(status_or_server.value()->IsCompletedSuccessfully(), Eq(false)); EXPECT_THAT(status_or_server.value()->State(), Eq(SecAggServerStateKind::R0_ADVERTISE_KEYS)); } TEST(SecaggServerTest, AbortClientWithInvalidIdThrowsError) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); auto server = CreateServer(sender.get()); EXPECT_THAT( server->AbortClient(1001, ClientAbortReason::CONNECTION_DROPPED).code(), Eq(FAILED_PRECONDITION)); } TEST(SecaggServerTest, ReceiveMessageWithInvalidIdThrowsError) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); auto server = CreateServer(sender.get()); ClientToServerWrapperMessage client_abort_message; client_abort_message.mutable_abort()->set_diagnostic_info("Abort for test."); EXPECT_THAT( server ->ReceiveMessage(1001, std::make_unique( client_abort_message)) .status() .code(), Eq(FAILED_PRECONDITION)); } TEST(SecaggServerTest, AbortCausesStateTransitionAndMessageToBeSent) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); auto server = CreateServer(sender.get()); const ServerToClientWrapperMessage abort_message = PARSE_TEXT_PROTO(R"pb( abort: { early_success: false diagnostic_info: "Abort upon external request." })pb"); EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message))); Status result = server->Abort(); EXPECT_THAT(result.code(), Eq(OK)); EXPECT_THAT(server->IsAborted(), Eq(true)); EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED)); ASSERT_THAT(server->ErrorMessage().ok(), Eq(true)); EXPECT_THAT(server->ErrorMessage().value(), Eq("Abort upon external request.")); EXPECT_THAT( tracing_recorder.root(), TraceRecorderHas( AllOf(IsSpan( SecAggServerTraceState_R0AdvertiseKeys), ElementsAre( IsSpan("Abort upon external request."))), IsSpan(SecAggServerTraceState_Aborted))); } TEST(SecaggServerTest, AbortWithReasonCausesStateTransitionAndMessageToBeSent) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); auto server = CreateServer(sender.get()); const ServerToClientWrapperMessage abort_message = PARSE_TEXT_PROTO(R"pb( abort: { early_success: false diagnostic_info: "Abort upon external request for reason ." })pb"); EXPECT_CALL(*sender, SendBroadcast(EqualsProto(abort_message))); Status result = server->Abort("Test reason.", SecAggServerOutcome::EXTERNAL_REQUEST); EXPECT_THAT(result.code(), Eq(OK)); EXPECT_THAT(server->IsAborted(), Eq(true)); EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED)); ASSERT_THAT(server->ErrorMessage().ok(), Eq(true)); EXPECT_THAT(server->ErrorMessage().value(), Eq("Abort upon external request for reason .")); EXPECT_THAT( tracing_recorder.root(), TraceRecorderHas( AllOf(IsSpan( SecAggServerTraceState_R0AdvertiseKeys), ElementsAre(IsSpan( "Abort upon external request for reason ."))), IsSpan(SecAggServerTraceState_Aborted))); } TEST(SecaggServerTest, AbortClientNotCheckedIn) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); MockSecAggServerMetricsListener* metrics = new MockSecAggServerMetricsListener(); auto server = CreateServer(sender.get(), metrics); EXPECT_CALL(*metrics, ClientsDropped( Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING), Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT))) .Times(0); // Client is not notified EXPECT_CALL(*sender, Send(_, _)).Times(0); Status result = server->AbortClient(2, ClientAbortReason::NOT_CHECKED_IN); EXPECT_THAT(result.code(), Eq(OK)); EXPECT_THAT(server->AbortedClientIds().contains(2), Eq(true)); EXPECT_THAT( tracing_recorder.root(), TraceRecorderHas(AllOf( IsSpan(SecAggServerTraceState_R0AdvertiseKeys), ElementsAre(IsSpan(2, "NOT_CHECKED_IN"))))); } TEST(SecaggServerTest, AbortClientWhenConnectionDropped) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); MockSecAggServerMetricsListener* metrics = new MockSecAggServerMetricsListener(); auto server = CreateServer(sender.get(), metrics); EXPECT_CALL(*metrics, ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING), Eq(ClientDropReason::CONNECTION_CLOSED))); // Client is not notified EXPECT_CALL(*sender, Send(_, _)).Times(0); Status result = server->AbortClient(2, ClientAbortReason::CONNECTION_DROPPED); EXPECT_THAT(result.code(), Eq(OK)); EXPECT_THAT(server->AbortedClientIds().contains(2), Eq(true)); EXPECT_THAT( tracing_recorder.root(), TraceRecorderHas(AllOf( IsSpan(SecAggServerTraceState_R0AdvertiseKeys), ElementsAre(IsSpan(2, "CONNECTION_DROPPED"))))); } TEST(SecaggServerTest, AbortClientWhenInvalidMessageSent) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); MockSecAggServerMetricsListener* metrics = new MockSecAggServerMetricsListener(); auto server = CreateServer(sender.get(), metrics); const ServerToClientWrapperMessage message = PARSE_TEXT_PROTO(R"pb( abort: { early_success: false diagnostic_info: "The protocol is closing client with ClientAbortReason ." })pb"); EXPECT_CALL(*sender, Send(2, EqualsProto(message))); EXPECT_CALL( *metrics, ClientsDropped(Eq(ClientStatus::DEAD_BEFORE_SENDING_ANYTHING), Eq(ClientDropReason::SERVER_PROTOCOL_ABORT_CLIENT))); Status result = server->AbortClient(2, ClientAbortReason::INVALID_MESSAGE); EXPECT_THAT(result.code(), Eq(OK)); EXPECT_THAT(server->AbortedClientIds().contains(2), Eq(true)); EXPECT_THAT( tracing_recorder.root(), TraceRecorderHas(AllOf( IsSpan(SecAggServerTraceState_R0AdvertiseKeys), ElementsAre(IsSpan(2, "INVALID_MESSAGE"))))); } TEST(SecaggServerTest, ReceiveMessageCausesServerToAbortIfTooManyClientsAbort) { // The actual behavior of the server upon receipt of messages is tested in the // state class test files, but this tests the special behavior that the server // should automatically transition to an abort state if it cannot continue. TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); auto server = CreateServer(sender.get()); StatusOr clients_needed = server->MinimumMessagesNeededForNextRound(); ASSERT_THAT(clients_needed.ok(), Eq(true)); int maximum_number_of_aborts = server->NumberOfAliveClients() - clients_needed.value(); EcdhPregeneratedTestKeys ecdh_keys; ClientToServerWrapperMessage client_abort_message; client_abort_message.mutable_abort()->set_diagnostic_info("Abort for test."); // Receiving `maximum_number_of_aborts - 1` aborts should not cause the entire // protocol to abort. std::vector> matchers; for (int i = 0; i < maximum_number_of_aborts; ++i) { StatusOr result = server->ReceiveMessage( i, std::make_unique(client_abort_message)); matchers.push_back(IsSpan(i)); ASSERT_THAT(result.ok(), Eq(true)); EXPECT_THAT(result.value(), Eq(false)); EXPECT_THAT(server->IsAborted(), Eq(false)); EXPECT_THAT( tracing_recorder.root(), TraceRecorderHas(AllOf(IsSpan( SecAggServerTraceState_R0AdvertiseKeys), ElementsAreArray(matchers)))); } // Receiving `maximum_number_of_aborts` aborts means the protocol is ready to // proceed to the aborted state, which is indicated by ReceiveMessage // returning true. StatusOr result = server->ReceiveMessage( maximum_number_of_aborts, std::make_unique(client_abort_message)); matchers.push_back(IsSpan(maximum_number_of_aborts)); ASSERT_THAT(result.ok(), Eq(true)); EXPECT_THAT(result.value(), Eq(true)); // However the server is not aborted until ProceedToNextRound is called. EXPECT_THAT(server->IsAborted(), Eq(false)); EXPECT_THAT(server->ProceedToNextRound(), IsOk()); matchers.push_back(IsSpan()); EXPECT_THAT(server->IsAborted(), Eq(true)); EXPECT_THAT(server->State(), Eq(SecAggServerStateKind::ABORTED)); EXPECT_THAT( tracing_recorder.root(), TraceRecorderHas( AllOf(IsSpan( SecAggServerTraceState_R0AdvertiseKeys), ElementsAreArray(matchers)), IsSpan(SecAggServerTraceState_Aborted))); } TEST(SecaggServerTest, VerifyErrorsInAbortedState) { TestTracingRecorder tracing_recorder; auto sender = std::make_unique(); auto server = CreateServer(sender.get()); EXPECT_THAT(server->Abort(), IsOk()); EXPECT_THAT( server->ReceiveMessage(1, std::make_unique( ClientToServerWrapperMessage{})), IsCode(FAILED_PRECONDITION)); EXPECT_THAT(server->ProceedToNextRound(), IsCode(FAILED_PRECONDITION)); EXPECT_THAT(server->MinimumMessagesNeededForNextRound(), IsCode(FAILED_PRECONDITION)); EXPECT_THAT(server->NumberOfMessagesReceivedInThisRound(), IsCode(FAILED_PRECONDITION)); EXPECT_THAT(server->ReadyForNextRound(), IsCode(FAILED_PRECONDITION)); } } // namespace } // namespace secagg } // namespace fcp