xref: /aosp_15_r20/external/federated-compute/fcp/client/http/http_secagg_send_to_server_impl.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_
17 #define FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/status/status.h"
26 #include "absl/status/statusor.h"
27 #include "fcp/client/http/protocol_request_helper.h"
28 #include "fcp/client/secagg_event_publisher.h"
29 #include "fcp/client/secagg_runner.h"
30 #include "fcp/protos/federatedcompute/common.pb.h"
31 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 
34 namespace fcp {
35 namespace client {
36 namespace http {
37 
38 // Implementation of SecAggSendToServerBase for HTTP federated protocol.
39 class HttpSecAggSendToServerImpl : public SecAggSendToServerBase {
40  public:
41   // Create an instance of HttpSecAggSendToServerImpl.
42   // This method returns error status when failed to create
43   // ProtocolRequestCreator based on the input ForwardingInfo or
44   // ByteStreamResources.
45   static absl::StatusOr<std::unique_ptr<HttpSecAggSendToServerImpl>> Create(
46       absl::string_view api_key, Clock* clock,
47       ProtocolRequestHelper* request_helper,
48       InterruptibleRunner* interruptible_runner,
49       std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
50           delayed_interruptible_runner_creator,
51       absl::StatusOr<secagg::ServerToClientWrapperMessage>*
52           server_response_holder,
53       absl::string_view aggregation_id, absl::string_view client_token,
54       const google::internal::federatedcompute::v1::ForwardingInfo&
55           secagg_upload_forwarding_info,
56       const google::internal::federatedcompute::v1::ByteStreamResource&
57           masked_result_resource,
58       const google::internal::federatedcompute::v1::ByteStreamResource&
59           nonmasked_result_resource,
60       std::optional<std::string> tf_checkpoint,
61       bool disable_request_body_compression,
62       absl::Duration waiting_period_for_cancellation);
63   ~HttpSecAggSendToServerImpl() override = default;
64 
65   // Sends a client to server request based on the
66   // secagg::ClientToServerWrapperMessage, waits for the response, and set it to
67   // the server response holder.
68   void Send(secagg::ClientToServerWrapperMessage* message) override;
69 
70  private:
HttpSecAggSendToServerImpl(absl::string_view api_key,Clock * clock,ProtocolRequestHelper * request_helper,InterruptibleRunner * interruptible_runner,std::function<std::unique_ptr<InterruptibleRunner> (absl::Time)> delayed_interruptible_runner_creator,absl::StatusOr<secagg::ServerToClientWrapperMessage> * server_response_holder,absl::string_view aggregation_id,absl::string_view client_token,absl::string_view masked_resource_name,absl::string_view nonmasked_resource_name,std::unique_ptr<ProtocolRequestCreator> secagg_request_creator,std::unique_ptr<ProtocolRequestCreator> masked_result_upload_request_creator,std::unique_ptr<ProtocolRequestCreator> nonmasked_result_upload_request_creator,std::optional<std::string> tf_checkpoint,absl::Duration waiting_period_for_cancellation)71   HttpSecAggSendToServerImpl(
72       absl::string_view api_key, Clock* clock,
73       ProtocolRequestHelper* request_helper,
74       InterruptibleRunner* interruptible_runner,
75       std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
76           delayed_interruptible_runner_creator,
77       absl::StatusOr<secagg::ServerToClientWrapperMessage>*
78           server_response_holder,
79       absl::string_view aggregation_id, absl::string_view client_token,
80       absl::string_view masked_resource_name,
81       absl::string_view nonmasked_resource_name,
82       std::unique_ptr<ProtocolRequestCreator> secagg_request_creator,
83       std::unique_ptr<ProtocolRequestCreator>
84           masked_result_upload_request_creator,
85       std::unique_ptr<ProtocolRequestCreator>
86           nonmasked_result_upload_request_creator,
87       std::optional<std::string> tf_checkpoint,
88       absl::Duration waiting_period_for_cancellation)
89       : api_key_(api_key),
90         clock_(*clock),
91         request_helper_(*request_helper),
92         interruptible_runner_(*interruptible_runner),
93         delayed_interruptible_runner_creator_(
94             delayed_interruptible_runner_creator),
95         server_response_holder_(*server_response_holder),
96         aggregation_id_(std::string(aggregation_id)),
97         client_token_(std::string(client_token)),
98         masked_resource_name_(std::string(masked_resource_name)),
99         nonmasked_resource_name_(std::string(nonmasked_resource_name)),
100         secagg_request_creator_(std::move(secagg_request_creator)),
101         masked_result_upload_request_creator_(
102             std::move(masked_result_upload_request_creator)),
103         nonmasked_result_upload_request_creator_(
104             std::move(nonmasked_result_upload_request_creator)),
105         tf_checkpoint_(std::move(tf_checkpoint)),
106         waiting_period_for_cancellation_(waiting_period_for_cancellation) {}
107 
108   // Sends an AbortSecureAggregationRequest.
109   absl::StatusOr<secagg::ServerToClientWrapperMessage> AbortSecureAggregation(
110       secagg::AbortMessage abort_message);
111   // Sends an AdvertiseKeysRequest and waits for the AdvertiseKeysResponse,
112   // polling the corresponding LRO if needed.
113   absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR0AdvertiseKeys(
114       secagg::AdvertiseKeys advertise_keys);
115   // Sends an ShareKeysRequest and waits for the ShareKeysResponse, polling
116   // the corresponding LRO if needed.
117   absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR1ShareKeys(
118       secagg::ShareKeysResponse share_keys_response);
119   // Uploads masked resource and (optional) nonmasked resource. After successful
120   // upload, sends an SubmitSecureAggregationResultRequest and waits for the
121   // SubmitSecureAggregationResultResponse, polling the corresponding LRO if
122   // needed.
123   absl::StatusOr<secagg::ServerToClientWrapperMessage>
124   DoR2SubmitSecureAggregationResult(
125       secagg::MaskedInputCollectionResponse masked_input_response);
126   // Sends an UnmaskRequest and waits for the UnmaskResponse.
127   absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR3Unmask(
128       secagg::UnmaskingResponse unmasking_response);
129   const std::string api_key_;
130   Clock& clock_;
131   ProtocolRequestHelper& request_helper_;
132   InterruptibleRunner& interruptible_runner_;
133   std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)>
134       delayed_interruptible_runner_creator_;
135   absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_;
136   std::string aggregation_id_;
137   std::string client_token_;
138   std::string masked_resource_name_;
139   std::string nonmasked_resource_name_;
140   std::unique_ptr<ProtocolRequestCreator> secagg_request_creator_;
141   std::unique_ptr<ProtocolRequestCreator> masked_result_upload_request_creator_;
142   std::unique_ptr<ProtocolRequestCreator>
143       nonmasked_result_upload_request_creator_;
144   std::optional<std::string> tf_checkpoint_;
145   absl::Duration waiting_period_for_cancellation_;
146 };
147 
148 // Implementation of SecAggProtocolDelegate for the HTTP federated protocol.
149 class HttpSecAggProtocolDelegate : public SecAggProtocolDelegate {
150  public:
HttpSecAggProtocolDelegate(google::protobuf::Map<std::string,google::internal::federatedcompute::v1::SecureAggregandExecutionInfo> secure_aggregands,absl::StatusOr<secagg::ServerToClientWrapperMessage> * server_response_holder)151   HttpSecAggProtocolDelegate(
152       google::protobuf::Map<
153           std::string,
154           google::internal::federatedcompute::v1::SecureAggregandExecutionInfo>
155           secure_aggregands,
156       absl::StatusOr<secagg::ServerToClientWrapperMessage>*
157           server_response_holder)
158       : secure_aggregands_(std::move(secure_aggregands)),
159         server_response_holder_(*server_response_holder) {}
160   // Retrieve the modulus for a given SecAgg vector.
161   absl::StatusOr<uint64_t> GetModulus(const std::string& key) override;
162   // Receive Server message.
163   absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage()
164       override;
165   // Called when the SecAgg protocol is interrupted.
166   void Abort() override;
167   size_t last_received_message_size() override;
168 
169  private:
170   google::protobuf::Map<
171       std::string,
172       google::internal::federatedcompute::v1::SecureAggregandExecutionInfo>
173       secure_aggregands_;
174   absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_;
175 };
176 
177 }  // namespace http
178 }  // namespace client
179 }  // namespace fcp
180 
181 #endif  // FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_
182