1 /* 2 * Copyright 2020 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef FCP_CLIENT_TEST_HELPERS_H_ 17 #define FCP_CLIENT_TEST_HELPERS_H_ 18 19 #include <functional> 20 #include <string> 21 #include <utility> 22 #include <variant> 23 #include <vector> 24 25 #include "absl/status/status.h" 26 #include "absl/status/statusor.h" 27 #include "fcp/base/monitoring.h" 28 #include "fcp/client/engine/example_iterator_factory.h" 29 #include "fcp/client/event_publisher.h" 30 #include "fcp/client/federated_protocol.h" 31 #include "fcp/client/federated_select.h" 32 #include "fcp/client/flags.h" 33 #include "fcp/client/http/http_client.h" 34 #include "fcp/client/log_manager.h" 35 #include "fcp/client/opstats/opstats_db.h" 36 #include "fcp/client/opstats/opstats_logger.h" 37 #include "fcp/client/phase_logger.h" 38 #include "fcp/client/secagg_event_publisher.h" 39 #include "fcp/client/secagg_runner.h" 40 #include "fcp/client/simple_task_environment.h" 41 #include "gmock/gmock.h" 42 #include "google/protobuf/duration.pb.h" 43 #include "tensorflow/core/example/example.pb.h" 44 #include "tensorflow/core/example/feature.pb.h" 45 46 namespace fcp { 47 namespace client { 48 49 class MockSecAggEventPublisher : public SecAggEventPublisher { 50 public: 51 MOCK_METHOD(void, PublishStateTransition, 52 (::fcp::secagg::ClientState state, size_t last_sent_message_size, 53 size_t last_received_message_size), 54 (override)); 55 MOCK_METHOD(void, PublishError, (), (override)); 56 MOCK_METHOD(void, PublishAbort, 57 (bool client_initiated, const std::string& error_message), 58 (override)); 59 MOCK_METHOD(void, set_execution_session_id, (int64_t execution_session_id), 60 (override)); 61 }; 62 63 class MockEventPublisher : public EventPublisher { 64 public: 65 MOCK_METHOD(void, PublishEligibilityEvalCheckin, (), (override)); 66 MOCK_METHOD(void, PublishEligibilityEvalPlanUriReceived, 67 (const NetworkStats& network_stats, 68 absl::Duration phase_duration), 69 (override)); 70 MOCK_METHOD(void, PublishEligibilityEvalPlanReceived, 71 (const NetworkStats& network_stats, 72 absl::Duration phase_duration), 73 (override)); 74 MOCK_METHOD(void, PublishEligibilityEvalNotConfigured, 75 (const NetworkStats& network_stats, 76 absl::Duration phase_duration), 77 (override)); 78 MOCK_METHOD(void, PublishEligibilityEvalRejected, 79 (const NetworkStats& network_stats, 80 absl::Duration phase_duration), 81 (override)); 82 MOCK_METHOD(void, PublishCheckin, (), (override)); 83 MOCK_METHOD(void, PublishCheckinFinished, 84 (const NetworkStats& network_stats, 85 absl::Duration phase_duration), 86 (override)); 87 MOCK_METHOD(void, PublishRejected, (), (override)); 88 MOCK_METHOD(void, PublishReportStarted, (int64_t report_size_bytes), 89 (override)); 90 MOCK_METHOD(void, PublishReportFinished, 91 (const NetworkStats& network_stats, 92 absl::Duration report_duration), 93 (override)); 94 MOCK_METHOD(void, PublishPlanExecutionStarted, (), (override)); 95 MOCK_METHOD(void, PublishTensorFlowError, 96 (int example_count, absl::string_view error_message), (override)); 97 MOCK_METHOD(void, PublishIoError, (absl::string_view error_message), 98 (override)); 99 MOCK_METHOD(void, PublishExampleSelectorError, 100 (int example_count, absl::string_view error_message), (override)); 101 MOCK_METHOD(void, PublishInterruption, 102 (const ExampleStats& example_stats, absl::Time start_time), 103 (override)); 104 MOCK_METHOD(void, PublishPlanCompleted, 105 (const ExampleStats& example_stats, absl::Time start_time), 106 (override)); 107 MOCK_METHOD(void, SetModelIdentifier, (const std::string& model_identifier), 108 (override)); 109 MOCK_METHOD(void, PublishTaskNotStarted, (absl::string_view error_message), 110 (override)); 111 MOCK_METHOD(void, PublishNonfatalInitializationError, 112 (absl::string_view error_message), (override)); 113 MOCK_METHOD(void, PublishFatalInitializationError, 114 (absl::string_view error_message), (override)); 115 MOCK_METHOD(void, PublishEligibilityEvalCheckinIoError, 116 (absl::string_view error_message, 117 const NetworkStats& network_stats, 118 absl::Duration phase_duration), 119 (override)); 120 MOCK_METHOD(void, PublishEligibilityEvalCheckinClientInterrupted, 121 (absl::string_view error_message, 122 const NetworkStats& network_stats, 123 absl::Duration phase_duration), 124 (override)); 125 MOCK_METHOD(void, PublishEligibilityEvalCheckinServerAborted, 126 (absl::string_view error_message, 127 const NetworkStats& network_stats, 128 absl::Duration phase_duration), 129 (override)); 130 MOCK_METHOD(void, PublishEligibilityEvalCheckinErrorInvalidPayload, 131 (absl::string_view error_message, 132 const NetworkStats& network_stats, 133 absl::Duration phase_duration), 134 (override)); 135 MOCK_METHOD(void, PublishEligibilityEvalComputationStarted, (), (override)); 136 MOCK_METHOD(void, PublishEligibilityEvalComputationInvalidArgument, 137 (absl::string_view error_message, 138 const ExampleStats& example_stats, 139 absl::Duration phase_duration), 140 (override)); 141 MOCK_METHOD(void, PublishEligibilityEvalComputationExampleIteratorError, 142 (absl::string_view error_message, 143 const ExampleStats& example_stats, 144 absl::Duration phase_duration), 145 (override)); 146 MOCK_METHOD(void, PublishEligibilityEvalComputationTensorflowError, 147 (absl::string_view error_message, 148 const ExampleStats& example_stats, 149 absl::Duration phase_duration), 150 (override)); 151 MOCK_METHOD(void, PublishEligibilityEvalComputationInterrupted, 152 (absl::string_view error_message, 153 const ExampleStats& example_stats, 154 absl::Duration phase_duration), 155 (override)); 156 MOCK_METHOD(void, PublishEligibilityEvalComputationCompleted, 157 (const ExampleStats& example_stats, 158 absl::Duration phase_duration), 159 (override)); 160 MOCK_METHOD(void, PublishCheckinIoError, 161 (absl::string_view error_message, 162 const NetworkStats& network_stats, 163 absl::Duration phase_duration), 164 (override)); 165 MOCK_METHOD(void, PublishCheckinClientInterrupted, 166 (absl::string_view error_message, 167 const NetworkStats& network_stats, 168 absl::Duration phase_duration), 169 (override)); 170 MOCK_METHOD(void, PublishCheckinServerAborted, 171 (absl::string_view error_message, 172 const NetworkStats& network_stats, 173 absl::Duration phase_duration), 174 (override)); 175 MOCK_METHOD(void, PublishCheckinInvalidPayload, 176 (absl::string_view error_message, 177 const NetworkStats& network_stats, 178 absl::Duration phase_duration), 179 (override)); 180 MOCK_METHOD(void, PublishRejected, 181 (const NetworkStats& network_stats, 182 absl::Duration phase_duration), 183 (override)); 184 MOCK_METHOD(void, PublishCheckinPlanUriReceived, 185 (const NetworkStats& network_stats, 186 absl::Duration phase_duration), 187 (override)); 188 MOCK_METHOD(void, PublishCheckinFinishedV2, 189 (const NetworkStats& network_stats, 190 absl::Duration phase_duration), 191 (override)); 192 MOCK_METHOD(void, PublishComputationStarted, (), (override)); 193 MOCK_METHOD(void, PublishComputationInvalidArgument, 194 (absl::string_view error_message, 195 const ExampleStats& example_stats, 196 const NetworkStats& network_stats, 197 absl::Duration phase_duration), 198 (override)); 199 MOCK_METHOD(void, PublishComputationIOError, 200 (absl::string_view error_message, 201 const ExampleStats& example_stats, 202 const NetworkStats& network_stats, 203 absl::Duration phase_duration), 204 (override)); 205 MOCK_METHOD(void, PublishComputationExampleIteratorError, 206 (absl::string_view error_message, 207 const ExampleStats& example_stats, 208 const NetworkStats& network_stats, 209 absl::Duration phase_duration), 210 (override)); 211 MOCK_METHOD(void, PublishComputationTensorflowError, 212 (absl::string_view error_message, 213 const ExampleStats& example_stats, 214 const NetworkStats& network_stats, 215 absl::Duration phase_duration), 216 (override)); 217 MOCK_METHOD(void, PublishComputationInterrupted, 218 (absl::string_view error_message, 219 const ExampleStats& example_stats, 220 const NetworkStats& network_stats, 221 absl::Duration phase_duration), 222 (override)); 223 MOCK_METHOD(void, PublishComputationCompleted, 224 (const ExampleStats& example_stats, 225 const NetworkStats& network_stats, 226 absl::Duration phase_duration), 227 (override)); 228 MOCK_METHOD(void, PublishResultUploadStarted, (), (override)); 229 MOCK_METHOD(void, PublishResultUploadIOError, 230 (absl::string_view error_message, 231 const NetworkStats& network_stats, 232 absl::Duration phase_duration), 233 (override)); 234 MOCK_METHOD(void, PublishResultUploadClientInterrupted, 235 (absl::string_view error_message, 236 const NetworkStats& network_stats, 237 absl::Duration phase_duration), 238 (override)); 239 MOCK_METHOD(void, PublishResultUploadServerAborted, 240 (absl::string_view error_message, 241 const NetworkStats& network_stats, 242 absl::Duration phase_duration), 243 (override)); 244 MOCK_METHOD(void, PublishResultUploadCompleted, 245 (const NetworkStats& network_stats, 246 absl::Duration phase_duration), 247 (override)); 248 MOCK_METHOD(void, PublishFailureUploadStarted, (), (override)); 249 MOCK_METHOD(void, PublishFailureUploadIOError, 250 (absl::string_view error_message, 251 const NetworkStats& network_stats, 252 absl::Duration phase_duration), 253 (override)); 254 MOCK_METHOD(void, PublishFailureUploadClientInterrupted, 255 (absl::string_view error_message, 256 const NetworkStats& network_stats, 257 absl::Duration phase_duration), 258 (override)); 259 MOCK_METHOD(void, PublishFailureUploadServerAborted, 260 (absl::string_view error_message, 261 const NetworkStats& network_stats, 262 absl::Duration phase_duration), 263 (override)); 264 MOCK_METHOD(void, PublishFailureUploadCompleted, 265 (const NetworkStats& network_stats, 266 absl::Duration phase_duration), 267 (override)); 268 secagg_event_publisher()269 SecAggEventPublisher* secagg_event_publisher() override { 270 return &secagg_event_publisher_; 271 } 272 273 private: 274 ::testing::NiceMock<MockSecAggEventPublisher> secagg_event_publisher_; 275 }; 276 277 // A mock FederatedProtocol implementation, which keeps track of the stages in 278 // the protocol and returns a different set of network stats and RetryWindow for 279 // each stage, making it easier to write accurate assertions in unit tests. 280 class MockFederatedProtocol : public FederatedProtocol { 281 public: 282 constexpr static NetworkStats 283 kPostEligibilityCheckinPlanUriReceivedNetworkStats = { 284 .bytes_downloaded = 280, 285 .bytes_uploaded = 380, 286 .network_duration = absl::Milliseconds(25)}; 287 constexpr static NetworkStats kPostEligibilityCheckinNetworkStats = { 288 .bytes_downloaded = 300, 289 .bytes_uploaded = 400, 290 .network_duration = absl::Milliseconds(50)}; 291 constexpr static NetworkStats kPostReportEligibilityEvalErrorNetworkStats = { 292 .bytes_downloaded = 400, 293 .bytes_uploaded = 500, 294 .network_duration = absl::Milliseconds(150)}; 295 constexpr static NetworkStats kPostCheckinPlanUriReceivedNetworkStats = { 296 .bytes_downloaded = 2970, 297 .bytes_uploaded = 3970, 298 .network_duration = absl::Milliseconds(225)}; 299 constexpr static NetworkStats kPostCheckinNetworkStats = { 300 .bytes_downloaded = 3000, 301 .bytes_uploaded = 4000, 302 .network_duration = absl::Milliseconds(250)}; 303 constexpr static NetworkStats kPostReportCompletedNetworkStats = { 304 .bytes_downloaded = 30000, 305 .bytes_uploaded = 40000, 306 .network_duration = absl::Milliseconds(350)}; 307 constexpr static NetworkStats kPostReportNotCompletedNetworkStats = { 308 .bytes_downloaded = 29999, 309 .bytes_uploaded = 39999, 310 .network_duration = absl::Milliseconds(450)}; 311 312 static google::internal::federatedml::v2::RetryWindow GetInitialRetryWindow()313 GetInitialRetryWindow() { 314 google::internal::federatedml::v2::RetryWindow retry_window; 315 retry_window.mutable_delay_min()->set_seconds(0L); 316 retry_window.mutable_delay_max()->set_seconds(1L); 317 *retry_window.mutable_retry_token() = "INITIAL"; 318 return retry_window; 319 } 320 321 static google::internal::federatedml::v2::RetryWindow GetPostEligibilityCheckinRetryWindow()322 GetPostEligibilityCheckinRetryWindow() { 323 google::internal::federatedml::v2::RetryWindow retry_window; 324 retry_window.mutable_delay_min()->set_seconds(100L); 325 retry_window.mutable_delay_max()->set_seconds(101L); 326 *retry_window.mutable_retry_token() = "POST_ELIGIBILITY"; 327 return retry_window; 328 } 329 330 static google::internal::federatedml::v2::RetryWindow GetPostCheckinRetryWindow()331 GetPostCheckinRetryWindow() { 332 google::internal::federatedml::v2::RetryWindow retry_window; 333 retry_window.mutable_delay_min()->set_seconds(200L); 334 retry_window.mutable_delay_max()->set_seconds(201L); 335 *retry_window.mutable_retry_token() = "POST_CHECKIN"; 336 return retry_window; 337 } 338 339 static google::internal::federatedml::v2::RetryWindow GetPostReportCompletedRetryWindow()340 GetPostReportCompletedRetryWindow() { 341 google::internal::federatedml::v2::RetryWindow retry_window; 342 retry_window.mutable_delay_min()->set_seconds(300L); 343 retry_window.mutable_delay_max()->set_seconds(301L); 344 *retry_window.mutable_retry_token() = "POST_REPORT_COMPLETED"; 345 return retry_window; 346 } 347 348 static google::internal::federatedml::v2::RetryWindow GetPostReportNotCompletedRetryWindow()349 GetPostReportNotCompletedRetryWindow() { 350 google::internal::federatedml::v2::RetryWindow retry_window; 351 retry_window.mutable_delay_min()->set_seconds(400L); 352 retry_window.mutable_delay_max()->set_seconds(401L); 353 *retry_window.mutable_retry_token() = "POST_REPORT_NOT_COMPLETED"; 354 return retry_window; 355 } 356 MockFederatedProtocol()357 explicit MockFederatedProtocol() {} 358 359 // We override the real FederatedProtocol methods so that we can intercept the 360 // progression of protocol stages, and expose dedicate gMock-overridable 361 // methods for use in tests. EligibilityEvalCheckin(std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)362 absl::StatusOr<EligibilityEvalCheckinResult> EligibilityEvalCheckin( 363 std::function<void(const EligibilityEvalTask&)> 364 payload_uris_received_callback) final { 365 absl::StatusOr<EligibilityEvalCheckinResult> result = 366 MockEligibilityEvalCheckin(); 367 if (result.ok() && 368 std::holds_alternative<FederatedProtocol::EligibilityEvalTask>( 369 *result)) { 370 network_stats_ = kPostEligibilityCheckinPlanUriReceivedNetworkStats; 371 payload_uris_received_callback( 372 std::get<FederatedProtocol::EligibilityEvalTask>(*result)); 373 } 374 network_stats_ = kPostEligibilityCheckinNetworkStats; 375 retry_window_ = GetPostEligibilityCheckinRetryWindow(); 376 return result; 377 }; 378 MOCK_METHOD(absl::StatusOr<EligibilityEvalCheckinResult>, 379 MockEligibilityEvalCheckin, ()); 380 ReportEligibilityEvalError(absl::Status error_status)381 void ReportEligibilityEvalError(absl::Status error_status) final { 382 network_stats_ = kPostReportEligibilityEvalErrorNetworkStats; 383 retry_window_ = GetPostEligibilityCheckinRetryWindow(); 384 MockReportEligibilityEvalError(error_status); 385 } 386 MOCK_METHOD(void, MockReportEligibilityEvalError, 387 (absl::Status error_status)); 388 Checkin(const std::optional<::google::internal::federatedml::v2::TaskEligibilityInfo> & task_eligibility_info,std::function<void (const FederatedProtocol::TaskAssignment &)> payload_uris_received_callback)389 absl::StatusOr<CheckinResult> Checkin( 390 const std::optional< 391 ::google::internal::federatedml::v2::TaskEligibilityInfo>& 392 task_eligibility_info, 393 std::function<void(const FederatedProtocol::TaskAssignment&)> 394 payload_uris_received_callback) final { 395 absl::StatusOr<CheckinResult> result = MockCheckin(task_eligibility_info); 396 if (result.ok() && 397 std::holds_alternative<FederatedProtocol::TaskAssignment>(*result)) { 398 network_stats_ = kPostCheckinPlanUriReceivedNetworkStats; 399 payload_uris_received_callback( 400 std::get<FederatedProtocol::TaskAssignment>(*result)); 401 } 402 retry_window_ = GetPostCheckinRetryWindow(); 403 network_stats_ = kPostCheckinNetworkStats; 404 return result; 405 }; 406 MOCK_METHOD(absl::StatusOr<CheckinResult>, MockCheckin, 407 (const std::optional< 408 ::google::internal::federatedml::v2::TaskEligibilityInfo>& 409 task_eligibility_info)); 410 PerformMultipleTaskAssignments(const std::vector<std::string> & task_names)411 absl::StatusOr<MultipleTaskAssignments> PerformMultipleTaskAssignments( 412 const std::vector<std::string>& task_names) final { 413 absl::StatusOr<MultipleTaskAssignments> result = 414 MockPerformMultipleTaskAssignments(task_names); 415 retry_window_ = GetPostCheckinRetryWindow(); 416 network_stats_ = kPostCheckinPlanUriReceivedNetworkStats; 417 return result; 418 }; 419 420 MOCK_METHOD(absl::StatusOr<MultipleTaskAssignments>, 421 MockPerformMultipleTaskAssignments, 422 (const std::vector<std::string>& task_names)); 423 ReportCompleted(ComputationResults results,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)424 absl::Status ReportCompleted( 425 ComputationResults results, absl::Duration plan_duration, 426 std::optional<std::string> aggregation_session_id) final { 427 network_stats_ = kPostReportCompletedNetworkStats; 428 retry_window_ = GetPostReportCompletedRetryWindow(); 429 return MockReportCompleted(std::move(results), plan_duration, 430 aggregation_session_id); 431 }; 432 MOCK_METHOD(absl::Status, MockReportCompleted, 433 (ComputationResults results, absl::Duration plan_duration, 434 std::optional<std::string> aggregation_session_id)); 435 ReportNotCompleted(engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)436 absl::Status ReportNotCompleted( 437 engine::PhaseOutcome phase_outcome, absl::Duration plan_duration, 438 std::optional<std::string> aggregation_session_id) final { 439 network_stats_ = kPostReportNotCompletedNetworkStats; 440 retry_window_ = GetPostReportNotCompletedRetryWindow(); 441 return MockReportNotCompleted(phase_outcome, plan_duration, 442 aggregation_session_id); 443 }; 444 MOCK_METHOD(absl::Status, MockReportNotCompleted, 445 (engine::PhaseOutcome phase_outcome, absl::Duration plan_duration, 446 std::optional<std::string> aggregation_session_id)); 447 GetLatestRetryWindow()448 ::google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow() 449 final { 450 return retry_window_; 451 } 452 GetNetworkStats()453 NetworkStats GetNetworkStats() final { return network_stats_; } 454 455 private: 456 NetworkStats network_stats_; 457 ::google::internal::federatedml::v2::RetryWindow retry_window_ = 458 GetInitialRetryWindow(); 459 }; 460 461 class MockLogManager : public LogManager { 462 public: 463 MOCK_METHOD(void, LogDiag, (ProdDiagCode), (override)); 464 MOCK_METHOD(void, LogDiag, (DebugDiagCode), (override)); 465 MOCK_METHOD(void, LogToLongHistogram, 466 (fcp::client::HistogramCounters, int, int, 467 fcp::client::engine::DataSourceType, int64_t), 468 (override)); 469 MOCK_METHOD(void, SetModelIdentifier, (const std::string&), (override)); 470 }; 471 472 class MockOpStatsLogger : public ::fcp::client::opstats::OpStatsLogger { 473 public: 474 MOCK_METHOD( 475 void, AddEventAndSetTaskName, 476 (const std::string& task_name, 477 ::fcp::client::opstats::OperationalStats::Event::EventKind event), 478 (override)); 479 MOCK_METHOD( 480 void, AddEvent, 481 (::fcp::client::opstats::OperationalStats::Event::EventKind event), 482 (override)); 483 MOCK_METHOD(void, AddEventWithErrorMessage, 484 (::fcp::client::opstats::OperationalStats::Event::EventKind event, 485 const std::string& error_message), 486 (override)); 487 MOCK_METHOD(void, UpdateDatasetStats, 488 (const std::string& collection_uri, int additional_example_count, 489 int64_t additional_example_size_bytes), 490 (override)); 491 MOCK_METHOD(void, SetNetworkStats, (const NetworkStats& network_stats), 492 (override)); 493 MOCK_METHOD(void, SetRetryWindow, 494 (google::internal::federatedml::v2::RetryWindow retry_window), 495 (override)); 496 MOCK_METHOD(::fcp::client::opstats::OpStatsDb*, GetOpStatsDb, (), (override)); 497 MOCK_METHOD(bool, IsOpStatsEnabled, (), (const override)); 498 MOCK_METHOD(absl::Status, CommitToStorage, (), (override)); 499 MOCK_METHOD(std::string, GetCurrentTaskName, (), (override)); 500 }; 501 502 class MockSimpleTaskEnvironment : public SimpleTaskEnvironment { 503 public: 504 MOCK_METHOD(std::string, GetBaseDir, (), (override)); 505 MOCK_METHOD(std::string, GetCacheDir, (), (override)); 506 MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>), 507 CreateExampleIterator, 508 (const google::internal::federated::plan::ExampleSelector& 509 example_selector), 510 (override)); 511 MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>), 512 CreateExampleIterator, 513 (const google::internal::federated::plan::ExampleSelector& 514 example_selector, 515 const SelectorContext& selector_context), 516 (override)); 517 MOCK_METHOD(std::unique_ptr<fcp::client::http::HttpClient>, CreateHttpClient, 518 (), (override)); 519 MOCK_METHOD(bool, TrainingConditionsSatisfied, (), (override)); 520 }; 521 522 class MockExampleIterator : public ExampleIterator { 523 public: 524 MOCK_METHOD(absl::StatusOr<std::string>, Next, (), (override)); 525 MOCK_METHOD(void, Close, (), (override)); 526 }; 527 528 // An iterator that passes through each example in the dataset once. 529 class SimpleExampleIterator : public ExampleIterator { 530 public: 531 // Uses the given bytes as the examples to return. 532 explicit SimpleExampleIterator(std::vector<const char*> examples); 533 // Passes through each of the examples in the `Dataset.client_data.example` 534 // field. 535 explicit SimpleExampleIterator( 536 google::internal::federated::plan::Dataset dataset); 537 // Passes through each of the examples in the 538 // `Dataset.client_data.selected_example.example` field, whose example 539 // collection URI matches the provided `collection_uri`. 540 SimpleExampleIterator(google::internal::federated::plan::Dataset dataset, 541 absl::string_view collection_uri); 542 absl::StatusOr<std::string> Next() override; Close()543 void Close() override {} 544 545 protected: 546 std::vector<std::string> examples_; 547 int index_ = 0; 548 }; 549 550 struct ComputationArtifacts { 551 // The path to the file containing the plan data. 552 std::string plan_filepath; 553 // The already-parsed plan data. 554 google::internal::federated::plan::ClientOnlyPlan plan; 555 // The test dataset. 556 google::internal::federated::plan::Dataset dataset; 557 // The path to the file containing the initial checkpoint data (not set for 558 // local compute task artifacts). 559 std::string checkpoint_filepath; 560 // The initial checkpoint data, as a string (not set for local compute task 561 // artifacts). 562 std::string checkpoint; 563 // The Federated Select slice data (not set for local compute task artifacts). 564 google::internal::federated::plan::SlicesTestDataset federated_select_slices; 565 }; 566 567 absl::StatusOr<ComputationArtifacts> LoadFlArtifacts(); 568 569 class MockFlags : public Flags { 570 public: 571 MOCK_METHOD(int64_t, condition_polling_period_millis, (), (const, override)); 572 MOCK_METHOD(int64_t, tf_execution_teardown_grace_period_millis, (), 573 (const, override)); 574 MOCK_METHOD(int64_t, tf_execution_teardown_extended_period_millis, (), 575 (const, override)); 576 MOCK_METHOD(int64_t, grpc_channel_deadline_seconds, (), (const, override)); 577 MOCK_METHOD(bool, log_tensorflow_error_messages, (), (const, override)); 578 MOCK_METHOD(bool, enable_opstats, (), (const, override)); 579 MOCK_METHOD(int64_t, opstats_ttl_days, (), (const, override)); 580 MOCK_METHOD(int64_t, opstats_db_size_limit_bytes, (), (const, override)); 581 MOCK_METHOD(int64_t, federated_training_transient_errors_retry_delay_secs, (), 582 (const, override)); 583 MOCK_METHOD(float, 584 federated_training_transient_errors_retry_delay_jitter_percent, 585 (), (const, override)); 586 MOCK_METHOD(int64_t, federated_training_permanent_errors_retry_delay_secs, (), 587 (const, override)); 588 MOCK_METHOD(float, 589 federated_training_permanent_errors_retry_delay_jitter_percent, 590 (), (const, override)); 591 MOCK_METHOD(std::vector<int32_t>, federated_training_permanent_error_codes, 592 (), (const, override)); 593 MOCK_METHOD(bool, use_tflite_training, (), (const, override)); 594 MOCK_METHOD(bool, enable_grpc_with_http_resource_support, (), 595 (const, override)); 596 MOCK_METHOD(bool, enable_grpc_with_eligibility_eval_http_resource_support, (), 597 (const, override)); 598 MOCK_METHOD(bool, ensure_dynamic_tensors_are_released, (), (const, override)); 599 MOCK_METHOD(int32_t, large_tensor_threshold_for_dynamic_allocation, (), 600 (const, override)); 601 MOCK_METHOD(bool, disable_http_request_body_compression, (), 602 (const, override)); 603 MOCK_METHOD(bool, use_http_federated_compute_protocol, (), (const, override)); 604 MOCK_METHOD(bool, enable_computation_id, (), (const, override)); 605 MOCK_METHOD(int32_t, waiting_period_sec_for_cancellation, (), 606 (const, override)); 607 MOCK_METHOD(bool, enable_federated_select, (), (const, override)); 608 MOCK_METHOD(int32_t, num_threads_for_tflite, (), (const, override)); 609 MOCK_METHOD(bool, disable_tflite_delegate_clustering, (), (const, override)); 610 MOCK_METHOD(bool, enable_example_query_plan_engine, (), (const, override)); 611 MOCK_METHOD(bool, support_constant_tf_inputs, (), (const, override)); 612 MOCK_METHOD(bool, http_protocol_supports_multiple_task_assignments, (), 613 (const, override)); 614 }; 615 616 // Helper methods for extracting opstats fields from TF examples. 617 std::string ExtractSingleString(const tensorflow::Example& example, 618 const char key[]); 619 google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString( 620 const tensorflow::Example& example, const char key[]); 621 int64_t ExtractSingleInt64(const tensorflow::Example& example, 622 const char key[]); 623 google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64( 624 const tensorflow::Example& example, const char key[]); 625 626 class MockOpStatsDb : public ::fcp::client::opstats::OpStatsDb { 627 public: 628 MOCK_METHOD(absl::StatusOr<::fcp::client::opstats::OpStatsSequence>, Read, (), 629 (override)); 630 MOCK_METHOD(absl::Status, Transform, 631 (std::function<void(::fcp::client::opstats::OpStatsSequence&)>), 632 (override)); 633 }; 634 635 class MockPhaseLogger : public PhaseLogger { 636 public: 637 MOCK_METHOD( 638 void, UpdateRetryWindowAndNetworkStats, 639 (const ::google::internal::federatedml::v2::RetryWindow& retry_window, 640 const NetworkStats& network_stats), 641 (override)); 642 MOCK_METHOD(void, SetModelIdentifier, (absl::string_view model_identifier), 643 (override)); 644 MOCK_METHOD(void, LogTaskNotStarted, (absl::string_view error_message), 645 (override)); 646 MOCK_METHOD(void, LogNonfatalInitializationError, (absl::Status error_status), 647 (override)); 648 MOCK_METHOD(void, LogFatalInitializationError, (absl::Status error_status), 649 (override)); 650 MOCK_METHOD(void, LogEligibilityEvalCheckinStarted, (), (override)); 651 MOCK_METHOD(void, LogEligibilityEvalCheckinIOError, 652 (absl::Status error_status, const NetworkStats& network_stats, 653 absl::Time time_before_checkin), 654 (override)); 655 MOCK_METHOD(void, LogEligibilityEvalCheckinInvalidPayloadError, 656 (absl::string_view error_message, 657 const NetworkStats& network_stats, 658 absl::Time time_before_checkin), 659 (override)); 660 MOCK_METHOD(void, LogEligibilityEvalCheckinClientInterrupted, 661 (absl::Status error_status, const NetworkStats& network_stats, 662 absl::Time time_before_checkin), 663 (override)); 664 MOCK_METHOD(void, LogEligibilityEvalCheckinServerAborted, 665 (absl::Status error_status, const NetworkStats& network_stats, 666 absl::Time time_before_checkin), 667 (override)); 668 MOCK_METHOD(void, LogEligibilityEvalNotConfigured, 669 (const NetworkStats& network_stats, 670 absl::Time time_before_checkin), 671 (override)); 672 MOCK_METHOD(void, LogEligibilityEvalCheckinTurnedAway, 673 (const NetworkStats& network_stats, 674 absl::Time time_before_checkin), 675 (override)); 676 MOCK_METHOD(void, LogEligibilityEvalCheckinPlanUriReceived, 677 (const NetworkStats& network_stats, 678 absl::Time time_before_checkin), 679 (override)); 680 MOCK_METHOD(void, LogEligibilityEvalCheckinCompleted, 681 (const NetworkStats& network_stats, 682 absl::Time time_before_checkin, 683 absl::Time time_before_plan_download), 684 (override)); 685 MOCK_METHOD(void, LogEligibilityEvalComputationStarted, (), (override)); 686 MOCK_METHOD(void, LogEligibilityEvalComputationInvalidArgument, 687 (absl::Status error_status, const ExampleStats& example_stats, 688 absl::Time run_plan_start_time), 689 (override)); 690 MOCK_METHOD(void, LogEligibilityEvalComputationExampleIteratorError, 691 (absl::Status error_status, const ExampleStats& example_stats, 692 absl::Time run_plan_start_time), 693 (override)); 694 MOCK_METHOD(void, LogEligibilityEvalComputationTensorflowError, 695 (absl::Status error_status, const ExampleStats& example_stats, 696 absl::Time run_plan_start_time, absl::Time reference_time), 697 (override)); 698 MOCK_METHOD(void, LogEligibilityEvalComputationInterrupted, 699 (absl::Status error_status, const ExampleStats& example_stats, 700 absl::Time run_plan_start_time, absl::Time reference_time), 701 (override)); 702 MOCK_METHOD(void, LogEligibilityEvalComputationCompleted, 703 (const ExampleStats& example_stats, 704 absl::Time run_plan_start_time, absl::Time reference_time), 705 (override)); 706 MOCK_METHOD(void, LogCheckinStarted, (), (override)); 707 MOCK_METHOD(void, LogCheckinIOError, 708 (absl::Status error_status, const NetworkStats& network_stats, 709 absl::Time time_before_checkin, absl::Time reference_time), 710 (override)); 711 MOCK_METHOD(void, LogCheckinInvalidPayload, 712 (absl::string_view error_message, 713 const NetworkStats& network_stats, 714 absl::Time time_before_checkin, absl::Time reference_time), 715 (override)); 716 MOCK_METHOD(void, LogCheckinClientInterrupted, 717 (absl::Status error_status, const NetworkStats& network_stats, 718 absl::Time time_before_checkin, absl::Time reference_time), 719 (override)); 720 MOCK_METHOD(void, LogCheckinServerAborted, 721 (absl::Status error_status, const NetworkStats& network_stats, 722 absl::Time time_before_checkin, absl::Time reference_time), 723 (override)); 724 MOCK_METHOD(void, LogCheckinTurnedAway, 725 (const NetworkStats& network_stats, 726 absl::Time time_before_checkin, absl::Time reference_time), 727 (override)); 728 MOCK_METHOD(void, LogCheckinPlanUriReceived, 729 (absl::string_view task_name, const NetworkStats& network_stats, 730 absl::Time time_before_checkin), 731 (override)); 732 MOCK_METHOD(void, LogCheckinCompleted, 733 (absl::string_view task_name, const NetworkStats& network_stats, 734 absl::Time time_before_checkin, 735 absl::Time time_before_plan_download, absl::Time reference_time), 736 (override)); 737 MOCK_METHOD(void, LogComputationStarted, (), (override)); 738 MOCK_METHOD(void, LogComputationInvalidArgument, 739 (absl::Status error_status, const ExampleStats& example_stats, 740 const NetworkStats& network_stats, 741 absl::Time run_plan_start_time), 742 (override)); 743 MOCK_METHOD(void, LogComputationExampleIteratorError, 744 (absl::Status error_status, const ExampleStats& example_stats, 745 const NetworkStats& network_stats, 746 absl::Time run_plan_start_time), 747 (override)); 748 MOCK_METHOD(void, LogComputationIOError, 749 (absl::Status error_status, const ExampleStats& example_stats, 750 const NetworkStats& network_stats, 751 absl::Time run_plan_start_time), 752 (override)); 753 MOCK_METHOD(void, LogComputationTensorflowError, 754 (absl::Status error_status, const ExampleStats& example_stats, 755 const NetworkStats& network_stats, 756 absl::Time run_plan_start_time, absl::Time reference_time), 757 (override)); 758 MOCK_METHOD(void, LogComputationInterrupted, 759 (absl::Status error_status, const ExampleStats& example_stats, 760 const NetworkStats& network_stats, 761 absl::Time run_plan_start_time, absl::Time reference_time), 762 (override)); 763 MOCK_METHOD(void, LogComputationCompleted, 764 (const ExampleStats& example_stats, 765 const NetworkStats& network_stats, 766 absl::Time run_plan_start_time, absl::Time reference_time), 767 (override)); 768 MOCK_METHOD(absl::Status, LogResultUploadStarted, (), (override)); 769 MOCK_METHOD(void, LogResultUploadIOError, 770 (absl::Status error_status, const NetworkStats& network_stats, 771 absl::Time time_before_result_upload, absl::Time reference_time), 772 (override)); 773 MOCK_METHOD(void, LogResultUploadClientInterrupted, 774 (absl::Status error_status, const NetworkStats& network_stats, 775 absl::Time time_before_result_upload, absl::Time reference_time), 776 (override)); 777 MOCK_METHOD(void, LogResultUploadServerAborted, 778 (absl::Status error_status, const NetworkStats& network_stats, 779 absl::Time time_before_result_upload, absl::Time reference_time), 780 (override)); 781 MOCK_METHOD(void, LogResultUploadCompleted, 782 (const NetworkStats& network_stats, 783 absl::Time time_before_result_upload, absl::Time reference_time), 784 (override)); 785 MOCK_METHOD(absl::Status, LogFailureUploadStarted, (), (override)); 786 MOCK_METHOD(void, LogFailureUploadIOError, 787 (absl::Status error_status, const NetworkStats& network_stats, 788 absl::Time time_before_failure_upload, 789 absl::Time reference_time), 790 (override)); 791 MOCK_METHOD(void, LogFailureUploadClientInterrupted, 792 (absl::Status error_status, const NetworkStats& network_stats, 793 absl::Time time_before_failure_upload, 794 absl::Time reference_time), 795 (override)); 796 MOCK_METHOD(void, LogFailureUploadServerAborted, 797 (absl::Status error_status, const NetworkStats& network_stats, 798 absl::Time time_before_failure_upload, 799 absl::Time reference_time), 800 (override)); 801 MOCK_METHOD(void, LogFailureUploadCompleted, 802 (const NetworkStats& network_stats, 803 absl::Time time_before_result_upload, absl::Time reference_time), 804 (override)); 805 }; 806 807 class MockFederatedSelectManager : public FederatedSelectManager { 808 public: 809 MOCK_METHOD(std::unique_ptr<engine::ExampleIteratorFactory>, 810 CreateExampleIteratorFactoryForUriTemplate, 811 (absl::string_view uri_template), (override)); 812 813 MOCK_METHOD(NetworkStats, GetNetworkStats, (), (override)); 814 }; 815 816 class MockFederatedSelectExampleIteratorFactory 817 : public FederatedSelectExampleIteratorFactory { 818 public: 819 MOCK_METHOD(absl::StatusOr<std::unique_ptr<ExampleIterator>>, 820 CreateExampleIterator, 821 (const ::google::internal::federated::plan::ExampleSelector& 822 example_selector), 823 (override)); 824 }; 825 826 class MockSecAggRunnerFactory : public SecAggRunnerFactory { 827 public: 828 MOCK_METHOD(std::unique_ptr<SecAggRunner>, CreateSecAggRunner, 829 (std::unique_ptr<SecAggSendToServerBase> send_to_server_impl, 830 std::unique_ptr<SecAggProtocolDelegate> protocol_delegate, 831 SecAggEventPublisher* secagg_event_publisher, 832 LogManager* log_manager, 833 InterruptibleRunner* interruptible_runner, 834 int64_t expected_number_of_clients, 835 int64_t minimum_surviving_clients_for_reconstruction), 836 (override)); 837 }; 838 839 class MockSecAggRunner : public SecAggRunner { 840 public: 841 MOCK_METHOD(absl::Status, Run, (ComputationResults results), (override)); 842 }; 843 844 class MockSecAggSendToServerBase : public SecAggSendToServerBase { 845 MOCK_METHOD(void, Send, (secagg::ClientToServerWrapperMessage * message), 846 (override)); 847 }; 848 849 class MockSecAggProtocolDelegate : public SecAggProtocolDelegate { 850 public: 851 MOCK_METHOD(absl::StatusOr<uint64_t>, GetModulus, (const std::string& key), 852 (override)); 853 MOCK_METHOD(absl::StatusOr<secagg::ServerToClientWrapperMessage>, 854 ReceiveServerMessage, (), (override)); 855 MOCK_METHOD(void, Abort, (), (override)); 856 }; 857 858 } // namespace client 859 } // namespace fcp 860 861 #endif // FCP_CLIENT_TEST_HELPERS_H_ 862