1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/http/http_secagg_send_to_server_impl.h"
17
18 #include <functional>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "google/protobuf/any.pb.h"
26 // #include "google/rpc/code.pb.h"
27 #include "absl/strings/substitute.h"
28 #include "fcp/client/http/http_client_util.h"
29 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
30
31 namespace fcp {
32 namespace client {
33 namespace http {
34
35 using ::google::internal::federatedcompute::v1::AbortSecureAggregationRequest;
36 using ::google::internal::federatedcompute::v1::AdvertiseKeysRequest;
37 using ::google::internal::federatedcompute::v1::AdvertiseKeysResponse;
38 using ::google::internal::federatedcompute::v1::ByteStreamResource;
39 using ::google::internal::federatedcompute::v1::ForwardingInfo;
40 using ::google::internal::federatedcompute::v1::ShareKeysRequest;
41 using ::google::internal::federatedcompute::v1::ShareKeysResponse;
42 using ::google::internal::federatedcompute::v1::
43 SubmitSecureAggregationResultRequest;
44 using ::google::internal::federatedcompute::v1::
45 SubmitSecureAggregationResultResponse;
46 using ::google::internal::federatedcompute::v1::UnmaskRequest;
47 // using ::google::longrunning::Operation;
48
49 namespace {
CreateAbortSecureAggregationUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)50 absl::StatusOr<std::string> CreateAbortSecureAggregationUriSuffix(
51 absl::string_view aggregation_id, absl::string_view client_token) {
52 constexpr absl::string_view pattern =
53 "/v1/secureaggregations/$0/clients/$1:abort";
54 FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
55 EncodeUriSinglePathSegment(aggregation_id));
56 FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
57 EncodeUriSinglePathSegment(client_token));
58 // Construct the URI suffix.
59 return absl::Substitute(pattern, encoded_aggregation_id,
60 encoded_client_token);
61 }
62
CreateAdvertiseKeysUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)63 absl::StatusOr<std::string> CreateAdvertiseKeysUriSuffix(
64 absl::string_view aggregation_id, absl::string_view client_token) {
65 constexpr absl::string_view pattern =
66 "/v1/secureaggregations/$0/clients/$1:advertisekeys";
67 FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
68 EncodeUriSinglePathSegment(aggregation_id));
69 FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
70 EncodeUriSinglePathSegment(client_token));
71 // Construct the URI suffix.
72 return absl::Substitute(pattern, encoded_aggregation_id,
73 encoded_client_token);
74 }
75
CreateShareKeysUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)76 absl::StatusOr<std::string> CreateShareKeysUriSuffix(
77 absl::string_view aggregation_id, absl::string_view client_token) {
78 constexpr absl::string_view pattern =
79 "/v1/secureaggregations/$0/clients/$1:sharekeys";
80 FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
81 EncodeUriSinglePathSegment(aggregation_id));
82 FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
83 EncodeUriSinglePathSegment(client_token));
84 // Construct the URI suffix.
85 return absl::Substitute(pattern, encoded_aggregation_id,
86 encoded_client_token);
87 }
88
CreateSubmitSecureAggregationResultUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)89 absl::StatusOr<std::string> CreateSubmitSecureAggregationResultUriSuffix(
90 absl::string_view aggregation_id, absl::string_view client_token) {
91 constexpr absl::string_view pattern =
92 "/v1/secureaggregations/$0/clients/$1:submit";
93 FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
94 EncodeUriSinglePathSegment(aggregation_id));
95 FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
96 EncodeUriSinglePathSegment(client_token));
97 // Construct the URI suffix.
98 return absl::Substitute(pattern, encoded_aggregation_id,
99 encoded_client_token);
100 }
101
CreateUnmaskUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)102 absl::StatusOr<std::string> CreateUnmaskUriSuffix(
103 absl::string_view aggregation_id, absl::string_view client_token) {
104 constexpr absl::string_view pattern =
105 "/v1/secureaggregations/$0/clients/$1:unmask";
106 FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
107 EncodeUriSinglePathSegment(aggregation_id));
108 FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
109 EncodeUriSinglePathSegment(client_token));
110 // Construct the URI suffix.
111 return absl::Substitute(pattern, encoded_aggregation_id,
112 encoded_client_token);
113 }
114
115 } // anonymous namespace
116
117 absl::StatusOr<std::unique_ptr<HttpSecAggSendToServerImpl>>
Create(absl::string_view api_key,Clock * clock,ProtocolRequestHelper * request_helper,InterruptibleRunner * interruptible_runner,std::function<std::unique_ptr<InterruptibleRunner> (absl::Time)> delayed_interruptible_runner_creator,absl::StatusOr<secagg::ServerToClientWrapperMessage> * server_response_holder,absl::string_view aggregation_id,absl::string_view client_token,const ForwardingInfo & secagg_upload_forwarding_info,const ByteStreamResource & masked_result_resource,const ByteStreamResource & nonmasked_result_resource,std::optional<std::string> tf_checkpoint,bool disable_request_body_compression,absl::Duration waiting_period_for_cancellation)118 HttpSecAggSendToServerImpl::Create(
119 absl::string_view api_key, Clock* clock,
120 ProtocolRequestHelper* request_helper,
121 InterruptibleRunner* interruptible_runner,
122 std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
123 delayed_interruptible_runner_creator,
124 absl::StatusOr<secagg::ServerToClientWrapperMessage>*
125 server_response_holder,
126 absl::string_view aggregation_id, absl::string_view client_token,
127 const ForwardingInfo& secagg_upload_forwarding_info,
128 const ByteStreamResource& masked_result_resource,
129 const ByteStreamResource& nonmasked_result_resource,
130 std::optional<std::string> tf_checkpoint,
131 bool disable_request_body_compression,
132 absl::Duration waiting_period_for_cancellation) {
133 FCP_ASSIGN_OR_RETURN(
134 std::unique_ptr<ProtocolRequestCreator> secagg_request_creator,
135 ProtocolRequestCreator::Create(api_key, secagg_upload_forwarding_info,
136 !disable_request_body_compression));
137 // We don't use request body compression for resource upload.
138 FCP_ASSIGN_OR_RETURN(
139 std::unique_ptr<ProtocolRequestCreator>
140 masked_result_upload_request_creator,
141 ProtocolRequestCreator::Create(
142 api_key, masked_result_resource.data_upload_forwarding_info(),
143 /*use_compression=*/false));
144 // We don't use request body compression for resource upload.
145 FCP_ASSIGN_OR_RETURN(
146 std::unique_ptr<ProtocolRequestCreator>
147 nonmasked_result_upload_request_creator,
148 ProtocolRequestCreator::Create(
149 api_key, nonmasked_result_resource.data_upload_forwarding_info(),
150 /*use_compression=*/false));
151
152 return absl::WrapUnique(new HttpSecAggSendToServerImpl(
153 api_key, clock, request_helper, interruptible_runner,
154 std::move(delayed_interruptible_runner_creator), server_response_holder,
155 aggregation_id, client_token, masked_result_resource.resource_name(),
156 nonmasked_result_resource.resource_name(),
157 std::move(secagg_request_creator),
158 std::move(masked_result_upload_request_creator),
159 std::move(nonmasked_result_upload_request_creator),
160 std::move(tf_checkpoint), waiting_period_for_cancellation));
161 }
162
// Despite its name, this method does more than "send": it sends the request,
// waits for the response, and stores the result (the server's reply or an
// error status) in the response holder, so the secagg client can access it in
// the next round of secagg communication.
//
// The current SecAgg library is built around the assumption that the underlying
// network protocol is fully asynchronous and bidirectional. This was true for
// the gRPC protocol but is no longer the case for the HTTP protocol (which has
// a more traditional request/response structure). Nevertheless, because we
// still need to support the gRPC protocol, the structure of the SecAgg library
// cannot be changed yet, so we currently need to store the result and let the
// secagg client access it at a later time. However, once gRPC protocol support
// is removed, we should consider updating the SecAgg library to assume the more
// traditional request/response structure (e.g. by having
// SecAggSendToServer::Send return the corresponding response message).
177 //
178 // TODO(team): Simplify SecAgg library around request/response structure
179 // once gRPC support is removed.
Send(secagg::ClientToServerWrapperMessage * message)180 void HttpSecAggSendToServerImpl::Send(
181 secagg::ClientToServerWrapperMessage* message) {
182 absl::StatusOr<secagg::ServerToClientWrapperMessage> server_message;
183 if (message->has_advertise_keys()) {
184 server_response_holder_ =
185 DoR0AdvertiseKeys(std::move(message->advertise_keys()));
186 } else if (message->has_share_keys_response()) {
187 server_response_holder_ =
188 DoR1ShareKeys(std::move(message->share_keys_response()));
189 } else if (message->has_masked_input_response()) {
190 server_response_holder_ = DoR2SubmitSecureAggregationResult(
191 std::move(message->masked_input_response()));
192 } else if (message->has_unmasking_response()) {
193 server_response_holder_ =
194 DoR3Unmask(std::move(message->unmasking_response()));
195 } else if (message->has_abort()) {
196 server_response_holder_ =
197 AbortSecureAggregation(std::move(message->abort()));
198 } else {
199 // When the protocol succeeds, the ClientToServerWrapperMessage will be
200 // empty, and we'll just set the empty server message.
201 server_response_holder_ = secagg::ServerToClientWrapperMessage();
202 }
203 }
204
205 absl::StatusOr<secagg::ServerToClientWrapperMessage>
AbortSecureAggregation(secagg::AbortMessage abort_message)206 HttpSecAggSendToServerImpl::AbortSecureAggregation(
207 secagg::AbortMessage abort_message) {
208 FCP_ASSIGN_OR_RETURN(
209 std::string uri_suffix,
210 CreateAbortSecureAggregationUriSuffix(aggregation_id_, client_token_));
211
212 AbortSecureAggregationRequest request;
213 ::google::internal::federatedcompute::v1::Status* status =
214 request.mutable_status();
215 status->set_code(13);
216 status->set_message(abort_message.diagnostic_info());
217
218 FCP_ASSIGN_OR_RETURN(
219 std::unique_ptr<HttpRequest> http_request,
220 secagg_request_creator_->CreateProtocolRequest(
221 uri_suffix, QueryParams(), HttpRequest::Method::kPost,
222 request.SerializeAsString(),
223 /*is_protobuf_encoded=*/true));
224 std::unique_ptr<InterruptibleRunner> delayed_interruptible_runner =
225 delayed_interruptible_runner_creator_(clock_.Now() +
226 waiting_period_for_cancellation_);
227 FCP_ASSIGN_OR_RETURN(
228 InMemoryHttpResponse response,
229 request_helper_.PerformProtocolRequest(std::move(http_request),
230 *delayed_interruptible_runner));
231
232 secagg::ServerToClientWrapperMessage server_message;
233 server_message.mutable_abort();
234 return server_message;
235 }
236
237 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR0AdvertiseKeys(secagg::AdvertiseKeys advertise_keys)238 HttpSecAggSendToServerImpl::DoR0AdvertiseKeys(
239 secagg::AdvertiseKeys advertise_keys) {
240 FCP_ASSIGN_OR_RETURN(
241 std::string uri_suffix,
242 CreateAdvertiseKeysUriSuffix(aggregation_id_, client_token_));
243
244 AdvertiseKeysRequest request;
245 *request.mutable_advertise_keys() = advertise_keys;
246 FCP_ASSIGN_OR_RETURN(
247 std::unique_ptr<HttpRequest> http_request,
248 secagg_request_creator_->CreateProtocolRequest(
249 uri_suffix, QueryParams(), HttpRequest::Method::kPost,
250 request.SerializeAsString(),
251 /*is_protobuf_encoded=*/true));
252 FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse response,
253 request_helper_.PerformProtocolRequest(
254 std::move(http_request), interruptible_runner_));
255 // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
256 // ParseOperationProtoFromHttpResponse(response));
257
258 // FCP_ASSIGN_OR_RETURN(
259 // Operation completed_operation,
260 // request_helper_.PollOperationResponseUntilDone(
261 // initial_operation, *secagg_request_creator_,
262 // interruptible_runner_));
263
264 // // The Operation has finished. Check if it resulted in an error, and if so
265 // // forward it after converting it to an absl::Status error.
266 // if (completed_operation.has_error()) {
267 // return ConvertRpcStatusToAbslStatus(completed_operation.error());
268 // }
269 AdvertiseKeysResponse response_proto;
270 if (!response_proto.ParseFromString(std::string(response.body))) {
271 return absl::InternalError("could not parse AdvertiseKeysResponse proto");
272 }
273 secagg::ServerToClientWrapperMessage server_message;
274 *server_message.mutable_share_keys_request() =
275 response_proto.share_keys_server_request();
276 return server_message;
277 }
278
279 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR1ShareKeys(secagg::ShareKeysResponse share_keys_response)280 HttpSecAggSendToServerImpl::DoR1ShareKeys(
281 secagg::ShareKeysResponse share_keys_response) {
282 FCP_ASSIGN_OR_RETURN(
283 std::string uri_suffix,
284 CreateShareKeysUriSuffix(aggregation_id_, client_token_));
285
286 ShareKeysRequest request;
287 *request.mutable_share_keys_client_response() = share_keys_response;
288 FCP_ASSIGN_OR_RETURN(
289 std::unique_ptr<HttpRequest> http_request,
290 secagg_request_creator_->CreateProtocolRequest(
291 uri_suffix, QueryParams(), HttpRequest::Method::kPost,
292 request.SerializeAsString(),
293 /*is_protobuf_encoded=*/true));
294
295 FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse response,
296 request_helper_.PerformProtocolRequest(
297 std::move(http_request), interruptible_runner_));
298 // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
299 // ParseOperationProtoFromHttpResponse(response));
300
301 // FCP_ASSIGN_OR_RETURN(
302 // Operation completed_operation,
303 // request_helper_.PollOperationResponseUntilDone(
304 // initial_operation, *secagg_request_creator_,
305 // interruptible_runner_));
306
307 // // The Operation has finished. Check if it resulted in an error, and if so
308 // // forward it after converting it to an absl::Status error.
309 // if (completed_operation.has_error()) {
310 // return ConvertRpcStatusToAbslStatus(completed_operation.error());
311 // }
312 ShareKeysResponse response_proto;
313 if (!response_proto.ParseFromString(std::string(response.body))) {
314 return absl::InternalError(
315 "could not parse StartSecureAggregationResponse proto");
316 }
317 secagg::ServerToClientWrapperMessage server_message;
318 *server_message.mutable_masked_input_request() =
319 response_proto.masked_input_collection_server_request();
320 return server_message;
321 }
322
323 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR2SubmitSecureAggregationResult(secagg::MaskedInputCollectionResponse masked_input_response)324 HttpSecAggSendToServerImpl::DoR2SubmitSecureAggregationResult(
325 secagg::MaskedInputCollectionResponse masked_input_response) {
326 std::vector<std::unique_ptr<HttpRequest>> requests;
327 FCP_ASSIGN_OR_RETURN(std::string masked_result_upload_uri_suffix,
328 CreateByteStreamUploadUriSuffix(masked_resource_name_));
329
330 FCP_ASSIGN_OR_RETURN(
331 std::unique_ptr<HttpRequest> masked_input_upload_request,
332 masked_result_upload_request_creator_->CreateProtocolRequest(
333 masked_result_upload_uri_suffix, {{"upload_protocol", "raw"}},
334 HttpRequest::Method::kPost,
335 std::move(masked_input_response).SerializeAsString(),
336 /*is_protobuf_encoded=*/false));
337 requests.push_back(std::move(masked_input_upload_request));
338 bool has_checkpoint = tf_checkpoint_.has_value();
339 if (has_checkpoint) {
340 FCP_ASSIGN_OR_RETURN(
341 std::string nonmasked_result_upload_uri_suffix,
342 CreateByteStreamUploadUriSuffix(nonmasked_resource_name_));
343 FCP_ASSIGN_OR_RETURN(
344 std::unique_ptr<HttpRequest> nonmasked_input_upload_request,
345 nonmasked_result_upload_request_creator_->CreateProtocolRequest(
346 nonmasked_result_upload_uri_suffix, {{"upload_protocol", "raw"}},
347 HttpRequest::Method::kPost, std::move(tf_checkpoint_).value(),
348 /*is_protobuf_encoded=*/false));
349 requests.push_back(std::move(nonmasked_input_upload_request));
350 }
351 FCP_ASSIGN_OR_RETURN(
352 std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
353 request_helper_.PerformMultipleProtocolRequests(std::move(requests),
354 interruptible_runner_));
355 for (const auto& response : responses) {
356 if (!response.ok()) {
357 return response.status();
358 }
359 }
360 FCP_ASSIGN_OR_RETURN(std::string submit_result_uri_suffix,
361 CreateSubmitSecureAggregationResultUriSuffix(
362 aggregation_id_, client_token_));
363 SubmitSecureAggregationResultRequest request;
364 *request.mutable_masked_result_resource_name() = masked_resource_name_;
365 if (has_checkpoint) {
366 *request.mutable_nonmasked_result_resource_name() =
367 nonmasked_resource_name_;
368 }
369 FCP_ASSIGN_OR_RETURN(
370 std::unique_ptr<HttpRequest> submit_result_request,
371 secagg_request_creator_->CreateProtocolRequest(
372 submit_result_uri_suffix, QueryParams(), HttpRequest::Method::kPost,
373 request.SerializeAsString(),
374 /*is_protobuf_encoded=*/true));
375 FCP_ASSIGN_OR_RETURN(
376 InMemoryHttpResponse response,
377 request_helper_.PerformProtocolRequest(std::move(submit_result_request),
378 interruptible_runner_));
379 // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
380 // ParseOperationProtoFromHttpResponse(response));
381 // FCP_ASSIGN_OR_RETURN(
382 // Operation completed_operation,
383 // request_helper_.PollOperationResponseUntilDone(
384 // initial_operation, *secagg_request_creator_,
385 // interruptible_runner_));
386
387 // // The Operation has finished. Check if it resulted in an error, and if so
388 // // forward it after converting it to an absl::Status error.
389 // if (completed_operation.has_error()) {
390 // return ConvertRpcStatusToAbslStatus(completed_operation.error());
391 // }
392 SubmitSecureAggregationResultResponse response_proto;
393 if (!response_proto.ParseFromString(std::string(response.body))) {
394 return absl::InvalidArgumentError(
395 "could not parse SubmitSecureAggregationResultResponse proto");
396 }
397 secagg::ServerToClientWrapperMessage server_message;
398 *server_message.mutable_unmasking_request() =
399 response_proto.unmasking_server_request();
400 return server_message;
401 }
402
403 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR3Unmask(secagg::UnmaskingResponse unmasking_response)404 HttpSecAggSendToServerImpl::DoR3Unmask(
405 secagg::UnmaskingResponse unmasking_response) {
406 FCP_ASSIGN_OR_RETURN(std::string unmask_uri_suffix,
407 CreateUnmaskUriSuffix(aggregation_id_, client_token_));
408 UnmaskRequest request;
409 *request.mutable_unmasking_client_response() = unmasking_response;
410 FCP_ASSIGN_OR_RETURN(
411 std::unique_ptr<HttpRequest> unmask_request,
412 secagg_request_creator_->CreateProtocolRequest(
413 unmask_uri_suffix, QueryParams(), HttpRequest::Method::kPost,
414 request.SerializeAsString(),
415 /*is_protobuf_encoded=*/true));
416 FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse unmask_response,
417 request_helper_.PerformProtocolRequest(
418 std::move(unmask_request), interruptible_runner_));
419 return secagg::ServerToClientWrapperMessage();
420 }
421
422 // TODO(team): remove GetModulus method, merge it into SecAggRunner.
GetModulus(const std::string & key)423 absl::StatusOr<uint64_t> HttpSecAggProtocolDelegate::GetModulus(
424 const std::string& key) {
425 if (!secure_aggregands_.contains(key)) {
426 return absl::InternalError(
427 absl::StrCat("Execution not found for aggregand: ", key));
428 }
429 return secure_aggregands_[key].modulus();
430 }
431
432 absl::StatusOr<secagg::ServerToClientWrapperMessage>
ReceiveServerMessage()433 HttpSecAggProtocolDelegate::ReceiveServerMessage() {
434 return server_response_holder_;
435 }
436
Abort()437 void HttpSecAggProtocolDelegate::Abort() {
438 // Intentional to be blank because we don't have internal states to clear.
439 }
440
last_received_message_size()441 size_t HttpSecAggProtocolDelegate::last_received_message_size() {
442 if (server_response_holder_.ok()) {
443 return server_response_holder_->ByteSizeLong();
444 } else {
445 // If the last request failed, return zero.
446 return 0;
447 }
448 }
449
450 } // namespace http
451 } // namespace client
452 } // namespace fcp
453