xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/serialization_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/serialization_utils.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_runner.h"
23 #include "tensorflow/core/data/dataset_utils.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27 
28 namespace tensorflow {
29 namespace data {
30 namespace {
31 
// Separator used inside VariantTensorData metadata to delimit stored keys.
constexpr char kDelimiter[] = "@@";
// Checkpoint key names used when (de)serializing vectors of elements.
constexpr char kComponent[] = "component";
constexpr char kNumComponents[] = "num_components";
constexpr char kNumElements[] = "num_elements";
// Key suffixes marking a stored tensor that actually holds a serialized
// dataset graph (plus the name of its output node).
constexpr char kIsDataset[] = ".is_dataset";
constexpr char kOutputNode[] = ".output_node";
38 
39 // We assume that all keys are of the form <iterator_prefix>:<name>. We extract
40 // the iterator name by getting rid of everything post the final colon.
GetIteratorName(StringPiece key,string * name)41 Status GetIteratorName(StringPiece key, string* name) {
42   if (!str_util::StartsWith(key, data::kFullNameRandomHex)) {
43     return errors::InvalidArgument("Save key: ", key,
44                                    " not generated using full_name.");
45   }
46   std::vector<string> split_keys = str_util::Split(key, data::kPipe);
47   if (split_keys.size() != 2) {
48     return errors::InvalidArgument("Save key: ", key,
49                                    " not generated using full_name.");
50   }
51   string real_key = split_keys[1];
52   const int pos = real_key.rfind(kColon);
53   *name = real_key.substr(0, pos);
54   return OkStatus();
55 }
56 
FromGraphDef(FunctionLibraryRuntime * flr,const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_list,const string & output_node,Tensor * result)57 Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def,
58                     const std::vector<std::pair<string, Tensor>>& input_list,
59                     const string& output_node, Tensor* result) {
60   FunctionLibraryRuntime* cloned_flr = nullptr;
61   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
62   std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
63   TF_RETURN_IF_ERROR(flr->Clone(&lib_def, &pflr, &cloned_flr, true));
64   TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
65   Graph graph(OpRegistry::Global());
66   TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
67   std::vector<Tensor> outputs;
68   GraphRunner graph_runner(cloned_flr->device());
69   TF_RETURN_IF_ERROR(graph_runner.Run(&graph, cloned_flr, input_list,
70                                       {output_node}, &outputs));
71   *result = outputs[0];
72   return OkStatus();
73 }
74 
75 // FindStatefulOps searches `graph_def` for all of its stateful ops storing
76 // their names in `stateful_op_names`.
FindStatefulOps(const GraphDef & graph_def,std::vector<string> * stateful_op_names)77 Status FindStatefulOps(const GraphDef& graph_def,
78                        std::vector<string>* stateful_op_names) {
79   FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library());
80 
81   // Iterate over all nodes in the graph.
82   for (const auto& node : graph_def.node()) {
83     // Each Dataset graph has a _Retval op in the end which is marked stateful
84     if (node.op() == FunctionLibraryDefinition::kRetOp) continue;
85     if (!IsNodeStateful(lib_def, node).ok()) {
86       stateful_op_names->push_back(node.op());
87     }
88   }
89 
90   // Iterate over all functions.
91   for (const auto& fdef : graph_def.library().function()) {
92     if (!fdef.signature().is_stateful()) continue;
93     for (const auto& node : fdef.node_def()) {
94       if (!IsNodeStateful(lib_def, node).ok()) {
95         stateful_op_names->push_back(
96             absl::StrCat(node.op(), " in function: ", fdef.signature().name()));
97       }
98     }
99   }
100   return OkStatus();
101 }
102 
103 }  // namespace
104 
ReadElementsFromCheckpoint(IteratorContext * ctx,IteratorStateReader * reader,StringPiece key_prefix,std::vector<std::vector<Tensor>> * elements)105 Status ReadElementsFromCheckpoint(IteratorContext* ctx,
106                                   IteratorStateReader* reader,
107                                   StringPiece key_prefix,
108                                   std::vector<std::vector<Tensor>>* elements) {
109   int64_t num_elements;
110   TF_RETURN_IF_ERROR(
111       reader->ReadScalar(key_prefix, kNumElements, &num_elements));
112   DCHECK(elements->empty());
113   elements->reserve(num_elements);
114   for (int i = 0; i < num_elements; ++i) {
115     std::string element_prefix = absl::StrCat(key_prefix, "::", i);
116     int64_t num_components;
117     TF_RETURN_IF_ERROR(
118         reader->ReadScalar(element_prefix, kNumComponents, &num_components));
119     elements->emplace_back();
120     std::vector<Tensor>& element = elements->at(i);
121     element.reserve(num_components);
122     for (int j = 0; j < num_components; ++j) {
123       element.emplace_back();
124       TF_RETURN_IF_ERROR(reader->ReadTensor(
125           ctx->flr(), element_prefix, absl::StrCat(kComponent, "[", j, "]"),
126           &element.back()));
127     }
128   }
129   return OkStatus();
130 }
131 
WriteElementsToCheckpoint(IteratorStateWriter * writer,StringPiece key_prefix,const std::vector<std::vector<Tensor>> & elements)132 Status WriteElementsToCheckpoint(
133     IteratorStateWriter* writer, StringPiece key_prefix,
134     const std::vector<std::vector<Tensor>>& elements) {
135   TF_RETURN_IF_ERROR(
136       writer->WriteScalar(key_prefix, kNumElements, elements.size()));
137   for (int i = 0; i < elements.size(); ++i) {
138     const std::vector<Tensor>& element = elements[i];
139     std::string element_prefix = absl::StrCat(key_prefix, "::", i);
140     TF_RETURN_IF_ERROR(
141         writer->WriteScalar(element_prefix, kNumComponents, element.size()));
142     for (int j = 0; j < elements[i].size(); ++j) {
143       TF_RETURN_IF_ERROR(writer->WriteTensor(
144           element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j]));
145     }
146   }
147   return OkStatus();
148 }
149 
VariantTensorDataReader(const std::vector<const tensorflow::VariantTensorData * > & data)150 VariantTensorDataReader::VariantTensorDataReader(
151     const std::vector<const tensorflow::VariantTensorData*>& data) {
152   for (const auto& d : data) {
153     string metadata;
154     d->get_metadata(&metadata);
155     auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
156     const string name = keys[0];
157     data_[name] = d;
158     map_[name] = std::map<string, size_t>();
159     for (size_t i = 1; i < keys.size(); ++i) {
160       map_[name][keys[i]] = i - 1;
161     }
162   }
163 }
164 
ReadScalar(StringPiece key,int64_t * val) const165 Status VariantTensorDataReader::ReadScalar(StringPiece key,
166                                            int64_t* val) const {
167   string name;
168   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
169   return ReadScalar(name, key, val);
170 }
171 
ReadScalar(StringPiece name,StringPiece key,int64_t * val) const172 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
173                                            int64_t* val) const {
174   return ReadScalarInternal(name, key, val);
175 }
176 
ReadScalar(StringPiece key,tstring * val) const177 Status VariantTensorDataReader::ReadScalar(StringPiece key,
178                                            tstring* val) const {
179   string name;
180   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
181   return ReadScalar(name, key, val);
182 }
183 
ReadScalar(StringPiece name,StringPiece key,tstring * val) const184 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
185                                            tstring* val) const {
186   return ReadScalarInternal(name, key, val);
187 }
188 
ReadTensor(StringPiece key,Tensor * val) const189 Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const {
190   string name;
191   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
192   return ReadTensor(name, key, val);
193 }
194 
ReadTensor(FunctionLibraryRuntime * flr,StringPiece key,Tensor * val) const195 Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
196                                            StringPiece key, Tensor* val) const {
197   string name;
198   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
199   return ReadTensorInternal(flr, name, key, val);
200 }
201 
ReadTensor(StringPiece name,StringPiece key,Tensor * val) const202 Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
203                                            Tensor* val) const {
204   return ReadTensor(/*flr=*/nullptr, name, key, val);
205 }
206 
ReadTensor(FunctionLibraryRuntime * flr,StringPiece name,StringPiece key,Tensor * val) const207 Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
208                                            StringPiece name, StringPiece key,
209                                            Tensor* val) const {
210   return ReadTensorInternal(flr, name, key, val);
211 }
212 
Contains(StringPiece key) const213 bool VariantTensorDataReader::Contains(StringPiece key) const {
214   string name;
215   if (!GetIteratorName(key, &name).ok()) {
216     return false;
217   }
218   return Contains(name, key);
219 }
220 
Contains(StringPiece n,StringPiece key) const221 bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const {
222   string name(n);
223   auto it = map_.find(name);
224   if (it == map_.end()) {
225     return false;
226   }
227   const auto& bucket = it->second;
228   return bucket.find(string(key)) != bucket.end();
229 }
230 
231 template <typename T>
ReadScalarInternal(StringPiece n,StringPiece key,T * val) const232 Status VariantTensorDataReader::ReadScalarInternal(StringPiece n,
233                                                    StringPiece key,
234                                                    T* val) const {
235   string name(n);
236   auto it = map_.find(name);
237   if (it == map_.end()) {
238     return errors::NotFound(name);
239   }
240   const auto& bucket = it->second;
241   auto key_it = bucket.find(string(key));
242   if (key_it == bucket.end()) {
243     return errors::NotFound(key);
244   }
245   *val = data_.at(name)->tensors(key_it->second).scalar<T>()();
246   return OkStatus();
247 }
248 
ReadTensorInternal(FunctionLibraryRuntime * flr,StringPiece n,StringPiece key,Tensor * val) const249 Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr,
250                                                    StringPiece n,
251                                                    StringPiece key,
252                                                    Tensor* val) const {
253   if (Contains(n, strings::StrCat(key, kIsDataset))) {
254     return ReadDatasetInternal(flr, n, key, val);
255   }
256   string name(n);
257   auto it = map_.find(name);
258   if (it == map_.end()) {
259     return errors::NotFound(name);
260   }
261   const auto& bucket = it->second;
262   auto key_it = bucket.find(string(key));
263   if (key_it == bucket.end()) {
264     return errors::NotFound(key);
265   }
266   *val = data_.at(name)->tensors(key_it->second);
267   return OkStatus();
268 }
269 
ReadDatasetInternal(FunctionLibraryRuntime * flr,StringPiece n,StringPiece key,Tensor * val) const270 Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr,
271                                                     StringPiece n,
272                                                     StringPiece key,
273                                                     Tensor* val) const {
274   if (flr == nullptr) {
275     return errors::Internal(
276         "Function library runtime is needed to restore a dataset.");
277   }
278   tstring output_node, serialized_graph_def;
279   TF_RETURN_IF_ERROR(
280       ReadScalar(n, strings::StrCat(key, kOutputNode), &output_node));
281   TF_RETURN_IF_ERROR(
282       ReadScalar(n, strings::StrCat(key), &serialized_graph_def));
283   GraphDef graph_def;
284   graph_def.ParseFromString(serialized_graph_def);
285   TF_RETURN_IF_ERROR(FromGraphDef(flr, graph_def, {}, output_node, val));
286   return OkStatus();
287 }
288 
WriteScalar(StringPiece key,const int64_t val)289 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
290                                             const int64_t val) {
291   string name;
292   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
293   return WriteScalar(name, key, val);
294 }
295 
WriteScalar(StringPiece name,StringPiece key,const int64_t val)296 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
297                                             const int64_t val) {
298   return WriteScalarInternal(name, key, val);
299 }
300 
WriteScalar(StringPiece key,const tstring & val)301 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
302                                             const tstring& val) {
303   string name;
304   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
305   return WriteScalar(name, key, val);
306 }
307 
WriteScalar(StringPiece name,StringPiece key,const tstring & val)308 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
309                                             const tstring& val) {
310   return WriteScalarInternal(name, key, val);
311 }
312 
WriteTensor(StringPiece key,const Tensor & val)313 Status VariantTensorDataWriter::WriteTensor(StringPiece key,
314                                             const Tensor& val) {
315   string name;
316   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
317   return WriteTensor(name, key, val);
318 }
319 
WriteTensor(StringPiece name,StringPiece key,const Tensor & val)320 Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
321                                             const Tensor& val) {
322   return WriteTensorInternal(name, key, val);
323 }
324 
MaybeFlush()325 void VariantTensorDataWriter::MaybeFlush() {
326   if (is_flushed_) return;
327   for (auto& keys : keys_) {
328     const string name = keys.first;
329     string metadata = name;
330     for (size_t i = 0; i < keys_[name].size(); ++i) {
331       strings::StrAppend(&metadata, kDelimiter, keys_[name][i]);
332     }
333     data_[name]->set_metadata(metadata);
334   }
335   is_flushed_ = true;
336 }
337 
Reset()338 void VariantTensorDataWriter::Reset() {
339   is_flushed_ = false;
340   data_.clear();
341   keys_.clear();
342 }
343 
ReleaseData(std::vector<std::unique_ptr<VariantTensorData>> * variants)344 void VariantTensorDataWriter::ReleaseData(
345     std::vector<std::unique_ptr<VariantTensorData>>* variants) {
346   MaybeFlush();
347   for (auto& it : data_) {
348     variants->push_back(std::move(it.second));
349   }
350   Reset();
351 }
352 
GetData(std::vector<const VariantTensorData * > * variants)353 void VariantTensorDataWriter::GetData(
354     std::vector<const VariantTensorData*>* variants) {
355   MaybeFlush();
356   for (auto& it : data_) {
357     variants->push_back(it.second.get());
358   }
359 }
360 
361 template <typename T>
WriteScalarInternal(StringPiece name,StringPiece key,const T & val)362 Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name,
363                                                     StringPiece key,
364                                                     const T& val) {
365   if (is_flushed_) {
366     return errors::FailedPrecondition(
367         "Cannot call WriteScalar after GetData or ReleaseData is called");
368   }
369   Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
370   val_t.scalar<T>()() = val;
371   return WriteTensorInternal(name, key, val_t);
372 }
373 
WriteTensorInternal(StringPiece n,StringPiece key,const Tensor & val)374 Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n,
375                                                     StringPiece key,
376                                                     const Tensor& val) {
377   DatasetBase* dataset;
378   if (GetDatasetFromVariantTensor(val, &dataset).ok()) {
379     return WriteDatasetInternal(n, key, dataset);
380   }
381   if (is_flushed_) {
382     return errors::FailedPrecondition(
383         "Cannot call WriteTensor after GetData or ReleaseData is called");
384   }
385   DCHECK_EQ(key.find(kDelimiter), string::npos);
386   string name(n);
387   if (keys_.count(name) == 0) {
388     keys_[name] = std::vector<string>();
389   }
390   keys_[name].push_back(string(key));
391   if (data_.count(name) == 0) {
392     data_[name] = std::make_unique<VariantTensorData>();
393     data_[name]->set_type_name("tensorflow::Iterator");
394   }
395   *(data_[name]->add_tensors()) = val;
396   return OkStatus();
397 }
398 
WriteDatasetInternal(StringPiece n,StringPiece key,const DatasetBase * dataset)399 Status VariantTensorDataWriter::WriteDatasetInternal(
400     StringPiece n, StringPiece key, const DatasetBase* dataset) {
401   GraphDef graph_def;
402   SerializationContext ctx((SerializationContext::Params()));
403   TF_RETURN_IF_ERROR(AsGraphDef(dataset, std::move(ctx), &graph_def));
404   string output_node;
405   for (const auto& node : graph_def.node()) {
406     if (node.op() == "_Retval") {
407       output_node = node.input(0);
408       break;
409     }
410   }
411   string result;
412   graph_def.SerializeToString(&result);
413   TF_RETURN_IF_ERROR(WriteScalar(n, strings::StrCat(key, kIsDataset), ""));
414   TF_RETURN_IF_ERROR(
415       WriteScalar(n, strings::StrCat(key, kOutputNode), output_node));
416   TF_RETURN_IF_ERROR(WriteScalar(n, key, result));
417   return OkStatus();
418 }
419 
AsGraphDefForRewrite(OpKernelContext * ctx,const DatasetBase * input,std::vector<std::pair<string,Tensor>> * input_list,GraphDef * result,string * dataset_node)420 Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input,
421                             std::vector<std::pair<string, Tensor>>* input_list,
422                             GraphDef* result, string* dataset_node) {
423   SerializationContext::Params params(ctx);
424   params.input_list = input_list;
425   params.external_state_policy =
426       SerializationContext::ExternalStatePolicy::kIgnore;
427   params.is_graph_rewrite = true;
428   SerializationContext serialization_ctx(params);
429   TF_RETURN_IF_ERROR(AsGraphDef(input, std::move(serialization_ctx), result));
430 
431   // Symbolic `_Retval` node indicates which node corresponds to the dataset.
432   for (const auto& node : result->node()) {
433     if (node.op() == "_Retval") {
434       *dataset_node = node.input(0);
435     }
436   }
437   return OkStatus();
438 }
439 
AsGraphDef(const DatasetBase * dataset,SerializationContext && serialization_ctx,GraphDef * graph_def)440 Status AsGraphDef(const DatasetBase* dataset,
441                   SerializationContext&& serialization_ctx,
442                   GraphDef* graph_def) {
443   if (serialization_ctx.external_state_policy() ==
444       SerializationContext::ExternalStatePolicy::kFail) {
445     TF_RETURN_IF_ERROR(dataset->CheckExternalState());
446   }
447   if (serialization_ctx.external_state_policy() ==
448       SerializationContext::ExternalStatePolicy::kWarn) {
449     std::vector<string> stateful_op_names;
450     TF_RETURN_IF_ERROR(FindStatefulOps(*graph_def, &stateful_op_names));
451     if (!stateful_op_names.empty()) {
452       LOG(WARNING) << "We found the following stateful ops in the dataset "
453                       "construction graph whose state would not be "
454                       "serialized and might "
455                       "cause subtle bugs: "
456                    << absl::StrJoin(stateful_op_names, ", ");
457     }
458   }
459   GraphDefBuilder b;
460   DatasetBase::DatasetGraphDefBuilder db(&b);
461   Node* output_node = nullptr;
462   TF_RETURN_IF_ERROR(
463       db.AddInputDataset(&serialization_ctx, dataset, &output_node));
464   // Insert a purely symbolic _Retval node to indicate to consumers which node
465   // represents `dataset`.
466   ops::UnaryOp("_Retval", output_node,
467                b.opts()
468                    .WithName("dataset")
469                    .WithAttr("T", DT_VARIANT)
470                    .WithAttr("index", 0));
471   TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
472   return OkStatus();
473 }
474 
475 }  // namespace data
476 }  // namespace tensorflow
477