1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/secagg_runner.h"
17
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>
22
23 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
24 #include "fcp/secagg/shared/crypto_rand_prng.h"
25 #include "fcp/secagg/shared/input_vector_specification.h"
26
27 namespace fcp {
28 namespace client {
29
30 using ::fcp::secagg::ClientState;
31
32 // Implementation of StateTransitionListenerInterface.
class SecAggStateTransitionListenerImpl
    : public secagg::StateTransitionListenerInterface {
 public:
  // Binds the listener to its collaborators. All four are held by reference
  // and are not owned, so each must outlive this listener.
  SecAggStateTransitionListenerImpl(
      SecAggEventPublisher& secagg_event_publisher, LogManager& log_manager,
      SecAggSendToServerBase& secagg_send_to_server_impl,
      SecAggProtocolDelegate& secagg_protocol_delegate);
  // Records `new_state` as the current state, logs a diag code when the new
  // state is ABORTED, and publishes the transition along with the last
  // sent/received message sizes.
  void Transition(secagg::ClientState new_state) override;

  // Not implemented yet; currently a no-op.
  void Started(secagg::ClientState state) override;

  // Not implemented yet; currently a no-op.
  void Stopped(secagg::ClientState state) override;

  // Forwards the id to the event publisher.
  void set_execution_session_id(int64_t execution_session_id) override;

 private:
  SecAggEventPublisher& secagg_event_publisher_;
  LogManager& log_manager_;
  // Queried for last_sent_message_size() on each transition.
  SecAggSendToServerBase& secagg_send_to_server_;
  // Queried for last_received_message_size() on each transition.
  SecAggProtocolDelegate& secagg_protocol_delegate_;
  // Most recent state passed to Transition(); starts at INITIAL.
  secagg::ClientState state_ = secagg::ClientState::INITIAL;
};
55
SecAggStateTransitionListenerImpl(SecAggEventPublisher & secagg_event_publisher,LogManager & log_manager,SecAggSendToServerBase & secagg_send_to_server_impl,SecAggProtocolDelegate & secagg_protocol_delegate)56 SecAggStateTransitionListenerImpl::SecAggStateTransitionListenerImpl(
57 SecAggEventPublisher& secagg_event_publisher, LogManager& log_manager,
58 SecAggSendToServerBase& secagg_send_to_server_impl,
59 SecAggProtocolDelegate& secagg_protocol_delegate)
60 : secagg_event_publisher_(secagg_event_publisher),
61 log_manager_(log_manager),
62 secagg_send_to_server_(secagg_send_to_server_impl),
63 secagg_protocol_delegate_(secagg_protocol_delegate) {}
64
Transition(ClientState new_state)65 void SecAggStateTransitionListenerImpl::Transition(ClientState new_state) {
66 FCP_LOG(INFO) << "Transitioning from state: " << static_cast<int>(state_)
67 << " to state: " << static_cast<int>(new_state);
68 state_ = new_state;
69 if (state_ == ClientState::ABORTED) {
70 log_manager_.LogDiag(ProdDiagCode::SECAGG_CLIENT_NATIVE_ERROR_GENERIC);
71 }
72 secagg_event_publisher_.PublishStateTransition(
73 new_state, secagg_send_to_server_.last_sent_message_size(),
74 secagg_protocol_delegate_.last_received_message_size());
75 }
76
Started(ClientState state)77 void SecAggStateTransitionListenerImpl::Started(ClientState state) {
78 // TODO(team): Implement this.
79 }
80
Stopped(ClientState state)81 void SecAggStateTransitionListenerImpl::Stopped(ClientState state) {
82 // TODO(team): Implement this.
83 }
84
set_execution_session_id(int64_t execution_session_id)85 void SecAggStateTransitionListenerImpl::set_execution_session_id(
86 int64_t execution_session_id) {
87 secagg_event_publisher_.set_execution_session_id(execution_session_id);
88 }
89
SecAggRunnerImpl(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,SecAggEventPublisher * secagg_event_publisher,LogManager * log_manager,InterruptibleRunner * interruptible_runner,int64_t expected_number_of_clients,int64_t minimum_surviving_clients_for_reconstruction)90 SecAggRunnerImpl::SecAggRunnerImpl(
91 std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
92 std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
93 SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
94 InterruptibleRunner* interruptible_runner,
95 int64_t expected_number_of_clients,
96 int64_t minimum_surviving_clients_for_reconstruction)
97 : send_to_server_impl_(std::move(send_to_server_impl)),
98 protocol_delegate_(std::move(protocol_delegate)),
99 secagg_event_publisher_(*secagg_event_publisher),
100 log_manager_(*log_manager),
101 interruptible_runner_(*interruptible_runner),
102 expected_number_of_clients_(expected_number_of_clients),
103 minimum_surviving_clients_for_reconstruction_(
104 minimum_surviving_clients_for_reconstruction) {}
105
Run(ComputationResults results)106 absl::Status SecAggRunnerImpl::Run(ComputationResults results) {
107 auto secagg_state_transition_listener =
108 std::make_unique<SecAggStateTransitionListenerImpl>(
109 secagg_event_publisher_, log_manager_, *send_to_server_impl_,
110 *protocol_delegate_);
111 auto input_map = std::make_unique<secagg::SecAggVectorMap>();
112 std::vector<secagg::InputVectorSpecification> input_vector_specification;
113 for (auto& [k, v] : results) {
114 if (std::holds_alternative<QuantizedTensor>(v)) {
115 FCP_ASSIGN_OR_RETURN(uint64_t modulus, protocol_delegate_->GetModulus(k));
116 // Note: std::move is used below to ensure that each QuantizedTensor
117 // is consumed when converted to SecAggVector and that we don't
118 // continue having both in memory for longer than needed.
119 auto vector = std::get<QuantizedTensor>(std::move(v));
120 if (modulus <= 1 || modulus > secagg::SecAggVector::kMaxModulus) {
121 return absl::InternalError(
122 absl::StrCat("Invalid SecAgg modulus configuration: ", modulus));
123 }
124 if (vector.values.empty())
125 return absl::InternalError(
126 absl::StrCat("Zero sized vector found: ", k));
127 int64_t flattened_length = 1;
128 for (const auto& size : vector.dimensions) flattened_length *= size;
129 auto data_length = vector.values.size();
130 if (flattened_length != data_length)
131 return absl::InternalError(
132 absl::StrCat("Flattened length: ", flattened_length,
133 " does not match vector size: ", data_length));
134 for (const auto& value : vector.values) {
135 if (value >= modulus) {
136 return absl::InternalError(absl::StrCat(
137 "The input SecAgg vector doesn't have the appropriate "
138 "modulus: element with value ",
139 value, " found, max value allowed ", (modulus - 1ULL)));
140 }
141 }
142 input_vector_specification.emplace_back(k, flattened_length, modulus);
143 input_map->try_emplace(
144 k, absl::MakeConstSpan(vector.values.data(), data_length), modulus);
145 }
146 }
147 secagg_client_ = std::make_unique<secagg::SecAggClient>(
148 expected_number_of_clients_,
149 minimum_surviving_clients_for_reconstruction_,
150 std::move(input_vector_specification),
151 std::make_unique<secagg::CryptoRandPrng>(),
152 std::move(send_to_server_impl_),
153 std::move(secagg_state_transition_listener),
154 std::make_unique<secagg::AesCtrPrngFactory>());
155
156 FCP_RETURN_IF_ERROR(interruptible_runner_.Run(
157 [this, &input_map]() -> absl::Status {
158 FCP_RETURN_IF_ERROR(secagg_client_->Start());
159 FCP_RETURN_IF_ERROR(secagg_client_->SetInput(std::move(input_map)));
160 while (!secagg_client_->IsCompletedSuccessfully()) {
161 absl::StatusOr<secagg::ServerToClientWrapperMessage>
162 server_to_client_wrapper_message =
163 this->protocol_delegate_->ReceiveServerMessage();
164 if (!server_to_client_wrapper_message.ok()) {
165 return absl::Status(
166 server_to_client_wrapper_message.status().code(),
167 absl::StrCat(
168 "Error during SecAgg receive: ",
169 server_to_client_wrapper_message.status().message()));
170 }
171 auto result =
172 secagg_client_->ReceiveMessage(*server_to_client_wrapper_message);
173 if (!result.ok()) {
174 this->secagg_event_publisher_.PublishError();
175 return absl::Status(result.status().code(),
176 absl::StrCat("Error receiving SecAgg message: ",
177 result.status().message()));
178 }
179 if (secagg_client_->IsAborted()) {
180 std::string error_message = "error message not found.";
181 if (secagg_client_->ErrorMessage().ok())
182 error_message = secagg_client_->ErrorMessage().value();
183 this->secagg_event_publisher_.PublishAbort(false, error_message);
184 return absl::CancelledError("SecAgg aborted: " + error_message);
185 }
186 }
187 return absl::OkStatus();
188 },
189 [this]() {
190 AbortInternal();
191 this->protocol_delegate_->Abort();
192 }));
193 return absl::OkStatus();
194 }
195
AbortInternal()196 void SecAggRunnerImpl::AbortInternal() {
197 log_manager_.LogDiag(ProdDiagCode::SECAGG_CLIENT_NATIVE_ERROR_GENERIC);
198 auto abort_message = "Client-initiated abort.";
199 auto result = secagg_client_->Abort(abort_message);
200 if (!result.ok()) {
201 FCP_LOG(ERROR) << "Could not initiate client abort, code: " << result.code()
202 << " message: " << result.message();
203 }
204 // Note: the implementation assumes that secagg_event_publisher
205 // cannot hang indefinitely, i.e. does not need its own interruption
206 // trigger.
207 secagg_event_publisher_.PublishAbort(true, abort_message);
208 }
209
CreateSecAggRunner(std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,SecAggEventPublisher * secagg_event_publisher,LogManager * log_manager,InterruptibleRunner * interruptible_runner,int64_t expected_number_of_clients,int64_t minimum_surviving_clients_for_reconstruction)210 std::unique_ptr<SecAggRunner> SecAggRunnerFactoryImpl::CreateSecAggRunner(
211 std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
212 std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
213 SecAggEventPublisher* secagg_event_publisher, LogManager* log_manager,
214 InterruptibleRunner* interruptible_runner,
215 int64_t expected_number_of_clients,
216 int64_t minimum_surviving_clients_for_reconstruction) {
217 return std::make_unique<SecAggRunnerImpl>(
218 std::move(send_to_server_impl), std::move(protocol_delegate),
219 secagg_event_publisher, log_manager, interruptible_runner,
220 expected_number_of_clients, minimum_surviving_clients_for_reconstruction);
221 }
222
223 } // namespace client
224 } // namespace fcp
225