1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
17
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/str_split.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/optional.h"
28 #include "tensorflow/c/eager/immediate_execution_context.h"
29 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
30 #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
31 #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
32 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
33 #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
34 #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
35 #include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
36 #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
37 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
38 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
39 #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
40 #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
41 #include "tensorflow/cc/saved_model/bundle_v2.h"
42 #include "tensorflow/cc/saved_model/constants.h"
43 #include "tensorflow/core/framework/attr_value.pb.h"
44 #include "tensorflow/core/framework/function.pb.h"
45 #include "tensorflow/core/framework/graph.pb.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/framework/tensor.pb.h"
49 #include "tensorflow/core/framework/tensor_shape.h"
50 #include "tensorflow/core/framework/types.pb.h"
51 #include "tensorflow/core/lib/gtl/flatmap.h"
52 #include "tensorflow/core/lib/hash/hash.h"
53 #include "tensorflow/core/platform/casts.h"
54 #include "tensorflow/core/platform/errors.h"
55 #include "tensorflow/core/platform/logging.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/path.h"
58 #include "tensorflow/core/platform/stringpiece.h"
59 #include "tensorflow/core/platform/tstring.h"
60 #include "tensorflow/core/protobuf/meta_graph.pb.h"
61 #include "tensorflow/core/protobuf/saved_model.pb.h"
62 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
63 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
64
65 namespace tensorflow {
66
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary.
// Keys are non-owning StringPieces, so they must not outlive the
// FunctionDefLibrary that owns the names.
// NOTE(review): not referenced anywhere in this file; confirm it is still
// needed before relying on it.
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
                                    StringPieceHasher>;

// Maps from a FunctionDef's name to the corresponding FlatTensorFunction
// (owned by the map).
// NOTE(review): not referenced anywhere in this file; confirm it is still
// needed before relying on it.
using FlatTensorFunctionMap =
    gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>;
74
75 namespace {
76
77 const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(const TrackableObjectGraph::TrackableObject & trackable_object,absl::string_view name)78 FindSerializedTensorInTrackable(
79 const TrackableObjectGraph::TrackableObject& trackable_object,
80 absl::string_view name) {
81 for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
82 if (maybe_serialized_tensor.name() == name) {
83 return &maybe_serialized_tensor;
84 }
85 }
86 return nullptr;
87 }
88
89 // This function reads the Checkpoint embedded in the SavedModel, and calls the
90 // appropriate Restore ops on each of the variables.
91 // Note(bmzhao): Conceptually, objects that contain checkpointable state
92 // implement the "_gather_saveables_for_checkpoint" method
93 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
94 // which returns a dict of string key -> EITHER:
95 // 1. python callable (taking a checkpoint key) returning SaveableObject OR
96 // 2. variable (partitioned/resource/reference or otherwise)
97 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
98 // The string key becomes the "name" attribute of the SerializedTensor proto
99 // in the TrackableObjectGraph,
100 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
101 // And the checkpoint_key is a globally unique string derived from this name:
102 // https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
103 // SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
104 // ops via their SaveSpec members
105 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
106 // which contain the "real" checkpoint keys into the TensorBundle SSTable.
107 // They also contain the logic needed to take the restored tensors from
108 // RestoreV2 and load them back into the "object" they came from via their
109 // overridden "restore" method:
110 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
RestoreCheckpoint(SavedModelV2Bundle * bundle,const RevivedObjects & revived_objects,const std::string & directory,ImmediateExecutionContext * context)111 Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
112 const RevivedObjects& revived_objects,
113 const std::string& directory,
114 ImmediateExecutionContext* context) {
115 // TODO(bmzhao): Batch up all the restores into a single restore op per
116 // device, following logic in MultiDeviceSaver.
117 TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
118 [&revived_objects, &directory, context, bundle](
119 int node, const TrackableObjectGraph::TrackableObject& trackable) {
120 if (bundle->saved_object_graph().nodes(node).kind_case() !=
121 SavedObject::kVariable) {
122 // TODO(bmzhao): This requires using the newly added Save/Restore
123 // functions from
124 // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
125 LOG(WARNING) << "Restoring non-variable objects has not been "
126 "implemented yet. (Kind="
127 << bundle->saved_object_graph().nodes(node).kind_case()
128 << ")";
129 return OkStatus();
130 }
131
132 Variable* variable = revived_objects.variables.at(node).get();
133
134 // Restore the tensor's value from the checkpoint
135 const TrackableObjectGraph::TrackableObject::SerializedTensor*
136 attribute =
137 FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
138 if (attribute == nullptr) {
139 return errors::FailedPrecondition(
140 "Could not find SerializedTensor with name VARIABLE_VALUE for "
141 "saved variable");
142 }
143
144 const std::string& checkpoint_key = attribute->checkpoint_key();
145 if (!bundle->variable_reader()->Contains(checkpoint_key)) {
146 LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
147 << ". Variable will be uninitialized.";
148 return Status();
149 }
150
151 std::string variables_path_prefix =
152 io::JoinPath(directory, kSavedModelVariablesDirectory,
153 kSavedModelVariablesFilename);
154 ImmediateTensorHandlePtr restored_output;
155 TF_RETURN_IF_ERROR(internal::SingleRestore(
156 context, variables_path_prefix, checkpoint_key, variable->dtype(),
157 &restored_output));
158
159 // Assign the restored tensor's value to the variable
160 return variable->Assign(restored_output.get());
161 }));
162
163 return Status();
164 }
165
InitializeAllResources(const RevivedObjects & revived)166 Status InitializeAllResources(const RevivedObjects& revived) {
167 for (const auto& node_and_resource : revived.restored_resources) {
168 const RestoredResource& resource = node_and_resource.second;
169 TF_RETURN_IF_ERROR(resource.Initialize());
170 }
171 return Status();
172 }
173
174 } // namespace
175
GetFunction(const std::string & function_path,ConcreteFunction ** function)176 Status TFSavedModelAPI::GetFunction(const std::string& function_path,
177 ConcreteFunction** function) {
178 absl::optional<int> node =
179 internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
180 if (!node.has_value()) {
181 return errors::NotFound("No saved object found at path ", function_path);
182 }
183
184 *function = revived_objects_.concrete_functions.Find(*node);
185 if (*function == nullptr) {
186 return errors::NotFound("No function found at path ", function_path);
187 }
188
189 return Status();
190 }
191
GetFunctions(int node_id,absl::flat_hash_map<std::string,ConcreteFunction * > * functions)192 Status TFSavedModelAPI::GetFunctions(
193 int node_id,
194 absl::flat_hash_map<std::string, ConcreteFunction*>* functions) {
195 const auto& nodes = bundle_.saved_object_graph().nodes();
196 if (node_id >= nodes.size()) {
197 return errors::OutOfRange(
198 "node_id ", node_id,
199 " not found. Maximum node ID: ", nodes.size() - 1);
200 }
201 const SavedObject* current_node = &nodes.Get(node_id);
202 for (const auto& child : current_node->children()) {
203 ConcreteFunction* concrete_fn;
204 Status status = GetFunction(child.local_name(), &concrete_fn);
205 if (status.ok()) {
206 (*functions)[child.local_name()] = concrete_fn;
207 }
208 }
209 return Status();
210 }
211
GetSignatureDefFunction(const std::string & signature_def_key,SignatureDefFunction ** function)212 Status TFSavedModelAPI::GetSignatureDefFunction(
213 const std::string& signature_def_key, SignatureDefFunction** function) {
214 auto signatures_iter =
215 revived_objects_.signatures_map.find(signature_def_key);
216 if (signatures_iter == revived_objects_.signatures_map.end()) {
217 return errors::NotFound("No signature with key ", signature_def_key,
218 " was found");
219 }
220 int node = signatures_iter->second;
221
222 auto function_iter = revived_objects_.signature_def_functions.find(node);
223 if (function_iter == revived_objects_.signature_def_functions.end()) {
224 return errors::Internal(
225 "Unable to find SignatureDefFunction associated with key ",
226 signature_def_key, " despite key being valid.");
227 }
228
229 *function = function_iter->second.get();
230 return Status();
231 }
232
GetVariable(const std::string & variable_path,Variable ** variable)233 Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
234 Variable** variable) {
235 absl::optional<int> node =
236 internal::FindNodeAtPath(variable_path, bundle_.saved_object_graph());
237 if (!node.has_value()) {
238 return errors::NotFound("No saved object found at path ", variable_path);
239 }
240
241 auto variables_iter = revived_objects_.variables.find(*node);
242 if (variables_iter == revived_objects_.variables.end()) {
243 return errors::NotFound("No variable found at path ", variable_path);
244 }
245
246 *variable = variables_iter->second.get();
247 return Status();
248 }
249
GetBundle()250 SavedModelV2Bundle* TFSavedModelAPI::GetBundle() { return &this->bundle_; }
251
// Builds a TFSavedModelAPI from an already-loaded bundle and its revived
// objects. `directory` is copied; `bundle` and `revived_objects` are taken by
// value and moved into the members.
TFSavedModelAPI::TFSavedModelAPI(const std::string& directory,
                                 SavedModelV2Bundle bundle,
                                 RevivedObjects revived_objects)
    : directory_(directory),
      bundle_(std::move(bundle)),
      revived_objects_(std::move(revived_objects)) {}
