1 /* 2 * Copyright 2020 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_CLIENT_CLIENT_RUNNER_H_ 18 #define FCP_CLIENT_CLIENT_RUNNER_H_ 19 20 #include <cxxabi.h> 21 #include <fcntl.h> 22 #include <sys/stat.h> 23 #include <sys/types.h> 24 25 #include <array> 26 #include <cstdint> 27 #include <cstdlib> 28 #include <ctime> 29 #include <filesystem> 30 #include <fstream> 31 #include <memory> 32 #include <string> 33 #include <string_view> 34 #include <typeinfo> 35 #include <utility> 36 #include <variant> 37 #include <vector> 38 39 #include "absl/container/flat_hash_map.h" 40 #include "absl/status/statusor.h" 41 #include "absl/strings/str_cat.h" 42 #include "absl/strings/str_split.h" 43 #include "absl/time/time.h" 44 #include "fcp/base/monitoring.h" 45 #include "fcp/client/client_runner_example_data.pb.h" 46 #include "fcp/client/diag_codes.pb.h" 47 #include "fcp/client/fake_event_publisher.h" 48 #include "fcp/client/files.h" 49 #include "fcp/client/flags.h" 50 #include "fcp/client/histogram_counters.pb.h" 51 #include "fcp/client/http/curl/curl_api.h" 52 #include "fcp/client/http/curl/curl_http_client.h" 53 #include "fcp/client/log_manager.h" 54 #include "fcp/client/simple_task_environment.h" 55 #include "fcp/protos/plan.pb.h" 56 #include "google/protobuf/any.pb.h" 57 #include "gtest/gtest.h" 58 59 namespace fcp::client { 60 61 // A stub implementation of the SimpleTaskEnvironment interface that logs calls 62 // to stderr and returns canned example data. 63 class FederatedTaskEnvDepsImpl : public SimpleTaskEnvironment { 64 public: 65 // Constructs a SimpleTaskEnvironment that will return an example iterator 66 // with `num_empty_examples` empty examples. 67 explicit FederatedTaskEnvDepsImpl(int num_empty_examples, 68 std::string test_cert_path = "") examples_(num_empty_examples)69 : examples_(num_empty_examples), 70 test_cert_path_(std::move(test_cert_path)) {} 71 72 // Constructs a SimpleTaskEnvironment that will return an example iterator 73 // with examples determined by the collection URI. 74 explicit FederatedTaskEnvDepsImpl(ClientRunnerExampleData example_data, 75 std::string test_cert_path = "") examples_(std::move (example_data))76 : examples_(std::move(example_data)), 77 test_cert_path_(std::move(test_cert_path)) {} 78 GetBaseDir()79 std::string GetBaseDir() override { 80 return std::filesystem::path(testing::TempDir()); 81 } 82 GetCacheDir()83 std::string GetCacheDir() override { 84 return std::filesystem::path(testing::TempDir()); 85 } 86 CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector)87 absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( 88 const google::internal::federated::plan::ExampleSelector& 89 example_selector) override { 90 SelectorContext unused; 91 return CreateExampleIterator(example_selector, unused); 92 } 93 CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector,const SelectorContext & selector_context)94 absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( 95 const google::internal::federated::plan::ExampleSelector& 96 example_selector, 97 const SelectorContext& selector_context) override { 98 // FCP_CLIENT_LOG_FUNCTION_NAME 99 // << ":\n\turi: " << example_selector.collection_uri() 100 // << "\n\ttype: " << example_selector.criteria().type_url(); 101 if (auto* num_empty_examples = std::get_if<int>(&examples_)) { 102 return std::make_unique<FakeExampleIterator>(*num_empty_examples); 103 } else if (auto* store = std::get_if<ClientRunnerExampleData>(&examples_)) { 104 const auto& examples_map = store->examples_by_collection_uri(); 105 if (auto it = examples_map.find(example_selector.collection_uri()); 106 it != examples_map.end()) { 107 return std::make_unique<FakeExampleIterator>(&it->second); 108 } 109 return absl::InvalidArgumentError("no examples for collection_uri"); 110 } 111 return absl::InternalError("unsupported examples variant type"); 112 } 113 CreateHttpClient()114 std::unique_ptr<fcp::client::http::HttpClient> CreateHttpClient() override { 115 return std::make_unique<fcp::client::http::curl::CurlHttpClient>( 116 &curl_api_, test_cert_path_); 117 } 118 119 private: 120 class FakeExampleIterator : public ExampleIterator { 121 public: FakeExampleIterator(int num_examples)122 explicit FakeExampleIterator(int num_examples) 123 : example_list_(nullptr), num_examples_(num_examples) {} FakeExampleIterator(const ClientRunnerExampleData::ExampleList * examples)124 explicit FakeExampleIterator( 125 const ClientRunnerExampleData::ExampleList* examples) 126 : example_list_(examples), num_examples_(examples->examples_size()) {} Next()127 absl::StatusOr<std::string> Next() override { 128 if (num_examples_served_ >= num_examples_) { 129 return absl::OutOfRangeError(""); 130 } 131 std::string example = 132 example_list_ ? example_list_->examples(num_examples_served_) : ""; 133 num_examples_served_++; 134 return example; 135 } Close()136 void Close() override {} 137 138 private: 139 const ClientRunnerExampleData::ExampleList* const example_list_; 140 const int num_examples_; 141 int num_examples_served_ = 0; 142 }; 143 TrainingConditionsSatisfied()144 bool TrainingConditionsSatisfied() override { 145 FCP_CLIENT_LOG_FUNCTION_NAME; 146 return true; 147 } 148 149 const std::variant<int, ClientRunnerExampleData> examples_; 150 const std::string test_cert_path_; 151 fcp::client::http::curl::CurlApi curl_api_; 152 }; 153 154 // An implementation of the Files interface that attempts to create a temporary 155 // file with the given prefix and suffix in a directory suitable for temporary 156 // files. 157 // NB this is a proof-of-concept implementation that does not use existing infra 158 // such as mkstemps() or std::tmpfile due to the requirements of the existing 159 // Files API: include prefix, suffix strings in filename; return file path 160 // instead of file descriptor. 161 class FilesImpl : public Files { 162 public: FilesImpl()163 FilesImpl() { std::srand(static_cast<int32_t>(std::time(nullptr))); } 164 CreateTempFile(const std::string & prefix,const std::string & suffix)165 absl::StatusOr<std::string> CreateTempFile( 166 const std::string& prefix, const std::string& suffix) override { 167 const auto tmp_dir = std::filesystem::path(testing::TempDir()); 168 std::filesystem::path candidate_path; 169 int fd; 170 do { 171 candidate_path = 172 tmp_dir / absl::StrCat(prefix, std::to_string(std::rand()), suffix); 173 } while ((fd = open(candidate_path.c_str(), O_CREAT | O_EXCL | O_RDWR, 174 S_IRWXU)) == -1 && 175 errno == EEXIST); 176 close(fd); 177 std::ofstream tmp_file(candidate_path); 178 if (!tmp_file) { 179 return absl::InvalidArgumentError( 180 absl::StrCat("could not create file ", candidate_path.string())); 181 } 182 // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << candidate_path; 183 return candidate_path.string(); 184 } 185 }; 186 187 // A stub implementation of the LogManager interface that logs invocations to 188 // stderr. 189 class LogManagerImpl : public LogManager { 190 public: LogDiag(ProdDiagCode diag_code)191 void LogDiag(ProdDiagCode diag_code) override { 192 // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << ProdDiagCode_Name(diag_code); 193 } LogDiag(DebugDiagCode diag_code)194 void LogDiag(DebugDiagCode diag_code) override { 195 // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << DebugDiagCode_Name(diag_code); 196 } LogToLongHistogram(HistogramCounters histogram_counter,int,int,engine::DataSourceType data_source_type,int64_t value)197 void LogToLongHistogram(HistogramCounters histogram_counter, int, int, 198 engine::DataSourceType data_source_type, 199 int64_t value) override { 200 // FCP_CLIENT_LOG_FUNCTION_NAME 201 // << ": " << HistogramCounters_Name(histogram_counter) << " <- " << 202 // value; 203 } 204 SetModelIdentifier(const std::string & model_identifier)205 void SetModelIdentifier(const std::string& model_identifier) override { 206 // FCP_CLIENT_LOG_FUNCTION_NAME << ":\n\t" << model_identifier; 207 } 208 }; 209 210 class FlagsImpl : public Flags { 211 public: set_use_http_federated_compute_protocol(bool value)212 void set_use_http_federated_compute_protocol(bool value) { 213 use_http_federated_compute_protocol_ = value; 214 } set_use_tflite_training(bool value)215 void set_use_tflite_training(bool value) { use_tflite_training_ = value; } 216 condition_polling_period_millis()217 int64_t condition_polling_period_millis() const override { return 1000; } tf_execution_teardown_grace_period_millis()218 int64_t tf_execution_teardown_grace_period_millis() const override { 219 return 1000; 220 } tf_execution_teardown_extended_period_millis()221 int64_t tf_execution_teardown_extended_period_millis() const override { 222 return 2000; 223 } grpc_channel_deadline_seconds()224 int64_t grpc_channel_deadline_seconds() const override { return 0; } log_tensorflow_error_messages()225 bool log_tensorflow_error_messages() const override { return true; } use_http_federated_compute_protocol()226 bool use_http_federated_compute_protocol() const override { 227 return use_http_federated_compute_protocol_; 228 } use_tflite_training()229 bool use_tflite_training() const override { return use_tflite_training_; } 230 231 private: 232 bool use_http_federated_compute_protocol_ = false; 233 bool use_tflite_training_ = false; 234 }; 235 236 } // namespace fcp::client 237 238 #endif // FCP_CLIENT_CLIENT_RUNNER_H_ 239