1 /*
2 * Copyright 2023 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/client/engine/example_query_plan_engine.h"
18
19 #include <atomic>
20 #include <memory>
21 #include <string>
22 #include <tuple>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/status/status.h"
28 #include "absl/status/statusor.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/client/engine/common.h"
31 #include "fcp/client/engine/plan_engine_helpers.h"
32 #include "fcp/client/example_query_result.pb.h"
33 #include "fcp/client/opstats/opstats_logger.h"
34 #include "fcp/client/simple_task_environment.h"
35 #include "fcp/protos/plan.pb.h"
36 #include "fcp/tensorflow/status.h"
37 #include "tensorflow/core/framework/tensor_shape.h"
38 #include "tensorflow/core/framework/tensor_slice.h"
39 #include "tensorflow/core/platform/tstring.h"
40 #include "tensorflow/core/util/tensor_slice_writer.h"
41
42 namespace fcp {
43 namespace client {
44 namespace engine {
45
46 namespace tf = ::tensorflow;
47
48 using ::fcp::client::ExampleQueryResult;
49 using ::fcp::client::engine::PlanResult;
50 using ::fcp::client::opstats::OpStatsLogger;
51 using ::google::internal::federated::plan::ExampleQuerySpec;
52 using ::google::internal::federated::plan::ExampleSelector;
53
54 namespace {
55
56 // Writes an one-dimensional tensor using the slice writer.
57 template <typename T>
WriteSlice(tf::checkpoint::TensorSliceWriter & slice_writer,const std::string & name,const int64_t size,const T * data)58 absl::Status WriteSlice(tf::checkpoint::TensorSliceWriter& slice_writer,
59 const std::string& name, const int64_t size,
60 const T* data) {
61 tf::TensorShape shape;
62 shape.AddDim(size);
63 tf::TensorSlice slice(shape.dims());
64 tf::Status tf_status = slice_writer.Add(name, shape, slice, data);
65 return ConvertFromTensorFlowStatus(tf_status);
66 }
67
68 // Returns a map of (vector name) -> tuple(output name, vector spec).
69 absl::flat_hash_map<std::string,
70 std::tuple<std::string, ExampleQuerySpec::OutputVectorSpec>>
GetOutputVectorSpecs(const ExampleQuerySpec::ExampleQuery & example_query)71 GetOutputVectorSpecs(const ExampleQuerySpec::ExampleQuery& example_query) {
72 absl::flat_hash_map<
73 std::string, std::tuple<std::string, ExampleQuerySpec::OutputVectorSpec>>
74 map;
75 for (auto const& [output_name, output_vector_spec] :
76 example_query.output_vector_specs()) {
77 map[output_vector_spec.vector_name()] =
78 std::make_tuple(output_name, output_vector_spec);
79 }
80 return map;
81 }
82
CheckOutputVectorDataType(const ExampleQuerySpec::OutputVectorSpec & output_vector_spec,const ExampleQuerySpec::OutputVectorSpec::DataType & expected_data_type)83 absl::Status CheckOutputVectorDataType(
84 const ExampleQuerySpec::OutputVectorSpec& output_vector_spec,
85 const ExampleQuerySpec::OutputVectorSpec::DataType& expected_data_type) {
86 if (output_vector_spec.data_type() != expected_data_type) {
87 return absl::FailedPreconditionError(
88 "Unexpected data type in the example query");
89 }
90 return absl::OkStatus();
91 }
92
93 // Writes example query results into a checkpoint. Example query results order
94 // must be the same as example_query_spec.example_queries.
WriteCheckpoint(const std::string & output_checkpoint_filename,const std::vector<ExampleQueryResult> & example_query_results,const ExampleQuerySpec & example_query_spec)95 absl::Status WriteCheckpoint(
96 const std::string& output_checkpoint_filename,
97 const std::vector<ExampleQueryResult>& example_query_results,
98 const ExampleQuerySpec& example_query_spec) {
99 tf::checkpoint::TensorSliceWriter slice_writer(
100 output_checkpoint_filename,
101 tf::checkpoint::CreateTableTensorSliceBuilder);
102 for (int i = 0; i < example_query_results.size(); ++i) {
103 const ExampleQueryResult& example_query_result = example_query_results[i];
104 const ExampleQuerySpec::ExampleQuery& example_query =
105 example_query_spec.example_queries()[i];
106 for (auto const& [vector_name, vector_tuple] :
107 GetOutputVectorSpecs(example_query)) {
108 std::string output_name = std::get<0>(vector_tuple);
109 ExampleQuerySpec::OutputVectorSpec output_vector_spec =
110 std::get<1>(vector_tuple);
111 auto it = example_query_result.vector_data().vectors().find(vector_name);
112 if (it == example_query_result.vector_data().vectors().end()) {
113 return absl::DataLossError(
114 "Expected value not found in the example query result");
115 }
116 const ExampleQueryResult::VectorData::Values values = it->second;
117 absl::Status status;
118 if (values.has_int32_values()) {
119 FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
120 output_vector_spec, ExampleQuerySpec::OutputVectorSpec::INT32));
121 int64_t size = values.int32_values().value_size();
122 auto data =
123 static_cast<const int32_t*>(values.int32_values().value().data());
124 FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
125 } else if (values.has_int64_values()) {
126 FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
127 output_vector_spec, ExampleQuerySpec::OutputVectorSpec::INT64));
128 int64_t size = values.int64_values().value_size();
129 auto data =
130 static_cast<const int64_t*>(values.int64_values().value().data());
131 FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
132 } else if (values.has_string_values()) {
133 FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
134 output_vector_spec, ExampleQuerySpec::OutputVectorSpec::STRING));
135 int64_t size = values.string_values().value_size();
136 std::vector<tf::tstring> tf_string_vector;
137 for (const auto& value : values.string_values().value()) {
138 tf_string_vector.emplace_back(value);
139 }
140 FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size,
141 tf_string_vector.data()));
142 } else if (values.has_bool_values()) {
143 FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
144 output_vector_spec, ExampleQuerySpec::OutputVectorSpec::BOOL));
145 int64_t size = values.bool_values().value_size();
146 auto data =
147 static_cast<const bool*>(values.bool_values().value().data());
148 FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
149 } else if (values.has_float_values()) {
150 FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
151 output_vector_spec, ExampleQuerySpec::OutputVectorSpec::FLOAT));
152 int64_t size = values.float_values().value_size();
153 auto data =
154 static_cast<const float*>(values.float_values().value().data());
155 FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
156 } else if (values.has_double_values()) {
157 FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
158 output_vector_spec, ExampleQuerySpec::OutputVectorSpec::DOUBLE));
159 int64_t size = values.double_values().value_size();
160 auto data =
161 static_cast<const double*>(values.double_values().value().data());
162 FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size, data));
163 } else if (values.has_bytes_values()) {
164 FCP_RETURN_IF_ERROR(CheckOutputVectorDataType(
165 output_vector_spec, ExampleQuerySpec::OutputVectorSpec::BYTES));
166 int64_t size = values.bytes_values().value_size();
167 std::vector<tf::tstring> tf_string_vector;
168 for (const auto& value : values.string_values().value()) {
169 tf_string_vector.emplace_back(value);
170 }
171 FCP_RETURN_IF_ERROR(WriteSlice(slice_writer, output_name, size,
172 tf_string_vector.data()));
173 } else {
174 return absl::DataLossError(
175 "Unexpected data type in the example query result");
176 }
177 }
178 }
179 return ConvertFromTensorFlowStatus(slice_writer.Finish());
180 }
181
182 } // anonymous namespace
183
ExampleQueryPlanEngine(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger)184 ExampleQueryPlanEngine::ExampleQueryPlanEngine(
185 std::vector<ExampleIteratorFactory*> example_iterator_factories,
186 OpStatsLogger* opstats_logger)
187 : example_iterator_factories_(example_iterator_factories),
188 opstats_logger_(opstats_logger) {}
189
RunPlan(const ExampleQuerySpec & example_query_spec,const std::string & output_checkpoint_filename)190 PlanResult ExampleQueryPlanEngine::RunPlan(
191 const ExampleQuerySpec& example_query_spec,
192 const std::string& output_checkpoint_filename) {
193 // TODO(team): Add the same logging as in simple_plan_engine.
194 std::vector<ExampleQueryResult> example_query_results;
195
196 for (const auto& example_query : example_query_spec.example_queries()) {
197 ExampleSelector selector = example_query.example_selector();
198 ExampleIteratorFactory* example_iterator_factory =
199 FindExampleIteratorFactory(selector, example_iterator_factories_);
200 if (example_iterator_factory == nullptr) {
201 return PlanResult(PlanOutcome::kExampleIteratorError,
202 absl::InternalError(
203 "Could not find suitable ExampleIteratorFactory"));
204 }
205 absl::StatusOr<std::unique_ptr<ExampleIterator>> example_iterator =
206 example_iterator_factory->CreateExampleIterator(selector);
207 if (!example_iterator.ok()) {
208 return PlanResult(PlanOutcome::kExampleIteratorError,
209 example_iterator.status());
210 }
211
212 std::atomic<int> total_example_count = 0;
213 std::atomic<int64_t> total_example_size_bytes = 0;
214 ExampleIteratorStatus example_iterator_status;
215
216 auto dataset_iterator = std::make_unique<DatasetIterator>(
217 std::move(*example_iterator), opstats_logger_, &total_example_count,
218 &total_example_size_bytes, &example_iterator_status,
219 selector.collection_uri(),
220 /*collect_stats=*/example_iterator_factory->ShouldCollectStats());
221
222 absl::StatusOr<std::string> example_query_result_str =
223 dataset_iterator->GetNext();
224 if (!example_query_result_str.ok()) {
225 return PlanResult(PlanOutcome::kExampleIteratorError,
226 example_query_result_str.status());
227 }
228
229 ExampleQueryResult example_query_result;
230 if (!example_query_result.ParseFromString(*example_query_result_str)) {
231 return PlanResult(
232 PlanOutcome::kExampleIteratorError,
233 absl::DataLossError("Unexpected example query result format"));
234 }
235 example_query_results.push_back(std::move(example_query_result));
236 }
237 absl::Status status = WriteCheckpoint(
238 output_checkpoint_filename, example_query_results, example_query_spec);
239 if (!status.ok()) {
240 return PlanResult(PlanOutcome::kExampleIteratorError, status);
241 }
242 return PlanResult(PlanOutcome::kSuccess, absl::OkStatus());
243 }
244
245 } // namespace engine
246 } // namespace client
247 } // namespace fcp
248