xref: /aosp_15_r20/external/federated-compute/fcp/client/secagg_runner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/secagg_runner.h"
17 
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "fcp/secagg/shared/aes_ctr_prng_factory.h"
#include "fcp/secagg/shared/crypto_rand_prng.h"
#include "fcp/secagg/shared/input_vector_specification.h"
26 
27 namespace fcp {
28 namespace client {
29 
30 using ::fcp::secagg::ClientState;
31 
32 // Implementation of StateTransitionListenerInterface.
33 class SecAggStateTransitionListenerImpl
34     : public secagg::StateTransitionListenerInterface {
35  public:
36   SecAggStateTransitionListenerImpl(
37       SecAggEventPublisher& secagg_event_publisher, LogManager& log_manager,
38       SecAggSendToServerBase& secagg_send_to_server_impl,
39       SecAggProtocolDelegate& secagg_protocol_delegate);
40   void Transition(secagg::ClientState new_state) override;
41 
42   void Started(secagg::ClientState state) override;
43 
44   void Stopped(secagg::ClientState state) override;
45 
46   void set_execution_session_id(int64_t execution_session_id) override;
47 
48  private:
49   SecAggEventPublisher& secagg_event_publisher_;
50   LogManager& log_manager_;
51   SecAggSendToServerBase& secagg_send_to_server_;
52   SecAggProtocolDelegate& secagg_protocol_delegate_;
53   secagg::ClientState state_ = secagg::ClientState::INITIAL;
54 };
55 
SecAggStateTransitionListenerImpl(SecAggEventPublisher & secagg_event_publisher,LogManager & log_manager,SecAggSendToServerBase & secagg_send_to_server_impl,SecAggProtocolDelegate & secagg_protocol_delegate)56 SecAggStateTransitionListenerImpl::SecAggStateTransitionListenerImpl(
57     SecAggEventPublisher& secagg_event_publisher, LogManager& log_manager,
58     SecAggSendToServerBase& secagg_send_to_server_impl,
59     SecAggProtocolDelegate& secagg_protocol_delegate)
60     : secagg_event_publisher_(secagg_event_publisher),
61       log_manager_(log_manager),
62       secagg_send_to_server_(secagg_send_to_server_impl),
63       secagg_protocol_delegate_(secagg_protocol_delegate) {}
64 
Transition(ClientState new_state)65 void SecAggStateTransitionListenerImpl::Transition(ClientState new_state) {
66   FCP_LOG(INFO) << "Transitioning from state: " << static_cast<int>(state_)
67                 << " to state: " << static_cast<int>(new_state);
68   state_ = new_state;
69   if (state_ == ClientState::ABORTED) {
70     log_manager_.LogDiag(ProdDiagCode::SECAGG_CLIENT_NATIVE_ERROR_GENERIC);
71   }
72   secagg_event_publisher_.PublishStateTransition(
73       new_state, secagg_send_to_server_.last_sent_message_size(),
74       secagg_protocol_delegate_.last_received_message_size());
75 }
76 
Started(ClientState state)77 void SecAggStateTransitionListenerImpl::Started(ClientState state) {
78   // TODO(team): Implement this.
79 }
80 
Stopped(ClientState state)81 void SecAggStateTransitionListenerImpl::Stopped(ClientState state) {
82   // TODO(team): Implement this.
83 }
84 
set_execution_session_id(int64_t execution_session_id)85 void SecAggStateTransitionListenerImpl::set_execution_session_id(
86     int64_t execution_session_id) {
87   secagg_event_publisher_.set_execution_session_id(execution_session_id);
88 }
89 
SecAggRunnerImpl(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,SecAggEventPublisher * secagg_event_publisher,LogManager * log_manager,InterruptibleRunner * interruptible_runner,int64_t expected_number_of_clients,int64_t minimum_surviving_clients_for_reconstruction)90 SecAggRunnerImpl::SecAggRunnerImpl(
91     std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
92     std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
93     SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
94     InterruptibleRunner* interruptible_runner,
95     int64_t expected_number_of_clients,
96     int64_t minimum_surviving_clients_for_reconstruction)
97     : send_to_server_impl_(std::move(send_to_server_impl)),
98       protocol_delegate_(std::move(protocol_delegate)),
99       secagg_event_publisher_(*secagg_event_publisher),
100       log_manager_(*log_manager),
101       interruptible_runner_(*interruptible_runner),
102       expected_number_of_clients_(expected_number_of_clients),
103       minimum_surviving_clients_for_reconstruction_(
104           minimum_surviving_clients_for_reconstruction) {}
105 
Run(ComputationResults results)106 absl::Status SecAggRunnerImpl::Run(ComputationResults results) {
107   auto secagg_state_transition_listener =
108       std::make_unique<SecAggStateTransitionListenerImpl>(
109           secagg_event_publisher_, log_manager_, *send_to_server_impl_,
110           *protocol_delegate_);
111   auto input_map = std::make_unique<secagg::SecAggVectorMap>();
112   std::vector<secagg::InputVectorSpecification> input_vector_specification;
113   for (auto& [k, v] : results) {
114     if (std::holds_alternative<QuantizedTensor>(v)) {
115       FCP_ASSIGN_OR_RETURN(uint64_t modulus, protocol_delegate_->GetModulus(k));
116       // Note: std::move is used below to ensure that each QuantizedTensor
117       // is consumed when converted to SecAggVector and that we don't
118       // continue having both in memory for longer than needed.
119       auto vector = std::get<QuantizedTensor>(std::move(v));
120       if (modulus <= 1 || modulus > secagg::SecAggVector::kMaxModulus) {
121         return absl::InternalError(
122             absl::StrCat("Invalid SecAgg modulus configuration: ", modulus));
123       }
124       if (vector.values.empty())
125         return absl::InternalError(
126             absl::StrCat("Zero sized vector found: ", k));
127       int64_t flattened_length = 1;
128       for (const auto& size : vector.dimensions) flattened_length *= size;
129       auto data_length = vector.values.size();
130       if (flattened_length != data_length)
131         return absl::InternalError(
132             absl::StrCat("Flattened length: ", flattened_length,
133                          " does not match vector size: ", data_length));
134       for (const auto& value : vector.values) {
135         if (value >= modulus) {
136           return absl::InternalError(absl::StrCat(
137               "The input SecAgg vector doesn't have the appropriate "
138               "modulus: element with value ",
139               value, " found, max value allowed ", (modulus - 1ULL)));
140         }
141       }
142       input_vector_specification.emplace_back(k, flattened_length, modulus);
143       input_map->try_emplace(
144           k, absl::MakeConstSpan(vector.values.data(), data_length), modulus);
145     }
146   }
147   secagg_client_ = std::make_unique<secagg::SecAggClient>(
148       expected_number_of_clients_,
149       minimum_surviving_clients_for_reconstruction_,
150       std::move(input_vector_specification),
151       std::make_unique<secagg::CryptoRandPrng>(),
152       std::move(send_to_server_impl_),
153       std::move(secagg_state_transition_listener),
154       std::make_unique<secagg::AesCtrPrngFactory>());
155 
156   FCP_RETURN_IF_ERROR(interruptible_runner_.Run(
157       [this, &input_map]() -> absl::Status {
158         FCP_RETURN_IF_ERROR(secagg_client_->Start());
159         FCP_RETURN_IF_ERROR(secagg_client_->SetInput(std::move(input_map)));
160         while (!secagg_client_->IsCompletedSuccessfully()) {
161           absl::StatusOr<secagg::ServerToClientWrapperMessage>
162               server_to_client_wrapper_message =
163                   this->protocol_delegate_->ReceiveServerMessage();
164           if (!server_to_client_wrapper_message.ok()) {
165             return absl::Status(
166                 server_to_client_wrapper_message.status().code(),
167                 absl::StrCat(
168                     "Error during SecAgg receive: ",
169                     server_to_client_wrapper_message.status().message()));
170           }
171           auto result =
172               secagg_client_->ReceiveMessage(*server_to_client_wrapper_message);
173           if (!result.ok()) {
174             this->secagg_event_publisher_.PublishError();
175             return absl::Status(result.status().code(),
176                                 absl::StrCat("Error receiving SecAgg message: ",
177                                              result.status().message()));
178           }
179           if (secagg_client_->IsAborted()) {
180             std::string error_message = "error message not found.";
181             if (secagg_client_->ErrorMessage().ok())
182               error_message = secagg_client_->ErrorMessage().value();
183             this->secagg_event_publisher_.PublishAbort(false, error_message);
184             return absl::CancelledError("SecAgg aborted: " + error_message);
185           }
186         }
187         return absl::OkStatus();
188       },
189       [this]() {
190         AbortInternal();
191         this->protocol_delegate_->Abort();
192       }));
193   return absl::OkStatus();
194 }
195 
AbortInternal()196 void SecAggRunnerImpl::AbortInternal() {
197   log_manager_.LogDiag(ProdDiagCode::SECAGG_CLIENT_NATIVE_ERROR_GENERIC);
198   auto abort_message = "Client-initiated abort.";
199   auto result = secagg_client_->Abort(abort_message);
200   if (!result.ok()) {
201     FCP_LOG(ERROR) << "Could not initiate client abort, code: " << result.code()
202                    << " message: " << result.message();
203   }
204   // Note: the implementation assumes that secagg_event_publisher
205   // cannot hang indefinitely, i.e. does not need its own interruption
206   // trigger.
207   secagg_event_publisher_.PublishAbort(true, abort_message);
208 }
209 
CreateSecAggRunner(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,SecAggEventPublisher * secagg_event_publisher,LogManager * log_manager,InterruptibleRunner * interruptible_runner,int64_t expected_number_of_clients,int64_t minimum_surviving_clients_for_reconstruction)210 std::unique_ptr<SecAggRunner> SecAggRunnerFactoryImpl::CreateSecAggRunner(
211     std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
212     std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
213     SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
214     InterruptibleRunner* interruptible_runner,
215     int64_t expected_number_of_clients,
216     int64_t minimum_surviving_clients_for_reconstruction) {
217   return std::make_unique<SecAggRunnerImpl>(
218       std::move(send_to_server_impl), std::move(protocol_delegate),
219       secagg_event_publisher, log_manager, interruptible_runner,
220       expected_number_of_clients, minimum_surviving_clients_for_reconstruction);
221 }
222 
223 }  // namespace client
224 }  // namespace fcp
225