/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_ 17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_ 18*14675a02SAndroid Build Coastguard Worker 19*14675a02SAndroid Build Coastguard Worker #include <cstdint> 20*14675a02SAndroid Build Coastguard Worker #include <functional> 21*14675a02SAndroid Build Coastguard Worker #include <memory> 22*14675a02SAndroid Build Coastguard Worker #include <optional> 23*14675a02SAndroid Build Coastguard Worker #include <string> 24*14675a02SAndroid Build Coastguard Worker #include <utility> 25*14675a02SAndroid Build Coastguard Worker #include <variant> 26*14675a02SAndroid Build Coastguard Worker #include <vector> 27*14675a02SAndroid Build Coastguard Worker 28*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h" 29*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_set.h" 30*14675a02SAndroid Build Coastguard Worker #include "absl/container/node_hash_map.h" 31*14675a02SAndroid Build Coastguard Worker #include "absl/random/random.h" 32*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h" 33*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h" 34*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h" 35*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h" 36*14675a02SAndroid Build Coastguard Worker #include "fcp/base/wall_clock_stopwatch.h" 37*14675a02SAndroid Build Coastguard Worker #include "fcp/client/cache/resource_cache.h" 38*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/engine.pb.h" 39*14675a02SAndroid Build Coastguard Worker #include "fcp/client/event_publisher.h" 40*14675a02SAndroid Build Coastguard Worker #include "fcp/client/federated_protocol.h" 41*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.pb.h" 42*14675a02SAndroid Build 
Coastguard Worker #include "fcp/client/flags.h" 43*14675a02SAndroid Build Coastguard Worker #include "fcp/client/grpc_bidi_stream.h" 44*14675a02SAndroid Build Coastguard Worker #include "fcp/client/http/http_client.h" 45*14675a02SAndroid Build Coastguard Worker #include "fcp/client/http/in_memory_request_response.h" 46*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h" 47*14675a02SAndroid Build Coastguard Worker #include "fcp/client/log_manager.h" 48*14675a02SAndroid Build Coastguard Worker #include "fcp/client/secagg_runner.h" 49*14675a02SAndroid Build Coastguard Worker #include "fcp/client/selector_context.pb.h" 50*14675a02SAndroid Build Coastguard Worker #include "fcp/client/stats.h" 51*14675a02SAndroid Build Coastguard Worker #include "fcp/protocol/grpc_chunked_bidi_stream.h" 52*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/federated_api.pb.h" 53*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h" 54*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/client/secagg_client.h" 55*14675a02SAndroid Build Coastguard Worker 56*14675a02SAndroid Build Coastguard Worker namespace fcp { 57*14675a02SAndroid Build Coastguard Worker namespace client { 58*14675a02SAndroid Build Coastguard Worker 59*14675a02SAndroid Build Coastguard Worker // Implements a single session of the gRPC-based Federated Learning protocol. 
60*14675a02SAndroid Build Coastguard Worker class GrpcFederatedProtocol : public ::fcp::client::FederatedProtocol { 61*14675a02SAndroid Build Coastguard Worker public: 62*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol( 63*14675a02SAndroid Build Coastguard Worker EventPublisher* event_publisher, LogManager* log_manager, 64*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory, 65*14675a02SAndroid Build Coastguard Worker const Flags* flags, ::fcp::client::http::HttpClient* http_client, 66*14675a02SAndroid Build Coastguard Worker const std::string& federated_service_uri, const std::string& api_key, 67*14675a02SAndroid Build Coastguard Worker const std::string& test_cert_path, absl::string_view population_name, 68*14675a02SAndroid Build Coastguard Worker absl::string_view retry_token, absl::string_view client_version, 69*14675a02SAndroid Build Coastguard Worker absl::string_view attestation_measurement, 70*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort, 71*14675a02SAndroid Build Coastguard Worker const InterruptibleRunner::TimingConfig& timing_config, 72*14675a02SAndroid Build Coastguard Worker const int64_t grpc_channel_deadline_seconds, 73*14675a02SAndroid Build Coastguard Worker cache::ResourceCache* resource_cache); 74*14675a02SAndroid Build Coastguard Worker 75*14675a02SAndroid Build Coastguard Worker // Test c'tor. 
76*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol( 77*14675a02SAndroid Build Coastguard Worker EventPublisher* event_publisher, LogManager* log_manager, 78*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory, 79*14675a02SAndroid Build Coastguard Worker const Flags* flags, ::fcp::client::http::HttpClient* http_client, 80*14675a02SAndroid Build Coastguard Worker std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream, 81*14675a02SAndroid Build Coastguard Worker absl::string_view population_name, absl::string_view retry_token, 82*14675a02SAndroid Build Coastguard Worker absl::string_view client_version, 83*14675a02SAndroid Build Coastguard Worker absl::string_view attestation_measurement, 84*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort, absl::BitGen bit_gen, 85*14675a02SAndroid Build Coastguard Worker const InterruptibleRunner::TimingConfig& timing_config, 86*14675a02SAndroid Build Coastguard Worker cache::ResourceCache* resource_cache); 87*14675a02SAndroid Build Coastguard Worker 88*14675a02SAndroid Build Coastguard Worker ~GrpcFederatedProtocol() override; 89*14675a02SAndroid Build Coastguard Worker 90*14675a02SAndroid Build Coastguard Worker absl::StatusOr<::fcp::client::FederatedProtocol::EligibilityEvalCheckinResult> 91*14675a02SAndroid Build Coastguard Worker EligibilityEvalCheckin(std::function<void(const EligibilityEvalTask&)> 92*14675a02SAndroid Build Coastguard Worker payload_uris_received_callback) override; 93*14675a02SAndroid Build Coastguard Worker 94*14675a02SAndroid Build Coastguard Worker void ReportEligibilityEvalError(absl::Status error_status) override; 95*14675a02SAndroid Build Coastguard Worker 96*14675a02SAndroid Build Coastguard Worker absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult> Checkin( 97*14675a02SAndroid Build Coastguard Worker const std::optional< 98*14675a02SAndroid Build Coastguard Worker 
google::internal::federatedml::v2::TaskEligibilityInfo>& 99*14675a02SAndroid Build Coastguard Worker task_eligibility_info, 100*14675a02SAndroid Build Coastguard Worker std::function<void(const TaskAssignment&)> payload_uris_received_callback) 101*14675a02SAndroid Build Coastguard Worker override; 102*14675a02SAndroid Build Coastguard Worker 103*14675a02SAndroid Build Coastguard Worker absl::StatusOr<::fcp::client::FederatedProtocol::MultipleTaskAssignments> 104*14675a02SAndroid Build Coastguard Worker PerformMultipleTaskAssignments( 105*14675a02SAndroid Build Coastguard Worker const std::vector<std::string>& task_names) override; 106*14675a02SAndroid Build Coastguard Worker 107*14675a02SAndroid Build Coastguard Worker absl::Status ReportCompleted( 108*14675a02SAndroid Build Coastguard Worker ComputationResults results, absl::Duration plan_duration, 109*14675a02SAndroid Build Coastguard Worker std::optional<std::string> aggregation_session_id) override; 110*14675a02SAndroid Build Coastguard Worker 111*14675a02SAndroid Build Coastguard Worker absl::Status ReportNotCompleted( 112*14675a02SAndroid Build Coastguard Worker engine::PhaseOutcome phase_outcome, absl::Duration plan_duration, 113*14675a02SAndroid Build Coastguard Worker std::optional<std::string> aggregation_session_id) override; 114*14675a02SAndroid Build Coastguard Worker 115*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow() 116*14675a02SAndroid Build Coastguard Worker override; 117*14675a02SAndroid Build Coastguard Worker 118*14675a02SAndroid Build Coastguard Worker NetworkStats GetNetworkStats() override; 119*14675a02SAndroid Build Coastguard Worker 120*14675a02SAndroid Build Coastguard Worker private: 121*14675a02SAndroid Build Coastguard Worker // Internal implementation of reporting for use by ReportCompleted() and 122*14675a02SAndroid Build Coastguard Worker // ReportNotCompleted(). 
123*14675a02SAndroid Build Coastguard Worker absl::Status Report(ComputationResults results, 124*14675a02SAndroid Build Coastguard Worker engine::PhaseOutcome phase_outcome, 125*14675a02SAndroid Build Coastguard Worker absl::Duration plan_duration); 126*14675a02SAndroid Build Coastguard Worker absl::Status ReportInternal( 127*14675a02SAndroid Build Coastguard Worker std::string tf_checkpoint, engine::PhaseOutcome phase_outcome, 128*14675a02SAndroid Build Coastguard Worker absl::Duration plan_duration, 129*14675a02SAndroid Build Coastguard Worker fcp::secagg::ClientToServerWrapperMessage* secagg_commit_message); 130*14675a02SAndroid Build Coastguard Worker 131*14675a02SAndroid Build Coastguard Worker // Helper function to send a ClientStreamMessage. If sending did not succeed, 132*14675a02SAndroid Build Coastguard Worker // closes the underlying grpc stream. If sending does succeed then it updates 133*14675a02SAndroid Build Coastguard Worker // `bytes_uploaded_`. 134*14675a02SAndroid Build Coastguard Worker absl::Status Send(google::internal::federatedml::v2::ClientStreamMessage* 135*14675a02SAndroid Build Coastguard Worker client_stream_message); 136*14675a02SAndroid Build Coastguard Worker 137*14675a02SAndroid Build Coastguard Worker // Helper function to receive a ServerStreamMessage. If receiving did not 138*14675a02SAndroid Build Coastguard Worker // succeed, closes the underlying grpc stream. If receiving does succeed then 139*14675a02SAndroid Build Coastguard Worker // it updates `bytes_downloaded_`. 140*14675a02SAndroid Build Coastguard Worker absl::Status Receive(google::internal::federatedml::v2::ServerStreamMessage* 141*14675a02SAndroid Build Coastguard Worker server_stream_message); 142*14675a02SAndroid Build Coastguard Worker 143*14675a02SAndroid Build Coastguard Worker // Helper function to compose a ProtocolOptionsRequest for eligibility eval or 144*14675a02SAndroid Build Coastguard Worker // regular checkin requests. 
145*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ProtocolOptionsRequest 146*14675a02SAndroid Build Coastguard Worker CreateProtocolOptionsRequest(bool should_ack_checkin) const; 147*14675a02SAndroid Build Coastguard Worker 148*14675a02SAndroid Build Coastguard Worker // Helper function to compose and send an EligibilityEvalCheckinRequest to the 149*14675a02SAndroid Build Coastguard Worker // server. 150*14675a02SAndroid Build Coastguard Worker absl::Status SendEligibilityEvalCheckinRequest(); 151*14675a02SAndroid Build Coastguard Worker 152*14675a02SAndroid Build Coastguard Worker // Helper function to compose and send a CheckinRequest to the server. 153*14675a02SAndroid Build Coastguard Worker absl::Status SendCheckinRequest( 154*14675a02SAndroid Build Coastguard Worker const std::optional< 155*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::TaskEligibilityInfo>& 156*14675a02SAndroid Build Coastguard Worker task_eligibility_info); 157*14675a02SAndroid Build Coastguard Worker 158*14675a02SAndroid Build Coastguard Worker // Helper to receive + process a CheckinRequestAck message. 159*14675a02SAndroid Build Coastguard Worker absl::Status ReceiveCheckinRequestAck(); 160*14675a02SAndroid Build Coastguard Worker 161*14675a02SAndroid Build Coastguard Worker // Helper to receive + process an EligibilityEvalCheckinResponse message. 162*14675a02SAndroid Build Coastguard Worker absl::StatusOr<EligibilityEvalCheckinResult> 163*14675a02SAndroid Build Coastguard Worker ReceiveEligibilityEvalCheckinResponse( 164*14675a02SAndroid Build Coastguard Worker absl::Time start_time, std::function<void(const EligibilityEvalTask&)> 165*14675a02SAndroid Build Coastguard Worker payload_uris_received_callback); 166*14675a02SAndroid Build Coastguard Worker 167*14675a02SAndroid Build Coastguard Worker // Helper to receive + process a CheckinResponse message. 
168*14675a02SAndroid Build Coastguard Worker absl::StatusOr<CheckinResult> ReceiveCheckinResponse( 169*14675a02SAndroid Build Coastguard Worker absl::Time start_time, std::function<void(const TaskAssignment&)> 170*14675a02SAndroid Build Coastguard Worker payload_uris_received_callback); 171*14675a02SAndroid Build Coastguard Worker 172*14675a02SAndroid Build Coastguard Worker // Utility class for holding an absolute retry time and a corresponding retry 173*14675a02SAndroid Build Coastguard Worker // token. 174*14675a02SAndroid Build Coastguard Worker struct RetryTimeAndToken { 175*14675a02SAndroid Build Coastguard Worker absl::Time retry_time; 176*14675a02SAndroid Build Coastguard Worker std::string retry_token; 177*14675a02SAndroid Build Coastguard Worker }; 178*14675a02SAndroid Build Coastguard Worker // Helper to generate a RetryWindow from a given time and token. 179*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::RetryWindow 180*14675a02SAndroid Build Coastguard Worker GenerateRetryWindowFromRetryTimeAndToken(const RetryTimeAndToken& retry_info); 181*14675a02SAndroid Build Coastguard Worker 182*14675a02SAndroid Build Coastguard Worker // Helper that moves to the given object state if the given status represents 183*14675a02SAndroid Build Coastguard Worker // a permanent error. 184*14675a02SAndroid Build Coastguard Worker void UpdateObjectStateIfPermanentError( 185*14675a02SAndroid Build Coastguard Worker absl::Status status, ObjectState permanent_error_object_state); 186*14675a02SAndroid Build Coastguard Worker 187*14675a02SAndroid Build Coastguard Worker // Utility struct to represent resource data coming from the gRPC protocol. 
188*14675a02SAndroid Build Coastguard Worker // A resource is either represented by a URI from which the data should be 189*14675a02SAndroid Build Coastguard Worker // fetched (in which case `has_uri` is true and `uri` should not be empty), or 190*14675a02SAndroid Build Coastguard Worker // is available as inline data (in which case `has_uri` is false and `data` 191*14675a02SAndroid Build Coastguard Worker // may or may not be empty). 192*14675a02SAndroid Build Coastguard Worker struct TaskResource { 193*14675a02SAndroid Build Coastguard Worker bool has_uri; 194*14675a02SAndroid Build Coastguard Worker const std::string& uri; 195*14675a02SAndroid Build Coastguard Worker const std::string& data; 196*14675a02SAndroid Build Coastguard Worker // The following fields will be set if the client should attempt to cache 197*14675a02SAndroid Build Coastguard Worker // the resource. 198*14675a02SAndroid Build Coastguard Worker const std::string& client_cache_id; 199*14675a02SAndroid Build Coastguard Worker const absl::Duration max_age; 200*14675a02SAndroid Build Coastguard Worker }; 201*14675a02SAndroid Build Coastguard Worker // Represents the common set of resources a task may have. 202*14675a02SAndroid Build Coastguard Worker struct TaskResources { 203*14675a02SAndroid Build Coastguard Worker TaskResource plan; 204*14675a02SAndroid Build Coastguard Worker TaskResource checkpoint; 205*14675a02SAndroid Build Coastguard Worker }; 206*14675a02SAndroid Build Coastguard Worker 207*14675a02SAndroid Build Coastguard Worker // Helper function for fetching the checkpoint/plan resources for an 208*14675a02SAndroid Build Coastguard Worker // eligibility eval task or regular task. This function will return an error 209*14675a02SAndroid Build Coastguard Worker // if either `TaskResource` represents an invalid state (e.g. if `has_uri && 210*14675a02SAndroid Build Coastguard Worker // uri.empty()`). 
211*14675a02SAndroid Build Coastguard Worker absl::StatusOr<PlanAndCheckpointPayloads> FetchTaskResources( 212*14675a02SAndroid Build Coastguard Worker TaskResources task_resources); 213*14675a02SAndroid Build Coastguard Worker // Validates the given `TaskResource` and converts it to a `UriOrInlineData` 214*14675a02SAndroid Build Coastguard Worker // object for use with the `FetchResourcesInMemory` utility method. 215*14675a02SAndroid Build Coastguard Worker absl::StatusOr<::fcp::client::http::UriOrInlineData> 216*14675a02SAndroid Build Coastguard Worker ConvertResourceToUriOrInlineData(const TaskResource& resource); 217*14675a02SAndroid Build Coastguard Worker 218*14675a02SAndroid Build Coastguard Worker ObjectState object_state_; 219*14675a02SAndroid Build Coastguard Worker EventPublisher* const event_publisher_; 220*14675a02SAndroid Build Coastguard Worker LogManager* const log_manager_; 221*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory_; 222*14675a02SAndroid Build Coastguard Worker const Flags* const flags_; 223*14675a02SAndroid Build Coastguard Worker ::fcp::client::http::HttpClient* const http_client_; 224*14675a02SAndroid Build Coastguard Worker std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream_; 225*14675a02SAndroid Build Coastguard Worker std::unique_ptr<InterruptibleRunner> interruptible_runner_; 226*14675a02SAndroid Build Coastguard Worker const std::string population_name_; 227*14675a02SAndroid Build Coastguard Worker const std::string retry_token_; 228*14675a02SAndroid Build Coastguard Worker const std::string client_version_; 229*14675a02SAndroid Build Coastguard Worker const std::string attestation_measurement_; 230*14675a02SAndroid Build Coastguard Worker std::function<absl::StatusOr<bool>()> should_abort_; 231*14675a02SAndroid Build Coastguard Worker absl::BitGen bit_gen_; 232*14675a02SAndroid Build Coastguard Worker // The set of canonical error codes that should be treated as 
'permanent' 233*14675a02SAndroid Build Coastguard Worker // errors. 234*14675a02SAndroid Build Coastguard Worker absl::flat_hash_set<int32_t> federated_training_permanent_error_codes_; 235*14675a02SAndroid Build Coastguard Worker int64_t http_bytes_downloaded_ = 0; 236*14675a02SAndroid Build Coastguard Worker int64_t http_bytes_uploaded_ = 0; 237*14675a02SAndroid Build Coastguard Worker std::unique_ptr<WallClockStopwatch> network_stopwatch_ = 238*14675a02SAndroid Build Coastguard Worker WallClockStopwatch::Create(); 239*14675a02SAndroid Build Coastguard Worker // Represents 2 absolute retry timestamps and their corresponding retry 240*14675a02SAndroid Build Coastguard Worker // tokens, to use when the device is rejected or accepted. The retry 241*14675a02SAndroid Build Coastguard Worker // timestamps will have been generated based on the retry windows specified in 242*14675a02SAndroid Build Coastguard Worker // the server's CheckinRequestAck message and the time at which that message 243*14675a02SAndroid Build Coastguard Worker // was received. 244*14675a02SAndroid Build Coastguard Worker struct CheckinRequestAckInfo { 245*14675a02SAndroid Build Coastguard Worker RetryTimeAndToken retry_info_if_rejected; 246*14675a02SAndroid Build Coastguard Worker RetryTimeAndToken retry_info_if_accepted; 247*14675a02SAndroid Build Coastguard Worker }; 248*14675a02SAndroid Build Coastguard Worker // Represents the information received via the CheckinRequestAck message. 249*14675a02SAndroid Build Coastguard Worker // This field will have an absent value until that message has been received. 250*14675a02SAndroid Build Coastguard Worker std::optional<CheckinRequestAckInfo> checkin_request_ack_info_; 251*14675a02SAndroid Build Coastguard Worker // The identifier of the task that was received in a CheckinResponse. 
Note 252*14675a02SAndroid Build Coastguard Worker // that this does not refer to the identifier of the eligbility eval task that 253*14675a02SAndroid Build Coastguard Worker // may have been received in an EligibilityEvalCheckinResponse. 254*14675a02SAndroid Build Coastguard Worker std::string execution_phase_id_; 255*14675a02SAndroid Build Coastguard Worker absl::flat_hash_map< 256*14675a02SAndroid Build Coastguard Worker std::string, google::internal::federatedml::v2::SideChannelExecutionInfo> 257*14675a02SAndroid Build Coastguard Worker side_channels_; 258*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::SideChannelProtocolExecutionInfo 259*14675a02SAndroid Build Coastguard Worker side_channel_protocol_execution_info_; 260*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::SideChannelProtocolOptionsResponse 261*14675a02SAndroid Build Coastguard Worker side_channel_protocol_options_response_; 262*14675a02SAndroid Build Coastguard Worker // `nullptr` if the feature is disabled. 263*14675a02SAndroid Build Coastguard Worker cache::ResourceCache* resource_cache_; 264*14675a02SAndroid Build Coastguard Worker }; 265*14675a02SAndroid Build Coastguard Worker 266*14675a02SAndroid Build Coastguard Worker } // namespace client 267*14675a02SAndroid Build Coastguard Worker } // namespace fcp 268*14675a02SAndroid Build Coastguard Worker 269*14675a02SAndroid Build Coastguard Worker #endif // FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_ 270