1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_ 17 #define FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <optional> 22 #include <string> 23 #include <utility> 24 25 #include "absl/status/status.h" 26 #include "absl/status/statusor.h" 27 #include "fcp/client/http/protocol_request_helper.h" 28 #include "fcp/client/secagg_event_publisher.h" 29 #include "fcp/client/secagg_runner.h" 30 #include "fcp/protos/federatedcompute/common.pb.h" 31 #include "fcp/protos/federatedcompute/secure_aggregations.pb.h" 32 #include "fcp/secagg/shared/secagg_messages.pb.h" 33 34 namespace fcp { 35 namespace client { 36 namespace http { 37 38 // Implementation of SecAggSendToServerBase for HTTP federated protocol. 39 class HttpSecAggSendToServerImpl : public SecAggSendToServerBase { 40 public: 41 // Create an instance of HttpSecAggSendToServerImpl. 42 // This method returns error status when failed to create 43 // ProtocolRequestCreator based on the input ForwardingInfo or 44 // ByteStreamResources. 45 static absl::StatusOr<std::unique_ptr<HttpSecAggSendToServerImpl>> Create( 46 absl::string_view api_key, Clock* clock, 47 ProtocolRequestHelper* request_helper, 48 InterruptibleRunner* interruptible_runner, 49 std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)> 50 delayed_interruptible_runner_creator, 51 absl::StatusOr<secagg::ServerToClientWrapperMessage>* 52 server_response_holder, 53 absl::string_view aggregation_id, absl::string_view client_token, 54 const google::internal::federatedcompute::v1::ForwardingInfo& 55 secagg_upload_forwarding_info, 56 const google::internal::federatedcompute::v1::ByteStreamResource& 57 masked_result_resource, 58 const google::internal::federatedcompute::v1::ByteStreamResource& 59 nonmasked_result_resource, 60 std::optional<std::string> tf_checkpoint, 61 bool disable_request_body_compression, 62 absl::Duration waiting_period_for_cancellation); 63 ~HttpSecAggSendToServerImpl() override = default; 64 65 // Sends a client to server request based on the 66 // secagg::ClientToServerWrapperMessage, waits for the response, and set it to 67 // the server response holder. 68 void Send(secagg::ClientToServerWrapperMessage* message) override; 69 70 private: HttpSecAggSendToServerImpl(absl::string_view api_key,Clock * clock,ProtocolRequestHelper * request_helper,InterruptibleRunner * interruptible_runner,std::function<std::unique_ptr<InterruptibleRunner> (absl::Time)> delayed_interruptible_runner_creator,absl::StatusOr<secagg::ServerToClientWrapperMessage> * server_response_holder,absl::string_view aggregation_id,absl::string_view client_token,absl::string_view masked_resource_name,absl::string_view nonmasked_resource_name,std::unique_ptr<ProtocolRequestCreator> secagg_request_creator,std::unique_ptr<ProtocolRequestCreator> masked_result_upload_request_creator,std::unique_ptr<ProtocolRequestCreator> nonmasked_result_upload_request_creator,std::optional<std::string> tf_checkpoint,absl::Duration waiting_period_for_cancellation)71 HttpSecAggSendToServerImpl( 72 absl::string_view api_key, Clock* clock, 73 ProtocolRequestHelper* request_helper, 74 InterruptibleRunner* interruptible_runner, 75 std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)> 76 delayed_interruptible_runner_creator, 77 absl::StatusOr<secagg::ServerToClientWrapperMessage>* 78 server_response_holder, 79 absl::string_view aggregation_id, absl::string_view client_token, 80 absl::string_view masked_resource_name, 81 absl::string_view nonmasked_resource_name, 82 std::unique_ptr<ProtocolRequestCreator> secagg_request_creator, 83 std::unique_ptr<ProtocolRequestCreator> 84 masked_result_upload_request_creator, 85 std::unique_ptr<ProtocolRequestCreator> 86 nonmasked_result_upload_request_creator, 87 std::optional<std::string> tf_checkpoint, 88 absl::Duration waiting_period_for_cancellation) 89 : api_key_(api_key), 90 clock_(*clock), 91 request_helper_(*request_helper), 92 interruptible_runner_(*interruptible_runner), 93 delayed_interruptible_runner_creator_( 94 delayed_interruptible_runner_creator), 95 server_response_holder_(*server_response_holder), 96 aggregation_id_(std::string(aggregation_id)), 97 client_token_(std::string(client_token)), 98 masked_resource_name_(std::string(masked_resource_name)), 99 nonmasked_resource_name_(std::string(nonmasked_resource_name)), 100 secagg_request_creator_(std::move(secagg_request_creator)), 101 masked_result_upload_request_creator_( 102 std::move(masked_result_upload_request_creator)), 103 nonmasked_result_upload_request_creator_( 104 std::move(nonmasked_result_upload_request_creator)), 105 tf_checkpoint_(std::move(tf_checkpoint)), 106 waiting_period_for_cancellation_(waiting_period_for_cancellation) {} 107 108 // Sends an AbortSecureAggregationRequest. 109 absl::StatusOr<secagg::ServerToClientWrapperMessage> AbortSecureAggregation( 110 secagg::AbortMessage abort_message); 111 // Sends an AdvertiseKeysRequest and waits for the AdvertiseKeysResponse, 112 // polling the corresponding LRO if needed. 113 absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR0AdvertiseKeys( 114 secagg::AdvertiseKeys advertise_keys); 115 // Sends an ShareKeysRequest and waits for the ShareKeysResponse, polling 116 // the corresponding LRO if needed. 117 absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR1ShareKeys( 118 secagg::ShareKeysResponse share_keys_response); 119 // Uploads masked resource and (optional) nonmasked resource. After successful 120 // upload, sends an SubmitSecureAggregationResultRequest and waits for the 121 // SubmitSecureAggregationResultResponse, polling the corresponding LRO if 122 // needed. 123 absl::StatusOr<secagg::ServerToClientWrapperMessage> 124 DoR2SubmitSecureAggregationResult( 125 secagg::MaskedInputCollectionResponse masked_input_response); 126 // Sends an UnmaskRequest and waits for the UnmaskResponse. 127 absl::StatusOr<secagg::ServerToClientWrapperMessage> DoR3Unmask( 128 secagg::UnmaskingResponse unmasking_response); 129 const std::string api_key_; 130 Clock& clock_; 131 ProtocolRequestHelper& request_helper_; 132 InterruptibleRunner& interruptible_runner_; 133 std::function<std::unique_ptr<InterruptibleRunner>(absl::Time)> 134 delayed_interruptible_runner_creator_; 135 absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_; 136 std::string aggregation_id_; 137 std::string client_token_; 138 std::string masked_resource_name_; 139 std::string nonmasked_resource_name_; 140 std::unique_ptr<ProtocolRequestCreator> secagg_request_creator_; 141 std::unique_ptr<ProtocolRequestCreator> masked_result_upload_request_creator_; 142 std::unique_ptr<ProtocolRequestCreator> 143 nonmasked_result_upload_request_creator_; 144 std::optional<std::string> tf_checkpoint_; 145 absl::Duration waiting_period_for_cancellation_; 146 }; 147 148 // Implementation of SecAggProtocolDelegate for the HTTP federated protocol. 149 class HttpSecAggProtocolDelegate : public SecAggProtocolDelegate { 150 public: HttpSecAggProtocolDelegate(google::protobuf::Map<std::string,google::internal::federatedcompute::v1::SecureAggregandExecutionInfo> secure_aggregands,absl::StatusOr<secagg::ServerToClientWrapperMessage> * server_response_holder)151 HttpSecAggProtocolDelegate( 152 google::protobuf::Map< 153 std::string, 154 google::internal::federatedcompute::v1::SecureAggregandExecutionInfo> 155 secure_aggregands, 156 absl::StatusOr<secagg::ServerToClientWrapperMessage>* 157 server_response_holder) 158 : secure_aggregands_(std::move(secure_aggregands)), 159 server_response_holder_(*server_response_holder) {} 160 // Retrieve the modulus for a given SecAgg vector. 161 absl::StatusOr<uint64_t> GetModulus(const std::string& key) override; 162 // Receive Server message. 163 absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage() 164 override; 165 // Called when the SecAgg protocol is interrupted. 166 void Abort() override; 167 size_t last_received_message_size() override; 168 169 private: 170 google::protobuf::Map< 171 std::string, 172 google::internal::federatedcompute::v1::SecureAggregandExecutionInfo> 173 secure_aggregands_; 174 absl::StatusOr<secagg::ServerToClientWrapperMessage>& server_response_holder_; 175 }; 176 177 } // namespace http 178 } // namespace client 179 } // namespace fcp 180 181 #endif // FCP_CLIENT_HTTP_HTTP_SECAGG_SEND_TO_SERVER_IMPL_H_ 182