1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/test_helpers.h"
17
#include <android-base/file.h>
#include <fcntl.h>

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
28
29 namespace fcp {
30 namespace client {
31
32 using ::google::internal::federated::plan::Dataset;
33
34 namespace {
// Reads the entire contents of the file at `path` into `*msg`.
//
// The stream is opened in binary mode because callers use this for raw
// payloads (checkpoints, serialized protos); text mode may translate line
// endings on some platforms and corrupt such data.
//
// Returns false, leaving `*msg` untouched, if the file cannot be opened.
// An empty file yields `true` with `*msg` set to "".
bool LoadFileAsString_(const std::string& path, std::string* msg) {
  std::ifstream input(path, std::ios_base::in | std::ios_base::binary);
  if (!input) {
    return false;
  }
  std::ostringstream buffer;
  // Note: streaming an empty streambuf sets failbit on `buffer`, so its state
  // is intentionally not checked here (empty files are valid input).
  buffer << input.rdbuf();
  *msg = buffer.str();
  return true;
}
45
LoadMessageLiteFromFile_(std::string path,google::protobuf::MessageLite * msg)46 bool LoadMessageLiteFromFile_(std::string path,
47 google::protobuf::MessageLite* msg) {
48 std::string data;
49 if (!LoadFileAsString_(path, &data)) {
50 return false;
51 }
52 if (!msg->ParseFromString(data)) {
53 return false;
54 }
55 return true;
56 }
57 } // namespace
58
SimpleExampleIterator(std::vector<const char * > examples)59 SimpleExampleIterator::SimpleExampleIterator(
60 std::vector<const char*> examples) {
61 FCP_LOG(INFO) << "***** create example iterator examples";
62 for (const auto& example : examples) {
63 examples_.push_back(std::string(example));
64 }
65 FCP_CHECK(!examples_.empty()) << "No data was loaded";
66 }
67
SimpleExampleIterator(Dataset dataset)68 SimpleExampleIterator::SimpleExampleIterator(Dataset dataset) {
69 FCP_LOG(INFO) << "***** create example iterator dataset";
70 for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
71 FCP_CHECK(client_dataset.selected_example_size() == 0)
72 << "This constructor can only be used for Dataset protos with unnamed "
73 "example data.";
74 for (const auto& example : client_dataset.example()) {
75 FCP_LOG(INFO) << "***** create example iterator";
76 examples_.push_back(example);
77 }
78 }
79 FCP_CHECK(!examples_.empty()) << "No data was loaded";
80 }
81
SimpleExampleIterator(Dataset dataset,absl::string_view collection_uri)82 SimpleExampleIterator::SimpleExampleIterator(Dataset dataset,
83 absl::string_view collection_uri) {
84 FCP_LOG(INFO) << "***** create example iterator dataset uri";
85
86 for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
87 FCP_CHECK(client_dataset.selected_example_size() > 0)
88 << "This constructor can only be used for Dataset protos with named "
89 "example data.";
90 for (const Dataset::ClientDataset::SelectedExample& selected_example :
91 client_dataset.selected_example()) {
92 // Only use those examples whose `ExampleSelector` matches the
93 // `collection_uri` argument. Note that the `ExampleSelector`'s selection
94 // criteria is ignored/not taken into account here.
95 if (selected_example.selector().collection_uri() != collection_uri) {
96 continue;
97 }
98 for (const auto& example : selected_example.example()) {
99 examples_.push_back(example);
100 }
101 }
102 }
103 FCP_CHECK(!examples_.empty()) << "No data was loaded for " << collection_uri;
104 }
105
Next()106 absl::StatusOr<std::string> SimpleExampleIterator::Next() {
107 if (index_ < examples_.size()) {
108 FCP_LOG(INFO) << "***** return next example " << examples_[index_];
109 return examples_[index_++];
110 }
111 return absl::OutOfRangeError("");
112 }
113
LoadFlArtifacts()114 absl::StatusOr<ComputationArtifacts> LoadFlArtifacts() {
115 FCP_LOG(INFO) << "***** LoadFlArtifacts";
116 std::string artifact_path_prefix =
117 absl::StrCat(android::base::GetExecutableDirectory(), "/fcp/testdata");
118 ComputationArtifacts result;
119 result.plan_filepath =
120 absl::StrCat(artifact_path_prefix, "/federation_client_only_plan.pb");
121 std::string plan;
122 // if (!LoadFileAsString_(result.plan_filepath, &plan)) {
123 // return absl::InternalError("Failed to load ClientOnlyPlan as string");
124 // }
125 // // Load the plan data from the file.
126 if (!LoadMessageLiteFromFile_(result.plan_filepath, &result.plan)) {
127 return absl::InternalError("Failed to load ClientOnlyPlan");
128 }
129
130 // Load dataset
131 auto dataset_filepath =
132 absl::StrCat(artifact_path_prefix, "/federation_proxy_train_examples.pb");
133 if (!LoadMessageLiteFromFile_(dataset_filepath, &result.dataset)) {
134 return absl::InternalError("Failed to load example Dataset");
135 }
136
137 result.checkpoint_filepath = absl::StrCat(
138 artifact_path_prefix, "/federation_test_checkpoint.client.ckp");
139 // Load the checkpoint data from the file.
140 if (!LoadFileAsString_(result.checkpoint_filepath, &result.checkpoint)) {
141 return absl::InternalError("Failed to load checkpoint");
142 }
143
144 auto federated_select_slices_filepath = absl::StrCat(
145 artifact_path_prefix, "/federation_test_select_checkpoints.pb");
146 // Load the federated select slices data.
147 if (!LoadMessageLiteFromFile_(federated_select_slices_filepath,
148 &result.federated_select_slices)) {
149 return absl::InternalError("Failed to load federated select slices");
150 }
151 return result;
152 }
153
ExtractSingleString(const tensorflow::Example & example,const char key[])154 std::string ExtractSingleString(const tensorflow::Example& example,
155 const char key[]) {
156 return example.features().feature().at(key).bytes_list().value().at(0);
157 }
158
ExtractRepeatedString(const tensorflow::Example & example,const char key[])159 google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString(
160 const tensorflow::Example& example, const char key[]) {
161 return example.features().feature().at(key).bytes_list().value();
162 }
163
ExtractSingleInt64(const tensorflow::Example & example,const char key[])164 int64_t ExtractSingleInt64(const tensorflow::Example& example,
165 const char key[]) {
166 return example.features().feature().at(key).int64_list().value().at(0);
167 }
168
ExtractRepeatedInt64(const tensorflow::Example & example,const char key[])169 google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64(
170 const tensorflow::Example& example, const char key[]) {
171 return example.features().feature().at(key).int64_list().value();
172 }
173
174 } // namespace client
175 } // namespace fcp
176