258
Load(const std::string & directory,const absl::optional<std::unordered_set<std::string>> & tags,ImmediateExecutionContext * context,std::unique_ptr<TFSavedModelAPI> * out)259 Status TFSavedModelAPI::Load(
260 const std::string& directory,
261 const absl::optional<std::unordered_set<std::string>>& tags,
262 ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
263 // TODO(bmzhao): Add support for loading a TF1 SavedModel.
264 if (tags) {
265 return errors::Unimplemented(
266 "Loading saved models with explicit tags will be supported in the "
267 "future");
268 }
269
270 SavedModelV2Bundle bundle;
271 TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
272
273 // TODO(bmzhao): Mangle loaded function names so that different
274 // models loaded in the same runtime Context don't clobber eachother.
275 // This occurs in python here:
276 // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
277
278 // For each node in the graph, we should initialize an object of the
279 // corresponding type. For objects that depend on the initialization of other
280 // objects (like functions which capture resources), we will initialize them
281 // later.
282 PartiallyRevivedObjects partially_revived_objects;
283 TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
284 bundle.meta_graph_def(), context, directory, &partially_revived_objects));
285
286 RevivedObjects revived_objects;
287 TF_RETURN_IF_ERROR(partially_revived_objects.Build(
288 context, bundle.saved_object_graph(), &revived_objects));
289
290 // Revive function library functions as concrete functions without captures.
291 // This is necessary because object graph functions may refer to functions
292 // _not_ in the object graph: A while loop, for example, will create two
293 // auxiliary `while_cond` and `while_body` functions that are only present in
294 // the graph def function library.
295 for (const FunctionDef& function :
296 bundle.meta_graph_def().graph_def().library().function()) {
297 std::unique_ptr<TFConcreteFunction> concrete_function;
298 TF_RETURN_IF_ERROR(TFConcreteFunction::Create(/*function_def=*/&function,
299 /*captures=*/{},
300 /*metadata=*/{},
301 /*ctx=*/context,
302 /*out=*/&concrete_function));
303 revived_objects.concrete_functions.Insert(std::move(concrete_function));
304 }
305
306 TF_RETURN_IF_ERROR(
307 RestoreCheckpoint(&bundle, revived_objects, directory, context));
308
309 TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects));
310
311 out->reset(new TFSavedModelAPI(directory, std::move(bundle),
312 std::move(revived_objects)));
313 return Status();
314 }
315
316 } // namespace tensorflow
317