// xref: /aosp_15_r20/external/federated-compute/fcp/client/secagg_runner.h (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_SECAGG_RUNNER_H_
17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_SECAGG_RUNNER_H_
18*14675a02SAndroid Build Coastguard Worker 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "fcp/client/federated_protocol.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/secagg_event_publisher.h"
#include "fcp/secagg/client/secagg_client.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
#include "fcp/secagg/shared/secagg_vector.h"
29*14675a02SAndroid Build Coastguard Worker 
30*14675a02SAndroid Build Coastguard Worker namespace fcp {
31*14675a02SAndroid Build Coastguard Worker namespace client {
32*14675a02SAndroid Build Coastguard Worker 
33*14675a02SAndroid Build Coastguard Worker // Base SecAggSendToServer class which provides message size and network
34*14675a02SAndroid Build Coastguard Worker // bandwidth usage metrics. When the child class inherit from this class, it's
35*14675a02SAndroid Build Coastguard Worker // up to the child class to record metrics correctly.
36*14675a02SAndroid Build Coastguard Worker class SecAggSendToServerBase : public secagg::SendToServerInterface {
37*14675a02SAndroid Build Coastguard Worker  public:
last_sent_message_size()38*14675a02SAndroid Build Coastguard Worker   size_t last_sent_message_size() const { return last_sent_message_size_; }
39*14675a02SAndroid Build Coastguard Worker 
40*14675a02SAndroid Build Coastguard Worker  protected:
41*14675a02SAndroid Build Coastguard Worker   size_t last_sent_message_size_ = 0;
42*14675a02SAndroid Build Coastguard Worker };
43*14675a02SAndroid Build Coastguard Worker 
44*14675a02SAndroid Build Coastguard Worker // A delegate class which handles server to client communication protocol
45*14675a02SAndroid Build Coastguard Worker // specific details (HTTP vs gRPC etc).
46*14675a02SAndroid Build Coastguard Worker class SecAggProtocolDelegate {
47*14675a02SAndroid Build Coastguard Worker  public:
48*14675a02SAndroid Build Coastguard Worker   virtual ~SecAggProtocolDelegate() = default;
49*14675a02SAndroid Build Coastguard Worker   // Retrieve the modulus for a given SecAgg vector.
50*14675a02SAndroid Build Coastguard Worker   virtual absl::StatusOr<uint64_t> GetModulus(const std::string& key) = 0;
51*14675a02SAndroid Build Coastguard Worker   // Receive Server message.
52*14675a02SAndroid Build Coastguard Worker   virtual absl::StatusOr<secagg::ServerToClientWrapperMessage>
53*14675a02SAndroid Build Coastguard Worker   ReceiveServerMessage() = 0;
54*14675a02SAndroid Build Coastguard Worker   // Called when the SecAgg protocol is interrupted.
55*14675a02SAndroid Build Coastguard Worker   virtual void Abort() = 0;
56*14675a02SAndroid Build Coastguard Worker   virtual size_t last_received_message_size() = 0;
57*14675a02SAndroid Build Coastguard Worker };
58*14675a02SAndroid Build Coastguard Worker 
59*14675a02SAndroid Build Coastguard Worker // A helper class which runs the secure aggregation protocol.
60*14675a02SAndroid Build Coastguard Worker class SecAggRunner {
61*14675a02SAndroid Build Coastguard Worker  public:
62*14675a02SAndroid Build Coastguard Worker   virtual ~SecAggRunner() = default;
63*14675a02SAndroid Build Coastguard Worker   virtual absl::Status Run(ComputationResults results) = 0;
64*14675a02SAndroid Build Coastguard Worker };
65*14675a02SAndroid Build Coastguard Worker 
66*14675a02SAndroid Build Coastguard Worker // Implementation of SecAggRunner.
67*14675a02SAndroid Build Coastguard Worker class SecAggRunnerImpl : public SecAggRunner {
68*14675a02SAndroid Build Coastguard Worker  public:
69*14675a02SAndroid Build Coastguard Worker   SecAggRunnerImpl(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
70*14675a02SAndroid Build Coastguard Worker                    std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
71*14675a02SAndroid Build Coastguard Worker                    SecAggEventPublisher* secagg_event_publisher,
72*14675a02SAndroid Build Coastguard Worker                    LogManager* log_manager,
73*14675a02SAndroid Build Coastguard Worker                    InterruptibleRunner* interruptible_runner,
74*14675a02SAndroid Build Coastguard Worker                    int64_t expected_number_of_clients,
75*14675a02SAndroid Build Coastguard Worker                    int64_t minimum_surviving_clients_for_reconstruction);
76*14675a02SAndroid Build Coastguard Worker   // Run the secure aggregation protocol.
77*14675a02SAndroid Build Coastguard Worker   // SecAggProtocolDelegate and SecAggSendToServerBase will only be invoked from
78*14675a02SAndroid Build Coastguard Worker   // a single thread.
79*14675a02SAndroid Build Coastguard Worker   absl::Status Run(ComputationResults results) override;
80*14675a02SAndroid Build Coastguard Worker 
81*14675a02SAndroid Build Coastguard Worker  private:
82*14675a02SAndroid Build Coastguard Worker   void AbortInternal();
83*14675a02SAndroid Build Coastguard Worker 
84*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<SecAggSendToServerBase> send_to_server_impl_;
85*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<SecAggProtocolDelegate> protocol_delegate_;
86*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<secagg::SecAggClient> secagg_client_;
87*14675a02SAndroid Build Coastguard Worker   SecAggEventPublisher& secagg_event_publisher_;
88*14675a02SAndroid Build Coastguard Worker   LogManager& log_manager_;
89*14675a02SAndroid Build Coastguard Worker   InterruptibleRunner& interruptible_runner_;
90*14675a02SAndroid Build Coastguard Worker   const int64_t expected_number_of_clients_;
91*14675a02SAndroid Build Coastguard Worker   const int64_t minimum_surviving_clients_for_reconstruction_;
92*14675a02SAndroid Build Coastguard Worker };
93*14675a02SAndroid Build Coastguard Worker 
94*14675a02SAndroid Build Coastguard Worker // A factory interface for SecAggRunner.
95*14675a02SAndroid Build Coastguard Worker class SecAggRunnerFactory {
96*14675a02SAndroid Build Coastguard Worker  public:
97*14675a02SAndroid Build Coastguard Worker   virtual ~SecAggRunnerFactory() = default;
98*14675a02SAndroid Build Coastguard Worker   virtual std::unique_ptr<SecAggRunner> CreateSecAggRunner(
99*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
100*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
101*14675a02SAndroid Build Coastguard Worker       SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
102*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner* interruptible_runner,
103*14675a02SAndroid Build Coastguard Worker       int64_t expected_number_of_clients,
104*14675a02SAndroid Build Coastguard Worker       int64_t minimum_surviving_clients_for_reconstruction) = 0;
105*14675a02SAndroid Build Coastguard Worker };
106*14675a02SAndroid Build Coastguard Worker 
107*14675a02SAndroid Build Coastguard Worker class SecAggRunnerFactoryImpl : public SecAggRunnerFactory {
108*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<SecAggRunner> CreateSecAggRunner(
109*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
110*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
111*14675a02SAndroid Build Coastguard Worker       SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
112*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner* interruptible_runner,
113*14675a02SAndroid Build Coastguard Worker       int64_t expected_number_of_clients,
114*14675a02SAndroid Build Coastguard Worker       int64_t minimum_surviving_clients_for_reconstruction) override;
115*14675a02SAndroid Build Coastguard Worker };
116*14675a02SAndroid Build Coastguard Worker 
117*14675a02SAndroid Build Coastguard Worker }  // namespace client
118*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
119*14675a02SAndroid Build Coastguard Worker 
120*14675a02SAndroid Build Coastguard Worker #endif  // FCP_CLIENT_SECAGG_RUNNER_H_
121