// xref: /aosp_15_r20/external/federated-compute/fcp/client/secagg_runner.h (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef FCP_CLIENT_SECAGG_RUNNER_H_
17 #define FCP_CLIENT_SECAGG_RUNNER_H_
18 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include "fcp/client/federated_protocol.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/secagg_event_publisher.h"
#include "fcp/secagg/client/secagg_client.h"
#include "fcp/secagg/shared/input_vector_specification.h"
#include "fcp/secagg/shared/secagg_messages.pb.h"
#include "fcp/secagg/shared/secagg_vector.h"
29 
30 namespace fcp {
31 namespace client {
32 
33 // Base SecAggSendToServer class which provides message size and network
34 // bandwidth usage metrics. When the child class inherit from this class, it's
35 // up to the child class to record metrics correctly.
36 class SecAggSendToServerBase : public secagg::SendToServerInterface {
37  public:
last_sent_message_size()38   size_t last_sent_message_size() const { return last_sent_message_size_; }
39 
40  protected:
41   size_t last_sent_message_size_ = 0;
42 };
43 
44 // A delegate class which handles server to client communication protocol
45 // specific details (HTTP vs gRPC etc).
46 class SecAggProtocolDelegate {
47  public:
48   virtual ~SecAggProtocolDelegate() = default;
49   // Retrieve the modulus for a given SecAgg vector.
50   virtual absl::StatusOr<uint64_t> GetModulus(const std::string& key) = 0;
51   // Receive Server message.
52   virtual absl::StatusOr<secagg::ServerToClientWrapperMessage>
53   ReceiveServerMessage() = 0;
54   // Called when the SecAgg protocol is interrupted.
55   virtual void Abort() = 0;
56   virtual size_t last_received_message_size() = 0;
57 };
58 
59 // A helper class which runs the secure aggregation protocol.
60 class SecAggRunner {
61  public:
62   virtual ~SecAggRunner() = default;
63   virtual absl::Status Run(ComputationResults results) = 0;
64 };
65 
66 // Implementation of SecAggRunner.
67 class SecAggRunnerImpl : public SecAggRunner {
68  public:
69   SecAggRunnerImpl(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
70                    std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
71                    SecAggEventPublisher* secagg_event_publisher,
72                    LogManager* log_manager,
73                    InterruptibleRunner* interruptible_runner,
74                    int64_t expected_number_of_clients,
75                    int64_t minimum_surviving_clients_for_reconstruction);
76   // Run the secure aggregation protocol.
77   // SecAggProtocolDelegate and SecAggSendToServerBase will only be invoked from
78   // a single thread.
79   absl::Status Run(ComputationResults results) override;
80 
81  private:
82   void AbortInternal();
83 
84   std::unique_ptr<SecAggSendToServerBase> send_to_server_impl_;
85   std::unique_ptr<SecAggProtocolDelegate> protocol_delegate_;
86   std::unique_ptr<secagg::SecAggClient> secagg_client_;
87   SecAggEventPublisher& secagg_event_publisher_;
88   LogManager& log_manager_;
89   InterruptibleRunner& interruptible_runner_;
90   const int64_t expected_number_of_clients_;
91   const int64_t minimum_surviving_clients_for_reconstruction_;
92 };
93 
94 // A factory interface for SecAggRunner.
95 class SecAggRunnerFactory {
96  public:
97   virtual ~SecAggRunnerFactory() = default;
98   virtual std::unique_ptr<SecAggRunner> CreateSecAggRunner(
99       std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
100       std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
101       SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
102       InterruptibleRunner* interruptible_runner,
103       int64_t expected_number_of_clients,
104       int64_t minimum_surviving_clients_for_reconstruction) = 0;
105 };
106 
107 class SecAggRunnerFactoryImpl : public SecAggRunnerFactory {
108   std::unique_ptr<SecAggRunner> CreateSecAggRunner(
109       std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
110       std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
111       SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
112       InterruptibleRunner* interruptible_runner,
113       int64_t expected_number_of_clients,
114       int64_t minimum_surviving_clients_for_reconstruction) override;
115 };
116 
117 }  // namespace client
118 }  // namespace fcp
119 
120 #endif  // FCP_CLIENT_SECAGG_RUNNER_H_
121