xref: /aosp_15_r20/external/tensorflow/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/str_split.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/c/eager/immediate_execution_context.h"
29 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
30 #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
31 #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
32 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
33 #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
34 #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
35 #include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
36 #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
37 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
38 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
39 #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
40 #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
41 #include "tensorflow/cc/saved_model/bundle_v2.h"
42 #include "tensorflow/cc/saved_model/constants.h"
43 #include "tensorflow/core/framework/attr_value.pb.h"
44 #include "tensorflow/core/framework/function.pb.h"
45 #include "tensorflow/core/framework/graph.pb.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/framework/tensor.pb.h"
49 #include "tensorflow/core/framework/tensor_shape.h"
50 #include "tensorflow/core/framework/types.pb.h"
51 #include "tensorflow/core/lib/gtl/flatmap.h"
52 #include "tensorflow/core/lib/hash/hash.h"
53 #include "tensorflow/core/platform/casts.h"
54 #include "tensorflow/core/platform/errors.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/path.h"
58 #include "tensorflow/core/platform/stringpiece.h"
59 #include "tensorflow/core/platform/tstring.h"
60 #include "tensorflow/core/protobuf/meta_graph.pb.h"
61 #include "tensorflow/core/protobuf/saved_model.pb.h"
62 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
63 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
64 
65 namespace tensorflow {
66 
67 // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
68 using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
69                                     StringPieceHasher>;
70 
71 // Maps from a FunctionDef's name to the corresponding FlatTensorFunction
72 using FlatTensorFunctionMap =
73     gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>;
74 
75 namespace {
76 
77 const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(const TrackableObjectGraph::TrackableObject & trackable_object,absl::string_view name)78 FindSerializedTensorInTrackable(
79     const TrackableObjectGraph::TrackableObject& trackable_object,
80     absl::string_view name) {
81   for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
82     if (maybe_serialized_tensor.name() == name) {
83       return &maybe_serialized_tensor;
84     }
85   }
86   return nullptr;
87 }
88 
89 // This function reads the Checkpoint embedded in the SavedModel, and calls the
90 // appropriate Restore ops on each of the variables.
91 // Note(bmzhao): Conceptually, objects that contain checkpointable state
92 // implement the "_gather_saveables_for_checkpoint" method
93 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
94 // which returns a dict of string key -> EITHER:
95 // 1. python callable (taking a checkpoint key) returning SaveableObject OR
96 // 2. variable (partitioned/resource/reference or otherwise)
97 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
98 // The string key becomes the "name" attribute of the SerializedTensor proto
99 // in the TrackableObjectGraph,
100 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
101 // And the checkpoint_key is a globally unique string derived from this name:
102 // https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
103 // SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
104 // ops via their SaveSpec members
105 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
106 // which contain the "real" checkpoint keys into the TensorBundle SSTable.
107 // They also contain the logic needed to take the restored tensors from
108 // RestoreV2 and load them back into the "object" they came from via their
109 // overridden "restore" method:
110 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
RestoreCheckpoint(SavedModelV2Bundle * bundle,const RevivedObjects & revived_objects,const std::string & directory,ImmediateExecutionContext * context)111 Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
112                          const RevivedObjects& revived_objects,
113                          const std::string& directory,
114                          ImmediateExecutionContext* context) {
115   // TODO(bmzhao): Batch up all the restores into a single restore op per
116   // device, following logic in MultiDeviceSaver.
117   TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
118       [&revived_objects, &directory, context, bundle](
119           int node, const TrackableObjectGraph::TrackableObject& trackable) {
120         if (bundle->saved_object_graph().nodes(node).kind_case() !=
121             SavedObject::kVariable) {
122           // TODO(bmzhao): This requires using the newly added Save/Restore
123           // functions from
124           // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
125           LOG(WARNING) << "Restoring non-variable objects has not been "
126                           "implemented yet. (Kind="
127                        << bundle->saved_object_graph().nodes(node).kind_case()
128                        << ")";
129           return OkStatus();
130         }
131 
132         Variable* variable = revived_objects.variables.at(node).get();
133 
134         // Restore the tensor's value from the checkpoint
135         const TrackableObjectGraph::TrackableObject::SerializedTensor*
136             attribute =
137                 FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
138         if (attribute == nullptr) {
139           return errors::FailedPrecondition(
140               "Could not find SerializedTensor with name VARIABLE_VALUE for "
141               "saved variable");
142         }
143 
144         const std::string& checkpoint_key = attribute->checkpoint_key();
145         if (!bundle->variable_reader()->Contains(checkpoint_key)) {
146           LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
147                        << ". Variable will be uninitialized.";
148           return Status();
149         }
150 
151         std::string variables_path_prefix =
152             io::JoinPath(directory, kSavedModelVariablesDirectory,
153                          kSavedModelVariablesFilename);
154         ImmediateTensorHandlePtr restored_output;
155         TF_RETURN_IF_ERROR(internal::SingleRestore(
156             context, variables_path_prefix, checkpoint_key, variable->dtype(),
157             &restored_output));
158 
159         // Assign the restored tensor's value to the variable
160         return variable->Assign(restored_output.get());
161       }));
162 
163   return Status();
164 }
165 
InitializeAllResources(const RevivedObjects & revived)166 Status InitializeAllResources(const RevivedObjects& revived) {
167   for (const auto& node_and_resource : revived.restored_resources) {
168     const RestoredResource& resource = node_and_resource.second;
169     TF_RETURN_IF_ERROR(resource.Initialize());
170   }
171   return Status();
172 }
173 
174 }  // namespace
175 
GetFunction(const std::string & function_path,ConcreteFunction ** function)176 Status TFSavedModelAPI::GetFunction(const std::string& function_path,
177                                     ConcreteFunction** function) {
178   absl::optional<int> node =
179       internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
180   if (!node.has_value()) {
181     return errors::NotFound("No saved object found at path ", function_path);
182   }
183 
184   *function = revived_objects_.concrete_functions.Find(*node);
185   if (*function == nullptr) {
186     return errors::NotFound("No function found at path ", function_path);
187   }
188 
189   return Status();
190 }
191 
GetFunctions(int node_id,absl::flat_hash_map<std::string,ConcreteFunction * > * functions)192 Status TFSavedModelAPI::GetFunctions(
193     int node_id,
194     absl::flat_hash_map<std::string, ConcreteFunction*>* functions) {
195   const auto& nodes = bundle_.saved_object_graph().nodes();
196   if (node_id >= nodes.size()) {
197     return errors::OutOfRange(
198         "node_id ", node_id,
199         " not found.  Maximum node ID: ", nodes.size() - 1);
200   }
201   const SavedObject* current_node = &nodes.Get(node_id);
202   for (const auto& child : current_node->children()) {
203     ConcreteFunction* concrete_fn;
204     Status status = GetFunction(child.local_name(), &concrete_fn);
205     if (status.ok()) {
206       (*functions)[child.local_name()] = concrete_fn;
207     }
208   }
209   return Status();
210 }
211 
GetSignatureDefFunction(const std::string & signature_def_key,SignatureDefFunction ** function)212 Status TFSavedModelAPI::GetSignatureDefFunction(
213     const std::string& signature_def_key, SignatureDefFunction** function) {
214   auto signatures_iter =
215       revived_objects_.signatures_map.find(signature_def_key);
216   if (signatures_iter == revived_objects_.signatures_map.end()) {
217     return errors::NotFound("No signature with key ", signature_def_key,
218                             " was found");
219   }
220   int node = signatures_iter->second;
221 
222   auto function_iter = revived_objects_.signature_def_functions.find(node);
223   if (function_iter == revived_objects_.signature_def_functions.end()) {
224     return errors::Internal(
225         "Unable to find SignatureDefFunction associated with key ",
226         signature_def_key, " despite key being valid.");
227   }
228 
229   *function = function_iter->second.get();
230   return Status();
231 }
232 
GetVariable(const std::string & variable_path,Variable ** variable)233 Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
234                                     Variable** variable) {
235   absl::optional<int> node =
236       internal::FindNodeAtPath(variable_path, bundle_.saved_object_graph());
237   if (!node.has_value()) {
238     return errors::NotFound("No saved object found at path ", variable_path);
239   }
240 
241   auto variables_iter = revived_objects_.variables.find(*node);
242   if (variables_iter == revived_objects_.variables.end()) {
243     return errors::NotFound("No variable found at path ", variable_path);
244   }
245 
246   *variable = variables_iter->second.get();
247   return Status();
248 }
249 
GetBundle()250 SavedModelV2Bundle* TFSavedModelAPI::GetBundle() { return &this->bundle_; }
251 
TFSavedModelAPI(const std::string & directory,SavedModelV2Bundle bundle,RevivedObjects revived_objects)252 TFSavedModelAPI::TFSavedModelAPI(const std::string& directory,
253                                  SavedModelV2Bundle bundle,
254                                  RevivedObjects revived_objects)
255     : directory_(directory),
256       bundle_(std::move(bundle)),
257       revived_objects_(std::move(revived_objects)) {}
258 
Load(const std::string & directory,const absl::optional<std::unordered_set<std::string>> & tags,ImmediateExecutionContext * context,std::unique_ptr<TFSavedModelAPI> * out)259 Status TFSavedModelAPI::Load(
260     const std::string& directory,
261     const absl::optional<std::unordered_set<std::string>>& tags,
262     ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
263   // TODO(bmzhao): Add support for loading a TF1 SavedModel.
264   if (tags) {
265     return errors::Unimplemented(
266         "Loading saved models with explicit tags will be supported in the "
267         "future");
268   }
269 
270   SavedModelV2Bundle bundle;
271   TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
272 
273   // TODO(bmzhao): Mangle loaded function names so that different
274   // models loaded in the same runtime Context don't clobber eachother.
275   // This occurs in python here:
276   // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
277 
278   // For each node in the graph, we should initialize an object of the
279   // corresponding type. For objects that depend on the initialization of other
280   // objects (like functions which capture resources), we will initialize them
281   // later.
282   PartiallyRevivedObjects partially_revived_objects;
283   TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
284       bundle.meta_graph_def(), context, directory, &partially_revived_objects));
285 
286   RevivedObjects revived_objects;
287   TF_RETURN_IF_ERROR(partially_revived_objects.Build(
288       context, bundle.saved_object_graph(), &revived_objects));
289 
290   // Revive function library functions as concrete functions without captures.
291   // This is necessary because object graph functions may refer to functions
292   // _not_ in the object graph: A while loop, for example, will create two
293   // auxiliary `while_cond` and `while_body` functions that are only present in
294   // the graph def function library.
295   for (const FunctionDef& function :
296        bundle.meta_graph_def().graph_def().library().function()) {
297     std::unique_ptr<TFConcreteFunction> concrete_function;
298     TF_RETURN_IF_ERROR(TFConcreteFunction::Create(/*function_def=*/&function,
299                                                   /*captures=*/{},
300                                                   /*metadata=*/{},
301                                                   /*ctx=*/context,
302                                                   /*out=*/&concrete_function));
303     revived_objects.concrete_functions.Insert(std::move(concrete_function));
304   }
305 
306   TF_RETURN_IF_ERROR(
307       RestoreCheckpoint(&bundle, revived_objects, directory, context));
308 
309   TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects));
310 
311   out->reset(new TFSavedModelAPI(directory, std::move(bundle),
312                                  std::move(revived_objects)));
313   return Status();
314 }
315 
316 }  // namespace tensorflow
317