1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <cstdint>
#include <fstream>
#include <optional>
#include <string>
#include <utility>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/flags/usage.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_split.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "fcp/base/monitoring.h"
#include "fcp/client/client_runner.h"
#include "fcp/client/client_runner_example_data.pb.h"
#include "fcp/client/fake_event_publisher.h"
#include "fcp/client/fl_runner.h"
34
35 ABSL_FLAG(std::string, server, "",
36 "Federated Server URI (supports https+test:// and https:// URIs");
37 ABSL_FLAG(std::string, api_key, "", "API Key");
38 ABSL_FLAG(std::string, test_cert, "",
39 "Path to test CA certificate PEM file; used for https+test:// URIs");
40 ABSL_FLAG(std::string, session, "", "Session name");
41 ABSL_FLAG(std::string, population, "", "Population name");
42 ABSL_FLAG(std::string, retry_token, "", "Retry token");
43 ABSL_FLAG(std::string, client_version, "", "Client version");
44 ABSL_FLAG(std::string, attestation_string, "", "Attestation string");
45 ABSL_FLAG(std::string, example_data_path, "",
46 "Path to a serialized ClientRunnerExampleData proto with client "
47 "example data. Falls back to --num_empty_examples if unset.");
48 ABSL_FLAG(int, num_empty_examples, 0,
49 "Number of (empty) examples each created iterator serves. Ignored if "
50 "--example_store_path is set.");
51 ABSL_FLAG(int, num_rounds, 1, "Number of rounds to train");
52 ABSL_FLAG(int, sleep_after_round_secs, 3,
53 "Number of seconds to sleep after each round.");
54 ABSL_FLAG(bool, use_http_federated_compute_protocol, false,
55 "Whether to enable the HTTP FederatedCompute protocol instead "
56 "of the gRPC FederatedTrainingApi protocol.");
57 ABSL_FLAG(bool, use_tflite_training, false, "Whether use TFLite for training.");
58
// Program usage message shown by --help (installed via
// absl::SetProgramUsageMessage in main()).
static constexpr char kUsageString[] =
    "Stand-alone Federated Client Executable.\n\n"
    "Connects to the specified server, tries to retrieve a plan, runs the\n"
    "plan (feeding it either the examples from --example_data_path or the\n"
    "specified number of empty examples), and reports the results of the\n"
    "computation back to the server.";
64
LoadExampleData(const std::string & examples_path)65 static absl::StatusOr<fcp::client::ClientRunnerExampleData> LoadExampleData(
66 const std::string& examples_path) {
67 std::ifstream examples_file(examples_path);
68 fcp::client::ClientRunnerExampleData data;
69 if (!data.ParseFromIstream(&examples_file) || !examples_file.eof()) {
70 return absl::InvalidArgumentError(
71 "Failed to parse ClientRunnerExampleData");
72 }
73 return data;
74 }
75
main(int argc,char ** argv)76 int main(int argc, char** argv) {
77 absl::SetProgramUsageMessage(kUsageString);
78 absl::ParseCommandLine(argc, argv);
79
80 int num_rounds = absl::GetFlag(FLAGS_num_rounds);
81 std::string server = absl::GetFlag(FLAGS_server);
82 std::string session = absl::GetFlag(FLAGS_session);
83 std::string population = absl::GetFlag(FLAGS_population);
84 std::string client_version = absl::GetFlag(FLAGS_client_version);
85 std::string test_cert = absl::GetFlag(FLAGS_test_cert);
86 FCP_LOG(INFO) << "Running for " << num_rounds << " rounds:";
87 FCP_LOG(INFO) << " - server: " << server;
88 FCP_LOG(INFO) << " - session: " << session;
89 FCP_LOG(INFO) << " - population: " << population;
90 FCP_LOG(INFO) << " - client_version: " << client_version;
91
92 std::optional<fcp::client::ClientRunnerExampleData> example_data;
93 if (std::string path = absl::GetFlag(FLAGS_example_data_path);
94 !path.empty()) {
95 auto statusor = LoadExampleData(path);
96 if (!statusor.ok()) {
97 FCP_LOG(ERROR) << "Failed to load example data: " << statusor.status();
98 return 1;
99 }
100 example_data = *std::move(statusor);
101 }
102
103 bool success = false;
104 for (auto i = 0; i < num_rounds || num_rounds < 0; ++i) {
105 fcp::client::FederatedTaskEnvDepsImpl federated_task_env_deps_impl =
106 example_data
107 ? fcp::client::FederatedTaskEnvDepsImpl(*example_data, test_cert)
108 : fcp::client::FederatedTaskEnvDepsImpl(
109 absl::GetFlag(FLAGS_num_empty_examples), test_cert);
110 fcp::client::FakeEventPublisher event_publisher(/*quiet=*/false);
111 fcp::client::FilesImpl files_impl;
112 fcp::client::LogManagerImpl log_manager_impl;
113 fcp::client::FlagsImpl flags;
114 flags.set_use_http_federated_compute_protocol(
115 absl::GetFlag(FLAGS_use_http_federated_compute_protocol));
116 flags.set_use_tflite_training(absl::GetFlag(FLAGS_use_tflite_training));
117
118 auto fl_runner_result = RunFederatedComputation(
119 &federated_task_env_deps_impl, &event_publisher, &files_impl,
120 &log_manager_impl, &flags, server, absl::GetFlag(FLAGS_api_key),
121 test_cert, session, population, absl::GetFlag(FLAGS_retry_token),
122 client_version, absl::GetFlag(FLAGS_attestation_string));
123 if (fl_runner_result.ok()) {
124 FCP_LOG(INFO) << "Run finished successfully; result: "
125 << fl_runner_result.value().DebugString();
126 success = true;
127 } else {
128 FCP_LOG(ERROR) << "Error during run: " << fl_runner_result.status();
129 }
130 int sleep_secs = absl::GetFlag(FLAGS_sleep_after_round_secs);
131 FCP_LOG(INFO) << "Sleeping for " << sleep_secs << " secs";
132 absl::SleepFor(absl::Seconds(sleep_secs));
133 }
134 return success ? 0 : 1;
135 }
136