1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker *
4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker *
8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker *
10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/test_helpers.h"
17*14675a02SAndroid Build Coastguard Worker
#include <android-base/file.h>
#include <fcntl.h>

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
28*14675a02SAndroid Build Coastguard Worker
29*14675a02SAndroid Build Coastguard Worker namespace fcp {
30*14675a02SAndroid Build Coastguard Worker namespace client {
31*14675a02SAndroid Build Coastguard Worker
32*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::Dataset;
33*14675a02SAndroid Build Coastguard Worker
34*14675a02SAndroid Build Coastguard Worker namespace {
// Reads the entire contents of the file at `path` into `*msg`.
//
// The file is opened in binary mode: callers use this for serialized protobufs
// and checkpoints, whose bytes must not be altered by any platform-specific
// newline translation.
//
// Returns false, leaving `*msg` untouched, if the file cannot be opened;
// returns true otherwise (an empty file yields an empty string).
bool LoadFileAsString_(const std::string& path, std::string* msg) {
  std::ifstream checkpoint_istream(path, std::ios::in | std::ios::binary);
  if (!checkpoint_istream) {
    return false;
  }
  std::stringstream checkpoint_stream;
  checkpoint_stream << checkpoint_istream.rdbuf();
  *msg = checkpoint_stream.str();
  return true;
}
45*14675a02SAndroid Build Coastguard Worker
LoadMessageLiteFromFile_(std::string path,google::protobuf::MessageLite * msg)46*14675a02SAndroid Build Coastguard Worker bool LoadMessageLiteFromFile_(std::string path,
47*14675a02SAndroid Build Coastguard Worker google::protobuf::MessageLite* msg) {
48*14675a02SAndroid Build Coastguard Worker std::string data;
49*14675a02SAndroid Build Coastguard Worker if (!LoadFileAsString_(path, &data)) {
50*14675a02SAndroid Build Coastguard Worker return false;
51*14675a02SAndroid Build Coastguard Worker }
52*14675a02SAndroid Build Coastguard Worker if (!msg->ParseFromString(data)) {
53*14675a02SAndroid Build Coastguard Worker return false;
54*14675a02SAndroid Build Coastguard Worker }
55*14675a02SAndroid Build Coastguard Worker return true;
56*14675a02SAndroid Build Coastguard Worker }
57*14675a02SAndroid Build Coastguard Worker } // namespace
58*14675a02SAndroid Build Coastguard Worker
SimpleExampleIterator(std::vector<const char * > examples)59*14675a02SAndroid Build Coastguard Worker SimpleExampleIterator::SimpleExampleIterator(
60*14675a02SAndroid Build Coastguard Worker std::vector<const char*> examples) {
61*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "***** create example iterator examples";
62*14675a02SAndroid Build Coastguard Worker for (const auto& example : examples) {
63*14675a02SAndroid Build Coastguard Worker examples_.push_back(std::string(example));
64*14675a02SAndroid Build Coastguard Worker }
65*14675a02SAndroid Build Coastguard Worker FCP_CHECK(!examples_.empty()) << "No data was loaded";
66*14675a02SAndroid Build Coastguard Worker }
67*14675a02SAndroid Build Coastguard Worker
SimpleExampleIterator(Dataset dataset)68*14675a02SAndroid Build Coastguard Worker SimpleExampleIterator::SimpleExampleIterator(Dataset dataset) {
69*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "***** create example iterator dataset";
70*14675a02SAndroid Build Coastguard Worker for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
71*14675a02SAndroid Build Coastguard Worker FCP_CHECK(client_dataset.selected_example_size() == 0)
72*14675a02SAndroid Build Coastguard Worker << "This constructor can only be used for Dataset protos with unnamed "
73*14675a02SAndroid Build Coastguard Worker "example data.";
74*14675a02SAndroid Build Coastguard Worker for (const auto& example : client_dataset.example()) {
75*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "***** create example iterator";
76*14675a02SAndroid Build Coastguard Worker examples_.push_back(example);
77*14675a02SAndroid Build Coastguard Worker }
78*14675a02SAndroid Build Coastguard Worker }
79*14675a02SAndroid Build Coastguard Worker FCP_CHECK(!examples_.empty()) << "No data was loaded";
80*14675a02SAndroid Build Coastguard Worker }
81*14675a02SAndroid Build Coastguard Worker
SimpleExampleIterator(Dataset dataset,absl::string_view collection_uri)82*14675a02SAndroid Build Coastguard Worker SimpleExampleIterator::SimpleExampleIterator(Dataset dataset,
83*14675a02SAndroid Build Coastguard Worker absl::string_view collection_uri) {
84*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "***** create example iterator dataset uri";
85*14675a02SAndroid Build Coastguard Worker
86*14675a02SAndroid Build Coastguard Worker for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
87*14675a02SAndroid Build Coastguard Worker FCP_CHECK(client_dataset.selected_example_size() > 0)
88*14675a02SAndroid Build Coastguard Worker << "This constructor can only be used for Dataset protos with named "
89*14675a02SAndroid Build Coastguard Worker "example data.";
90*14675a02SAndroid Build Coastguard Worker for (const Dataset::ClientDataset::SelectedExample& selected_example :
91*14675a02SAndroid Build Coastguard Worker client_dataset.selected_example()) {
92*14675a02SAndroid Build Coastguard Worker // Only use those examples whose `ExampleSelector` matches the
93*14675a02SAndroid Build Coastguard Worker // `collection_uri` argument. Note that the `ExampleSelector`'s selection
94*14675a02SAndroid Build Coastguard Worker // criteria is ignored/not taken into account here.
95*14675a02SAndroid Build Coastguard Worker if (selected_example.selector().collection_uri() != collection_uri) {
96*14675a02SAndroid Build Coastguard Worker continue;
97*14675a02SAndroid Build Coastguard Worker }
98*14675a02SAndroid Build Coastguard Worker for (const auto& example : selected_example.example()) {
99*14675a02SAndroid Build Coastguard Worker examples_.push_back(example);
100*14675a02SAndroid Build Coastguard Worker }
101*14675a02SAndroid Build Coastguard Worker }
102*14675a02SAndroid Build Coastguard Worker }
103*14675a02SAndroid Build Coastguard Worker FCP_CHECK(!examples_.empty()) << "No data was loaded for " << collection_uri;
104*14675a02SAndroid Build Coastguard Worker }
105*14675a02SAndroid Build Coastguard Worker
Next()106*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::string> SimpleExampleIterator::Next() {
107*14675a02SAndroid Build Coastguard Worker if (index_ < examples_.size()) {
108*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "***** return next example " << examples_[index_];
109*14675a02SAndroid Build Coastguard Worker return examples_[index_++];
110*14675a02SAndroid Build Coastguard Worker }
111*14675a02SAndroid Build Coastguard Worker return absl::OutOfRangeError("");
112*14675a02SAndroid Build Coastguard Worker }
113*14675a02SAndroid Build Coastguard Worker
LoadFlArtifacts()114*14675a02SAndroid Build Coastguard Worker absl::StatusOr<ComputationArtifacts> LoadFlArtifacts() {
115*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "***** LoadFlArtifacts";
116*14675a02SAndroid Build Coastguard Worker std::string artifact_path_prefix =
117*14675a02SAndroid Build Coastguard Worker absl::StrCat(android::base::GetExecutableDirectory(), "/fcp/testdata");
118*14675a02SAndroid Build Coastguard Worker ComputationArtifacts result;
119*14675a02SAndroid Build Coastguard Worker result.plan_filepath =
120*14675a02SAndroid Build Coastguard Worker absl::StrCat(artifact_path_prefix, "/federation_client_only_plan.pb");
121*14675a02SAndroid Build Coastguard Worker std::string plan;
122*14675a02SAndroid Build Coastguard Worker // if (!LoadFileAsString_(result.plan_filepath, &plan)) {
123*14675a02SAndroid Build Coastguard Worker // return absl::InternalError("Failed to load ClientOnlyPlan as string");
124*14675a02SAndroid Build Coastguard Worker // }
125*14675a02SAndroid Build Coastguard Worker // // Load the plan data from the file.
126*14675a02SAndroid Build Coastguard Worker if (!LoadMessageLiteFromFile_(result.plan_filepath, &result.plan)) {
127*14675a02SAndroid Build Coastguard Worker return absl::InternalError("Failed to load ClientOnlyPlan");
128*14675a02SAndroid Build Coastguard Worker }
129*14675a02SAndroid Build Coastguard Worker
130*14675a02SAndroid Build Coastguard Worker // Load dataset
131*14675a02SAndroid Build Coastguard Worker auto dataset_filepath =
132*14675a02SAndroid Build Coastguard Worker absl::StrCat(artifact_path_prefix, "/federation_proxy_train_examples.pb");
133*14675a02SAndroid Build Coastguard Worker if (!LoadMessageLiteFromFile_(dataset_filepath, &result.dataset)) {
134*14675a02SAndroid Build Coastguard Worker return absl::InternalError("Failed to load example Dataset");
135*14675a02SAndroid Build Coastguard Worker }
136*14675a02SAndroid Build Coastguard Worker
137*14675a02SAndroid Build Coastguard Worker result.checkpoint_filepath = absl::StrCat(
138*14675a02SAndroid Build Coastguard Worker artifact_path_prefix, "/federation_test_checkpoint.client.ckp");
139*14675a02SAndroid Build Coastguard Worker // Load the checkpoint data from the file.
140*14675a02SAndroid Build Coastguard Worker if (!LoadFileAsString_(result.checkpoint_filepath, &result.checkpoint)) {
141*14675a02SAndroid Build Coastguard Worker return absl::InternalError("Failed to load checkpoint");
142*14675a02SAndroid Build Coastguard Worker }
143*14675a02SAndroid Build Coastguard Worker
144*14675a02SAndroid Build Coastguard Worker auto federated_select_slices_filepath = absl::StrCat(
145*14675a02SAndroid Build Coastguard Worker artifact_path_prefix, "/federation_test_select_checkpoints.pb");
146*14675a02SAndroid Build Coastguard Worker // Load the federated select slices data.
147*14675a02SAndroid Build Coastguard Worker if (!LoadMessageLiteFromFile_(federated_select_slices_filepath,
148*14675a02SAndroid Build Coastguard Worker &result.federated_select_slices)) {
149*14675a02SAndroid Build Coastguard Worker return absl::InternalError("Failed to load federated select slices");
150*14675a02SAndroid Build Coastguard Worker }
151*14675a02SAndroid Build Coastguard Worker return result;
152*14675a02SAndroid Build Coastguard Worker }
153*14675a02SAndroid Build Coastguard Worker
ExtractSingleString(const tensorflow::Example & example,const char key[])154*14675a02SAndroid Build Coastguard Worker std::string ExtractSingleString(const tensorflow::Example& example,
155*14675a02SAndroid Build Coastguard Worker const char key[]) {
156*14675a02SAndroid Build Coastguard Worker return example.features().feature().at(key).bytes_list().value().at(0);
157*14675a02SAndroid Build Coastguard Worker }
158*14675a02SAndroid Build Coastguard Worker
ExtractRepeatedString(const tensorflow::Example & example,const char key[])159*14675a02SAndroid Build Coastguard Worker google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString(
160*14675a02SAndroid Build Coastguard Worker const tensorflow::Example& example, const char key[]) {
161*14675a02SAndroid Build Coastguard Worker return example.features().feature().at(key).bytes_list().value();
162*14675a02SAndroid Build Coastguard Worker }
163*14675a02SAndroid Build Coastguard Worker
ExtractSingleInt64(const tensorflow::Example & example,const char key[])164*14675a02SAndroid Build Coastguard Worker int64_t ExtractSingleInt64(const tensorflow::Example& example,
165*14675a02SAndroid Build Coastguard Worker const char key[]) {
166*14675a02SAndroid Build Coastguard Worker return example.features().feature().at(key).int64_list().value().at(0);
167*14675a02SAndroid Build Coastguard Worker }
168*14675a02SAndroid Build Coastguard Worker
ExtractRepeatedInt64(const tensorflow::Example & example,const char key[])169*14675a02SAndroid Build Coastguard Worker google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64(
170*14675a02SAndroid Build Coastguard Worker const tensorflow::Example& example, const char key[]) {
171*14675a02SAndroid Build Coastguard Worker return example.features().feature().at(key).int64_list().value();
172*14675a02SAndroid Build Coastguard Worker }
173*14675a02SAndroid Build Coastguard Worker
174*14675a02SAndroid Build Coastguard Worker } // namespace client
175*14675a02SAndroid Build Coastguard Worker } // namespace fcp
176