xref: /aosp_15_r20/external/federated-compute/fcp/client/client_runner.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_CLIENT_RUNNER_H_
18*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_CLIENT_RUNNER_H_
19*14675a02SAndroid Build Coastguard Worker 
#include <cxxabi.h>
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <array>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <filesystem>
#include <fstream>
#include <memory>
#include <string>
#include <string_view>
#include <typeinfo>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/time/time.h"
#include "fcp/base/monitoring.h"
#include "fcp/client/client_runner_example_data.pb.h"
#include "fcp/client/diag_codes.pb.h"
#include "fcp/client/fake_event_publisher.h"
#include "fcp/client/files.h"
#include "fcp/client/flags.h"
#include "fcp/client/histogram_counters.pb.h"
#include "fcp/client/http/curl/curl_api.h"
#include "fcp/client/http/curl/curl_http_client.h"
#include "fcp/client/log_manager.h"
#include "fcp/client/simple_task_environment.h"
#include "fcp/protos/plan.pb.h"
#include "google/protobuf/any.pb.h"
#include "gtest/gtest.h"
58*14675a02SAndroid Build Coastguard Worker 
59*14675a02SAndroid Build Coastguard Worker namespace fcp::client {
60*14675a02SAndroid Build Coastguard Worker 
61*14675a02SAndroid Build Coastguard Worker // A stub implementation of the SimpleTaskEnvironment interface that logs calls
62*14675a02SAndroid Build Coastguard Worker // to stderr and returns canned example data.
63*14675a02SAndroid Build Coastguard Worker class FederatedTaskEnvDepsImpl : public SimpleTaskEnvironment {
64*14675a02SAndroid Build Coastguard Worker  public:
65*14675a02SAndroid Build Coastguard Worker   // Constructs a SimpleTaskEnvironment that will return an example iterator
66*14675a02SAndroid Build Coastguard Worker   // with `num_empty_examples` empty examples.
67*14675a02SAndroid Build Coastguard Worker   explicit FederatedTaskEnvDepsImpl(int num_empty_examples,
68*14675a02SAndroid Build Coastguard Worker                                     std::string test_cert_path = "")
examples_(num_empty_examples)69*14675a02SAndroid Build Coastguard Worker       : examples_(num_empty_examples),
70*14675a02SAndroid Build Coastguard Worker         test_cert_path_(std::move(test_cert_path)) {}
71*14675a02SAndroid Build Coastguard Worker 
72*14675a02SAndroid Build Coastguard Worker   // Constructs a SimpleTaskEnvironment that will return an example iterator
73*14675a02SAndroid Build Coastguard Worker   // with examples determined by the collection URI.
74*14675a02SAndroid Build Coastguard Worker   explicit FederatedTaskEnvDepsImpl(ClientRunnerExampleData example_data,
75*14675a02SAndroid Build Coastguard Worker                                     std::string test_cert_path = "")
examples_(std::move (example_data))76*14675a02SAndroid Build Coastguard Worker       : examples_(std::move(example_data)),
77*14675a02SAndroid Build Coastguard Worker         test_cert_path_(std::move(test_cert_path)) {}
78*14675a02SAndroid Build Coastguard Worker 
GetBaseDir()79*14675a02SAndroid Build Coastguard Worker   std::string GetBaseDir() override {
80*14675a02SAndroid Build Coastguard Worker     return std::filesystem::path(testing::TempDir());
81*14675a02SAndroid Build Coastguard Worker   }
82*14675a02SAndroid Build Coastguard Worker 
GetCacheDir()83*14675a02SAndroid Build Coastguard Worker   std::string GetCacheDir() override {
84*14675a02SAndroid Build Coastguard Worker     return std::filesystem::path(testing::TempDir());
85*14675a02SAndroid Build Coastguard Worker   }
86*14675a02SAndroid Build Coastguard Worker 
CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector)87*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
88*14675a02SAndroid Build Coastguard Worker       const google::internal::federated::plan::ExampleSelector&
89*14675a02SAndroid Build Coastguard Worker           example_selector) override {
90*14675a02SAndroid Build Coastguard Worker     SelectorContext unused;
91*14675a02SAndroid Build Coastguard Worker     return CreateExampleIterator(example_selector, unused);
92*14675a02SAndroid Build Coastguard Worker   }
93*14675a02SAndroid Build Coastguard Worker 
CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector,const SelectorContext & selector_context)94*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
95*14675a02SAndroid Build Coastguard Worker       const google::internal::federated::plan::ExampleSelector&
96*14675a02SAndroid Build Coastguard Worker           example_selector,
97*14675a02SAndroid Build Coastguard Worker       const SelectorContext& selector_context) override {
98*14675a02SAndroid Build Coastguard Worker     // FCP_CLIENT_LOG_FUNCTION_NAME
99*14675a02SAndroid Build Coastguard Worker     //     << ":\n\turi: " << example_selector.collection_uri()
100*14675a02SAndroid Build Coastguard Worker     //     << "\n\ttype: " << example_selector.criteria().type_url();
101*14675a02SAndroid Build Coastguard Worker     if (auto* num_empty_examples = std::get_if<int>(&examples_)) {
102*14675a02SAndroid Build Coastguard Worker       return std::make_unique<FakeExampleIterator>(*num_empty_examples);
103*14675a02SAndroid Build Coastguard Worker     } else if (auto* store = std::get_if<ClientRunnerExampleData>(&examples_)) {
104*14675a02SAndroid Build Coastguard Worker       const auto& examples_map = store->examples_by_collection_uri();
105*14675a02SAndroid Build Coastguard Worker       if (auto it = examples_map.find(example_selector.collection_uri());
106*14675a02SAndroid Build Coastguard Worker           it != examples_map.end()) {
107*14675a02SAndroid Build Coastguard Worker         return std::make_unique<FakeExampleIterator>(&it->second);
108*14675a02SAndroid Build Coastguard Worker       }
109*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("no examples for collection_uri");
110*14675a02SAndroid Build Coastguard Worker     }
111*14675a02SAndroid Build Coastguard Worker     return absl::InternalError("unsupported examples variant type");
112*14675a02SAndroid Build Coastguard Worker   }
113*14675a02SAndroid Build Coastguard Worker 
CreateHttpClient()114*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<fcp::client::http::HttpClient> CreateHttpClient() override {
115*14675a02SAndroid Build Coastguard Worker     return std::make_unique<fcp::client::http::curl::CurlHttpClient>(
116*14675a02SAndroid Build Coastguard Worker         &curl_api_, test_cert_path_);
117*14675a02SAndroid Build Coastguard Worker   }
118*14675a02SAndroid Build Coastguard Worker 
119*14675a02SAndroid Build Coastguard Worker  private:
120*14675a02SAndroid Build Coastguard Worker   class FakeExampleIterator : public ExampleIterator {
121*14675a02SAndroid Build Coastguard Worker    public:
FakeExampleIterator(int num_examples)122*14675a02SAndroid Build Coastguard Worker     explicit FakeExampleIterator(int num_examples)
123*14675a02SAndroid Build Coastguard Worker         : example_list_(nullptr), num_examples_(num_examples) {}
FakeExampleIterator(const ClientRunnerExampleData::ExampleList * examples)124*14675a02SAndroid Build Coastguard Worker     explicit FakeExampleIterator(
125*14675a02SAndroid Build Coastguard Worker         const ClientRunnerExampleData::ExampleList* examples)
126*14675a02SAndroid Build Coastguard Worker         : example_list_(examples), num_examples_(examples->examples_size()) {}
Next()127*14675a02SAndroid Build Coastguard Worker     absl::StatusOr<std::string> Next() override {
128*14675a02SAndroid Build Coastguard Worker       if (num_examples_served_ >= num_examples_) {
129*14675a02SAndroid Build Coastguard Worker         return absl::OutOfRangeError("");
130*14675a02SAndroid Build Coastguard Worker       }
131*14675a02SAndroid Build Coastguard Worker       std::string example =
132*14675a02SAndroid Build Coastguard Worker           example_list_ ? example_list_->examples(num_examples_served_) : "";
133*14675a02SAndroid Build Coastguard Worker       num_examples_served_++;
134*14675a02SAndroid Build Coastguard Worker       return example;
135*14675a02SAndroid Build Coastguard Worker     }
Close()136*14675a02SAndroid Build Coastguard Worker     void Close() override {}
137*14675a02SAndroid Build Coastguard Worker 
138*14675a02SAndroid Build Coastguard Worker    private:
139*14675a02SAndroid Build Coastguard Worker     const ClientRunnerExampleData::ExampleList* const example_list_;
140*14675a02SAndroid Build Coastguard Worker     const int num_examples_;
141*14675a02SAndroid Build Coastguard Worker     int num_examples_served_ = 0;
142*14675a02SAndroid Build Coastguard Worker   };
143*14675a02SAndroid Build Coastguard Worker 
TrainingConditionsSatisfied()144*14675a02SAndroid Build Coastguard Worker   bool TrainingConditionsSatisfied() override {
145*14675a02SAndroid Build Coastguard Worker     FCP_CLIENT_LOG_FUNCTION_NAME;
146*14675a02SAndroid Build Coastguard Worker     return true;
147*14675a02SAndroid Build Coastguard Worker   }
148*14675a02SAndroid Build Coastguard Worker 
149*14675a02SAndroid Build Coastguard Worker   const std::variant<int, ClientRunnerExampleData> examples_;
150*14675a02SAndroid Build Coastguard Worker   const std::string test_cert_path_;
151*14675a02SAndroid Build Coastguard Worker   fcp::client::http::curl::CurlApi curl_api_;
152*14675a02SAndroid Build Coastguard Worker };
153*14675a02SAndroid Build Coastguard Worker 
154*14675a02SAndroid Build Coastguard Worker // An implementation of the Files interface that attempts to create a temporary
155*14675a02SAndroid Build Coastguard Worker // file with the given prefix and suffix in a directory suitable for temporary
156*14675a02SAndroid Build Coastguard Worker // files.
157*14675a02SAndroid Build Coastguard Worker // NB this is a proof-of-concept implementation that does not use existing infra
158*14675a02SAndroid Build Coastguard Worker // such as mkstemps() or std::tmpfile due to the requirements of the existing
159*14675a02SAndroid Build Coastguard Worker // Files API: include prefix, suffix strings in filename; return file path
160*14675a02SAndroid Build Coastguard Worker // instead of file descriptor.
161*14675a02SAndroid Build Coastguard Worker class FilesImpl : public Files {
162*14675a02SAndroid Build Coastguard Worker  public:
FilesImpl()163*14675a02SAndroid Build Coastguard Worker   FilesImpl() { std::srand(static_cast<int32_t>(std::time(nullptr))); }
164*14675a02SAndroid Build Coastguard Worker 
CreateTempFile(const std::string & prefix,const std::string & suffix)165*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::string> CreateTempFile(
166*14675a02SAndroid Build Coastguard Worker       const std::string& prefix, const std::string& suffix) override {
167*14675a02SAndroid Build Coastguard Worker     const auto tmp_dir = std::filesystem::path(testing::TempDir());
168*14675a02SAndroid Build Coastguard Worker     std::filesystem::path candidate_path;
169*14675a02SAndroid Build Coastguard Worker     int fd;
170*14675a02SAndroid Build Coastguard Worker     do {
171*14675a02SAndroid Build Coastguard Worker       candidate_path =
172*14675a02SAndroid Build Coastguard Worker           tmp_dir / absl::StrCat(prefix, std::to_string(std::rand()), suffix);
173*14675a02SAndroid Build Coastguard Worker     } while ((fd = open(candidate_path.c_str(), O_CREAT | O_EXCL | O_RDWR,
174*14675a02SAndroid Build Coastguard Worker                         S_IRWXU)) == -1 &&
175*14675a02SAndroid Build Coastguard Worker              errno == EEXIST);
176*14675a02SAndroid Build Coastguard Worker     close(fd);
177*14675a02SAndroid Build Coastguard Worker     std::ofstream tmp_file(candidate_path);
178*14675a02SAndroid Build Coastguard Worker     if (!tmp_file) {
179*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
180*14675a02SAndroid Build Coastguard Worker           absl::StrCat("could not create file ", candidate_path.string()));
181*14675a02SAndroid Build Coastguard Worker     }
182*14675a02SAndroid Build Coastguard Worker     // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << candidate_path;
183*14675a02SAndroid Build Coastguard Worker     return candidate_path.string();
184*14675a02SAndroid Build Coastguard Worker   }
185*14675a02SAndroid Build Coastguard Worker };
186*14675a02SAndroid Build Coastguard Worker 
187*14675a02SAndroid Build Coastguard Worker // A stub implementation of the LogManager interface that logs invocations to
188*14675a02SAndroid Build Coastguard Worker // stderr.
189*14675a02SAndroid Build Coastguard Worker class LogManagerImpl : public LogManager {
190*14675a02SAndroid Build Coastguard Worker  public:
LogDiag(ProdDiagCode diag_code)191*14675a02SAndroid Build Coastguard Worker   void LogDiag(ProdDiagCode diag_code) override {
192*14675a02SAndroid Build Coastguard Worker     // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << ProdDiagCode_Name(diag_code);
193*14675a02SAndroid Build Coastguard Worker   }
LogDiag(DebugDiagCode diag_code)194*14675a02SAndroid Build Coastguard Worker   void LogDiag(DebugDiagCode diag_code) override {
195*14675a02SAndroid Build Coastguard Worker     // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << DebugDiagCode_Name(diag_code);
196*14675a02SAndroid Build Coastguard Worker   }
LogToLongHistogram(HistogramCounters histogram_counter,int,int,engine::DataSourceType data_source_type,int64_t value)197*14675a02SAndroid Build Coastguard Worker   void LogToLongHistogram(HistogramCounters histogram_counter, int, int,
198*14675a02SAndroid Build Coastguard Worker                           engine::DataSourceType data_source_type,
199*14675a02SAndroid Build Coastguard Worker                           int64_t value) override {
200*14675a02SAndroid Build Coastguard Worker     // FCP_CLIENT_LOG_FUNCTION_NAME
201*14675a02SAndroid Build Coastguard Worker     //     << ": " << HistogramCounters_Name(histogram_counter) << " <- " <<
202*14675a02SAndroid Build Coastguard Worker     //     value;
203*14675a02SAndroid Build Coastguard Worker   }
204*14675a02SAndroid Build Coastguard Worker 
SetModelIdentifier(const std::string & model_identifier)205*14675a02SAndroid Build Coastguard Worker   void SetModelIdentifier(const std::string& model_identifier) override {
206*14675a02SAndroid Build Coastguard Worker     // FCP_CLIENT_LOG_FUNCTION_NAME << ":\n\t" << model_identifier;
207*14675a02SAndroid Build Coastguard Worker   }
208*14675a02SAndroid Build Coastguard Worker };
209*14675a02SAndroid Build Coastguard Worker 
210*14675a02SAndroid Build Coastguard Worker class FlagsImpl : public Flags {
211*14675a02SAndroid Build Coastguard Worker  public:
set_use_http_federated_compute_protocol(bool value)212*14675a02SAndroid Build Coastguard Worker   void set_use_http_federated_compute_protocol(bool value) {
213*14675a02SAndroid Build Coastguard Worker     use_http_federated_compute_protocol_ = value;
214*14675a02SAndroid Build Coastguard Worker   }
set_use_tflite_training(bool value)215*14675a02SAndroid Build Coastguard Worker   void set_use_tflite_training(bool value) { use_tflite_training_ = value; }
216*14675a02SAndroid Build Coastguard Worker 
condition_polling_period_millis()217*14675a02SAndroid Build Coastguard Worker   int64_t condition_polling_period_millis() const override { return 1000; }
tf_execution_teardown_grace_period_millis()218*14675a02SAndroid Build Coastguard Worker   int64_t tf_execution_teardown_grace_period_millis() const override {
219*14675a02SAndroid Build Coastguard Worker     return 1000;
220*14675a02SAndroid Build Coastguard Worker   }
tf_execution_teardown_extended_period_millis()221*14675a02SAndroid Build Coastguard Worker   int64_t tf_execution_teardown_extended_period_millis() const override {
222*14675a02SAndroid Build Coastguard Worker     return 2000;
223*14675a02SAndroid Build Coastguard Worker   }
grpc_channel_deadline_seconds()224*14675a02SAndroid Build Coastguard Worker   int64_t grpc_channel_deadline_seconds() const override { return 0; }
log_tensorflow_error_messages()225*14675a02SAndroid Build Coastguard Worker   bool log_tensorflow_error_messages() const override { return true; }
use_http_federated_compute_protocol()226*14675a02SAndroid Build Coastguard Worker   bool use_http_federated_compute_protocol() const override {
227*14675a02SAndroid Build Coastguard Worker     return use_http_federated_compute_protocol_;
228*14675a02SAndroid Build Coastguard Worker   }
use_tflite_training()229*14675a02SAndroid Build Coastguard Worker   bool use_tflite_training() const override { return use_tflite_training_; }
230*14675a02SAndroid Build Coastguard Worker 
231*14675a02SAndroid Build Coastguard Worker  private:
232*14675a02SAndroid Build Coastguard Worker   bool use_http_federated_compute_protocol_ = false;
233*14675a02SAndroid Build Coastguard Worker   bool use_tflite_training_ = false;
234*14675a02SAndroid Build Coastguard Worker };
235*14675a02SAndroid Build Coastguard Worker 
236*14675a02SAndroid Build Coastguard Worker }  // namespace fcp::client
237*14675a02SAndroid Build Coastguard Worker 
238*14675a02SAndroid Build Coastguard Worker #endif  // FCP_CLIENT_CLIENT_RUNNER_H_
239