xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/serialization_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_
17 #define TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/lib/core/status.h"
23 
24 namespace tensorflow {
25 namespace data {
26 
27 // Reads dataset elements from the checkpoint reader using the given key prefix.
   //
   // Counterpart of WriteElementsToCheckpoint: pass the same `key_prefix` that was
   // used when the elements were written.
   //
   // Parameters:
   //   ctx        - iterator context; how it is used is not visible in this
   //                header.
   //   reader     - checkpoint reader the elements are read from.
   //   key_prefix - prefix under which the elements were stored.
   //   elements   - receives the result; each element is itself a list of
   //                Tensors. NOTE(review): whether existing contents are cleared
   //                or appended to is decided in the implementation file --
   //                confirm before passing a non-empty vector.
   //
   // Returns a Status describing the outcome of the read.
28 Status ReadElementsFromCheckpoint(IteratorContext* ctx,
29                                   IteratorStateReader* reader,
30                                   StringPiece key_prefix,
31                                   std::vector<std::vector<Tensor>>* elements);
32 
33 // Writes dataset elements to the checkpoint writer using the given key prefix.
34 // The elements can be read back by passing the same key prefix to
35 // ReadElementsFromCheckpoint. Only one list of elements can be written under
36 // the same key_prefix.
   //
   // Parameters:
   //   writer     - checkpoint writer that receives the elements.
   //   key_prefix - prefix identifying this list of elements in the checkpoint.
   //   elements   - the elements to store; each element is a list of Tensors.
   //                Taken by const reference, so the caller's data is not
   //                modified.
   //
   // Returns a Status describing the outcome of the write.
   // NOTE(review): what happens when a second list is written under a prefix
   // that was already used is not specified in this header -- confirm in the
   // implementation before relying on any particular behavior.
37 Status WriteElementsToCheckpoint(
38     IteratorStateWriter* writer, StringPiece key_prefix,
39     const std::vector<std::vector<Tensor>>& elements);
40 
41 // Helper class for reading data from a vector of VariantTensorData objects.
   //
   // Implements the IteratorStateReader interface on top of the
   // VariantTensorData pointers handed to the constructor. Those objects are not
   // owned by the reader (see `data_` below), so the caller presumably has to
   // keep them alive for as long as the reader is in use -- TODO confirm against
   // callers.
   //
   // Every public accessor comes in two flavours: one taking a single `key` and
   // one taking a (`name`, `key`) pair. How a lone `key` is mapped onto a name is
   // decided in the implementation file and is not visible in this header.
   // All accessors are const member functions.
42 class VariantTensorDataReader : public IteratorStateReader {
43  public:
     // Builds a reader over `data`. `explicit` prevents a vector of pointers
     // from being implicitly converted into a reader.
44   explicit VariantTensorDataReader(
45       const std::vector<const VariantTensorData*>& data);
46 
     // Membership queries for `key`, or for `key` under `name`.
     // NOTE(review): presumably true means a value was stored for the key --
     // confirm in the implementation file.
47   bool Contains(StringPiece key) const override;
48   bool Contains(StringPiece name, StringPiece key) const override;
49 
     // Scalar reads. The value is delivered through `val` (int64_t or tstring);
     // the returned Status reports the outcome.
50   Status ReadScalar(StringPiece key, int64_t* val) const override;
51   Status ReadScalar(StringPiece name, StringPiece key,
52                     int64_t* val) const override;
53   Status ReadScalar(StringPiece key, tstring* val) const override;
54   Status ReadScalar(StringPiece name, StringPiece key,
55                     tstring* val) const override;
     // Tensor reads. The value is delivered through `val`. Two of the overloads
     // additionally accept a FunctionLibraryRuntime.
     // NOTE(review): the private ReadDatasetInternal helper suggests `flr` is
     // needed when the stored value is a dataset -- confirm in the
     // implementation file.
56   Status ReadTensor(StringPiece key, Tensor* val) const override;
57   Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key,
58                     Tensor* val) const override;
59   Status ReadTensor(StringPiece name, StringPiece key,
60                     Tensor* val) const override;
61   Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name,
62                     StringPiece key, Tensor* val) const override;
63 
64  private:
     // Internal helpers; presumably the shared back ends of the public
     // overloads above (their definitions are not in this header). The template
     // parameter T of ReadScalarInternal is the scalar value type.
65   template <typename T>
66   Status ReadScalarInternal(StringPiece name, StringPiece key, T* val) const;
67   Status ReadTensorInternal(FunctionLibraryRuntime* flr, StringPiece name,
68                             StringPiece key, Tensor* val) const;
69   Status ReadDatasetInternal(FunctionLibraryRuntime* flr, StringPiece name,
70                              StringPiece key, Tensor* val) const;
71 
     // Two-level string-keyed lookup table ending in a size_t.
     // NOTE(review): presumably name -> (key -> index into the matching
     // VariantTensorData) -- confirm in the implementation file.
72   std::map<string, std::map<string, size_t>> map_;
     // String-keyed table of the VariantTensorData objects being read.
     // NOTE(review): the key is presumably the same `name` used by `map_`.
73   std::map<string, const VariantTensorData*> data_;  // Not owned.
74 };
75 
75 // Helper class used to build a list of VariantTensorData objects, one for each
76 // iterator which is determined from the key supplied from the Write* calls.
77 // Sample usage:
78 // VariantTensorDataWriter writer;
79 // writer.WriteScalar(full_name("buffer_size"), buffer_.size());
80 // writer.WriteScalar(full_name("num_threads"), threadpool_.size());
81 // ....
82 // std::vector<std::unique_ptr<VariantTensorData>> variants;
83 // writer.ReleaseData(&variants);
84 // Now the VariantTensorData objects can be used to serialize.
   //
   // Implements the IteratorStateWriter interface. Every Write* method comes in
   // two flavours: one taking a single `key` and one taking a (`name`, `key`)
   // pair; each returns a Status reporting the outcome. How a lone `key` is
   // mapped onto a name is decided in the implementation file.
