/*
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/
#ifndef FCP_CLIENT_SECAGG_RUNNER_H_
#define FCP_CLIENT_SECAGG_RUNNER_H_

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include "fcp/client/federated_protocol.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/secagg_event_publisher.h"
#include "fcp/secagg/client/secagg_client.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
#include "fcp/secagg/shared/secagg_vector.h"

namespace fcp {
namespace client {

// Base SecAggSendToServer class which provides message size and network
// bandwidth usage metrics. When a child class inherits from this class, it is
// up to the child class to record these metrics correctly (i.e. to keep
// `last_sent_message_size_` up to date).
36*14675a02SAndroid Build Coastguard Worker class SecAggSendToServerBase : public secagg::SendToServerInterface { 37*14675a02SAndroid Build Coastguard Worker public: last_sent_message_size()38*14675a02SAndroid Build Coastguard Worker size_t last_sent_message_size() const { return last_sent_message_size_; } 39*14675a02SAndroid Build Coastguard Worker 40*14675a02SAndroid Build Coastguard Worker protected: 41*14675a02SAndroid Build Coastguard Worker size_t last_sent_message_size_ = 0; 42*14675a02SAndroid Build Coastguard Worker }; 43*14675a02SAndroid Build Coastguard Worker 44*14675a02SAndroid Build Coastguard Worker // A delegate class which handles server to client communication protocol 45*14675a02SAndroid Build Coastguard Worker // specific details (HTTP vs gRPC etc). 46*14675a02SAndroid Build Coastguard Worker class SecAggProtocolDelegate { 47*14675a02SAndroid Build Coastguard Worker public: 48*14675a02SAndroid Build Coastguard Worker virtual ~SecAggProtocolDelegate() = default; 49*14675a02SAndroid Build Coastguard Worker // Retrieve the modulus for a given SecAgg vector. 50*14675a02SAndroid Build Coastguard Worker virtual absl::StatusOr<uint64_t> GetModulus(const std::string& key) = 0; 51*14675a02SAndroid Build Coastguard Worker // Receive Server message. 52*14675a02SAndroid Build Coastguard Worker virtual absl::StatusOr<secagg::ServerToClientWrapperMessage> 53*14675a02SAndroid Build Coastguard Worker ReceiveServerMessage() = 0; 54*14675a02SAndroid Build Coastguard Worker // Called when the SecAgg protocol is interrupted. 55*14675a02SAndroid Build Coastguard Worker virtual void Abort() = 0; 56*14675a02SAndroid Build Coastguard Worker virtual size_t last_received_message_size() = 0; 57*14675a02SAndroid Build Coastguard Worker }; 58*14675a02SAndroid Build Coastguard Worker 59*14675a02SAndroid Build Coastguard Worker // A helper class which runs the secure aggregation protocol. 
60*14675a02SAndroid Build Coastguard Worker class SecAggRunner { 61*14675a02SAndroid Build Coastguard Worker public: 62*14675a02SAndroid Build Coastguard Worker virtual ~SecAggRunner() = default; 63*14675a02SAndroid Build Coastguard Worker virtual absl::Status Run(ComputationResults results) = 0; 64*14675a02SAndroid Build Coastguard Worker }; 65*14675a02SAndroid Build Coastguard Worker 66*14675a02SAndroid Build Coastguard Worker // Implementation of SecAggRunner. 67*14675a02SAndroid Build Coastguard Worker class SecAggRunnerImpl : public SecAggRunner { 68*14675a02SAndroid Build Coastguard Worker public: 69*14675a02SAndroid Build Coastguard Worker SecAggRunnerImpl(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl, 70*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggProtocolDelegate> protocol_delegate, 71*14675a02SAndroid Build Coastguard Worker SecAggEventPublisher* secagg_event_publisher, 72*14675a02SAndroid Build Coastguard Worker LogManager* log_manager, 73*14675a02SAndroid Build Coastguard Worker InterruptibleRunner* interruptible_runner, 74*14675a02SAndroid Build Coastguard Worker int64_t expected_number_of_clients, 75*14675a02SAndroid Build Coastguard Worker int64_t minimum_surviving_clients_for_reconstruction); 76*14675a02SAndroid Build Coastguard Worker // Run the secure aggregation protocol. 77*14675a02SAndroid Build Coastguard Worker // SecAggProtocolDelegate and SecAggSendToServerBase will only be invoked from 78*14675a02SAndroid Build Coastguard Worker // a single thread. 
79*14675a02SAndroid Build Coastguard Worker absl::Status Run(ComputationResults results) override; 80*14675a02SAndroid Build Coastguard Worker 81*14675a02SAndroid Build Coastguard Worker private: 82*14675a02SAndroid Build Coastguard Worker void AbortInternal(); 83*14675a02SAndroid Build Coastguard Worker 84*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggSendToServerBase> send_to_server_impl_; 85*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggProtocolDelegate> protocol_delegate_; 86*14675a02SAndroid Build Coastguard Worker std::unique_ptr<secagg::SecAggClient> secagg_client_; 87*14675a02SAndroid Build Coastguard Worker SecAggEventPublisher& secagg_event_publisher_; 88*14675a02SAndroid Build Coastguard Worker LogManager& log_manager_; 89*14675a02SAndroid Build Coastguard Worker InterruptibleRunner& interruptible_runner_; 90*14675a02SAndroid Build Coastguard Worker const int64_t expected_number_of_clients_; 91*14675a02SAndroid Build Coastguard Worker const int64_t minimum_surviving_clients_for_reconstruction_; 92*14675a02SAndroid Build Coastguard Worker }; 93*14675a02SAndroid Build Coastguard Worker 94*14675a02SAndroid Build Coastguard Worker // A factory interface for SecAggRunner. 
95*14675a02SAndroid Build Coastguard Worker class SecAggRunnerFactory { 96*14675a02SAndroid Build Coastguard Worker public: 97*14675a02SAndroid Build Coastguard Worker virtual ~SecAggRunnerFactory() = default; 98*14675a02SAndroid Build Coastguard Worker virtual std::unique_ptr<SecAggRunner> CreateSecAggRunner( 99*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggSendToServerBase> send_to_server_impl, 100*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggProtocolDelegate> protocol_delegate, 101*14675a02SAndroid Build Coastguard Worker SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager, 102*14675a02SAndroid Build Coastguard Worker InterruptibleRunner* interruptible_runner, 103*14675a02SAndroid Build Coastguard Worker int64_t expected_number_of_clients, 104*14675a02SAndroid Build Coastguard Worker int64_t minimum_surviving_clients_for_reconstruction) = 0; 105*14675a02SAndroid Build Coastguard Worker }; 106*14675a02SAndroid Build Coastguard Worker 107*14675a02SAndroid Build Coastguard Worker class SecAggRunnerFactoryImpl : public SecAggRunnerFactory { 108*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggRunner> CreateSecAggRunner( 109*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggSendToServerBase> send_to_server_impl, 110*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggProtocolDelegate> protocol_delegate, 111*14675a02SAndroid Build Coastguard Worker SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager, 112*14675a02SAndroid Build Coastguard Worker InterruptibleRunner* interruptible_runner, 113*14675a02SAndroid Build Coastguard Worker int64_t expected_number_of_clients, 114*14675a02SAndroid Build Coastguard Worker int64_t minimum_surviving_clients_for_reconstruction) override; 115*14675a02SAndroid Build Coastguard Worker }; 116*14675a02SAndroid Build Coastguard Worker 117*14675a02SAndroid Build Coastguard Worker } // namespace client 118*14675a02SAndroid Build 
Coastguard Worker } // namespace fcp 119*14675a02SAndroid Build Coastguard Worker 120*14675a02SAndroid Build Coastguard Worker #endif // FCP_CLIENT_SECAGG_RUNNER_H_ 121