xref: /aosp_15_r20/external/federated-compute/fcp/client/test_helpers.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_TEST_HELPERS_H_
17 #define FCP_CLIENT_TEST_HELPERS_H_
18 
19 #include <functional>
20 #include <string>
21 #include <utility>
22 #include <variant>
23 #include <vector>
24 
25 #include "absl/status/status.h"
26 #include "absl/status/statusor.h"
27 #include "fcp/base/monitoring.h"
28 #include "fcp/client/engine/example_iterator_factory.h"
29 #include "fcp/client/event_publisher.h"
30 #include "fcp/client/federated_protocol.h"
31 #include "fcp/client/federated_select.h"
32 #include "fcp/client/flags.h"
33 #include "fcp/client/http/http_client.h"
34 #include "fcp/client/log_manager.h"
35 #include "fcp/client/opstats/opstats_db.h"
36 #include "fcp/client/opstats/opstats_logger.h"
37 #include "fcp/client/phase_logger.h"
38 #include "fcp/client/secagg_event_publisher.h"
39 #include "fcp/client/secagg_runner.h"
40 #include "fcp/client/simple_task_environment.h"
41 #include "gmock/gmock.h"
42 #include "google/protobuf/duration.pb.h"
43 #include "tensorflow/core/example/example.pb.h"
44 #include "tensorflow/core/example/feature.pb.h"
45 
46 namespace fcp {
47 namespace client {
48 
49 class MockSecAggEventPublisher : public SecAggEventPublisher {
50  public:
51   MOCK_METHOD(void, PublishStateTransition,
52               (::fcp::secagg::ClientState state, size_t last_sent_message_size,
53                size_t last_received_message_size),
54               (override));
55   MOCK_METHOD(void, PublishError, (), (override));
56   MOCK_METHOD(void, PublishAbort,
57               (bool client_initiated, const std::string& error_message),
58               (override));
59   MOCK_METHOD(void, set_execution_session_id, (int64_t execution_session_id),
60               (override));
61 };
62 
63 class MockEventPublisher : public EventPublisher {
64  public:
65   MOCK_METHOD(void, PublishEligibilityEvalCheckin, (), (override));
66   MOCK_METHOD(void, PublishEligibilityEvalPlanUriReceived,
67               (const NetworkStats& network_stats,
68                absl::Duration phase_duration),
69               (override));
70   MOCK_METHOD(void, PublishEligibilityEvalPlanReceived,
71               (const NetworkStats& network_stats,
72                absl::Duration phase_duration),
73               (override));
74   MOCK_METHOD(void, PublishEligibilityEvalNotConfigured,
75               (const NetworkStats& network_stats,
76                absl::Duration phase_duration),
77               (override));
78   MOCK_METHOD(void, PublishEligibilityEvalRejected,
79               (const NetworkStats& network_stats,
80                absl::Duration phase_duration),
81               (override));
82   MOCK_METHOD(void, PublishCheckin, (), (override));
83   MOCK_METHOD(void, PublishCheckinFinished,
84               (const NetworkStats& network_stats,
85                absl::Duration phase_duration),
86               (override));
87   MOCK_METHOD(void, PublishRejected, (), (override));
88   MOCK_METHOD(void, PublishReportStarted, (int64_t report_size_bytes),
89               (override));
90   MOCK_METHOD(void, PublishReportFinished,
91               (const NetworkStats& network_stats,
92                absl::Duration report_duration),
93               (override));
94   MOCK_METHOD(void, PublishPlanExecutionStarted, (), (override));
95   MOCK_METHOD(void, PublishTensorFlowError,
96               (int example_count, absl::string_view error_message), (override));
97   MOCK_METHOD(void, PublishIoError, (absl::string_view error_message),
98               (override));
99   MOCK_METHOD(void, PublishExampleSelectorError,
100               (int example_count, absl::string_view error_message), (override));
101   MOCK_METHOD(void, PublishInterruption,
102               (const ExampleStats& example_stats, absl::Time start_time),
103               (override));
104   MOCK_METHOD(void, PublishPlanCompleted,
105               (const ExampleStats& example_stats, absl::Time start_time),
106               (override));
107   MOCK_METHOD(void, SetModelIdentifier, (const std::string& model_identifier),
108               (override));
109   MOCK_METHOD(void, PublishTaskNotStarted, (absl::string_view error_message),
110               (override));
111   MOCK_METHOD(void, PublishNonfatalInitializationError,
112               (absl::string_view error_message), (override));
113   MOCK_METHOD(void, PublishFatalInitializationError,
114               (absl::string_view error_message), (override));
115   MOCK_METHOD(void, PublishEligibilityEvalCheckinIoError,
116               (absl::string_view error_message,
117                const NetworkStats& network_stats,
118                absl::Duration phase_duration),
119               (override));
120   MOCK_METHOD(void, PublishEligibilityEvalCheckinClientInterrupted,
121               (absl::string_view error_message,
122                const NetworkStats& network_stats,
123                absl::Duration phase_duration),
124               (override));
125   MOCK_METHOD(void, PublishEligibilityEvalCheckinServerAborted,
126               (absl::string_view error_message,
127                const NetworkStats& network_stats,
128                absl::Duration phase_duration),
129               (override));
130   MOCK_METHOD(void, PublishEligibilityEvalCheckinErrorInvalidPayload,
131               (absl::string_view error_message,
132                const NetworkStats& network_stats,
133                absl::Duration phase_duration),
134               (override));
135   MOCK_METHOD(void, PublishEligibilityEvalComputationStarted, (), (override));
136   MOCK_METHOD(void, PublishEligibilityEvalComputationInvalidArgument,
137               (absl::string_view error_message,
138                const ExampleStats& example_stats,
139                absl::Duration phase_duration),
140               (override));
141   MOCK_METHOD(void, PublishEligibilityEvalComputationExampleIteratorError,
142               (absl::string_view error_message,
143                const ExampleStats& example_stats,
144                absl::Duration phase_duration),
145               (override));
146   MOCK_METHOD(void, PublishEligibilityEvalComputationTensorflowError,
147               (absl::string_view error_message,
148                const ExampleStats& example_stats,
149                absl::Duration phase_duration),
150               (override));
151   MOCK_METHOD(void, PublishEligibilityEvalComputationInterrupted,
152               (absl::string_view error_message,
153                const ExampleStats& example_stats,
154                absl::Duration phase_duration),
155               (override));
156   MOCK_METHOD(void, PublishEligibilityEvalComputationCompleted,
157               (const ExampleStats& example_stats,
158                absl::Duration phase_duration),
159               (override));
160   MOCK_METHOD(void, PublishCheckinIoError,
161               (absl::string_view error_message,
162                const NetworkStats& network_stats,
163                absl::Duration phase_duration),
164               (override));
165   MOCK_METHOD(void, PublishCheckinClientInterrupted,
166               (absl::string_view error_message,
167                const NetworkStats& network_stats,
168                absl::Duration phase_duration),
169               (override));
170   MOCK_METHOD(void, PublishCheckinServerAborted,
171               (absl::string_view error_message,
172                const NetworkStats& network_stats,
173                absl::Duration phase_duration),
174               (override));
175   MOCK_METHOD(void, PublishCheckinInvalidPayload,
176               (absl::string_view error_message,
177                const NetworkStats& network_stats,
178                absl::Duration phase_duration),
179               (override));
180   MOCK_METHOD(void, PublishRejected,
181               (const NetworkStats& network_stats,
182                absl::Duration phase_duration),
183               (override));
184   MOCK_METHOD(void, PublishCheckinPlanUriReceived,
185               (const NetworkStats& network_stats,
186                absl::Duration phase_duration),
187               (override));
188   MOCK_METHOD(void, PublishCheckinFinishedV2,
189               (const NetworkStats& network_stats,
190                absl::Duration phase_duration),
191               (override));
192   MOCK_METHOD(void, PublishComputationStarted, (), (override));
193   MOCK_METHOD(void, PublishComputationInvalidArgument,
194               (absl::string_view error_message,
195                const ExampleStats& example_stats,
196                const NetworkStats& network_stats,
197                absl::Duration phase_duration),
198               (override));
199   MOCK_METHOD(void, PublishComputationIOError,
200               (absl::string_view error_message,
201                const ExampleStats& example_stats,
202                const NetworkStats& network_stats,
203                absl::Duration phase_duration),
204               (override));
205   MOCK_METHOD(void, PublishComputationExampleIteratorError,
206               (absl::string_view error_message,
207                const ExampleStats& example_stats,
208                const NetworkStats& network_stats,
209                absl::Duration phase_duration),
210               (override));
211   MOCK_METHOD(void, PublishComputationTensorflowError,
212               (absl::string_view error_message,
213                const ExampleStats& example_stats,
214                const NetworkStats& network_stats,
215                absl::Duration phase_duration),
216               (override));
217   MOCK_METHOD(void, PublishComputationInterrupted,
218               (absl::string_view error_message,
219                const ExampleStats& example_stats,
220                const NetworkStats& network_stats,
221                absl::Duration phase_duration),
222               (override));
223   MOCK_METHOD(void, PublishComputationCompleted,
224               (const ExampleStats& example_stats,
225                const NetworkStats& network_stats,
226                absl::Duration phase_duration),
227               (override));
228   MOCK_METHOD(void, PublishResultUploadStarted, (), (override));
229   MOCK_METHOD(void, PublishResultUploadIOError,
230               (absl::string_view error_message,
231                const NetworkStats& network_stats,
232                absl::Duration phase_duration),
233               (override));
234   MOCK_METHOD(void, PublishResultUploadClientInterrupted,
235               (absl::string_view error_message,
236                const NetworkStats& network_stats,
237                absl::Duration phase_duration),
238               (override));
239   MOCK_METHOD(void, PublishResultUploadServerAborted,
240               (absl::string_view error_message,
241                const NetworkStats& network_stats,
242                absl::Duration phase_duration),
243               (override));
244   MOCK_METHOD(void, PublishResultUploadCompleted,
245               (const NetworkStats& network_stats,
246                absl::Duration phase_duration),
247               (override));
248   MOCK_METHOD(void, PublishFailureUploadStarted, (), (override));
249   MOCK_METHOD(void, PublishFailureUploadIOError,
250               (absl::string_view error_message,
251                const NetworkStats& network_stats,
252                absl::Duration phase_duration),
253               (override));
254   MOCK_METHOD(void, PublishFailureUploadClientInterrupted,
255               (absl::string_view error_message,
256                const NetworkStats& network_stats,
257                absl::Duration phase_duration),
258               (override));
259   MOCK_METHOD(void, PublishFailureUploadServerAborted,
260               (absl::string_view error_message,
261                const NetworkStats& network_stats,
262                absl::Duration phase_duration),
263               (override));
264   MOCK_METHOD(void, PublishFailureUploadCompleted,
265               (const NetworkStats& network_stats,
266                absl::Duration phase_duration),
267               (override));
268 
secagg_event_publisher()269   SecAggEventPublisher* secagg_event_publisher() override {
270     return &secagg_event_publisher_;
271   }
272 
273  private:
274   ::testing::NiceMock<MockSecAggEventPublisher> secagg_event_publisher_;
275 };
276 
277 // A mock FederatedProtocol implementation, which keeps track of the stages in
278 // the protocol and returns a different set of network stats and RetryWindow for
279 // each stage, making it easier to write accurate assertions in unit tests.
280 class MockFederatedProtocol : public FederatedProtocol {
281  public:
282   constexpr static NetworkStats
283       kPostEligibilityCheckinPlanUriReceivedNetworkStats = {
284           .bytes_downloaded = 280,
285           .bytes_uploaded = 380,
286           .network_duration = absl::Milliseconds(25)};
287   constexpr static NetworkStats kPostEligibilityCheckinNetworkStats = {
288       .bytes_downloaded = 300,
289       .bytes_uploaded = 400,
290       .network_duration = absl::Milliseconds(50)};
291   constexpr static NetworkStats kPostReportEligibilityEvalErrorNetworkStats = {
292       .bytes_downloaded = 400,
293       .bytes_uploaded = 500,
294       .network_duration = absl::Milliseconds(150)};
295   constexpr static NetworkStats kPostCheckinPlanUriReceivedNetworkStats = {
296       .bytes_downloaded = 2970,
297       .bytes_uploaded = 3970,
298       .network_duration = absl::Milliseconds(225)};
299   constexpr static NetworkStats kPostCheckinNetworkStats = {
300       .bytes_downloaded = 3000,
301       .bytes_uploaded = 4000,
302       .network_duration = absl::Milliseconds(250)};
303   constexpr static NetworkStats kPostReportCompletedNetworkStats = {
304       .bytes_downloaded = 30000,
305       .bytes_uploaded = 40000,
306       .network_duration = absl::Milliseconds(350)};
307   constexpr static NetworkStats kPostReportNotCompletedNetworkStats = {
308       .bytes_downloaded = 29999,
309       .bytes_uploaded = 39999,
310       .network_duration = absl::Milliseconds(450)};
311 
312   static google::internal::federatedml::v2::RetryWindow
GetInitialRetryWindow()313   GetInitialRetryWindow() {
314     google::internal::federatedml::v2::RetryWindow retry_window;
315     retry_window.mutable_delay_min()->set_seconds(0L);
316     retry_window.mutable_delay_max()->set_seconds(1L);
317     *retry_window.mutable_retry_token() = "INITIAL";
318     return retry_window;
319   }
320 
321   static google::internal::federatedml::v2::RetryWindow
GetPostEligibilityCheckinRetryWindow()322   GetPostEligibilityCheckinRetryWindow() {
323     google::internal::federatedml::v2::RetryWindow retry_window;
324     retry_window.mutable_delay_min()->set_seconds(100L);
325     retry_window.mutable_delay_max()->set_seconds(101L);
326     *retry_window.mutable_retry_token() = "POST_ELIGIBILITY";
327     return retry_window;
328   }
329 
330   static google::internal::federatedml::v2::RetryWindow
GetPostCheckinRetryWindow()331   GetPostCheckinRetryWindow() {
332     google::internal::federatedml::v2::RetryWindow retry_window;
333     retry_window.mutable_delay_min()->set_seconds(200L);
334     retry_window.mutable_delay_max()->set_seconds(201L);
335     *retry_window.mutable_retry_token() = "POST_CHECKIN";
336     return retry_window;
337   }
338 
339   static google::internal::federatedml::v2::RetryWindow
GetPostReportCompletedRetryWindow()340   GetPostReportCompletedRetryWindow() {
341     google::internal::federatedml::v2::RetryWindow retry_window;
342     retry_window.mutable_delay_min()->set_seconds(300L);
343     retry_window.mutable_delay_max()->set_seconds(301L);
344     *retry_window.mutable_retry_token() = "POST_REPORT_COMPLETED";
345     return retry_window;
346   }
347 
348   static google::internal::federatedml::v2::RetryWindow
GetPostReportNotCompletedRetryWindow()349   GetPostReportNotCompletedRetryWindow() {
350     google::internal::federatedml::v2::RetryWindow retry_window;
351     retry_window.mutable_delay_min()->set_seconds(400L);
352     retry_window.mutable_delay_max()->set_seconds(401L);
353     *retry_window.mutable_retry_token() = "POST_REPORT_NOT_COMPLETED";
354     return retry_window;
355   }
356 
MockFederatedProtocol()357   explicit MockFederatedProtocol() {}
358 
359   // We override the real FederatedProtocol methods so that we can intercept the
360   // progression of protocol stages, and expose dedicate gMock-overridable
361   // methods for use in tests.
EligibilityEvalCheckin(std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)362   absl::StatusOr<EligibilityEvalCheckinResult> EligibilityEvalCheckin(
363       std::function<void(const EligibilityEvalTask&)>
364           payload_uris_received_callback) final {
365     absl::StatusOr<EligibilityEvalCheckinResult> result =
366         MockEligibilityEvalCheckin();
367     if (result.ok() &&
368         std::holds_alternative<FederatedProtocol::EligibilityEvalTask>(
369             *result)) {
370       network_stats_ = kPostEligibilityCheckinPlanUriReceivedNetworkStats;
371       payload_uris_received_callback(
372           std::get<FederatedProtocol::EligibilityEvalTask>(*result));
373     }
374     network_stats_ = kPostEligibilityCheckinNetworkStats;
375     retry_window_ = GetPostEligibilityCheckinRetryWindow();
376     return result;
377   };
378   MOCK_METHOD(absl::StatusOr<EligibilityEvalCheckinResult>,
379               MockEligibilityEvalCheckin, ());
380 
ReportEligibilityEvalError(absl::Status error_status)381   void ReportEligibilityEvalError(absl::Status error_status) final {
382     network_stats_ = kPostReportEligibilityEvalErrorNetworkStats;
383     retry_window_ = GetPostEligibilityCheckinRetryWindow();
384     MockReportEligibilityEvalError(error_status);
385   }
386   MOCK_METHOD(void, MockReportEligibilityEvalError,
387               (absl::Status error_status));
388 
Checkin(const std::optional<::google::internal::federatedml::v2::TaskEligibilityInfo> & task_eligibility_info,std::function<void (const FederatedProtocol::TaskAssignment &)> payload_uris_received_callback)389   absl::StatusOr<CheckinResult> Checkin(
390       const std::optional<
391           ::google::internal::federatedml::v2::TaskEligibilityInfo>&
392           task_eligibility_info,
393       std::function<void(const FederatedProtocol::TaskAssignment&)>
394           payload_uris_received_callback) final {
395     absl::StatusOr<CheckinResult> result = MockCheckin(task_eligibility_info);
396     if (result.ok() &&
397         std::holds_alternative<FederatedProtocol::TaskAssignment>(*result)) {
398       network_stats_ = kPostCheckinPlanUriReceivedNetworkStats;
399       payload_uris_received_callback(
400           std::get<FederatedProtocol::TaskAssignment>(*result));
401     }
402     retry_window_ = GetPostCheckinRetryWindow();
403     network_stats_ = kPostCheckinNetworkStats;
404     return result;
405   };
406   MOCK_METHOD(absl::StatusOr<CheckinResult>, MockCheckin,
407               (const std::optional<
408                   ::google::internal::federatedml::v2::TaskEligibilityInfo>&
409                    task_eligibility_info));
410 
PerformMultipleTaskAssignments(const std::vector<std::string> & task_names)411   absl::StatusOr<MultipleTaskAssignments> PerformMultipleTaskAssignments(
412       const std::vector<std::string>& task_names) final {
413     absl::StatusOr<MultipleTaskAssignments> result =
414         MockPerformMultipleTaskAssignments(task_names);
415     retry_window_ = GetPostCheckinRetryWindow();
416     network_stats_ = kPostCheckinPlanUriReceivedNetworkStats;
417     return result;
418   };
419 
420   MOCK_METHOD(absl::StatusOr<MultipleTaskAssignments>,
421               MockPerformMultipleTaskAssignments,
422               (const std::vector<std::string>& task_names));
423 
ReportCompleted(ComputationResults results,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)424   absl::Status ReportCompleted(
425       ComputationResults results, absl::Duration plan_duration,
426       std::optional<std::string> aggregation_session_id) final {
427     network_stats_ = kPostReportCompletedNetworkStats;
428     retry_window_ = GetPostReportCompletedRetryWindow();
429     return MockReportCompleted(std::move(results), plan_duration,
430                                aggregation_session_id);
431   };
432   MOCK_METHOD(absl::Status, MockReportCompleted,
433               (ComputationResults results, absl::Duration plan_duration,
434                std::optional<std::string> aggregation_session_id));
435 
ReportNotCompleted(engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)436   absl::Status ReportNotCompleted(
437       engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
438       std::optional<std::string> aggregation_session_id) final {
439     network_stats_ = kPostReportNotCompletedNetworkStats;
440     retry_window_ = GetPostReportNotCompletedRetryWindow();
441     return MockReportNotCompleted(phase_outcome, plan_duration,
442                                   aggregation_session_id);
443   };
444   MOCK_METHOD(absl::Status, MockReportNotCompleted,
445               (engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
446                std::optional<std::string> aggregation_session_id));
447 
GetLatestRetryWindow()448   ::google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
449       final {
450     return retry_window_;
451   }
452 
GetNetworkStats()453   NetworkStats GetNetworkStats() final { return network_stats_; }
454 
455  private:
456   NetworkStats network_stats_;
457   ::google::internal::federatedml::v2::RetryWindow retry_window_ =
458       GetInitialRetryWindow();
459 };
460 
461 class MockLogManager : public LogManager {
462  public:
463   MOCK_METHOD(void, LogDiag, (ProdDiagCode), (override));
464   MOCK_METHOD(void, LogDiag, (DebugDiagCode), (override));
465   MOCK_METHOD(void, LogToLongHistogram,
466               (fcp::client::HistogramCounters, int, int,
467                fcp::client::engine::DataSourceType, int64_t),
468               (override));
469   MOCK_METHOD(void, SetModelIdentifier, (const std::string&), (override));
470 };
471 
472 class MockOpStatsLogger : public ::fcp::client::opstats::OpStatsLogger {
473  public:
474   MOCK_METHOD(
475       void, AddEventAndSetTaskName,
476       (const std::string& task_name,
477        ::fcp::client::opstats::OperationalStats::Event::EventKind event),
478       (override));
479   MOCK_METHOD(
480       void, AddEvent,
481       (::fcp::client::opstats::OperationalStats::Event::EventKind event),
482       (override));
483   MOCK_METHOD(void, AddEventWithErrorMessage,
484               (::fcp::client::opstats::OperationalStats::Event::EventKind event,
485                const std::string& error_message),
486               (override));
487   MOCK_METHOD(void, UpdateDatasetStats,
488               (const std::string& collection_uri, int additional_example_count,
489                int64_t additional_example_size_bytes),
490               (override));
491   MOCK_METHOD(void, SetNetworkStats, (const NetworkStats& network_stats),
492               (override));
493   MOCK_METHOD(void, SetRetryWindow,
494               (google::internal::federatedml::v2::RetryWindow retry_window),
495               (override));
496   MOCK_METHOD(::fcp::client::opstats::OpStatsDb*, GetOpStatsDb, (), (override));
497   MOCK_METHOD(bool, IsOpStatsEnabled, (), (const override));
498   MOCK_METHOD(absl::Status, CommitToStorage, (), (override));
499   MOCK_METHOD(std::string, GetCurrentTaskName, (), (override));
500 };
501 
502 class MockSimpleTaskEnvironment : public SimpleTaskEnvironment {
503  public:
504   MOCK_METHOD(std::string, GetBaseDir, (), (override));
505   MOCK_METHOD(std::string, GetCacheDir, (), (override));
506   MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>),
507               CreateExampleIterator,
508               (const google::internal::federated::plan::ExampleSelector&
509                    example_selector),
510               (override));
511   MOCK_METHOD((absl::StatusOr<std::unique_ptr<ExampleIterator>>),
512               CreateExampleIterator,
513               (const google::internal::federated::plan::ExampleSelector&
514                    example_selector,
515                const SelectorContext& selector_context),
516               (override));
517   MOCK_METHOD(std::unique_ptr<fcp::client::http::HttpClient>, CreateHttpClient,
518               (), (override));
519   MOCK_METHOD(bool, TrainingConditionsSatisfied, (), (override));
520 };
521 
522 class MockExampleIterator : public ExampleIterator {
523  public:
524   MOCK_METHOD(absl::StatusOr<std::string>, Next, (), (override));
525   MOCK_METHOD(void, Close, (), (override));
526 };
527 
528 // An iterator that passes through each example in the dataset once.
529 class SimpleExampleIterator : public ExampleIterator {
530  public:
531   // Uses the given bytes as the examples to return.
532   explicit SimpleExampleIterator(std::vector<const char*> examples);
533   // Passes through each of the examples in the `Dataset.client_data.example`
534   // field.
535   explicit SimpleExampleIterator(
536       google::internal::federated::plan::Dataset dataset);
537   // Passes through each of the examples in the
538   // `Dataset.client_data.selected_example.example` field, whose example
539   // collection URI matches the provided `collection_uri`.
540   SimpleExampleIterator(google::internal::federated::plan::Dataset dataset,
541                         absl::string_view collection_uri);
542   absl::StatusOr<std::string> Next() override;
Close()543   void Close() override {}
544 
545  protected:
546   std::vector<std::string> examples_;
547   int index_ = 0;
548 };
549 
550 struct ComputationArtifacts {
551   // The path to the file containing the plan data.
552   std::string plan_filepath;
553   // The already-parsed plan data.
554   google::internal::federated::plan::ClientOnlyPlan plan;
555   // The test dataset.
556   google::internal::federated::plan::Dataset dataset;
557   // The path to the file containing the initial checkpoint data (not set for
558   // local compute task artifacts).
559   std::string checkpoint_filepath;
560   // The initial checkpoint data, as a string (not set for local compute task
561   // artifacts).
562   std::string checkpoint;
563   // The Federated Select slice data (not set for local compute task artifacts).
564   google::internal::federated::plan::SlicesTestDataset federated_select_slices;
565 };
566 
567 absl::StatusOr<ComputationArtifacts> LoadFlArtifacts();
568 
569 class MockFlags : public Flags {
570  public:
571   MOCK_METHOD(int64_t, condition_polling_period_millis, (), (const, override));
572   MOCK_METHOD(int64_t, tf_execution_teardown_grace_period_millis, (),
573               (const, override));
574   MOCK_METHOD(int64_t, tf_execution_teardown_extended_period_millis, (),
575               (const, override));
576   MOCK_METHOD(int64_t, grpc_channel_deadline_seconds, (), (const, override));
577   MOCK_METHOD(bool, log_tensorflow_error_messages, (), (const, override));
578   MOCK_METHOD(bool, enable_opstats, (), (const, override));
579   MOCK_METHOD(int64_t, opstats_ttl_days, (), (const, override));
580   MOCK_METHOD(int64_t, opstats_db_size_limit_bytes, (), (const, override));
581   MOCK_METHOD(int64_t, federated_training_transient_errors_retry_delay_secs, (),
582               (const, override));
583   MOCK_METHOD(float,
584               federated_training_transient_errors_retry_delay_jitter_percent,
585               (), (const, override));
586   MOCK_METHOD(int64_t, federated_training_permanent_errors_retry_delay_secs, (),
587               (const, override));
588   MOCK_METHOD(float,
589               federated_training_permanent_errors_retry_delay_jitter_percent,
590               (), (const, override));
591   MOCK_METHOD(std::vector<int32_t>, federated_training_permanent_error_codes,
592               (), (const, override));
593   MOCK_METHOD(bool, use_tflite_training, (), (const, override));
594   MOCK_METHOD(bool, enable_grpc_with_http_resource_support, (),
595               (const, override));
596   MOCK_METHOD(bool, enable_grpc_with_eligibility_eval_http_resource_support, (),
597               (const, override));
598   MOCK_METHOD(bool, ensure_dynamic_tensors_are_released, (), (const, override));
599   MOCK_METHOD(int32_t, large_tensor_threshold_for_dynamic_allocation, (),
600               (const, override));
601   MOCK_METHOD(bool, disable_http_request_body_compression, (),
602               (const, override));
603   MOCK_METHOD(bool, use_http_federated_compute_protocol, (), (const, override));
604   MOCK_METHOD(bool, enable_computation_id, (), (const, override));
605   MOCK_METHOD(int32_t, waiting_period_sec_for_cancellation, (),
606               (const, override));
607   MOCK_METHOD(bool, enable_federated_select, (), (const, override));
608   MOCK_METHOD(int32_t, num_threads_for_tflite, (), (const, override));
609   MOCK_METHOD(bool, disable_tflite_delegate_clustering, (), (const, override));
610   MOCK_METHOD(bool, enable_example_query_plan_engine, (), (const, override));
611   MOCK_METHOD(bool, support_constant_tf_inputs, (), (const, override));
612   MOCK_METHOD(bool, http_protocol_supports_multiple_task_assignments, (),
613               (const, override));
614 };
615 
616 // Helper methods for extracting opstats fields from TF examples.
617 std::string ExtractSingleString(const tensorflow::Example& example,
618                                 const char key[]);
619 google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString(
620     const tensorflow::Example& example, const char key[]);
621 int64_t ExtractSingleInt64(const tensorflow::Example& example,
622                            const char key[]);
623 google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64(
624     const tensorflow::Example& example, const char key[]);
625 
626 class MockOpStatsDb : public ::fcp::client::opstats::OpStatsDb {
627  public:
628   MOCK_METHOD(absl::StatusOr<::fcp::client::opstats::OpStatsSequence>, Read, (),
629               (override));
630   MOCK_METHOD(absl::Status, Transform,
631               (std::function<void(::fcp::client::opstats::OpStatsSequence&)>),
632               (override));
633 };
634 
635 class MockPhaseLogger : public PhaseLogger {
636  public:
637   MOCK_METHOD(
638       void, UpdateRetryWindowAndNetworkStats,
639       (const ::google::internal::federatedml::v2::RetryWindow& retry_window,
640        const NetworkStats& network_stats),
641       (override));
642   MOCK_METHOD(void, SetModelIdentifier, (absl::string_view model_identifier),
643               (override));
644   MOCK_METHOD(void, LogTaskNotStarted, (absl::string_view error_message),
645               (override));
646   MOCK_METHOD(void, LogNonfatalInitializationError, (absl::Status error_status),
647               (override));
648   MOCK_METHOD(void, LogFatalInitializationError, (absl::Status error_status),
649               (override));
650   MOCK_METHOD(void, LogEligibilityEvalCheckinStarted, (), (override));
651   MOCK_METHOD(void, LogEligibilityEvalCheckinIOError,
652               (absl::Status error_status, const NetworkStats& network_stats,
653                absl::Time time_before_checkin),
654               (override));
655   MOCK_METHOD(void, LogEligibilityEvalCheckinInvalidPayloadError,
656               (absl::string_view error_message,
657                const NetworkStats& network_stats,
658                absl::Time time_before_checkin),
659               (override));
660   MOCK_METHOD(void, LogEligibilityEvalCheckinClientInterrupted,
661               (absl::Status error_status, const NetworkStats& network_stats,
662                absl::Time time_before_checkin),
663               (override));
664   MOCK_METHOD(void, LogEligibilityEvalCheckinServerAborted,
665               (absl::Status error_status, const NetworkStats& network_stats,
666                absl::Time time_before_checkin),
667               (override));
668   MOCK_METHOD(void, LogEligibilityEvalNotConfigured,
669               (const NetworkStats& network_stats,
670                absl::Time time_before_checkin),
671               (override));
672   MOCK_METHOD(void, LogEligibilityEvalCheckinTurnedAway,
673               (const NetworkStats& network_stats,
674                absl::Time time_before_checkin),
675               (override));
676   MOCK_METHOD(void, LogEligibilityEvalCheckinPlanUriReceived,
677               (const NetworkStats& network_stats,
678                absl::Time time_before_checkin),
679               (override));
680   MOCK_METHOD(void, LogEligibilityEvalCheckinCompleted,
681               (const NetworkStats& network_stats,
682                absl::Time time_before_checkin,
683                absl::Time time_before_plan_download),
684               (override));
685   MOCK_METHOD(void, LogEligibilityEvalComputationStarted, (), (override));
686   MOCK_METHOD(void, LogEligibilityEvalComputationInvalidArgument,
687               (absl::Status error_status, const ExampleStats& example_stats,
688                absl::Time run_plan_start_time),
689               (override));
690   MOCK_METHOD(void, LogEligibilityEvalComputationExampleIteratorError,
691               (absl::Status error_status, const ExampleStats& example_stats,
692                absl::Time run_plan_start_time),
693               (override));
694   MOCK_METHOD(void, LogEligibilityEvalComputationTensorflowError,
695               (absl::Status error_status, const ExampleStats& example_stats,
696                absl::Time run_plan_start_time, absl::Time reference_time),
697               (override));
698   MOCK_METHOD(void, LogEligibilityEvalComputationInterrupted,
699               (absl::Status error_status, const ExampleStats& example_stats,
700                absl::Time run_plan_start_time, absl::Time reference_time),
701               (override));
702   MOCK_METHOD(void, LogEligibilityEvalComputationCompleted,
703               (const ExampleStats& example_stats,
704                absl::Time run_plan_start_time, absl::Time reference_time),
705               (override));
706   MOCK_METHOD(void, LogCheckinStarted, (), (override));
707   MOCK_METHOD(void, LogCheckinIOError,
708               (absl::Status error_status, const NetworkStats& network_stats,
709                absl::Time time_before_checkin, absl::Time reference_time),
710               (override));
711   MOCK_METHOD(void, LogCheckinInvalidPayload,
712               (absl::string_view error_message,
713                const NetworkStats& network_stats,
714                absl::Time time_before_checkin, absl::Time reference_time),
715               (override));
716   MOCK_METHOD(void, LogCheckinClientInterrupted,
717               (absl::Status error_status, const NetworkStats& network_stats,
718                absl::Time time_before_checkin, absl::Time reference_time),
719               (override));
720   MOCK_METHOD(void, LogCheckinServerAborted,
721               (absl::Status error_status, const NetworkStats& network_stats,
722                absl::Time time_before_checkin, absl::Time reference_time),
723               (override));
724   MOCK_METHOD(void, LogCheckinTurnedAway,
725               (const NetworkStats& network_stats,
726                absl::Time time_before_checkin, absl::Time reference_time),
727               (override));
728   MOCK_METHOD(void, LogCheckinPlanUriReceived,
729               (absl::string_view task_name, const NetworkStats& network_stats,
730                absl::Time time_before_checkin),
731               (override));
732   MOCK_METHOD(void, LogCheckinCompleted,
733               (absl::string_view task_name, const NetworkStats& network_stats,
734                absl::Time time_before_checkin,
735                absl::Time time_before_plan_download, absl::Time reference_time),
736               (override));
737   MOCK_METHOD(void, LogComputationStarted, (), (override));
738   MOCK_METHOD(void, LogComputationInvalidArgument,
739               (absl::Status error_status, const ExampleStats& example_stats,
740                const NetworkStats& network_stats,
741                absl::Time run_plan_start_time),
742               (override));
743   MOCK_METHOD(void, LogComputationExampleIteratorError,
744               (absl::Status error_status, const ExampleStats& example_stats,
745                const NetworkStats& network_stats,
746                absl::Time run_plan_start_time),
747               (override));
748   MOCK_METHOD(void, LogComputationIOError,
749               (absl::Status error_status, const ExampleStats& example_stats,
750                const NetworkStats& network_stats,
751                absl::Time run_plan_start_time),
752               (override));
753   MOCK_METHOD(void, LogComputationTensorflowError,
754               (absl::Status error_status, const ExampleStats& example_stats,
755                const NetworkStats& network_stats,
756                absl::Time run_plan_start_time, absl::Time reference_time),
757               (override));
758   MOCK_METHOD(void, LogComputationInterrupted,
759               (absl::Status error_status, const ExampleStats& example_stats,
760                const NetworkStats& network_stats,
761                absl::Time run_plan_start_time, absl::Time reference_time),
762               (override));
763   MOCK_METHOD(void, LogComputationCompleted,
764               (const ExampleStats& example_stats,
765                const NetworkStats& network_stats,
766                absl::Time run_plan_start_time, absl::Time reference_time),
767               (override));
768   MOCK_METHOD(absl::Status, LogResultUploadStarted, (), (override));
769   MOCK_METHOD(void, LogResultUploadIOError,
770               (absl::Status error_status, const NetworkStats& network_stats,
771                absl::Time time_before_result_upload, absl::Time reference_time),
772               (override));
773   MOCK_METHOD(void, LogResultUploadClientInterrupted,
774               (absl::Status error_status, const NetworkStats& network_stats,
775                absl::Time time_before_result_upload, absl::Time reference_time),
776               (override));
777   MOCK_METHOD(void, LogResultUploadServerAborted,
778               (absl::Status error_status, const NetworkStats& network_stats,
779                absl::Time time_before_result_upload, absl::Time reference_time),
780               (override));
781   MOCK_METHOD(void, LogResultUploadCompleted,
782               (const NetworkStats& network_stats,
783                absl::Time time_before_result_upload, absl::Time reference_time),
784               (override));
785   MOCK_METHOD(absl::Status, LogFailureUploadStarted, (), (override));
786   MOCK_METHOD(void, LogFailureUploadIOError,
787               (absl::Status error_status, const NetworkStats& network_stats,
788                absl::Time time_before_failure_upload,
789                absl::Time reference_time),
790               (override));
791   MOCK_METHOD(void, LogFailureUploadClientInterrupted,
792               (absl::Status error_status, const NetworkStats& network_stats,
793                absl::Time time_before_failure_upload,
794                absl::Time reference_time),
795               (override));
796   MOCK_METHOD(void, LogFailureUploadServerAborted,
797               (absl::Status error_status, const NetworkStats& network_stats,
798                absl::Time time_before_failure_upload,
799                absl::Time reference_time),
800               (override));
801   MOCK_METHOD(void, LogFailureUploadCompleted,
802               (const NetworkStats& network_stats,
803                absl::Time time_before_result_upload, absl::Time reference_time),
804               (override));
805 };
806 
807 class MockFederatedSelectManager : public FederatedSelectManager {
808  public:
809   MOCK_METHOD(std::unique_ptr<engine::ExampleIteratorFactory>,
810               CreateExampleIteratorFactoryForUriTemplate,
811               (absl::string_view uri_template), (override));
812 
813   MOCK_METHOD(NetworkStats, GetNetworkStats, (), (override));
814 };
815 
816 class MockFederatedSelectExampleIteratorFactory
817     : public FederatedSelectExampleIteratorFactory {
818  public:
819   MOCK_METHOD(absl::StatusOr<std::unique_ptr<ExampleIterator>>,
820               CreateExampleIterator,
821               (const ::google::internal::federated::plan::ExampleSelector&
822                    example_selector),
823               (override));
824 };
825 
826 class MockSecAggRunnerFactory : public SecAggRunnerFactory {
827  public:
828   MOCK_METHOD(std::unique_ptr<SecAggRunner>, CreateSecAggRunner,
829               (std::unique_ptr<SecAggSendToServerBase> send_to_server_impl,
830                std::unique_ptr<SecAggProtocolDelegate> protocol_delegate,
831                SecAggEventPublisher* secagg_event_publisher,
832                LogManager* log_manager,
833                InterruptibleRunner* interruptible_runner,
834                int64_t expected_number_of_clients,
835                int64_t minimum_surviving_clients_for_reconstruction),
836               (override));
837 };
838 
839 class MockSecAggRunner : public SecAggRunner {
840  public:
841   MOCK_METHOD(absl::Status, Run, (ComputationResults results), (override));
842 };
843 
844 class MockSecAggSendToServerBase : public SecAggSendToServerBase {
845   MOCK_METHOD(void, Send, (secagg::ClientToServerWrapperMessage * message),
846               (override));
847 };
848 
849 class MockSecAggProtocolDelegate : public SecAggProtocolDelegate {
850  public:
851   MOCK_METHOD(absl::StatusOr<uint64_t>, GetModulus, (const std::string& key),
852               (override));
853   MOCK_METHOD(absl::StatusOr<secagg::ServerToClientWrapperMessage>,
854               ReceiveServerMessage, (), (override));
855   MOCK_METHOD(void, Abort, (), (override));
856 };
857 
858 }  // namespace client
859 }  // namespace fcp
860 
861 #endif  // FCP_CLIENT_TEST_HELPERS_H_
862