85 class VariantTensorDataWriter : public IteratorStateWriter {
86  public:
     // Records an int64_t scalar under `key` (optionally scoped by `name`).
87   Status WriteScalar(StringPiece key, const int64_t val) override;
88   Status WriteScalar(StringPiece name, StringPiece key,
89                      const int64_t val) override;
90 
     // Records a tstring scalar under `key` (optionally scoped by `name`).
91   Status WriteScalar(StringPiece key, const tstring& val) override;
92   Status WriteScalar(StringPiece name, StringPiece key,
93                      const tstring& val) override;
94 
     // Records a Tensor under `key` (optionally scoped by `name`).
95   Status WriteTensor(StringPiece key, const Tensor& val) override;
96   Status WriteTensor(StringPiece name, StringPiece key,
97                      const Tensor& val) override;
98 
99   // Releases the built VariantTensorData's to `variants`. Clears out all
100   // class state.
      // Ownership moves to the caller: `variants` holds unique_ptrs.
101   void ReleaseData(std::vector<std::unique_ptr<VariantTensorData>>* variants);
102 
103   // Obtains a read-only version of the VariantTensorData's built.
      // `variants` receives const raw pointers.
      // NOTE(review): presumably these point at objects still held in `data_`
      // and therefore stop being valid once ReleaseData is called or the
      // writer is destroyed -- confirm. The method is non-const, possibly
      // because it flushes pending state first (see MaybeFlush) -- confirm.
104   void GetData(std::vector<const VariantTensorData*>* variants);
105 
106  private:
      // NOTE(review): presumably finalizes pending state once, guarded by
      // `is_flushed_` -- the definition is not in this header.
107   void MaybeFlush();
      // NOTE(review): presumably the step that clears all class state for
      // ReleaseData -- the definition is not in this header.
108   void Reset();
109 
      // Internal helpers; presumably the shared back ends of the public Write*
      // overloads above. The template parameter T of WriteScalarInternal is
      // the scalar value type.
110   template <typename T>
111   Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val);
112   Status WriteTensorInternal(StringPiece name, StringPiece key,
113                              const Tensor& val);
114   Status WriteDatasetInternal(StringPiece name, StringPiece key,
115                               const DatasetBase* dataset);
116 
      // Starts out false; see MaybeFlush.
117   bool is_flushed_ = false;
      // Owned VariantTensorData objects, keyed by string.
      // NOTE(review): the key is presumably the iterator `name`.
118   std::map<string, std::unique_ptr<VariantTensorData>> data_;
      // Per-string list of strings.
      // NOTE(review): presumably the keys written so far for each name, in
      // write order -- confirm in the implementation file.
119   std::map<string, std::vector<string>> keys_;
120 };
122 
122 // Returns a GraphDef representation of the given dataset.
    //
    // Parameters:
    //   dataset           - the dataset to convert.
    //   serialization_ctx - serialization settings. Taken by rvalue reference,
    //                       so callers must pass a temporary or use std::move.
    //   graph_def         - receives the resulting GraphDef.
    //
    // Returns a Status describing the outcome of the conversion.
123 Status AsGraphDef(const DatasetBase* dataset,
124                   SerializationContext&& serialization_ctx,
125                   GraphDef* graph_def);
127 
127 // Returns a GraphDef representation of the given dataset suitable for
128 // optimization rewrites. It sets serialization parameters to export a minimum
129 // graph with additional information for optimization (i.e. ignoring external
130 // state, not serializing data tensors, not failing if there are datasets which
131 // do not have AsGraphDef implemented). Sets the `dataset_node` parameter to the
132 // dataset's node name in the resulting GraphDef.
    //
    // Parameters:
    //   ctx          - op kernel context; how it is used is not visible in this
    //                  header.
    //   input        - the dataset to convert.
    //   input_list   - receives (string, Tensor) pairs. NOTE(review): presumably
    //                  the data tensors that were left out of the graph, keyed
    //                  by a node or placeholder name -- confirm in the
    //                  implementation file.
    //   result       - receives the resulting GraphDef.
    //   dataset_node - receives the dataset's node name within `result`.
    //
    // Returns a Status describing the outcome of the conversion.
133 Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input,
134                             std::vector<std::pair<string, Tensor>>* input_list,
135                             GraphDef* result, string* dataset_node);
137 
138 }  // namespace data
139 }  // namespace tensorflow
140 
141 #endif  // TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_
142