1*14675a02SAndroid Build Coastguard Worker /* 2*14675a02SAndroid Build Coastguard Worker * Copyright 2020 Google LLC 3*14675a02SAndroid Build Coastguard Worker * 4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*14675a02SAndroid Build Coastguard Worker * 8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*14675a02SAndroid Build Coastguard Worker * 10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*14675a02SAndroid Build Coastguard Worker * limitations under the License. 15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker 17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_CLIENT_RUNNER_H_ 18*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_CLIENT_RUNNER_H_ 19*14675a02SAndroid Build Coastguard Worker 20*14675a02SAndroid Build Coastguard Worker #include <cxxabi.h> 21*14675a02SAndroid Build Coastguard Worker #include <fcntl.h> 22*14675a02SAndroid Build Coastguard Worker #include <sys/stat.h> 23*14675a02SAndroid Build Coastguard Worker #include <sys/types.h> 24*14675a02SAndroid Build Coastguard Worker 25*14675a02SAndroid Build Coastguard Worker #include <array> 26*14675a02SAndroid Build Coastguard Worker #include <cstdint> 27*14675a02SAndroid Build Coastguard Worker #include <cstdlib> 28*14675a02SAndroid Build Coastguard Worker #include <ctime> 29*14675a02SAndroid Build Coastguard Worker #include <filesystem> 30*14675a02SAndroid Build Coastguard Worker #include <fstream> 31*14675a02SAndroid Build Coastguard Worker #include <memory> 32*14675a02SAndroid Build Coastguard Worker #include <string> 33*14675a02SAndroid Build Coastguard Worker #include <string_view> 34*14675a02SAndroid Build Coastguard Worker #include <typeinfo> 35*14675a02SAndroid Build Coastguard Worker #include <utility> 36*14675a02SAndroid Build Coastguard Worker #include <variant> 37*14675a02SAndroid Build Coastguard Worker #include <vector> 38*14675a02SAndroid Build Coastguard Worker 39*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h" 40*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h" 41*14675a02SAndroid Build Coastguard Worker #include "absl/strings/str_cat.h" 42*14675a02SAndroid Build Coastguard Worker #include "absl/strings/str_split.h" 43*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h" 44*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h" 45*14675a02SAndroid Build Coastguard Worker #include "fcp/client/client_runner_example_data.pb.h" 46*14675a02SAndroid Build Coastguard Worker #include "fcp/client/diag_codes.pb.h" 47*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fake_event_publisher.h" 48*14675a02SAndroid Build Coastguard Worker #include "fcp/client/files.h" 49*14675a02SAndroid Build Coastguard Worker #include "fcp/client/flags.h" 50*14675a02SAndroid Build Coastguard Worker #include "fcp/client/histogram_counters.pb.h" 51*14675a02SAndroid Build Coastguard Worker #include "fcp/client/http/curl/curl_api.h" 52*14675a02SAndroid Build Coastguard Worker #include "fcp/client/http/curl/curl_http_client.h" 53*14675a02SAndroid Build Coastguard Worker #include "fcp/client/log_manager.h" 54*14675a02SAndroid Build Coastguard Worker #include "fcp/client/simple_task_environment.h" 55*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h" 56*14675a02SAndroid Build Coastguard Worker #include "google/protobuf/any.pb.h" 57*14675a02SAndroid Build Coastguard Worker #include "gtest/gtest.h" 58*14675a02SAndroid Build Coastguard Worker 59*14675a02SAndroid Build Coastguard Worker namespace fcp::client { 60*14675a02SAndroid Build Coastguard Worker 61*14675a02SAndroid Build Coastguard Worker // A stub implementation of the SimpleTaskEnvironment interface that logs calls 62*14675a02SAndroid Build Coastguard Worker // to stderr and returns canned example data. 63*14675a02SAndroid Build Coastguard Worker class FederatedTaskEnvDepsImpl : public SimpleTaskEnvironment { 64*14675a02SAndroid Build Coastguard Worker public: 65*14675a02SAndroid Build Coastguard Worker // Constructs a SimpleTaskEnvironment that will return an example iterator 66*14675a02SAndroid Build Coastguard Worker // with `num_empty_examples` empty examples. 67*14675a02SAndroid Build Coastguard Worker explicit FederatedTaskEnvDepsImpl(int num_empty_examples, 68*14675a02SAndroid Build Coastguard Worker std::string test_cert_path = "") examples_(num_empty_examples)69*14675a02SAndroid Build Coastguard Worker : examples_(num_empty_examples), 70*14675a02SAndroid Build Coastguard Worker test_cert_path_(std::move(test_cert_path)) {} 71*14675a02SAndroid Build Coastguard Worker 72*14675a02SAndroid Build Coastguard Worker // Constructs a SimpleTaskEnvironment that will return an example iterator 73*14675a02SAndroid Build Coastguard Worker // with examples determined by the collection URI. 74*14675a02SAndroid Build Coastguard Worker explicit FederatedTaskEnvDepsImpl(ClientRunnerExampleData example_data, 75*14675a02SAndroid Build Coastguard Worker std::string test_cert_path = "") examples_(std::move (example_data))76*14675a02SAndroid Build Coastguard Worker : examples_(std::move(example_data)), 77*14675a02SAndroid Build Coastguard Worker test_cert_path_(std::move(test_cert_path)) {} 78*14675a02SAndroid Build Coastguard Worker GetBaseDir()79*14675a02SAndroid Build Coastguard Worker std::string GetBaseDir() override { 80*14675a02SAndroid Build Coastguard Worker return std::filesystem::path(testing::TempDir()); 81*14675a02SAndroid Build Coastguard Worker } 82*14675a02SAndroid Build Coastguard Worker GetCacheDir()83*14675a02SAndroid Build Coastguard Worker std::string GetCacheDir() override { 84*14675a02SAndroid Build Coastguard Worker return std::filesystem::path(testing::TempDir()); 85*14675a02SAndroid Build Coastguard Worker } 86*14675a02SAndroid Build Coastguard Worker CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector)87*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( 88*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::ExampleSelector& 89*14675a02SAndroid Build Coastguard Worker example_selector) override { 90*14675a02SAndroid Build Coastguard Worker SelectorContext unused; 91*14675a02SAndroid Build Coastguard Worker return CreateExampleIterator(example_selector, unused); 92*14675a02SAndroid Build Coastguard Worker } 93*14675a02SAndroid Build Coastguard Worker CreateExampleIterator(const google::internal::federated::plan::ExampleSelector & example_selector,const SelectorContext & selector_context)94*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<ExampleIterator>> CreateExampleIterator( 95*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::ExampleSelector& 96*14675a02SAndroid Build Coastguard Worker example_selector, 97*14675a02SAndroid Build Coastguard Worker const SelectorContext& selector_context) override { 98*14675a02SAndroid Build Coastguard Worker // FCP_CLIENT_LOG_FUNCTION_NAME 99*14675a02SAndroid Build Coastguard Worker // << ":\n\turi: " << example_selector.collection_uri() 100*14675a02SAndroid Build Coastguard Worker // << "\n\ttype: " << example_selector.criteria().type_url(); 101*14675a02SAndroid Build Coastguard Worker if (auto* num_empty_examples = std::get_if<int>(&examples_)) { 102*14675a02SAndroid Build Coastguard Worker return std::make_unique<FakeExampleIterator>(*num_empty_examples); 103*14675a02SAndroid Build Coastguard Worker } else if (auto* store = std::get_if<ClientRunnerExampleData>(&examples_)) { 104*14675a02SAndroid Build Coastguard Worker const auto& examples_map = store->examples_by_collection_uri(); 105*14675a02SAndroid Build Coastguard Worker if (auto it = examples_map.find(example_selector.collection_uri()); 106*14675a02SAndroid Build Coastguard Worker it != examples_map.end()) { 107*14675a02SAndroid Build Coastguard Worker return std::make_unique<FakeExampleIterator>(&it->second); 108*14675a02SAndroid Build Coastguard Worker } 109*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError("no examples for collection_uri"); 110*14675a02SAndroid Build Coastguard Worker } 111*14675a02SAndroid Build Coastguard Worker return absl::InternalError("unsupported examples variant type"); 112*14675a02SAndroid Build Coastguard Worker } 113*14675a02SAndroid Build Coastguard Worker CreateHttpClient()114*14675a02SAndroid Build Coastguard Worker std::unique_ptr<fcp::client::http::HttpClient> CreateHttpClient() override { 115*14675a02SAndroid Build Coastguard Worker return std::make_unique<fcp::client::http::curl::CurlHttpClient>( 116*14675a02SAndroid Build Coastguard Worker &curl_api_, test_cert_path_); 117*14675a02SAndroid Build Coastguard Worker } 118*14675a02SAndroid Build Coastguard Worker 119*14675a02SAndroid Build Coastguard Worker private: 120*14675a02SAndroid Build Coastguard Worker class FakeExampleIterator : public ExampleIterator { 121*14675a02SAndroid Build Coastguard Worker public: FakeExampleIterator(int num_examples)122*14675a02SAndroid Build Coastguard Worker explicit FakeExampleIterator(int num_examples) 123*14675a02SAndroid Build Coastguard Worker : example_list_(nullptr), num_examples_(num_examples) {} FakeExampleIterator(const ClientRunnerExampleData::ExampleList * examples)124*14675a02SAndroid Build Coastguard Worker explicit FakeExampleIterator( 125*14675a02SAndroid Build Coastguard Worker const ClientRunnerExampleData::ExampleList* examples) 126*14675a02SAndroid Build Coastguard Worker : example_list_(examples), num_examples_(examples->examples_size()) {} Next()127*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::string> Next() override { 128*14675a02SAndroid Build Coastguard Worker if (num_examples_served_ >= num_examples_) { 129*14675a02SAndroid Build Coastguard Worker return absl::OutOfRangeError(""); 130*14675a02SAndroid Build Coastguard Worker } 131*14675a02SAndroid Build Coastguard Worker std::string example = 132*14675a02SAndroid Build Coastguard Worker example_list_ ? example_list_->examples(num_examples_served_) : ""; 133*14675a02SAndroid Build Coastguard Worker num_examples_served_++; 134*14675a02SAndroid Build Coastguard Worker return example; 135*14675a02SAndroid Build Coastguard Worker } Close()136*14675a02SAndroid Build Coastguard Worker void Close() override {} 137*14675a02SAndroid Build Coastguard Worker 138*14675a02SAndroid Build Coastguard Worker private: 139*14675a02SAndroid Build Coastguard Worker const ClientRunnerExampleData::ExampleList* const example_list_; 140*14675a02SAndroid Build Coastguard Worker const int num_examples_; 141*14675a02SAndroid Build Coastguard Worker int num_examples_served_ = 0; 142*14675a02SAndroid Build Coastguard Worker }; 143*14675a02SAndroid Build Coastguard Worker TrainingConditionsSatisfied()144*14675a02SAndroid Build Coastguard Worker bool TrainingConditionsSatisfied() override { 145*14675a02SAndroid Build Coastguard Worker FCP_CLIENT_LOG_FUNCTION_NAME; 146*14675a02SAndroid Build Coastguard Worker return true; 147*14675a02SAndroid Build Coastguard Worker } 148*14675a02SAndroid Build Coastguard Worker 149*14675a02SAndroid Build Coastguard Worker const std::variant<int, ClientRunnerExampleData> examples_; 150*14675a02SAndroid Build Coastguard Worker const std::string test_cert_path_; 151*14675a02SAndroid Build Coastguard Worker fcp::client::http::curl::CurlApi curl_api_; 152*14675a02SAndroid Build Coastguard Worker }; 153*14675a02SAndroid Build Coastguard Worker 154*14675a02SAndroid Build Coastguard Worker // An implementation of the Files interface that attempts to create a temporary 155*14675a02SAndroid Build Coastguard Worker // file with the given prefix and suffix in a directory suitable for temporary 156*14675a02SAndroid Build Coastguard Worker // files. 157*14675a02SAndroid Build Coastguard Worker // NB this is a proof-of-concept implementation that does not use existing infra 158*14675a02SAndroid Build Coastguard Worker // such as mkstemps() or std::tmpfile due to the requirements of the existing 159*14675a02SAndroid Build Coastguard Worker // Files API: include prefix, suffix strings in filename; return file path 160*14675a02SAndroid Build Coastguard Worker // instead of file descriptor. 161*14675a02SAndroid Build Coastguard Worker class FilesImpl : public Files { 162*14675a02SAndroid Build Coastguard Worker public: FilesImpl()163*14675a02SAndroid Build Coastguard Worker FilesImpl() { std::srand(static_cast<int32_t>(std::time(nullptr))); } 164*14675a02SAndroid Build Coastguard Worker CreateTempFile(const std::string & prefix,const std::string & suffix)165*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::string> CreateTempFile( 166*14675a02SAndroid Build Coastguard Worker const std::string& prefix, const std::string& suffix) override { 167*14675a02SAndroid Build Coastguard Worker const auto tmp_dir = std::filesystem::path(testing::TempDir()); 168*14675a02SAndroid Build Coastguard Worker std::filesystem::path candidate_path; 169*14675a02SAndroid Build Coastguard Worker int fd; 170*14675a02SAndroid Build Coastguard Worker do { 171*14675a02SAndroid Build Coastguard Worker candidate_path = 172*14675a02SAndroid Build Coastguard Worker tmp_dir / absl::StrCat(prefix, std::to_string(std::rand()), suffix); 173*14675a02SAndroid Build Coastguard Worker } while ((fd = open(candidate_path.c_str(), O_CREAT | O_EXCL | O_RDWR, 174*14675a02SAndroid Build Coastguard Worker S_IRWXU)) == -1 && 175*14675a02SAndroid Build Coastguard Worker errno == EEXIST); 176*14675a02SAndroid Build Coastguard Worker close(fd); 177*14675a02SAndroid Build Coastguard Worker std::ofstream tmp_file(candidate_path); 178*14675a02SAndroid Build Coastguard Worker if (!tmp_file) { 179*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError( 180*14675a02SAndroid Build Coastguard Worker absl::StrCat("could not create file ", candidate_path.string())); 181*14675a02SAndroid Build Coastguard Worker } 182*14675a02SAndroid Build Coastguard Worker // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << candidate_path; 183*14675a02SAndroid Build Coastguard Worker return candidate_path.string(); 184*14675a02SAndroid Build Coastguard Worker } 185*14675a02SAndroid Build Coastguard Worker }; 186*14675a02SAndroid Build Coastguard Worker 187*14675a02SAndroid Build Coastguard Worker // A stub implementation of the LogManager interface that logs invocations to 188*14675a02SAndroid Build Coastguard Worker // stderr. 189*14675a02SAndroid Build Coastguard Worker class LogManagerImpl : public LogManager { 190*14675a02SAndroid Build Coastguard Worker public: LogDiag(ProdDiagCode diag_code)191*14675a02SAndroid Build Coastguard Worker void LogDiag(ProdDiagCode diag_code) override { 192*14675a02SAndroid Build Coastguard Worker // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << ProdDiagCode_Name(diag_code); 193*14675a02SAndroid Build Coastguard Worker } LogDiag(DebugDiagCode diag_code)194*14675a02SAndroid Build Coastguard Worker void LogDiag(DebugDiagCode diag_code) override { 195*14675a02SAndroid Build Coastguard Worker // FCP_CLIENT_LOG_FUNCTION_NAME << ": " << DebugDiagCode_Name(diag_code); 196*14675a02SAndroid Build Coastguard Worker } LogToLongHistogram(HistogramCounters histogram_counter,int,int,engine::DataSourceType data_source_type,int64_t value)197*14675a02SAndroid Build Coastguard Worker void LogToLongHistogram(HistogramCounters histogram_counter, int, int, 198*14675a02SAndroid Build Coastguard Worker engine::DataSourceType data_source_type, 199*14675a02SAndroid Build Coastguard Worker int64_t value) override { 200*14675a02SAndroid Build Coastguard Worker // FCP_CLIENT_LOG_FUNCTION_NAME 201*14675a02SAndroid Build Coastguard Worker // << ": " << HistogramCounters_Name(histogram_counter) << " <- " << 202*14675a02SAndroid Build Coastguard Worker // value; 203*14675a02SAndroid Build Coastguard Worker } 204*14675a02SAndroid Build Coastguard Worker SetModelIdentifier(const std::string & model_identifier)205*14675a02SAndroid Build Coastguard Worker void SetModelIdentifier(const std::string& model_identifier) override { 206*14675a02SAndroid Build Coastguard Worker // FCP_CLIENT_LOG_FUNCTION_NAME << ":\n\t" << model_identifier; 207*14675a02SAndroid Build Coastguard Worker } 208*14675a02SAndroid Build Coastguard Worker }; 209*14675a02SAndroid Build Coastguard Worker 210*14675a02SAndroid Build Coastguard Worker class FlagsImpl : public Flags { 211*14675a02SAndroid Build Coastguard Worker public: set_use_http_federated_compute_protocol(bool value)212*14675a02SAndroid Build Coastguard Worker void set_use_http_federated_compute_protocol(bool value) { 213*14675a02SAndroid Build Coastguard Worker use_http_federated_compute_protocol_ = value; 214*14675a02SAndroid Build Coastguard Worker } set_use_tflite_training(bool value)215*14675a02SAndroid Build Coastguard Worker void set_use_tflite_training(bool value) { use_tflite_training_ = value; } 216*14675a02SAndroid Build Coastguard Worker condition_polling_period_millis()217*14675a02SAndroid Build Coastguard Worker int64_t condition_polling_period_millis() const override { return 1000; } tf_execution_teardown_grace_period_millis()218*14675a02SAndroid Build Coastguard Worker int64_t tf_execution_teardown_grace_period_millis() const override { 219*14675a02SAndroid Build Coastguard Worker return 1000; 220*14675a02SAndroid Build Coastguard Worker } tf_execution_teardown_extended_period_millis()221*14675a02SAndroid Build Coastguard Worker int64_t tf_execution_teardown_extended_period_millis() const override { 222*14675a02SAndroid Build Coastguard Worker return 2000; 223*14675a02SAndroid Build Coastguard Worker } grpc_channel_deadline_seconds()224*14675a02SAndroid Build Coastguard Worker int64_t grpc_channel_deadline_seconds() const override { return 0; } log_tensorflow_error_messages()225*14675a02SAndroid Build Coastguard Worker bool log_tensorflow_error_messages() const override { return true; } use_http_federated_compute_protocol()226*14675a02SAndroid Build Coastguard Worker bool use_http_federated_compute_protocol() const override { 227*14675a02SAndroid Build Coastguard Worker return use_http_federated_compute_protocol_; 228*14675a02SAndroid Build Coastguard Worker } use_tflite_training()229*14675a02SAndroid Build Coastguard Worker bool use_tflite_training() const override { return use_tflite_training_; } 230*14675a02SAndroid Build Coastguard Worker 231*14675a02SAndroid Build Coastguard Worker private: 232*14675a02SAndroid Build Coastguard Worker bool use_http_federated_compute_protocol_ = false; 233*14675a02SAndroid Build Coastguard Worker bool use_tflite_training_ = false; 234*14675a02SAndroid Build Coastguard Worker }; 235*14675a02SAndroid Build Coastguard Worker 236*14675a02SAndroid Build Coastguard Worker } // namespace fcp::client 237*14675a02SAndroid Build Coastguard Worker 238*14675a02SAndroid Build Coastguard Worker #endif // FCP_CLIENT_CLIENT_RUNNER_H_ 239