xref: /aosp_15_r20/external/federated-compute/fcp/client/client_runner_main.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <fstream>
18 #include <optional>
19 #include <string>
20 #include <utility>
21 
22 
23 #include "absl/flags/flag.h"
24 #include "absl/flags/parse.h"
25 #include "absl/flags/usage.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/strings/str_split.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/client/client_runner.h"
31 #include "fcp/client/client_runner_example_data.pb.h"
32 #include "fcp/client/fake_event_publisher.h"
33 #include "fcp/client/fl_runner.h"
34 
35 ABSL_FLAG(std::string, server, "",
36           "Federated Server URI (supports https+test:// and https:// URIs");
37 ABSL_FLAG(std::string, api_key, "", "API Key");
38 ABSL_FLAG(std::string, test_cert, "",
39           "Path to test CA certificate PEM file; used for https+test:// URIs");
40 ABSL_FLAG(std::string, session, "", "Session name");
41 ABSL_FLAG(std::string, population, "", "Population name");
42 ABSL_FLAG(std::string, retry_token, "", "Retry token");
43 ABSL_FLAG(std::string, client_version, "", "Client version");
44 ABSL_FLAG(std::string, attestation_string, "", "Attestation string");
45 ABSL_FLAG(std::string, example_data_path, "",
46           "Path to a serialized ClientRunnerExampleData proto with client "
47           "example data. Falls back to --num_empty_examples if unset.");
48 ABSL_FLAG(int, num_empty_examples, 0,
49           "Number of (empty) examples each created iterator serves. Ignored if "
50           "--example_store_path is set.");
51 ABSL_FLAG(int, num_rounds, 1, "Number of rounds to train");
52 ABSL_FLAG(int, sleep_after_round_secs, 3,
53           "Number of seconds to sleep after each round.");
54 ABSL_FLAG(bool, use_http_federated_compute_protocol, false,
55           "Whether to enable the HTTP FederatedCompute protocol instead "
56           "of the gRPC FederatedTrainingApi protocol.");
57 ABSL_FLAG(bool, use_tflite_training, false, "Whether use TFLite for training.");
58 
// Usage message registered with absl::SetProgramUsageMessage() in main().
// Written as a raw string literal, so the line breaks below are part of the
// message text.
static constexpr char kUsageString[] =
    R"(Stand-alone Federated Client Executable.

Connects to the specified server, tries to retrieve a plan, run the
plan (feeding the specified number of empty examples), and report the
results of the computation back to the server.)";
64 
LoadExampleData(const std::string & examples_path)65 static absl::StatusOr<fcp::client::ClientRunnerExampleData> LoadExampleData(
66     const std::string& examples_path) {
67   std::ifstream examples_file(examples_path);
68   fcp::client::ClientRunnerExampleData data;
69   if (!data.ParseFromIstream(&examples_file) || !examples_file.eof()) {
70     return absl::InvalidArgumentError(
71         "Failed to parse ClientRunnerExampleData");
72   }
73   return data;
74 }
75 
main(int argc,char ** argv)76 int main(int argc, char** argv) {
77   absl::SetProgramUsageMessage(kUsageString);
78   absl::ParseCommandLine(argc, argv);
79 
80   int num_rounds = absl::GetFlag(FLAGS_num_rounds);
81   std::string server = absl::GetFlag(FLAGS_server);
82   std::string session = absl::GetFlag(FLAGS_session);
83   std::string population = absl::GetFlag(FLAGS_population);
84   std::string client_version = absl::GetFlag(FLAGS_client_version);
85   std::string test_cert = absl::GetFlag(FLAGS_test_cert);
86   FCP_LOG(INFO) << "Running for " << num_rounds << " rounds:";
87   FCP_LOG(INFO) << " - server:         " << server;
88   FCP_LOG(INFO) << " - session:        " << session;
89   FCP_LOG(INFO) << " - population:     " << population;
90   FCP_LOG(INFO) << " - client_version: " << client_version;
91 
92   std::optional<fcp::client::ClientRunnerExampleData> example_data;
93   if (std::string path = absl::GetFlag(FLAGS_example_data_path);
94       !path.empty()) {
95     auto statusor = LoadExampleData(path);
96     if (!statusor.ok()) {
97       FCP_LOG(ERROR) << "Failed to load example data: " << statusor.status();
98       return 1;
99     }
100     example_data = *std::move(statusor);
101   }
102 
103   bool success = false;
104   for (auto i = 0; i < num_rounds || num_rounds < 0; ++i) {
105     fcp::client::FederatedTaskEnvDepsImpl federated_task_env_deps_impl =
106         example_data
107             ? fcp::client::FederatedTaskEnvDepsImpl(*example_data, test_cert)
108             : fcp::client::FederatedTaskEnvDepsImpl(
109                   absl::GetFlag(FLAGS_num_empty_examples), test_cert);
110     fcp::client::FakeEventPublisher event_publisher(/*quiet=*/false);
111     fcp::client::FilesImpl files_impl;
112     fcp::client::LogManagerImpl log_manager_impl;
113     fcp::client::FlagsImpl flags;
114     flags.set_use_http_federated_compute_protocol(
115         absl::GetFlag(FLAGS_use_http_federated_compute_protocol));
116     flags.set_use_tflite_training(absl::GetFlag(FLAGS_use_tflite_training));
117 
118     auto fl_runner_result = RunFederatedComputation(
119         &federated_task_env_deps_impl, &event_publisher, &files_impl,
120         &log_manager_impl, &flags, server, absl::GetFlag(FLAGS_api_key),
121         test_cert, session, population, absl::GetFlag(FLAGS_retry_token),
122         client_version, absl::GetFlag(FLAGS_attestation_string));
123     if (fl_runner_result.ok()) {
124       FCP_LOG(INFO) << "Run finished successfully; result: "
125                     << fl_runner_result.value().DebugString();
126       success = true;
127     } else {
128       FCP_LOG(ERROR) << "Error during run: " << fl_runner_result.status();
129     }
130     int sleep_secs = absl::GetFlag(FLAGS_sleep_after_round_secs);
131     FCP_LOG(INFO) << "Sleeping for " << sleep_secs << " secs";
132     absl::SleepFor(absl::Seconds(sleep_secs));
133   }
134   return success ? 0 : 1;
135 }
136