xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/example_proto_helper.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
17 #define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
18 
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/example/example.pb.h"
24 #include "tensorflow/core/example/feature.pb.h"
25 #include "tensorflow/core/framework/allocator.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/partial_tensor_shape.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/types.h"
32 #include "tensorflow/core/util/sparse/sparse_tensor.h"
33 
34 // This is a set of helper methods that will make it possible to share
35 // tensorflow::Example proto Tensor conversion code inside the ExampleParserOp
36 // OpKernel as well as in external code.
37 namespace tensorflow {
38 
39 // "Dense" feature configuration.
40 struct FixedLenFeature {
41   string key;
42   DataType dtype;
43   TensorShape shape;
44   Tensor default_value;
45   string values_output_tensor_name;
46 };
47 
48 // "Sparse" feature configuration.
49 struct VarLenFeature {
50   string key;
51   DataType dtype;
52   string values_output_tensor_name;
53   string indices_output_tensor_name;
54   string shapes_output_tensor_name;
55 };
56 
57 // Given a single tensorflow::Example, with an optional example name
58 // at a particular index within a batch, and dense and sparse feature
59 // configurations from fixed_len_features, var_len_features, this method
60 // updates the dense value tensor and the sparse values temporary vector
61 // of tensors. The indexing of the output vectors correspond 1:1 to the
62 // indexing of the feature configuration vectors.
63 //
64 // The fixed_len_features and var_len_features maps are assume to be
65 // have disjoint key fields from the Feature map in the tensorflow.Example
66 // proto.
67 //
68 // For each sparse feature, the sparse values temporary vector holds a
69 // tensor for each Example. Each tensor is either empty or filled, depending
70 // on if the sparse feature value is set for the Example. This
71 // temporary structure is needed because we need to know the total number
72 // of filled elements in the batch to get the proper final sparse tensor
73 // shapes allocated.  After the entire batch is processed,
74 // GetSparseTensorShape can be used to calculate the final shapes and
75 // CopyIntoSparseTensor can be used to copy from the temporary vector
76 // into the final allocated tensors.
77 Status SingleExampleProtoToTensors(
78     const Example& example, const string& name, const int batch_index,
79     const std::vector<FixedLenFeature>& fixed_len_features,
80     const std::vector<VarLenFeature>& var_len_features,
81     std::vector<Tensor*>* dense_values,
82     std::vector<std::vector<Tensor>>* sparse_values_temporary_vector);
83 
84 // The shape of the indices and values tensors associated with a SparseTensor
85 // are dependent on the contents of the batch.
86 struct VarLenFeatureBatchShapes {
87   TensorShape indices_shape;
88   TensorShape values_shape;
89   int max_num_features;
90 };
91 
92 // Get the shape of the sparse values and indices tensors for the batch,
93 // given how many of the tensors in the temporary sparse values vector
94 // are actually filled.
95 Status GetSparseTensorShapes(const VarLenFeature& var_len_feature,
96                              const std::vector<Tensor>& sparse_values_tmp,
97                              const int batch_size,
98                              VarLenFeatureBatchShapes* output_shapes);
99 
100 // A method to convert a batch of tensorflow::Example protos into output
101 // tensors. This method is useful if there already is a batch of deserialized
102 // Example protos in memory (such as a serving use-case) and we do not wish
103 // to incur an extraneous serialize/deserialize.  It is intended
104 // as an outside of OpKernel compatible replacement for the functionality of
105 // ExampleParserOp. In a serving setting, this method could be used to produce
106 // a feed_dict of Tensors that could bypass the ExampleParserOp.
107 //
108 // Note that unlike SingleExampleProtoToTensors, output tensors are
109 // allocated using a provided Allocator within this method.
110 Status BatchExampleProtoToTensors(
111     const std::vector<const Example*>& examples,
112     const std::vector<string>& names,
113     const std::vector<FixedLenFeature>& fixed_len_features,
114     const std::vector<VarLenFeature>& var_len_features, Allocator* allocator,
115     std::vector<Tensor>* output_dense_values_tensor,
116     std::vector<Tensor>* output_sparse_indices_tensor,
117     std::vector<Tensor>* output_sparse_values_tensor,
118     std::vector<Tensor>* output_sparse_shapes_tensor);
119 
120 // Check that the given dtype is one that is compatible with
121 // tensorflow::Example protocol buffer feature values.
122 Status CheckValidType(const DataType& dtype);
123 
124 // Check that the provided Feature proto message's oneof value
125 // matches that of the provided dtype.
126 Status CheckTypesMatch(const Feature& feature, const DataType& dtype,
127                        bool* match);
128 
129 // For a single Example, copy a dense feature value into an output
130 // dense value tensor Out at the provided out_index offset.
131 Status FeatureDenseCopy(const std::size_t out_index, const string& name,
132                         const string& key, const DataType& dtype,
133                         const TensorShape& shape, const Feature& feature,
134                         Tensor* out);
135 
136 // Copy the value a provided Tensor into an output dense_value tensor Out
137 // at the provided out_index offset.
138 void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
139                   const Tensor& in, Tensor* out);
140 
141 // For a single Example, and given sparse feature return a temporary output
142 // Tensor suitable for being collected in the temporary sparse value vector.
143 Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
144                          const DataType& dtype, const Feature& feature);
145 
146 // Copy a temporary Tensor into the final sparse indices and values
147 // tensor at a given batch index and element offset. This method
148 // assumes that the indices/values Tensors have been properly allocated
149 // for the batch.
150 int64_t CopyIntoSparseTensor(const Tensor& in, const int batch,
151                              const int64_t offset, Tensor* indices,
152                              Tensor* values);
153 
154 // Check that each dense_shape has known rank and inner dimensions; and
155 // update variable_length (whether the outer dimension is None) and
156 // elements_per_stride for each denes_shape.
157 Status GetDenseShapes(const std::vector<PartialTensorShape>& dense_shapes,
158                       std::vector<bool>* variable_length,
159                       std::vector<std::size_t>* elements_per_stride);
160 
161 // Parses the attributes passed to ParseExample.
162 // REQUIRES: Init must be called after construction.
163 struct ParseExampleAttrs {
164  public:
165   template <typename ContextType>
166   Status Init(ContextType* ctx, int op_version = 1) {
167     TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types", &sparse_types));
168     TF_RETURN_IF_ERROR(ctx->GetAttr("Tdense", &dense_types));
169     TF_RETURN_IF_ERROR(ctx->GetAttr("dense_shapes", &dense_shapes));
170     TF_RETURN_IF_ERROR(
171         GetDenseShapes(dense_shapes, &variable_length, &elements_per_stride));
172     switch (op_version) {
173       case 1:
174         TF_RETURN_IF_ERROR(ctx->GetAttr("Nsparse", &num_sparse));
175         TF_RETURN_IF_ERROR(ctx->GetAttr("Ndense", &num_dense));
176         break;
177       case 2:
178         TF_RETURN_IF_ERROR(
179             ctx->GetAttr("ragged_value_types", &ragged_value_types));
180         TF_RETURN_IF_ERROR(ctx->GetAttr("num_sparse", &num_sparse));
181         TF_RETURN_IF_ERROR(
182             ctx->GetAttr("ragged_split_types", &ragged_split_types));
183         break;
184       default:
185         return errors::InvalidArgument("Unexpected op_version", op_version);
186     }
187     return FinishInit(op_version);
188   }
189 
190   int64_t num_sparse;
191   int64_t num_dense;
192   int64_t num_ragged;
193   std::vector<DataType> sparse_types;
194   std::vector<DataType> dense_types;
195   std::vector<DataType> ragged_value_types;
196   std::vector<DataType> ragged_split_types;
197   std::vector<PartialTensorShape> dense_shapes;
198   std::vector<bool> variable_length;
199   std::vector<std::size_t> elements_per_stride;
200 
201  private:
202   Status FinishInit(int op_version);  // for context-independent parts of Init.
203 };
204 
205 // Parses the attributes passed to ParseSingleExample.
206 // REQUIRES: Init must be called after construction.
207 struct ParseSingleExampleAttrs {
208  public:
209   template <typename ContextType>
InitParseSingleExampleAttrs210   Status Init(ContextType* ctx) {
211     TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_keys", &sparse_keys));
212     TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types", &sparse_types));
213     TF_RETURN_IF_ERROR(ctx->GetAttr("dense_keys", &dense_keys));
214     TF_RETURN_IF_ERROR(ctx->GetAttr("Tdense", &dense_types));
215     TF_RETURN_IF_ERROR(ctx->GetAttr("dense_shapes", &dense_shapes));
216 
217     int num_sparse;
218     TF_RETURN_IF_ERROR(ctx->GetAttr("num_sparse", &num_sparse));
219     if (num_sparse != sparse_keys.size() || num_sparse != sparse_types.size()) {
220       return errors::InvalidArgument(
221           "num_sparse (", num_sparse, ") must match the size of sparse_keys (",
222           sparse_keys.size(), ") and sparse_types (", sparse_types.size(), ")");
223     }
224 
225     TF_RETURN_IF_ERROR(
226         GetDenseShapes(dense_shapes, &variable_length, &elements_per_stride));
227     return FinishInit();
228   }
229 
230   std::vector<tstring> sparse_keys;
231   std::vector<DataType> sparse_types;
232   std::vector<tstring> dense_keys;
233   std::vector<DataType> dense_types;
234   std::vector<PartialTensorShape> dense_shapes;
235   std::vector<bool> variable_length;
236   std::vector<std::size_t> elements_per_stride;
237 
238  private:
239   Status FinishInit();  // for context-independent parts of Init.
240 };
241 
242 // Parses the attributes passed to ParseSequenceExample.
243 // REQUIRES: Init must be called after construction.
244 struct ParseSequenceExampleAttrs {
245  public:
246   template <typename ContextType>
247   Status Init(ContextType* ctx, int op_version = 1) {
248     switch (op_version) {
249       case 1: {
250         std::vector<string> missing_empty_vector;
251         TF_RETURN_IF_ERROR(ctx->GetAttr(
252             "feature_list_dense_missing_assumed_empty", &missing_empty_vector));
253         for (const string& feature : missing_empty_vector) {
254           feature_list_dense_missing_assumed_empty.insert(feature);
255         }
256       }
257         TF_RETURN_IF_ERROR(
258             ctx->GetAttr("context_sparse_keys", &context_sparse_keys));
259         TF_RETURN_IF_ERROR(
260             ctx->GetAttr("context_dense_keys", &context_dense_keys));
261         TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_sparse_keys",
262                                         &feature_list_sparse_keys));
263         TF_RETURN_IF_ERROR(
264             ctx->GetAttr("feature_list_dense_keys", &feature_list_dense_keys));
265         TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense));
266         break;
267       case 2:
268         TF_RETURN_IF_ERROR(ctx->GetAttr("context_ragged_value_types",
269                                         &context_ragged_value_types));
270         TF_RETURN_IF_ERROR(ctx->GetAttr("context_ragged_split_types",
271                                         &context_ragged_split_types));
272         TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_ragged_value_types",
273                                         &feature_list_ragged_value_types));
274         TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_ragged_split_types",
275                                         &feature_list_ragged_split_types));
276         break;
277       default:
278         return errors::InvalidArgument("Unexpected op_version", op_version);
279     }
280     TF_RETURN_IF_ERROR(
281         ctx->GetAttr("context_sparse_types", &context_sparse_types));
282     TF_RETURN_IF_ERROR(
283         ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense));
284     TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse));
285     TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types));
286     TF_RETURN_IF_ERROR(
287         ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types));
288     TF_RETURN_IF_ERROR(
289         ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types));
290     TF_RETURN_IF_ERROR(
291         ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse));
292     TF_RETURN_IF_ERROR(
293         ctx->GetAttr("context_dense_shapes", &context_dense_shapes));
294     TF_RETURN_IF_ERROR(
295         ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes));
296     return FinishInit(op_version);
297   }
298 
299   std::unordered_set<string> feature_list_dense_missing_assumed_empty;
300   int64_t num_context_sparse;
301   int64_t num_context_dense;
302   int64_t num_context_ragged;
303   int64_t num_feature_list_sparse;
304   int64_t num_feature_list_dense;
305   int64_t num_feature_list_ragged;
306   std::vector<tstring> context_sparse_keys;
307   std::vector<tstring> context_dense_keys;
308   std::vector<tstring> feature_list_sparse_keys;
309   std::vector<tstring> feature_list_dense_keys;
310   std::vector<DataType> context_sparse_types;
311   std::vector<DataType> context_dense_types;
312   std::vector<TensorShape> context_dense_shapes;
313   std::vector<DataType> feature_list_sparse_types;
314   std::vector<DataType> feature_list_dense_types;
315   std::vector<TensorShape> feature_list_dense_shapes;
316   std::vector<DataType> context_ragged_value_types;
317   std::vector<DataType> context_ragged_split_types;
318   std::vector<DataType> feature_list_ragged_value_types;
319   std::vector<DataType> feature_list_ragged_split_types;
320 
321  private:
322   Status FinishInit(int op_version);  // for context-independent parts of Init.
323 };
324 
325 // Parses the attributes passed to ParseSingleSequenceExample.
326 // REQUIRES: Init must be called after construction.
327 struct ParseSingleSequenceExampleAttrs {
328  public:
329   template <typename ContextType>
InitParseSingleSequenceExampleAttrs330   Status Init(ContextType* ctx) {
331     TF_RETURN_IF_ERROR(
332         ctx->GetAttr("context_sparse_types", &context_sparse_types));
333     TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense));
334     TF_RETURN_IF_ERROR(
335         ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense));
336     TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse));
337     TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types));
338     TF_RETURN_IF_ERROR(
339         ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types));
340     TF_RETURN_IF_ERROR(
341         ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types));
342     TF_RETURN_IF_ERROR(
343         ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse));
344     TF_RETURN_IF_ERROR(
345         ctx->GetAttr("context_dense_shapes", &context_dense_shapes));
346     TF_RETURN_IF_ERROR(
347         ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes));
348     return FinishInit();
349   }
350 
351   int64_t num_context_sparse;
352   int64_t num_context_dense;
353   int64_t num_feature_list_sparse;
354   int64_t num_feature_list_dense;
355   std::vector<DataType> context_sparse_types;
356   std::vector<DataType> context_dense_types;
357   std::vector<TensorShape> context_dense_shapes;
358   std::vector<DataType> feature_list_sparse_types;
359   std::vector<DataType> feature_list_dense_types;
360   std::vector<TensorShape> feature_list_dense_shapes;
361 
362  private:
363   Status FinishInit();  // for context-independent parts of Init.
364 };
365 
366 }  // namespace tensorflow
367 
368 #endif  // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_
369