xref: /aosp_15_r20/external/federated-compute/fcp/client/client_runner.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_CLIENT_CLIENT_RUNNER_H_
18 #define FCP_CLIENT_CLIENT_RUNNER_H_
19 
#include <cxxabi.h>
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <array>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <filesystem>
#include <fstream>
#include <memory>
#include <random>
#include <string>
#include <string_view>
#include <typeinfo>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/time/time.h"
#include "fcp/base/monitoring.h"
#include "fcp/client/client_runner_example_data.pb.h"
#include "fcp/client/diag_codes.pb.h"
#include "fcp/client/fake_event_publisher.h"
#include "fcp/client/files.h"
#include "fcp/client/flags.h"
#include "fcp/client/histogram_counters.pb.h"
#include "fcp/client/http/curl/curl_api.h"
#include "fcp/client/http/curl/curl_http_client.h"
#include "fcp/client/log_manager.h"
#include "fcp/client/simple_task_environment.h"
#include "fcp/protos/plan.pb.h"
#include "google/protobuf/any.pb.h"
#include "gtest/gtest.h"
58 
59 namespace fcp::client {
60 
61 // A stub implementation of the SimpleTaskEnvironment interface that logs calls
62 // to stderr and returns canned example data.
63 class FederatedTaskEnvDepsImpl : public SimpleTaskEnvironment {
64  public:
65   // Constructs a SimpleTaskEnvironment that will return an example iterator
66   // with `num_empty_examples` empty examples.
67   explicit FederatedTaskEnvDepsImpl(int num_empty_examples,
68                                     std::string test_cert_path = "")
examples_(num_empty_examples)69       : examples_(num_empty_examples),
70         test_cert_path_(std::move(test_cert_path)) {}
71 
72   // Constructs a SimpleTaskEnvironment that will return an example iterator
73   // with examples determined by the collection URI.
74   explicit FederatedTaskEnvDepsImpl(ClientRunnerExampleData example_data,
75                                     std::string test_cert_path = "")
examples_(std::move (example_data))76       : examples_(std::move(example_data)),
77         test_cert_path_(std::move(test_cert_path)) {}
78 
GetBaseDir()79   std::string GetBaseDir() override {
80     return std::filesystem::path(testing::TempDir());
81   }
82 
GetCacheDir()83   std::string GetCacheDir() override {
84     return std::filesystem::path(testing::TempDir());
85   }
86 
CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector)87   absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
88       const google::internal::federated::plan::ExampleSelector&
89           example_selector) override {
90     SelectorContext unused;
91     return CreateExampleIterator(example_selector, unused);
92   }
93 
CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector,const SelectorContext & selector_context)94   absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator(
95       const google::internal::federated::plan::ExampleSelector&
96           example_selector,
97       const SelectorContext& selector_context) override {
98     // FCP_CLIENT_LOG_FUNCTION_NAME
99     //     << ":\n\turi: " << example_selector.collection_uri()
100     //     << "\n\ttype: " << example_selector.criteria().type_url();
101     if (auto* num_empty_examples = std::get_if<int>(&examples_)) {
102       return std::make_unique<FakeExampleIterator>(*num_empty_examples);
103     } else if (auto* store = std::get_if<ClientRunnerExampleData>(&examples_)) {
104       const auto& examples_map = store->examples_by_collection_uri();
105       if (auto it = examples_map.find(example_selector.collection_uri());
106           it != examples_map.end()) {
107         return std::make_unique<FakeExampleIterator>(&it->second);
108       }
109       return absl::InvalidArgumentError("no examples for collection_uri");
110     }
111     return absl::InternalError("unsupported examples variant type");
112   }
113 
CreateHttpClient()114   std::unique_ptr<fcp::client::http::HttpClient> CreateHttpClient() override {
115     return std::make_unique<fcp::client::http::curl::CurlHttpClient>(
116         &curl_api_, test_cert_path_);
117   }
118 
119  private:
120   class FakeExampleIterator : public ExampleIterator {
121    public:
FakeExampleIterator(int num_examples)122     explicit FakeExampleIterator(int num_examples)
123         : example_list_(nullptr), num_examples_(num_examples) {}
FakeExampleIterator(const ClientRunnerExampleData::ExampleList * examples)124     explicit FakeExampleIterator(
125         const ClientRunnerExampleData::ExampleList* examples)
126         : example_list_(examples), num_examples_(examples->examples_size()) {}
Next()127     absl::StatusOr<std::string> Next() override {
128       if (num_examples_served_ >= num_examples_) {
129         return absl::OutOfRangeError("");
130       }
131       std::string example =
132           example_list_ ? example_list_->examples(num_examples_served_) : "";
133       num_examples_served_++;
134       return example;
135     }
Close()136     void Close() override {}
137 
138    private:
139     const ClientRunnerExampleData::ExampleList* const example_list_;
140     const int num_examples_;
141     int num_examples_served_ = 0;
142   };
143 
TrainingConditionsSatisfied()144   bool TrainingConditionsSatisfied() override {
145     FCP_CLIENT_LOG_FUNCTION_NAME;
146     return true;
147   }
148 
149   const std::variant<int, ClientRunnerExampleData> examples_;
150   const std::string test_cert_path_;
151   fcp::client::http::curl::CurlApi curl_api_;
152 };
153 
154 // An implementation of the Files interface that attempts to create a temporary
155 // file with the given prefix and suffix in a directory suitable for temporary
156 // files.
157 // NB this is a proof-of-concept implementation that does not use existing infra
158 // such as mkstemps() or std::tmpfile due to the requirements of the existing
159 // Files API: include prefix, suffix strings in filename; return file path
160 // instead of file descriptor.
161 class FilesImpl : public Files {
162  public:
FilesImpl()163   FilesImpl() { std::srand(static_cast<int32_t>(std::time(nullptr))); }
164 
CreateTempFile(const std::string & prefix,const std::string & suffix)165   absl::StatusOr<std::string> CreateTempFile(
166       const std::string& prefix, const std::string& suffix) override {
167     const auto tmp_dir = std::filesystem::path(testing::TempDir());
168     std::filesystem::path candidate_path;
169     int fd;
170     do {
171       candidate_path =
172           tmp_dir / absl::StrCat(prefix, std::to_string(std::rand()), suffix);
173     } while ((fd = open(candidate_path.c_str(), O_CREAT | O_EXCL | O_RDWR,
174                         S_IRWXU)) == -1 &&
175              errno == EEXIST);
176     close(fd);
177     std::ofstream tmp_file(candidate_path);
178     if (!tmp_file) {
179       return absl::InvalidArgumentError(
180           absl::StrCat("could not create file ", candidate_path.string()));
181     }
182     // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << candidate_path;
183     return candidate_path.string();
184   }
185 };
186 
187 // A stub implementation of the LogManager interface that logs invocations to
188 // stderr.
189 class LogManagerImpl : public LogManager {
190  public:
LogDiag(ProdDiagCode diag_code)191   void LogDiag(ProdDiagCode diag_code) override {
192     // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << ProdDiagCode_Name(diag_code);
193   }
LogDiag(DebugDiagCode diag_code)194   void LogDiag(DebugDiagCode diag_code) override {
195     // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << DebugDiagCode_Name(diag_code);
196   }
LogToLongHistogram(HistogramCounters histogram_counter,int,int,engine::DataSourceType data_source_type,int64_t value)197   void LogToLongHistogram(HistogramCounters histogram_counter, int, int,
198                           engine::DataSourceType data_source_type,
199                           int64_t value) override {
200     // FCP_CLIENT_LOG_FUNCTION_NAME
201     //     << ": " << HistogramCounters_Name(histogram_counter) << " <- " <<
202     //     value;
203   }
204 
SetModelIdentifier(const std::string & model_identifier)205   void SetModelIdentifier(const std::string& model_identifier) override {
206     // FCP_CLIENT_LOG_FUNCTION_NAME << ":\n\t" << model_identifier;
207   }
208 };
209 
210 class FlagsImpl : public Flags {
211  public:
set_use_http_federated_compute_protocol(bool value)212   void set_use_http_federated_compute_protocol(bool value) {
213     use_http_federated_compute_protocol_ = value;
214   }
set_use_tflite_training(bool value)215   void set_use_tflite_training(bool value) { use_tflite_training_ = value; }
216 
condition_polling_period_millis()217   int64_t condition_polling_period_millis() const override { return 1000; }
tf_execution_teardown_grace_period_millis()218   int64_t tf_execution_teardown_grace_period_millis() const override {
219     return 1000;
220   }
tf_execution_teardown_extended_period_millis()221   int64_t tf_execution_teardown_extended_period_millis() const override {
222     return 2000;
223   }
grpc_channel_deadline_seconds()224   int64_t grpc_channel_deadline_seconds() const override { return 0; }
log_tensorflow_error_messages()225   bool log_tensorflow_error_messages() const override { return true; }
use_http_federated_compute_protocol()226   bool use_http_federated_compute_protocol() const override {
227     return use_http_federated_compute_protocol_;
228   }
use_tflite_training()229   bool use_tflite_training() const override { return use_tflite_training_; }
230 
231  private:
232   bool use_http_federated_compute_protocol_ = false;
233   bool use_tflite_training_ = false;
234 };
235 
236 }  // namespace fcp::client
237 
238 #endif  // FCP_CLIENT_CLIENT_RUNNER_H_
239