xref: /aosp_15_r20/external/federated-compute/fcp/client/test_helpers.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "fcp/client/test_helpers.h"

#include <android-base/file.h>
#include <fcntl.h>

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
28 
29 namespace fcp {
30 namespace client {
31 
32 using ::google::internal::federated::plan::Dataset;
33 
34 namespace {
// Reads the entire file at `path` into `*msg`.
//
// The file is opened in binary mode because the callers in this file use it
// for serialized protos and checkpoints; text mode may translate line endings
// on some platforms and corrupt such payloads.
//
// Returns false (leaving `*msg` untouched) if the file cannot be opened.
// An existing but empty file yields true with `*msg` set to "".
bool LoadFileAsString_(const std::string& path, std::string* msg) {
  std::ifstream file_stream(path, std::ios::in | std::ios::binary);
  if (!file_stream) {
    return false;
  }
  std::ostringstream contents;
  // Streaming the underlying buffer copies every byte verbatim.
  contents << file_stream.rdbuf();
  *msg = contents.str();
  return true;
}
45 
LoadMessageLiteFromFile_(std::string path,google::protobuf::MessageLite * msg)46 bool LoadMessageLiteFromFile_(std::string path,
47                               google::protobuf::MessageLite* msg) {
48   std::string data;
49   if (!LoadFileAsString_(path, &data)) {
50     return false;
51   }
52   if (!msg->ParseFromString(data)) {
53     return false;
54   }
55   return true;
56 }
57 }  // namespace
58 
SimpleExampleIterator(std::vector<const char * > examples)59 SimpleExampleIterator::SimpleExampleIterator(
60     std::vector<const char*> examples) {
61   FCP_LOG(INFO) << "***** create example iterator examples";
62   for (const auto& example : examples) {
63     examples_.push_back(std::string(example));
64   }
65   FCP_CHECK(!examples_.empty()) << "No data was loaded";
66 }
67 
SimpleExampleIterator(Dataset dataset)68 SimpleExampleIterator::SimpleExampleIterator(Dataset dataset) {
69   FCP_LOG(INFO) << "***** create example iterator dataset";
70   for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
71     FCP_CHECK(client_dataset.selected_example_size() == 0)
72         << "This constructor can only be used for Dataset protos with unnamed "
73            "example data.";
74     for (const auto& example : client_dataset.example()) {
75       FCP_LOG(INFO) << "***** create example iterator";
76       examples_.push_back(example);
77     }
78   }
79   FCP_CHECK(!examples_.empty()) << "No data was loaded";
80 }
81 
SimpleExampleIterator(Dataset dataset,absl::string_view collection_uri)82 SimpleExampleIterator::SimpleExampleIterator(Dataset dataset,
83                                              absl::string_view collection_uri) {
84   FCP_LOG(INFO) << "***** create example iterator dataset uri";
85 
86   for (const Dataset::ClientDataset& client_dataset : dataset.client_data()) {
87     FCP_CHECK(client_dataset.selected_example_size() > 0)
88         << "This constructor can only be used for Dataset protos with named "
89            "example data.";
90     for (const Dataset::ClientDataset::SelectedExample& selected_example :
91          client_dataset.selected_example()) {
92       // Only use those examples whose `ExampleSelector` matches the
93       // `collection_uri` argument. Note that the `ExampleSelector`'s selection
94       // criteria is ignored/not taken into account here.
95       if (selected_example.selector().collection_uri() != collection_uri) {
96         continue;
97       }
98       for (const auto& example : selected_example.example()) {
99         examples_.push_back(example);
100       }
101     }
102   }
103   FCP_CHECK(!examples_.empty()) << "No data was loaded for " << collection_uri;
104 }
105 
Next()106 absl::StatusOr<std::string> SimpleExampleIterator::Next() {
107   if (index_ < examples_.size()) {
108     FCP_LOG(INFO) << "***** return next example " << examples_[index_];
109     return examples_[index_++];
110   }
111   return absl::OutOfRangeError("");
112 }
113 
LoadFlArtifacts()114 absl::StatusOr<ComputationArtifacts> LoadFlArtifacts() {
115   FCP_LOG(INFO) << "***** LoadFlArtifacts";
116   std::string artifact_path_prefix =
117       absl::StrCat(android::base::GetExecutableDirectory(), "/fcp/testdata");
118   ComputationArtifacts result;
119   result.plan_filepath =
120       absl::StrCat(artifact_path_prefix, "/federation_client_only_plan.pb");
121   std::string plan;
122   // if (!LoadFileAsString_(result.plan_filepath, &plan)) {
123   //   return absl::InternalError("Failed to load ClientOnlyPlan as string");
124   // }
125   //     //  Load the plan data from the file.
126   if (!LoadMessageLiteFromFile_(result.plan_filepath, &result.plan)) {
127     return absl::InternalError("Failed to load ClientOnlyPlan");
128   }
129 
130   // Load dataset
131   auto dataset_filepath =
132       absl::StrCat(artifact_path_prefix, "/federation_proxy_train_examples.pb");
133   if (!LoadMessageLiteFromFile_(dataset_filepath, &result.dataset)) {
134     return absl::InternalError("Failed to load example Dataset");
135   }
136 
137   result.checkpoint_filepath = absl::StrCat(
138       artifact_path_prefix, "/federation_test_checkpoint.client.ckp");
139   // Load the checkpoint data from the file.
140   if (!LoadFileAsString_(result.checkpoint_filepath, &result.checkpoint)) {
141     return absl::InternalError("Failed to load checkpoint");
142   }
143 
144   auto federated_select_slices_filepath = absl::StrCat(
145       artifact_path_prefix, "/federation_test_select_checkpoints.pb");
146   // Load the federated select slices data.
147   if (!LoadMessageLiteFromFile_(federated_select_slices_filepath,
148                                 &result.federated_select_slices)) {
149     return absl::InternalError("Failed to load federated select slices");
150   }
151   return result;
152 }
153 
ExtractSingleString(const tensorflow::Example & example,const char key[])154 std::string ExtractSingleString(const tensorflow::Example& example,
155                                 const char key[]) {
156   return example.features().feature().at(key).bytes_list().value().at(0);
157 }
158 
ExtractRepeatedString(const tensorflow::Example & example,const char key[])159 google::protobuf::RepeatedPtrField<std::string> ExtractRepeatedString(
160     const tensorflow::Example& example, const char key[]) {
161   return example.features().feature().at(key).bytes_list().value();
162 }
163 
ExtractSingleInt64(const tensorflow::Example & example,const char key[])164 int64_t ExtractSingleInt64(const tensorflow::Example& example,
165                            const char key[]) {
166   return example.features().feature().at(key).int64_list().value().at(0);
167 }
168 
ExtractRepeatedInt64(const tensorflow::Example & example,const char key[])169 google::protobuf::RepeatedField<int64_t> ExtractRepeatedInt64(
170     const tensorflow::Example& example, const char key[]) {
171   return example.features().feature().at(key).int64_list().value();
172 }
173 
174 }  // namespace client
175 }  // namespace fcp
176