/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FCP_CLIENT_FEDERATED_PROTOCOL_H_
#define FCP_CLIENT_FEDERATED_PROTOCOL_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/time/time.h"
#include "fcp/client/engine/engine.pb.h"
#include "fcp/client/stats.h"
#include "fcp/protos/federated_api.pb.h"
#include "fcp/protos/federatedcompute/eligibility_eval_tasks.pb.h"
#include "fcp/protos/plan.pb.h"

namespace fcp {
namespace client {

// Data type used to encode results of a computation - a TensorFlow
// checkpoint, or SecAgg quantized tensors.
// For non-SecAgg use (simple federated aggregation, or local computation),
// this map should only contain one entry - a TFCheckpoint - and the string
// should be ignored by downstream code.
// For SecAgg use, there should be
// * at most one TFCheckpoint - again, the key should be ignored - and
// * N QuantizedTensors, whose string keys must map to the tensor names
//   provided in the server's CheckinResponse's SideChannelExecutionInfo.
49 using TFCheckpoint = std::string; 50 struct QuantizedTensor { 51 std::vector<uint64_t> values; 52 int32_t bitwidth = 0; 53 std::vector<int64_t> dimensions; 54 55 QuantizedTensor() = default; 56 // Disallow copy and assign. 57 QuantizedTensor(const QuantizedTensor&) = delete; 58 QuantizedTensor& operator=(const QuantizedTensor&) = delete; 59 // Enable move semantics. 60 QuantizedTensor(QuantizedTensor&&) = default; 61 QuantizedTensor& operator=(QuantizedTensor&&) = default; 62 }; 63 // This is equivalent to using ComputationResults = 64 // std::map<std::string, std::variant<TFCheckpoint, QuantizedTensor>>; 65 // except copy construction and assignment are explicitly prohibited and move 66 // semantics is enforced. 67 class ComputationResults 68 : public absl::node_hash_map<std::string, 69 std::variant<TFCheckpoint, QuantizedTensor>> { 70 public: 71 using Base = absl::node_hash_map<std::string, 72 std::variant<TFCheckpoint, QuantizedTensor>>; 73 using Base::Base; 74 using Base::operator=; 75 ComputationResults(const ComputationResults&) = delete; 76 ComputationResults& operator=(const ComputationResults&) = delete; 77 ComputationResults(ComputationResults&&) = default; 78 ComputationResults& operator=(ComputationResults&&) = default; 79 }; 80 81 // An interface that represents a single Federated Compute protocol session. 82 // 83 // An instance of this class represents a single session of client-server 84 // interaction. Instances are generally stateful, and therefore cannot be 85 // reused (each session should use a dedicated instance). 86 // 87 // The protocol consists of 3 phases, which must occur in the following order: 88 // 1. A call to `EligibilityEvalCheckin()`. 89 // 2. A call to `Checkin(...)`, only if the client wasn't rejected by the server 90 // in the previous phase. 91 // 3. A call to `ReportCompleted(...)` or `ReportNotCompleted(...)`, only if the 92 // client wasn't rejected in the previous phase. 
93 class FederatedProtocol { 94 public: 95 virtual ~FederatedProtocol() = default; 96 97 // The unparsed plan and checkpoint payload which make up a computation. The 98 // data can be provided as either an std::string or an absl::Cord. 99 struct PlanAndCheckpointPayloads { 100 std::variant<std::string, absl::Cord> plan; 101 std::variant<std::string, absl::Cord> checkpoint; 102 }; 103 104 // An eligibility task, consisting of task payloads and an execution ID. 105 struct EligibilityEvalTask { 106 PlanAndCheckpointPayloads payloads; 107 std::string execution_id; 108 std::optional< 109 google::internal::federatedcompute::v1::PopulationEligibilitySpec> 110 population_eligibility_spec; 111 }; 112 // A rejection of a client by the server. 113 struct Rejection {}; 114 // Indicates that the server does not have an eligibility eval task configured 115 // for the population. 116 struct EligibilityEvalDisabled {}; 117 // EligibilityEvalCheckin() returns either 118 // 1. an `EligibilityEvalTask` struct holding the payloads for an eligibility 119 // eval task, if the population is configured with such a task. In this 120 // case the caller should execute the task and pass the resulting 121 // `TaskEligibilityInfo` value to the `Checkin(...)` method. 122 // 2. an `EligibilityEvalDisabled` struct if the population doesn't have an 123 // eligibility eval task configured. In this case the caller should 124 // continue the protocol by calling the `Checkin(...)` method without 125 // providing a `TaskEligibilityInfo` value. 126 // 3. a `Rejection` if the server rejected this device. In this case the 127 // caller 128 // should end its protocol interaction. 129 using EligibilityEvalCheckinResult = 130 std::variant<EligibilityEvalTask, EligibilityEvalDisabled, Rejection>; 131 132 // Checks in with a federated server to receive the population's eligibility 133 // eval task. This method is optional and may be called 0 or 1 times. 
If it is 134 // called, then it must be called before any call to `Checkin(...)`. 135 // 136 // If an eligibility eval task is configured, then the 137 // `payload_uris_received_callback` function will be called with a partially 138 // populated `EligibilityEvalTask` containing all of the task's info except 139 // for the actual payloads (which are yet to be fetched at that point). 140 // 141 // Returns: 142 // - On success, an EligibilityEvalCheckinResult. 143 // - On error: 144 // - ABORTED when one of the I/O operations got aborted by the server. 145 // - CANCELLED when one of the I/O operations was interrupted by the client 146 // (possibly due to a positive result from the should_abort callback). 147 // - UNAVAILABLE when server cannot be reached or URI is invalid. 148 // - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the 149 // specified population name is incorrect. 150 // - UNIMPLEMENTED if an unexpected server response is received. 151 // - INTERNAL if the server-provided ClientOnlyPlan cannot be parsed. (See 152 // note in federated_protocol.cc for the reasoning for this.) 153 // - INTERNAL for other unexpected client-side errors. 154 // - any server-provided error code. 155 virtual absl::StatusOr<EligibilityEvalCheckinResult> EligibilityEvalCheckin( 156 std::function<void(const EligibilityEvalTask&)> 157 payload_uris_received_callback) = 0; 158 159 // Report an eligibility eval task error to the federated server. 160 // Must only be called once and after a successful call to 161 // EligibilityEvalCheckin() which returns an eligibility eval task. This 162 // method is only used to report an error happened during the computation of 163 // the eligibility eval task. If the eligibility eval computation succeeds, 164 // the success will be reported during task assignment. 165 // @param status the outcome of the eligibility eval computation. 
166 virtual void ReportEligibilityEvalError(absl::Status error_status) = 0; 167 168 // SecAgg metadata, e.g. see SecureAggregationProtocolExecutionInfo in 169 // federated_api.proto. 170 struct SecAggInfo { 171 int32_t expected_number_of_clients; 172 int32_t minimum_clients_in_server_visible_aggregate; 173 }; 174 175 // A task assignment, consisting of task payloads, a URI template to download 176 // federated select task slices with (if the plan uses federated select), a 177 // session identifier, and SecAgg-related metadata. 178 struct TaskAssignment { 179 PlanAndCheckpointPayloads payloads; 180 std::string federated_select_uri_template; 181 std::string aggregation_session_id; 182 std::optional<SecAggInfo> sec_agg_info; 183 }; 184 // Checkin() returns either 185 // 1. a `TaskAssignment` struct if the client was assigned a task to run, or 186 // 2. a `Rejection` struct if the server rejected this device. 187 using CheckinResult = std::variant<TaskAssignment, Rejection>; 188 189 // Checks in with a federated server. Must only be called once. If the 190 // `EligibilityEvalCheckin()` method was previously called, then this method 191 // must only be called if the result of that call was not a `Rejection`. 192 // 193 // If the caller previously called `EligibilityEvalCheckin()` and: 194 // - received a payload, then the `TaskEligibilityInfo` value computed by that 195 // payload must be provided via the `task_eligibility_info` parameter. 196 // - received an `EligibilityEvalDisabled` result, then the 197 // `task_eligibility_info` parameter should be left empty. 198 // 199 // If the caller did not previously call `EligibilityEvalCheckin()`, then the 200 // `task_eligibility_info` parameter should be left empty. 
201 // 202 // If the client is assigned a task by the server, then the 203 // `payload_uris_received_callback` function will be called with a partially 204 // populated `TaskAssignment` containing all of the task's info except for the 205 // actual payloads (which are yet to be fetched at that point) 206 // 207 // Returns: 208 // - On success, a `CheckinResult`. 209 // - On error: 210 // - ABORTED when one of the I/O operations got aborted by the server. 211 // - CANCELLED when one of the I/O operations was interrupted by the client 212 // (possibly due to a positive result from the should_abort callback). 213 // - UNAVAILABLE when server cannot be reached or URI is invalid. 214 // - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the 215 // specified population name is incorrect. 216 // - UNIMPLEMENTED if an unexpected server response is received. 217 // - INTERNAL if the server-provided ClientOnlyPlan cannot be parsed. (See 218 // note in federated_protocol.cc for the reasoning for this.) 219 // - INTERNAL for other unexpected client-side errors. 220 // - any server-provided error code. 221 // TODO(team): Replace this reference to protocol-specific 222 // TaskEligibilityInfo proto with a protocol-agnostic struct. 223 virtual absl::StatusOr<CheckinResult> Checkin( 224 const std::optional< 225 google::internal::federatedml::v2::TaskEligibilityInfo>& 226 task_eligibility_info, 227 std::function<void(const TaskAssignment&)> 228 payload_uris_received_callback) = 0; 229 230 // A list of absl::StatusOr<TaskAssignment> returned by 231 // PerformMultipleTaskAssignments. Individual absl::StatusOr<TaskAssignment> 232 // may be an error status due to failed to fetch the plan resources. 233 struct MultipleTaskAssignments { 234 std::vector<absl::StatusOr<TaskAssignment>> task_assignments; 235 }; 236 237 // Checks in with a federated server to get multiple task assignments. 
238 // 239 // Must only be called once after the following conditions are met: 240 // 241 // - the caller previously called `EligibilityEvalCheckin()` and, 242 // - received a payload, and the returned EligibilityEvalTask's 243 // `PopulationEligibilitySpec` contained at least one task with 244 // TASK_ASSIGNMENT_MODE_MULTIPLE, for which the device is eligible. 245 // 246 // 247 // Returns: 248 // - On success, a `MultipleTaskAssignments`. 249 // - On error: 250 // - ABORTED when one of the I/O operations got aborted by the server. 251 // - CANCELLED when one of the I/O operations was interrupted by the client 252 // (possibly due to a positive result from the should_abort callback). 253 // - UNAVAILABLE when server cannot be reached or URI is invalid. 254 // - NOT_FOUND if the server responds with NOT_FOUND, e.g. because the 255 // specified population name is incorrect. 256 // - UNIMPLEMENTED if an unexpected server response is received. 257 // - INTERNAL for other unexpected client-side errors. 258 // - any server-provided error code. 259 virtual absl::StatusOr<MultipleTaskAssignments> 260 PerformMultipleTaskAssignments( 261 const std::vector<std::string>& task_names) = 0; 262 263 // Reports the result of a federated computation to the server. Must only be 264 // called once and after a successful call to Checkin(). 265 // @param checkpoint A checkpoint proto. 266 // @param stats all stats reported during the computation. 267 // @param plan_duration the duration for executing the plan in the plan 268 // engine. Does not include time spent on downloading the plan. 269 // Returns: 270 // - On success, OK. 271 // - On error (e.g. an interruption, network error, or other unexpected 272 // error): 273 // - ABORTED when one of the I/O operations got aborted by the server. 274 // - CANCELLED when one of the I/O operations was interrupted by the client 275 // (possibly due to a positive result from the should_abort callback). 
276 // - UNIMPLEMENTED if the server responded with an unexpected response 277 // message. 278 // - INTERNAL for other unexpected client-side errors. 279 // - any server-provided error code. 280 virtual absl::Status ReportCompleted( 281 ComputationResults results, absl::Duration plan_duration, 282 std::optional<std::string> aggregation_session_id) = 0; 283 284 // Reports the unsuccessful result of a federated computation to the server. 285 // Must only be called once and after a successful call to Checkin(). 286 // @param phase_outcome the outcome of the federated computation. 287 // @param plan_duration the duration for executing the plan in the plan 288 // engine. Does not include time spent on downloading the plan. 289 // Returns: 290 // - On success, OK. 291 // - On error: 292 // - ABORTED when one of the I/O operations got aborted by the server. 293 // - CANCELLED when one of the I/O operations was interrupted by the client 294 // (possibly due to a positive result from the should_abort callback). 295 // - UNIMPLEMENTED if the server responded with an unexpected response 296 // message, or if the results to report require SecAgg support. 297 // - INTERNAL for other unexpected client-side errors. 298 // - any server-provided error code. 299 virtual absl::Status ReportNotCompleted( 300 engine::PhaseOutcome phase_outcome, absl::Duration plan_duration, 301 std::optional<std::string> aggregation_session_id) = 0; 302 303 // Returns the RetryWindow the caller should use when rescheduling, based on 304 // the current protocol phase. The value returned by this method may change 305 // after every interaction with the protocol, so callers should call this 306 // right before ending their interactions with the FederatedProtocol object to 307 // ensure they use the most recent value. 308 // TODO(team): Replace this reference to protocol-specific 309 // RetryWindow proto with a protocol-agnostic struct (or just a single 310 // absl::Duration). 
311 virtual google::internal::federatedml::v2::RetryWindow 312 GetLatestRetryWindow() = 0; 313 314 // Returns the best estimate of the total bytes downloaded and uploaded over 315 // the network, plus the best estimate of the duration of wall clock time 316 // spent waiting for network requests to finish (but, for example, excluding 317 // any idle time spent waiting between issuing polling requests). 318 // 319 // Note that this estimate may still include time spent simply waiting for a 320 // server response, even if no data was being sent or received during that 321 // time. E.g. in the case of the legacy gRPC protocol where the single checkin 322 // request blocks until a task is assigned to the client. 323 // 324 // If possible, this estimate should also include time spent 325 // compressing/decompressing payloads before writing them to or after reading 326 // them from the network. 327 virtual NetworkStats GetNetworkStats() = 0; 328 329 protected: 330 // A list of states representing the sequence of calls we expect to receive 331 // via this interface, as well as their possible outcomes. Implementations of 332 // this class are likely to share these coarse-grained states, and use them to 333 // determine which values to return from `GetLatestRetryWindow()`. 334 enum class ObjectState { 335 // The initial object state. 336 kInitialized, 337 // EligibilityEvalCheckin() was called but it failed with a 'transient' 338 // error (e.g. an UNAVAILABLE network error, although the set of transient 339 // errors is flag-defined). 340 kEligibilityEvalCheckinFailed, 341 // EligibilityEvalCheckin() was called but it failed with a 'permanent' 342 // error (e.g. a NOT_FOUND network error, although the set of permanent 343 // errors is flag-defined). 344 kEligibilityEvalCheckinFailedPermanentError, 345 // EligibilityEvalCheckin() was called, and the server rejected the client. 
346 kEligibilityEvalCheckinRejected, 347 // EligibilityEvalCheckin() was called, and the server did not return an 348 // eligibility eval payload. 349 kEligibilityEvalDisabled, 350 // EligibilityEvalCheckin() was called, and the server did return an 351 // eligibility eval payload, which must then be run to produce a 352 // TaskEligibilityInfo value. 353 kEligibilityEvalEnabled, 354 // Checkin(...) was called but it failed with a 'transient' error. 355 kCheckinFailed, 356 // Checkin(...) was called but it failed with a 'permanent' error. 357 kCheckinFailedPermanentError, 358 // Checkin(...) was called, and the server rejected the client. 359 kCheckinRejected, 360 // Checkin(...) was called, and the server accepted the client and returned 361 // a payload, which must then be run to produce a report. 362 kCheckinAccepted, 363 // PerformMultipleTaskAssignments(...) was called but it failed with a 364 // 'transient' error, without receiving a single task assignment. If some 365 // task assignments were successfully received, but some others failed (e.g. 366 // because their resources failed to be downloaded), then this state won't 367 // be used. 368 kMultipleTaskAssignmentsFailed, 369 // PerformMultipleTaskAssignments(...) was called but it failed with a 370 // 'permanent' error. 371 kMultipleTaskAssignmentsFailedPermanentError, 372 // PerformMultipleTaskAssignments(...) was called but an empty list of tasks 373 // is returned by the server. 374 kMultipleTaskAssignmentsNoAvailableTask, 375 // PerformMultipleTaskAssignments(...) was called, and the server accepted 376 // the client and returned one or more payload, which must then be run to 377 // produce a report. 378 kMultipleTaskAssignmentsAccepted, 379 // Report(...) was called. 380 kReportCalled, 381 // Report(...) was called and it resulted in a 'permanent' error. 
382 // 383 // Note: there is no kReportFailed (corresponding to 'transient' errors, 384 // like the other phases have), because by the time the report phase is 385 // reached, a set of RetryWindows is guaranteed to have been received from 386 // the server. 387 kReportFailedPermanentError, 388 // Report(...) was called for multiple tasks, and only a subset of the tasks 389 // succeed. 390 kReportMultipleTaskPartialError, 391 }; 392 }; 393 394 } // namespace client 395 } // namespace fcp 396 397 #endif // FCP_CLIENT_FEDERATED_PROTOCOL_H_ 398