1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_ 17 #define TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/framework/dataset.h" 22 #include "tensorflow/core/lib/core/status.h" 23 24 namespace tensorflow { 25 namespace data { 26 27 // Reads dataset elements from the checkpoint reader using the given key prefix. 28 Status ReadElementsFromCheckpoint(IteratorContext* ctx, 29 IteratorStateReader* reader, 30 StringPiece key_prefix, 31 std::vector<std::vector<Tensor>>* elements); 32 33 // Writes dataset elements to the checkpoint writer using the given key prefix. 34 // The elements can be read back by passing the same key prefix to 35 // ReadElementsFromCheckpoint. Only one list of elements can be written under 36 // the same key_prefix. 37 Status WriteElementsToCheckpoint( 38 IteratorStateWriter* writer, StringPiece key_prefix, 39 const std::vector<std::vector<Tensor>>& elements); 40 41 // Helper class for reading data from a vector of VariantTensorData objects. 42 class VariantTensorDataReader : public IteratorStateReader { 43 public: 44 explicit VariantTensorDataReader( 45 const std::vector<const VariantTensorData*>& data); 46 47 bool Contains(StringPiece key) const override; 48 bool Contains(StringPiece name, StringPiece key) const override; 49 50 Status ReadScalar(StringPiece key, int64_t* val) const override; 51 Status ReadScalar(StringPiece name, StringPiece key, 52 int64_t* val) const override; 53 Status ReadScalar(StringPiece key, tstring* val) const override; 54 Status ReadScalar(StringPiece name, StringPiece key, 55 tstring* val) const override; 56 Status ReadTensor(StringPiece key, Tensor* val) const override; 57 Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key, 58 Tensor* val) const override; 59 Status ReadTensor(StringPiece name, StringPiece key, 60 Tensor* val) const override; 61 Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name, 62 StringPiece key, Tensor* val) const override; 63 64 private: 65 template <typename T> 66 Status ReadScalarInternal(StringPiece name, StringPiece key, T* val) const; 67 Status ReadTensorInternal(FunctionLibraryRuntime* flr, StringPiece name, 68 StringPiece key, Tensor* val) const; 69 Status ReadDatasetInternal(FunctionLibraryRuntime* flr, StringPiece name, 70 StringPiece key, Tensor* val) const; 71 72 std::map<string, std::map<string, size_t>> map_; 73 std::map<string, const VariantTensorData*> data_; // Not owned. 74 }; 75 76 // Helper class used to build a list of VariantTensorData objects, one for each 77 // iterator which is determined from the key supplied from the Write* calls. 78 // Sample usage: 79 // VariantTensorDataWriter writer; 80 // writer.WriteScalar(full_name("buffer_size"), buffer_.size()); 81 // writer.WriteScalar(full_name("num_threads"), threadpool_.size()); 82 // .... 83 // std::vector<std::unique_ptr<VariantTensorData>> variants; 84 // writer.ReleaseData(&variants); 85 // Now the VariantTensorData objects can be used to serialize. 86 class VariantTensorDataWriter : public IteratorStateWriter { 87 public: 88 Status WriteScalar(StringPiece key, const int64_t val) override; 89 Status WriteScalar(StringPiece name, StringPiece key, 90 const int64_t val) override; 91 92 Status WriteScalar(StringPiece key, const tstring& val) override; 93 Status WriteScalar(StringPiece name, StringPiece key, 94 const tstring& val) override; 95 96 Status WriteTensor(StringPiece key, const Tensor& val) override; 97 Status WriteTensor(StringPiece name, StringPiece key, 98 const Tensor& val) override; 99 100 // Releases the built VariantTensorData's to `variants`. Clears out all 101 // class state. 102 void ReleaseData(std::vector<std::unique_ptr<VariantTensorData>>* variants); 103 104 // Obtains a read-only version of the VariantTensorData's built. 105 void GetData(std::vector<const VariantTensorData*>* variants); 106 107 private: 108 void MaybeFlush(); 109 void Reset(); 110 111 template <typename T> 112 Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val); 113 Status WriteTensorInternal(StringPiece name, StringPiece key, 114 const Tensor& val); 115 Status WriteDatasetInternal(StringPiece name, StringPiece key, 116 const DatasetBase* dataset); 117 118 bool is_flushed_ = false; 119 std::map<string, std::unique_ptr<VariantTensorData>> data_; 120 std::map<string, std::vector<string>> keys_; 121 }; 122 123 // Returns a GraphDef representation of the given dataset. 124 Status AsGraphDef(const DatasetBase* dataset, 125 SerializationContext&& serialization_ctx, 126 GraphDef* graph_def); 127 128 // Returns a GraphDef representation of the given dataset suitable for 129 // optimization rewrites. It sets serialization parameters to export a minimum 130 // graph with additional information for optimization (i.e. ignoring external 131 // state, not serializing data tensors, not failing if there are datasets which 132 // do not have AsGraphDef implemented). Sets the `dataset_node` parameter to the 133 // dataset's node name in the resulting GraphDef. 134 Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input, 135 std::vector<std::pair<string, Tensor>>* input_list, 136 GraphDef* result, string* dataset_node); 137 138 } // namespace data 139 } // namespace tensorflow 140 141 #endif // TENSORFLOW_CORE_KERNELS_SERIALIZATION_UTILS_H_ 142