1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef FCP_CLIENT_SECAGG_RUNNER_H_ 17 #define FCP_CLIENT_SECAGG_RUNNER_H_ 18 19 #include <memory> 20 #include <string> 21 22 #include "fcp/client/federated_protocol.h" 23 #include "fcp/client/interruptible_runner.h" 24 #include "fcp/client/secagg_event_publisher.h" 25 #include "fcp/secagg/client/secagg_client.h" 26 #include "fcp/secagg/shared/input_vector_specification.h" 27 #include "fcp/secagg/shared/secagg_messages.pb.h" 28 #include "fcp/secagg/shared/secagg_vector.h" 29 30 namespace fcp { 31 namespace client { 32 33 // Base SecAggSendToServer class which provides message size and network 34 // bandwidth usage metrics. When the child class inherit from this class, it's 35 // up to the child class to record metrics correctly. 36 class SecAggSendToServerBase : public secagg::SendToServerInterface { 37 public: last_sent_message_size()38 size_t last_sent_message_size() const { return last_sent_message_size_; } 39 40 protected: 41 size_t last_sent_message_size_ = 0; 42 }; 43 44 // A delegate class which handles server to client communication protocol 45 // specific details (HTTP vs gRPC etc). 46 class SecAggProtocolDelegate { 47 public: 48 virtual ~SecAggProtocolDelegate() = default; 49 // Retrieve the modulus for a given SecAgg vector. 50 virtual absl::StatusOr<uint64_t> GetModulus(const std::string& key) = 0; 51 // Receive Server message. 52 virtual absl::StatusOr<secagg::ServerToClientWrapperMessage> 53 ReceiveServerMessage() = 0; 54 // Called when the SecAgg protocol is interrupted. 55 virtual void Abort() = 0; 56 virtual size_t last_received_message_size() = 0; 57 }; 58 59 // A helper class which runs the secure aggregation protocol. 60 class SecAggRunner { 61 public: 62 virtual ~SecAggRunner() = default; 63 virtual absl::Status Run(ComputationResults results) = 0; 64 }; 65 66 // Implementation of SecAggRunner. 67 class SecAggRunnerImpl : public SecAggRunner { 68 public: 69 SecAggRunnerImpl(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl, 70 std::unique_ptr<SecAggProtocolDelegate> protocol_delegate, 71 SecAggEventPublisher* secagg_event_publisher, 72 LogManager* log_manager, 73 InterruptibleRunner* interruptible_runner, 74 int64_t expected_number_of_clients, 75 int64_t minimum_surviving_clients_for_reconstruction); 76 // Run the secure aggregation protocol. 77 // SecAggProtocolDelegate and SecAggSendToServerBase will only be invoked from 78 // a single thread. 79 absl::Status Run(ComputationResults results) override; 80 81 private: 82 void AbortInternal(); 83 84 std::unique_ptr<SecAggSendToServerBase> send_to_server_impl_; 85 std::unique_ptr<SecAggProtocolDelegate> protocol_delegate_; 86 std::unique_ptr<secagg::SecAggClient> secagg_client_; 87 SecAggEventPublisher& secagg_event_publisher_; 88 LogManager& log_manager_; 89 InterruptibleRunner& interruptible_runner_; 90 const int64_t expected_number_of_clients_; 91 const int64_t minimum_surviving_clients_for_reconstruction_; 92 }; 93 94 // A factory interface for SecAggRunner. 95 class SecAggRunnerFactory { 96 public: 97 virtual ~SecAggRunnerFactory() = default; 98 virtual std::unique_ptr<SecAggRunner> CreateSecAggRunner( 99 std::unique_ptr<SecAggSendToServerBase> send_to_server_impl, 100 std::unique_ptr<SecAggProtocolDelegate> protocol_delegate, 101 SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager, 102 InterruptibleRunner* interruptible_runner, 103 int64_t expected_number_of_clients, 104 int64_t minimum_surviving_clients_for_reconstruction) = 0; 105 }; 106 107 class SecAggRunnerFactoryImpl : public SecAggRunnerFactory { 108 std::unique_ptr<SecAggRunner> CreateSecAggRunner( 109 std::unique_ptr<SecAggSendToServerBase> send_to_server_impl, 110 std::unique_ptr<SecAggProtocolDelegate> protocol_delegate, 111 SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager, 112 InterruptibleRunner* interruptible_runner, 113 int64_t expected_number_of_clients, 114 int64_t minimum_surviving_clients_for_reconstruction) override; 115 }; 116 117 } // namespace client 118 } // namespace fcp 119 120 #endif // FCP_CLIENT_SECAGG_RUNNER_H_ 121