xref: /aosp_15_r20/external/federated-compute/fcp/client/http/http_secagg_send_to_server_impl.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/http/http_secagg_send_to_server_impl.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "google/protobuf/any.pb.h"
26 // #include "google/rpc/code.pb.h"
27 #include "absl/strings/substitute.h"
28 #include "fcp/client/http/http_client_util.h"
29 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
30 
31 namespace fcp {
32 namespace client {
33 namespace http {
34 
35 using ::google::internal::federatedcompute::v1::AbortSecureAggregationRequest;
36 using ::google::internal::federatedcompute::v1::AdvertiseKeysRequest;
37 using ::google::internal::federatedcompute::v1::AdvertiseKeysResponse;
38 using ::google::internal::federatedcompute::v1::ByteStreamResource;
39 using ::google::internal::federatedcompute::v1::ForwardingInfo;
40 using ::google::internal::federatedcompute::v1::ShareKeysRequest;
41 using ::google::internal::federatedcompute::v1::ShareKeysResponse;
42 using ::google::internal::federatedcompute::v1::
43     SubmitSecureAggregationResultRequest;
44 using ::google::internal::federatedcompute::v1::
45     SubmitSecureAggregationResultResponse;
46 using ::google::internal::federatedcompute::v1::UnmaskRequest;
47 // using ::google::longrunning::Operation;
48 
49 namespace {
CreateAbortSecureAggregationUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)50 absl::StatusOr<std::string> CreateAbortSecureAggregationUriSuffix(
51     absl::string_view aggregation_id, absl::string_view client_token) {
52   constexpr absl::string_view pattern =
53       "/v1/secureaggregations/$0/clients/$1:abort";
54   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
55                        EncodeUriSinglePathSegment(aggregation_id));
56   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
57                        EncodeUriSinglePathSegment(client_token));
58   // Construct the URI suffix.
59   return absl::Substitute(pattern, encoded_aggregation_id,
60                           encoded_client_token);
61 }
62 
CreateAdvertiseKeysUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)63 absl::StatusOr<std::string> CreateAdvertiseKeysUriSuffix(
64     absl::string_view aggregation_id, absl::string_view client_token) {
65   constexpr absl::string_view pattern =
66       "/v1/secureaggregations/$0/clients/$1:advertisekeys";
67   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
68                        EncodeUriSinglePathSegment(aggregation_id));
69   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
70                        EncodeUriSinglePathSegment(client_token));
71   // Construct the URI suffix.
72   return absl::Substitute(pattern, encoded_aggregation_id,
73                           encoded_client_token);
74 }
75 
CreateShareKeysUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)76 absl::StatusOr<std::string> CreateShareKeysUriSuffix(
77     absl::string_view aggregation_id, absl::string_view client_token) {
78   constexpr absl::string_view pattern =
79       "/v1/secureaggregations/$0/clients/$1:sharekeys";
80   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
81                        EncodeUriSinglePathSegment(aggregation_id));
82   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
83                        EncodeUriSinglePathSegment(client_token));
84   // Construct the URI suffix.
85   return absl::Substitute(pattern, encoded_aggregation_id,
86                           encoded_client_token);
87 }
88 
CreateSubmitSecureAggregationResultUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)89 absl::StatusOr<std::string> CreateSubmitSecureAggregationResultUriSuffix(
90     absl::string_view aggregation_id, absl::string_view client_token) {
91   constexpr absl::string_view pattern =
92       "/v1/secureaggregations/$0/clients/$1:submit";
93   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
94                        EncodeUriSinglePathSegment(aggregation_id));
95   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
96                        EncodeUriSinglePathSegment(client_token));
97   // Construct the URI suffix.
98   return absl::Substitute(pattern, encoded_aggregation_id,
99                           encoded_client_token);
100 }
101 
CreateUnmaskUriSuffix(absl::string_view aggregation_id,absl::string_view client_token)102 absl::StatusOr<std::string> CreateUnmaskUriSuffix(
103     absl::string_view aggregation_id, absl::string_view client_token) {
104   constexpr absl::string_view pattern =
105       "/v1/secureaggregations/$0/clients/$1:unmask";
106   FCP_ASSIGN_OR_RETURN(std::string encoded_aggregation_id,
107                        EncodeUriSinglePathSegment(aggregation_id));
108   FCP_ASSIGN_OR_RETURN(std::string encoded_client_token,
109                        EncodeUriSinglePathSegment(client_token));
110   // Construct the URI suffix.
111   return absl::Substitute(pattern, encoded_aggregation_id,
112                           encoded_client_token);
113 }
114 
115 }  // anonymous namespace
116 
117 absl::StatusOr<std::unique_ptr<HttpSecAggSendToServerImpl>>
Create(absl::string_view api_key,Clock * clock,ProtocolRequestHelper * request_helper,InterruptibleRunner * interruptible_runner,std::function<std::unique_ptr<InterruptibleRunner> (absl::Time)> delayed_interruptible_runner_creator,absl::StatusOr<secagg::ServerToClientWrapperMessage> * server_response_holder,absl::string_view aggregation_id,absl::string_view client_token,const ForwardingInfo & secagg_upload_forwarding_info,const ByteStreamResource & masked_result_resource,const ByteStreamResource & nonmasked_result_resource,std::optional<std::string> tf_checkpoint,bool disable_request_body_compression,absl::Duration waiting_period_for_cancellation)118 HttpSecAggSendToServerImpl::Create(
119     absl::string_view api_key, Clock* clock,
120     ProtocolRequestHelper* request_helper,
121     InterruptibleRunner* interruptible_runner,
122     std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
123         delayed_interruptible_runner_creator,
124     absl::StatusOr<secagg::ServerToClientWrapperMessage>*
125         server_response_holder,
126     absl::string_view aggregation_id, absl::string_view client_token,
127     const ForwardingInfo& secagg_upload_forwarding_info,
128     const ByteStreamResource& masked_result_resource,
129     const ByteStreamResource& nonmasked_result_resource,
130     std::optional<std::string> tf_checkpoint,
131     bool disable_request_body_compression,
132     absl::Duration waiting_period_for_cancellation) {
133   FCP_ASSIGN_OR_RETURN(
134       std::unique_ptr<ProtocolRequestCreator> secagg_request_creator,
135       ProtocolRequestCreator::Create(api_key, secagg_upload_forwarding_info,
136                                      !disable_request_body_compression));
137   // We don't use request body compression for resource upload.
138   FCP_ASSIGN_OR_RETURN(
139       std::unique_ptr<ProtocolRequestCreator>
140           masked_result_upload_request_creator,
141       ProtocolRequestCreator::Create(
142           api_key, masked_result_resource.data_upload_forwarding_info(),
143           /*use_compression=*/false));
144   // We don't use request body compression for resource upload.
145   FCP_ASSIGN_OR_RETURN(
146       std::unique_ptr<ProtocolRequestCreator>
147           nonmasked_result_upload_request_creator,
148       ProtocolRequestCreator::Create(
149           api_key, nonmasked_result_resource.data_upload_forwarding_info(),
150           /*use_compression=*/false));
151 
152   return absl::WrapUnique(new HttpSecAggSendToServerImpl(
153       api_key, clock, request_helper, interruptible_runner,
154       std::move(delayed_interruptible_runner_creator), server_response_holder,
155       aggregation_id, client_token, masked_result_resource.resource_name(),
156       nonmasked_result_resource.resource_name(),
157       std::move(secagg_request_creator),
158       std::move(masked_result_upload_request_creator),
159       std::move(nonmasked_result_upload_request_creator),
160       std::move(tf_checkpoint), waiting_period_for_cancellation));
161 }
162 
163 // Despite the method name is "Send", this method is doing more. It sends the
164 // request, waits for the response and set the response to the response holder
165 // for the secagg client to access in the next round of secagg communications.
166 //
167 // The current SecAgg library is built around the assumption that the underlying
168 // network protocol is fully asynchronous and bidirectional. This was true for
169 // the gRPC protocol but isn't the case anymore for the HTTP protocol (which has
170 // a more traditional request/response structure). Nevertheless, because we
171 // still need to support the gRPC protocol the structure of the SecAgg library
172 // cannot be changed yet, and this means that currently we need to store away
173 // the result and let the secagg client to access on a later time. However, once
174 // the gRPC protocol support is removed, we should consider updating the SecAgg
175 // library to assume the more traditional request/response structure (e.g. by
176 // having SecAggSendToServer::Send return the corresponding response message).
177 //
178 // TODO(team): Simplify SecAgg library around request/response structure
179 // once gRPC support is removed.
Send(secagg::ClientToServerWrapperMessage * message)180 void HttpSecAggSendToServerImpl::Send(
181     secagg::ClientToServerWrapperMessage* message) {
182   absl::StatusOr<secagg::ServerToClientWrapperMessage> server_message;
183   if (message->has_advertise_keys()) {
184     server_response_holder_ =
185         DoR0AdvertiseKeys(std::move(message->advertise_keys()));
186   } else if (message->has_share_keys_response()) {
187     server_response_holder_ =
188         DoR1ShareKeys(std::move(message->share_keys_response()));
189   } else if (message->has_masked_input_response()) {
190     server_response_holder_ = DoR2SubmitSecureAggregationResult(
191         std::move(message->masked_input_response()));
192   } else if (message->has_unmasking_response()) {
193     server_response_holder_ =
194         DoR3Unmask(std::move(message->unmasking_response()));
195   } else if (message->has_abort()) {
196     server_response_holder_ =
197         AbortSecureAggregation(std::move(message->abort()));
198   } else {
199     // When the protocol succeeds, the ClientToServerWrapperMessage will be
200     // empty, and we'll just set the empty server message.
201     server_response_holder_ = secagg::ServerToClientWrapperMessage();
202   }
203 }
204 
205 absl::StatusOr<secagg::ServerToClientWrapperMessage>
AbortSecureAggregation(secagg::AbortMessage abort_message)206 HttpSecAggSendToServerImpl::AbortSecureAggregation(
207     secagg::AbortMessage abort_message) {
208   FCP_ASSIGN_OR_RETURN(
209       std::string uri_suffix,
210       CreateAbortSecureAggregationUriSuffix(aggregation_id_, client_token_));
211 
212   AbortSecureAggregationRequest request;
213   ::google::internal::federatedcompute::v1::Status* status =
214       request.mutable_status();
215   status->set_code(13);
216   status->set_message(abort_message.diagnostic_info());
217 
218   FCP_ASSIGN_OR_RETURN(
219       std::unique_ptr<HttpRequest> http_request,
220       secagg_request_creator_->CreateProtocolRequest(
221           uri_suffix, QueryParams(), HttpRequest::Method::kPost,
222           request.SerializeAsString(),
223           /*is_protobuf_encoded=*/true));
224   std::unique_ptr<InterruptibleRunner> delayed_interruptible_runner =
225       delayed_interruptible_runner_creator_(clock_.Now() +
226                                             waiting_period_for_cancellation_);
227   FCP_ASSIGN_OR_RETURN(
228       InMemoryHttpResponse response,
229       request_helper_.PerformProtocolRequest(std::move(http_request),
230                                              *delayed_interruptible_runner));
231 
232   secagg::ServerToClientWrapperMessage server_message;
233   server_message.mutable_abort();
234   return server_message;
235 }
236 
237 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR0AdvertiseKeys(secagg::AdvertiseKeys advertise_keys)238 HttpSecAggSendToServerImpl::DoR0AdvertiseKeys(
239     secagg::AdvertiseKeys advertise_keys) {
240   FCP_ASSIGN_OR_RETURN(
241       std::string uri_suffix,
242       CreateAdvertiseKeysUriSuffix(aggregation_id_, client_token_));
243 
244   AdvertiseKeysRequest request;
245   *request.mutable_advertise_keys() = advertise_keys;
246   FCP_ASSIGN_OR_RETURN(
247       std::unique_ptr<HttpRequest> http_request,
248       secagg_request_creator_->CreateProtocolRequest(
249           uri_suffix, QueryParams(), HttpRequest::Method::kPost,
250           request.SerializeAsString(),
251           /*is_protobuf_encoded=*/true));
252   FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse response,
253                        request_helper_.PerformProtocolRequest(
254                            std::move(http_request), interruptible_runner_));
255   // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
256   //                      ParseOperationProtoFromHttpResponse(response));
257 
258   // FCP_ASSIGN_OR_RETURN(
259   //     Operation completed_operation,
260   //     request_helper_.PollOperationResponseUntilDone(
261   //         initial_operation, *secagg_request_creator_,
262   //         interruptible_runner_));
263 
264   // // The Operation has finished. Check if it resulted in an error, and if so
265   // // forward it after converting it to an absl::Status error.
266   // if (completed_operation.has_error()) {
267   //   return ConvertRpcStatusToAbslStatus(completed_operation.error());
268   // }
269   AdvertiseKeysResponse response_proto;
270   if (!response_proto.ParseFromString(std::string(response.body))) {
271     return absl::InternalError("could not parse AdvertiseKeysResponse proto");
272   }
273   secagg::ServerToClientWrapperMessage server_message;
274   *server_message.mutable_share_keys_request() =
275       response_proto.share_keys_server_request();
276   return server_message;
277 }
278 
279 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR1ShareKeys(secagg::ShareKeysResponse share_keys_response)280 HttpSecAggSendToServerImpl::DoR1ShareKeys(
281     secagg::ShareKeysResponse share_keys_response) {
282   FCP_ASSIGN_OR_RETURN(
283       std::string uri_suffix,
284       CreateShareKeysUriSuffix(aggregation_id_, client_token_));
285 
286   ShareKeysRequest request;
287   *request.mutable_share_keys_client_response() = share_keys_response;
288   FCP_ASSIGN_OR_RETURN(
289       std::unique_ptr<HttpRequest> http_request,
290       secagg_request_creator_->CreateProtocolRequest(
291           uri_suffix, QueryParams(), HttpRequest::Method::kPost,
292           request.SerializeAsString(),
293           /*is_protobuf_encoded=*/true));
294 
295   FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse response,
296                        request_helper_.PerformProtocolRequest(
297                            std::move(http_request), interruptible_runner_));
298   // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
299   //                      ParseOperationProtoFromHttpResponse(response));
300 
301   // FCP_ASSIGN_OR_RETURN(
302   //     Operation completed_operation,
303   //     request_helper_.PollOperationResponseUntilDone(
304   //         initial_operation, *secagg_request_creator_,
305   //         interruptible_runner_));
306 
307   // // The Operation has finished. Check if it resulted in an error, and if so
308   // // forward it after converting it to an absl::Status error.
309   // if (completed_operation.has_error()) {
310   //   return ConvertRpcStatusToAbslStatus(completed_operation.error());
311   // }
312   ShareKeysResponse response_proto;
313   if (!response_proto.ParseFromString(std::string(response.body))) {
314     return absl::InternalError(
315         "could not parse StartSecureAggregationResponse proto");
316   }
317   secagg::ServerToClientWrapperMessage server_message;
318   *server_message.mutable_masked_input_request() =
319       response_proto.masked_input_collection_server_request();
320   return server_message;
321 }
322 
323 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR2SubmitSecureAggregationResult(secagg::MaskedInputCollectionResponse masked_input_response)324 HttpSecAggSendToServerImpl::DoR2SubmitSecureAggregationResult(
325     secagg::MaskedInputCollectionResponse masked_input_response) {
326   std::vector<std::unique_ptr<HttpRequest>> requests;
327   FCP_ASSIGN_OR_RETURN(std::string masked_result_upload_uri_suffix,
328                        CreateByteStreamUploadUriSuffix(masked_resource_name_));
329 
330   FCP_ASSIGN_OR_RETURN(
331       std::unique_ptr<HttpRequest> masked_input_upload_request,
332       masked_result_upload_request_creator_->CreateProtocolRequest(
333           masked_result_upload_uri_suffix, {{"upload_protocol", "raw"}},
334           HttpRequest::Method::kPost,
335           std::move(masked_input_response).SerializeAsString(),
336           /*is_protobuf_encoded=*/false));
337   requests.push_back(std::move(masked_input_upload_request));
338   bool has_checkpoint = tf_checkpoint_.has_value();
339   if (has_checkpoint) {
340     FCP_ASSIGN_OR_RETURN(
341         std::string nonmasked_result_upload_uri_suffix,
342         CreateByteStreamUploadUriSuffix(nonmasked_resource_name_));
343     FCP_ASSIGN_OR_RETURN(
344         std::unique_ptr<HttpRequest> nonmasked_input_upload_request,
345         nonmasked_result_upload_request_creator_->CreateProtocolRequest(
346             nonmasked_result_upload_uri_suffix, {{"upload_protocol", "raw"}},
347             HttpRequest::Method::kPost, std::move(tf_checkpoint_).value(),
348             /*is_protobuf_encoded=*/false));
349     requests.push_back(std::move(nonmasked_input_upload_request));
350   }
351   FCP_ASSIGN_OR_RETURN(
352       std::vector<absl::StatusOr<InMemoryHttpResponse>> responses,
353       request_helper_.PerformMultipleProtocolRequests(std::move(requests),
354                                                       interruptible_runner_));
355   for (const auto& response : responses) {
356     if (!response.ok()) {
357       return response.status();
358     }
359   }
360   FCP_ASSIGN_OR_RETURN(std::string submit_result_uri_suffix,
361                        CreateSubmitSecureAggregationResultUriSuffix(
362                            aggregation_id_, client_token_));
363   SubmitSecureAggregationResultRequest request;
364   *request.mutable_masked_result_resource_name() = masked_resource_name_;
365   if (has_checkpoint) {
366     *request.mutable_nonmasked_result_resource_name() =
367         nonmasked_resource_name_;
368   }
369   FCP_ASSIGN_OR_RETURN(
370       std::unique_ptr<HttpRequest> submit_result_request,
371       secagg_request_creator_->CreateProtocolRequest(
372           submit_result_uri_suffix, QueryParams(), HttpRequest::Method::kPost,
373           request.SerializeAsString(),
374           /*is_protobuf_encoded=*/true));
375   FCP_ASSIGN_OR_RETURN(
376       InMemoryHttpResponse response,
377       request_helper_.PerformProtocolRequest(std::move(submit_result_request),
378                                              interruptible_runner_));
379   // FCP_ASSIGN_OR_RETURN(Operation initial_operation,
380   //                      ParseOperationProtoFromHttpResponse(response));
381   // FCP_ASSIGN_OR_RETURN(
382   //     Operation completed_operation,
383   //     request_helper_.PollOperationResponseUntilDone(
384   //         initial_operation, *secagg_request_creator_,
385   //         interruptible_runner_));
386 
387   // // The Operation has finished. Check if it resulted in an error, and if so
388   // // forward it after converting it to an absl::Status error.
389   // if (completed_operation.has_error()) {
390   //   return ConvertRpcStatusToAbslStatus(completed_operation.error());
391   // }
392   SubmitSecureAggregationResultResponse response_proto;
393   if (!response_proto.ParseFromString(std::string(response.body))) {
394     return absl::InvalidArgumentError(
395         "could not parse SubmitSecureAggregationResultResponse proto");
396   }
397   secagg::ServerToClientWrapperMessage server_message;
398   *server_message.mutable_unmasking_request() =
399       response_proto.unmasking_server_request();
400   return server_message;
401 }
402 
403 absl::StatusOr<secagg::ServerToClientWrapperMessage>
DoR3Unmask(secagg::UnmaskingResponse unmasking_response)404 HttpSecAggSendToServerImpl::DoR3Unmask(
405     secagg::UnmaskingResponse unmasking_response) {
406   FCP_ASSIGN_OR_RETURN(std::string unmask_uri_suffix,
407                        CreateUnmaskUriSuffix(aggregation_id_, client_token_));
408   UnmaskRequest request;
409   *request.mutable_unmasking_client_response() = unmasking_response;
410   FCP_ASSIGN_OR_RETURN(
411       std::unique_ptr<HttpRequest> unmask_request,
412       secagg_request_creator_->CreateProtocolRequest(
413           unmask_uri_suffix, QueryParams(), HttpRequest::Method::kPost,
414           request.SerializeAsString(),
415           /*is_protobuf_encoded=*/true));
416   FCP_ASSIGN_OR_RETURN(InMemoryHttpResponse unmask_response,
417                        request_helper_.PerformProtocolRequest(
418                            std::move(unmask_request), interruptible_runner_));
419   return secagg::ServerToClientWrapperMessage();
420 }
421 
422 // TODO(team): remove GetModulus method, merge it into SecAggRunner.
GetModulus(const std::string & key)423 absl::StatusOr<uint64_t> HttpSecAggProtocolDelegate::GetModulus(
424     const std::string& key) {
425   if (!secure_aggregands_.contains(key)) {
426     return absl::InternalError(
427         absl::StrCat("Execution not found for aggregand: ", key));
428   }
429   return secure_aggregands_[key].modulus();
430 }
431 
432 absl::StatusOr<secagg::ServerToClientWrapperMessage>
ReceiveServerMessage()433 HttpSecAggProtocolDelegate::ReceiveServerMessage() {
434   return server_response_holder_;
435 }
436 
Abort()437 void HttpSecAggProtocolDelegate::Abort() {
438   // Intentional to be blank because we don't have internal states to clear.
439 }
440 
last_received_message_size()441 size_t HttpSecAggProtocolDelegate::last_received_message_size() {
442   if (server_response_holder_.ok()) {
443     return server_response_holder_->ByteSizeLong();
444   } else {
445     // If the last request failed, return zero.
446     return 0;
447   }
448 }
449 
450 }  // namespace http
451 }  // namespace client
452 }  // namespace fcp
453