xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/example_query_plan_engine.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/client/engine/example_query_plan_engine.h"
18 
19 #include <atomic>
20 #include <memory>
21 #include <string>
22 #include <tuple>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/status/status.h"
28 #include "absl/status/statusor.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/client/engine/common.h"
31 #include "fcp/client/engine/plan_engine_helpers.h"
32 #include "fcp/client/example_query_result.pb.h"
33 #include "fcp/client/opstats/opstats_logger.h"
34 #include "fcp/client/simple_task_environment.h"
35 #include "fcp/protos/plan.pb.h"
36 #include "fcp/tensorflow/status.h"
37 #include "tensorflow/core/framework/tensor_shape.h"
38 #include "tensorflow/core/framework/tensor_slice.h"
39 #include "tensorflow/core/platform/tstring.h"
40 #include "tensorflow/core/util/tensor_slice_writer.h"
41 
42 namespace fcp {
43 namespace client {
44 namespace engine {
45 
46 namespace tf = ::tensorflow;
47 
48 using ::fcp::client::ExampleQueryResult;
49 using ::fcp::client::engine::PlanResult;
50 using ::fcp::client::opstats::OpStatsLogger;
51 using ::google::internal::federated::plan::ExampleQuerySpec;
52 using ::google::internal::federated::plan::ExampleSelector;
53 
54 namespace {
55 
56 // Writes an one-dimensional tensor using the slice writer.
57 template <typename T>
WriteSlice(tf::checkpoint::TensorSliceWriter & slice_writer,const std::string & name,const int64_t size,const T * data)58 absl::Status WriteSlice(tf::checkpoint::TensorSliceWriter& slice_writer,
59                         const std::string& name, const int64_t size,
60                         const T* data) {
61   tf::TensorShape shape;
62   shape.AddDim(size);
63   tf::TensorSlice slice(shape.dims());
64   tf::Status tf_status = slice_writer.Add(name, shape, slice, data);
65   return ConvertFromTensorFlowStatus(tf_status);
66 }
67 
68 // Returns a map of (vector name) -> tuple(output name, vector spec).
69 absl::flat_hash_map<std::string,
70                     std::tuple<std::string, ExampleQuerySpec::OutputVectorSpec>>
GetOutputVectorSpecs(const ExampleQuerySpec::ExampleQuery & example_query)71 GetOutputVectorSpecs(const ExampleQuerySpec::ExampleQuery& example_query) {
72   absl::flat_hash_map<
73       std::string, std::tuple<std::string, ExampleQuerySpec::OutputVectorSpec>>
74       map;
75   for (auto const& [output_name, output_vector_spec] :
76        example_query.output_vector_specs()) {
77     map[output_vector_spec.vector_name()] =
78         std::make_tuple(output_name, output_vector_spec);
79   }
80   return map;
81 }
82 
CheckOutputVectorDataType(const ExampleQuerySpec::OutputVectorSpec & output_vector_spec,const ExampleQuerySpec::OutputVectorSpec::DataType & expected_data_type)83 absl::Status CheckOutputVectorDataType(
84     const ExampleQuerySpec::OutputVectorSpec& output_vector_spec,
85     const ExampleQuerySpec::OutputVectorSpec::DataType& expected_data_type) {
86   if (output_vector_spec.data_type() != expected_data_type) {
87     return absl::FailedPreconditionError(
88         "Unexpected data type in the example query");
89   }
90   return absl::OkStatus();
91 }
92 
93 // Writes example query results into a checkpoint. Example query results order
94 // must be the same as example_query_spec.example_queries.
WriteCheckpoint(const std::string & output_checkpoint_filename,const std::vector<ExampleQueryResult> & example_query_results,const ExampleQuerySpec & example_query_spec)95 absl::Status WriteCheckpoint(
96     const std::string& output_checkpoint_filename,
97     const std::vector<ExampleQueryResult>& example_query_results,
98     const ExampleQuerySpec& example_query_spec) {
99   tf::checkpoint::TensorSliceWriter slice_writer(
100       output_checkpoint_filename,
101       tf::checkpoint::CreateTableTensorSliceBuilder);
102   for (int i = 0; i < example_query_results.size(); ++i) {
103     const ExampleQueryResult& example_query_result = example_query_results[i];
104     const ExampleQuerySpec::ExampleQuery& example_query =
105         example_query_spec.example_queries()[i];
106     for (auto const& [vector_name, vector_tuple] :
107          GetOutputVectorSpecs(example_query)) {
108       std::string output_name = std::get<0>(vector_tuple);
109       ExampleQuerySpec::OutputVectorSpec output_vector_spec =
110           std::get<1>(vector_tuple);
111       auto it = example_query_result.vector_data().vectors().find(vector_name);
112       if (it == example_query_result.vector_data().vectors().end()) {
113         return absl::DataLossError(
114             "Expected value not found in the example query result");
115       }
116       const ExampleQueryResult::VectorData::Values values = it->second;
117       absl::Status status;
118       if (values.has_int32_values()) {
119         FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
120             output_vector_spec, ExampleQuerySpec::OutputVectorSpec::INT32));
121         int64_t size = values.int32_values().value_size();
122         auto data =
123             static_cast<const int32_t*>(values.int32_values().value().data());
124         FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
125       } else if (values.has_int64_values()) {
126         FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
127             output_vector_spec, ExampleQuerySpec::OutputVectorSpec::INT64));
128         int64_t size = values.int64_values().value_size();
129         auto data =
130             static_cast<const int64_t*>(values.int64_values().value().data());
131         FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
132       } else if (values.has_string_values()) {
133         FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
134             output_vector_spec, ExampleQuerySpec::OutputVectorSpec::STRING));
135         int64_t size = values.string_values().value_size();
136         std::vector<tf::tstring> tf_string_vector;
137         for (const auto& value : values.string_values().value()) {
138           tf_string_vector.emplace_back(value);
139         }
140         FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size,
141                                        tf_string_vector.data()));
142       } else if (values.has_bool_values()) {
143         FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
144             output_vector_spec, ExampleQuerySpec::OutputVectorSpec::BOOL));
145         int64_t size = values.bool_values().value_size();
146         auto data =
147             static_cast<const bool*>(values.bool_values().value().data());
148         FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
149       } else if (values.has_float_values()) {
150         FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
151             output_vector_spec, ExampleQuerySpec::OutputVectorSpec::FLOAT));
152         int64_t size = values.float_values().value_size();
153         auto data =
154             static_cast<const float*>(values.float_values().value().data());
155         FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
156       } else if (values.has_double_values()) {
157         FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
158             output_vector_spec, ExampleQuerySpec::OutputVectorSpec::DOUBLE));
159         int64_t size = values.double_values().value_size();
160         auto data =
161             static_cast<const double*>(values.double_values().value().data());
162         FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
163       } else if (values.has_bytes_values()) {
164         FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
165             output_vector_spec, ExampleQuerySpec::OutputVectorSpec::BYTES));
166         int64_t size = values.bytes_values().value_size();
167         std::vector<tf::tstring> tf_string_vector;
168         for (const auto& value : values.string_values().value()) {
169           tf_string_vector.emplace_back(value);
170         }
171         FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size,
172                                        tf_string_vector.data()));
173       } else {
174         return absl::DataLossError(
175             "Unexpected data type in the example query result");
176       }
177     }
178   }
179   return ConvertFromTensorFlowStatus(slice_writer.Finish());
180 }
181 
182 }  // anonymous namespace
183 
ExampleQueryPlanEngine(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger)184 ExampleQueryPlanEngine::ExampleQueryPlanEngine(
185     std::vector<ExampleIteratorFactory*> example_iterator_factories,
186     OpStatsLogger* opstats_logger)
187     : example_iterator_factories_(example_iterator_factories),
188       opstats_logger_(opstats_logger) {}
189 
RunPlan(const ExampleQuerySpec & example_query_spec,const std::string & output_checkpoint_filename)190 PlanResult ExampleQueryPlanEngine::RunPlan(
191     const ExampleQuerySpec& example_query_spec,
192     const std::string& output_checkpoint_filename) {
193   // TODO(team): Add the same logging as in simple_plan_engine.
194   std::vector<ExampleQueryResult> example_query_results;
195 
196   for (const auto& example_query : example_query_spec.example_queries()) {
197     ExampleSelector selector = example_query.example_selector();
198     ExampleIteratorFactory* example_iterator_factory =
199         FindExampleIteratorFactory(selector, example_iterator_factories_);
200     if (example_iterator_factory == nullptr) {
201       return PlanResult(PlanOutcome::kExampleIteratorError,
202                         absl::InternalError(
203                             "Could not find suitable ExampleIteratorFactory"));
204     }
205     absl::StatusOr<std::unique_ptr<ExampleIterator>> example_iterator =
206         example_iterator_factory->CreateExampleIterator(selector);
207     if (!example_iterator.ok()) {
208       return PlanResult(PlanOutcome::kExampleIteratorError,
209                         example_iterator.status());
210     }
211 
212     std::atomic<int> total_example_count = 0;
213     std::atomic<int64_t> total_example_size_bytes = 0;
214     ExampleIteratorStatus example_iterator_status;
215 
216     auto dataset_iterator = std::make_unique<DatasetIterator>(
217         std::move(*example_iterator), opstats_logger_, &total_example_count,
218         &total_example_size_bytes, &example_iterator_status,
219         selector.collection_uri(),
220         /*collect_stats=*/example_iterator_factory->ShouldCollectStats());
221 
222     absl::StatusOr<std::string> example_query_result_str =
223         dataset_iterator->GetNext();
224     if (!example_query_result_str.ok()) {
225       return PlanResult(PlanOutcome::kExampleIteratorError,
226                         example_query_result_str.status());
227     }
228 
229     ExampleQueryResult example_query_result;
230     if (!example_query_result.ParseFromString(*example_query_result_str)) {
231       return PlanResult(
232           PlanOutcome::kExampleIteratorError,
233           absl::DataLossError("Unexpected example query result format"));
234     }
235     example_query_results.push_back(std::move(example_query_result));
236   }
237   absl::Status status = WriteCheckpoint(
238       output_checkpoint_filename, example_query_results, example_query_spec);
239   if (!status.ok()) {
240     return PlanResult(PlanOutcome::kExampleIteratorError, status);
241   }
242   return PlanResult(PlanOutcome::kSuccess, absl::OkStatus());
243 }
244 
245 }  // namespace engine
246 }  // namespace client
247 }  // namespace fcp
248