xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/builder_impl.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel builder implementation."""
16
17import functools
18import os
19
20from google.protobuf.any_pb2 import Any
21
22from tensorflow.core.framework import types_pb2
23from tensorflow.core.protobuf import meta_graph_pb2
24from tensorflow.core.protobuf import saved_model_pb2
25from tensorflow.core.protobuf import saver_pb2
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.lib.io import file_io
29from tensorflow.python.ops import variables
30from tensorflow.python.platform import tf_logging
31from tensorflow.python.saved_model import constants
32from tensorflow.python.saved_model import signature_def_utils
33from tensorflow.python.saved_model import utils_impl as saved_model_utils
34from tensorflow.python.saved_model.pywrap_saved_model import metrics
35from tensorflow.python.training import saver as tf_saver
36from tensorflow.python.util import compat
37from tensorflow.python.util.deprecation import deprecated_args
38from tensorflow.python.util.tf_export import tf_export
39
# API label for SavedModel metrics. It is reported via
# `metrics.IncrementWriteApi` each time the v1 builder's `save()` is called.
_SAVE_BUILDER_LABEL = "save_v1_builder"
42
43
# Base class for the SavedModelBuilder that is only used by Tensorflow
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
@tf_export("__internal__.saved_model.SavedModelBuilder", v1=[])
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides the functionality to build a
  `SavedModel` protocol buffer. Specifically, this allows multiple meta
  graphs to be saved as part of a single language-neutral `SavedModel`,
  while sharing variables and assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:

  ```python
  ...
  builder = tf.compat.v1.saved_model.Builder(export_dir)

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This function will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    """Creates a builder that writes a SavedModel to `export_dir`.

    Args:
      export_dir: Directory the SavedModel is written to. It is created if it
        does not exist; if it exists it must be empty.

    Raises:
      AssertionError: If `export_dir` already exists and is not empty.
    """
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      if file_io.list_directory(export_dir):
        raise AssertionError(
            f"Export directory {export_dir} already exists, and isn't empty. "
            "Please choose a different export directory, or delete all the "
            "contents of the specified directory.")
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
          def.
    """
    meta_graph_def.meta_info_def.tags.extend(tags)

    if signature_def_map is not None:
      for key in signature_def_map:
        meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])

    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
    `dtype` fields exist and are non-empty.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If the `encoding` or `dtype` fields of the supplied
          `TensorInfo` proto are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    if tensor_info.WhichOneof("encoding") is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used "
          "in the SignatureDefs must have one of the 'encoding' fields (e.g., "
          "name or coo_sparse) set.")
    if tensor_info.WhichOneof("encoding") == "composite_tensor":
      # Composite tensors carry no dtype of their own; validate each component.
      for component in tensor_info.composite_tensor.components:
        self._validate_tensor_info(component)
    elif tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          f"Invalid `tensor_info`: {tensor_info}. All TensorInfo protos used in"
          " the SignatureDefs must have the dtype field set.")

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def_key in signature_def_map:
      signature_def = signature_def_map[signature_def_key]
      inputs = signature_def.inputs
      outputs = signature_def.outputs
      for inputs_key in inputs:
        self._validate_tensor_info(inputs[inputs_key])
      for outputs_key in outputs:
        self._validate_tensor_info(outputs[outputs_key])
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.INIT_OP_SIGNATURE_KEY}\" is "
          "reserved for initialization. Please use a different key.")
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          f"SignatureDef map key \"{constants.TRAIN_OP_SIGNATURE_KEY}\" is "
          "reserved for the train op. Please use a different key.")

  def _prepare_signature_def_map(self, signature_def_map, init_op, train_op):
    """Validates a copy of `signature_def_map` and adds the op signatures.

    The caller's map is copied before the init/train op signatures are added
    under the reserved keys. Writing into the caller's object would leak those
    reserved keys back to the caller, and passing the same map to a later
    `add_meta_graph*` call would then fail the reserved-key check in
    `_validate_signature_def_map`.

    Args:
      signature_def_map: The user-supplied map of signature defs, or None.
      init_op: Op to record under `constants.INIT_OP_SIGNATURE_KEY`, or None.
      train_op: Op to record under `constants.TRAIN_OP_SIGNATURE_KEY`, or None.

    Returns:
      A new dict containing the validated user signatures plus the op
      signatures.

    Raises:
      AssertionError: If a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
    """
    signature_def_map = dict(signature_def_map or {})
    self._validate_signature_def_map(signature_def_map)

    # Create SignatureDefs pointing to the graph initialization op and the
    # train op, which will be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)
    return signature_def_map

  def _maybe_create_saver(self, saver=None):
    """Creates a sharded saver if one does not already exist."""
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
          def. The map itself is not modified.
      assets_list: Assets to be saved with SavedModel. Note
          that this list should be a subset of the assets saved as part of
          the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
          yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate a private copy of the signature def map (so all included
    # TensorInfos are properly populated) and add the init/train op signatures.
    signature_def_map = self._prepare_signature_def_map(
        signature_def_map, init_op, train_op)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature def map to add to the meta graph
        def. The map itself is not modified.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a reserved signature key.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate a private copy of the signature def map (so all included
    # TensorInfos are properly populated) and add the init/train op signatures.
    signature_def_map = self._prepare_signature_def_map(
        signature_def_map, init_op, train_op)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in a serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    metrics.IncrementWriteApi(_SAVE_BUILDER_LABEL)
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = file_io.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      # Deterministic serialization keeps output stable across runs.
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
    metrics.IncrementWrite(write_version="1")
    return path
431
432
@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Add asset and op collections to be saved.

    Args:
      assets_collection: Asset path tensors to save, or None.
      main_op: Op to register under `constants.MAIN_OP_KEY`, or None.
      train_op: Tensor or Op to register under `constants.TRAIN_OP_KEY`, or
        None.
    """
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Unlike the base class, assets are recorded in a graph collection rather
    than directly in a `MetaGraphDef`.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `saved_model.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
        op will be added to the graph.

    Raises:
      TypeError: If the main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError(f"Expected {main_op} to be an Operation but got type "
                      f"{type(main_op)} instead.")

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            f"collection {init_op_key}.")

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed. If None,
        nothing is added.

    Raises:
      TypeError: If `train_op` is neither a `Tensor` nor an `Operation`.
    """
    if train_op is not None:
      if (not isinstance(train_op, ops.Tensor) and
          not isinstance(train_op, ops.Operation)):
        raise TypeError(f"`train_op` {train_op} needs to be a Tensor or Op.")
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  # The public methods below take their docstrings from the base class (see the
  # `__doc__` assignments at the end of this class body).
  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless. Compare with
    # None rather than relying on truthiness, which is not defined for every
    # op-like object (e.g. graph Tensors).
    if main_op is None:
      main_op = legacy_init_op

    # Add assets and ops. The v1 API has no train_op argument, so pass None and
    # no TRAIN_OP_KEY collection entry is added.
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless. Use the same
    # explicit None check as add_meta_graph().
    if main_op is None:
      main_op = legacy_init_op

    # Add assets and ops. The v1 API has no train_op argument, so pass None and
    # no TRAIN_OP_KEY collection entry is added.
    self._add_collections(assets_collection, main_op, None)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
      "assets_list", "assets_collection")
  add_meta_graph_and_variables.__doc__ = \
      _SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
          "assets_list", "assets_collection")
625
626
def _maybe_save_assets(write_fn, assets_to_add=None):
  """Saves assets to the meta graph.

  Each asset tensor is resolved to its source file path and assigned a
  filename for the SavedModel via `get_asset_filename_to_add`. It is then
  handed to `write_fn` so it can be recorded in the meta graph.

  Args:
    write_fn: A callback invoked as `write_fn(asset_filename, asset_tensor)`
      that records one asset in the meta graph (either in a collection or in
      the `asset_file_def` field, depending on the caller).
    assets_to_add: An iterable of asset path tensors, or None.

  Returns:
    A dict mapping each asset filename used in the SavedModel to the original
    full path of the asset. Empty if `assets_to_add` is None.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
    TypeError: Propagated from `_asset_path_from_tensor` when an asset is not a
      valid path tensor.
  """
  # Map of target file names to original filenames.
  asset_filename_map = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return asset_filename_map

  # Iterate over the supplied assets, build the `AssetFile` proto and add them
  # to the meta graph.
  for asset_tensor in assets_to_add:
    asset_source_filepath = _asset_path_from_tensor(asset_tensor)
    if not asset_source_filepath:
      raise ValueError(f"Asset filepath tensor {asset_tensor} is invalid.")

    asset_filename = get_asset_filename_to_add(
        asset_source_filepath, asset_filename_map)

    # Call the passed-in function that builds AssetFileDef proto and adds it
    # to either the collection or asset_file_def field of the meta graph.
    # Note that this should be done even when the file is a duplicate of an
    # already-added file, as the tensor reference should still exist.
    write_fn(asset_filename, asset_tensor)

    # In the cases where we are adding a duplicate, this will result in the
    # last of the filepaths being the one used for copying the file to the
    # SavedModel. Since the files in question are the same, it doesn't matter
    # either way.
    asset_filename_map[asset_filename] = asset_source_filepath

  tf_logging.info("Assets added to graph.")
  return asset_filename_map
671
672
def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Returns the filename under which an asset should be stored.

  Assets arrive as full paths but are stored in the SavedModel under their
  basenames, so distinct source files can collide on the same basename. The
  result is chosen as follows:

  * An unused basename is returned unchanged.
  * A basename already used for the same path, or for a different path whose
    file contents compare equal, is returned unchanged (the asset is a
    duplicate).
  * Otherwise an index-suffixed basename not yet present in
    `asset_filename_map` is returned.

  Args:
    asset_filepath: Full path of the asset being saved.
    asset_filename_map: Dict from filenames already used in the SavedModel to
      the full paths they were derived from. Only read, never modified.

  Returns:
    The filename to use for this asset in the SavedModel.
  """
  basename = os.path.basename(asset_filepath)
  if basename in asset_filename_map:
    previous_filepath = asset_filename_map[basename]
    # Same basename from a different path: only rename when the contents
    # actually differ, since an identical file can share the existing entry.
    if (previous_filepath != asset_filepath and
        not file_io.filecmp(asset_filepath, previous_filepath)):
      return _get_unique_asset_filename(basename, asset_filename_map)
  return basename
711
def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns a filename derived from `asset_filename` that is unused in the map.

  If `asset_filename` is not already taken, it is returned unchanged.
  Otherwise the candidates `<asset_filename>_1`, `<asset_filename>_2`, ...
  are tried in order, and the first free one is returned (as bytes).

  Args:
    asset_filename: the colliding basename, as `str` or `bytes`.
    asset_filename_map: dict whose keys are the filenames already used in
      the SavedModel assets directory.

  Returns:
    A filename that does not collide with any key of `asset_filename_map`,
    regardless of whether those keys are `str` or `bytes`.
  """
  # The map can hold both `str` keys (basenames returned as-is by
  # `get_asset_filename_to_add`) and `bytes` keys (suffixed names produced
  # below). Because `b"a" != "a"` in Python 3, compare everything in bytes
  # form. Otherwise a generated b"x_1" would not be seen to collide with an
  # existing "x_1" key, even though both name the same file once joined into
  # the assets directory as bytes.
  taken = {compat.as_bytes(name) for name in asset_filename_map}
  unique_filename = asset_filename
  i = 1
  while compat.as_bytes(unique_filename) in taken:
    unique_filename = compat.as_bytes("_").join(
        [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
    i += 1
  return unique_filename
720
721
def _asset_path_from_tensor(path_tensor):
  """Returns the filepath value stored in constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path.

  Returns:
    The single element of the `Const` op's `string_val`, i.e. the path
    stored in the tensor.

  Raises:
    TypeError: if `path_tensor` is not a `Tensor`, is not produced by a
      `Const` op, does not have dtype string, or does not hold exactly one
      string value.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError(f"Asset path tensor {path_tensor} must be a Tensor.")
  if path_tensor.op.type != "Const":
    raise TypeError(
        f"Asset path tensor {path_tensor} must be of type constant. "
        f"Has type {path_tensor.op.type} instead.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError(
        f"Asset path tensor {path_tensor} must be of dtype string. "
        f"Has type {path_tensor.dtype} instead.")
  # Read the path straight from the constant's `value` attr; no session run
  # is needed since the op is known to be a Const at this point.
  str_values = path_tensor.op.get_attr("value").string_val
  if len(str_values) != 1:
    raise TypeError(f"Asset path tensor {path_tensor} must be a scalar.")
  return str_values[0]
746
747
748def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
749  """Builds an asset proto and adds it to the meta graph def.
750
751  Args:
752    meta_graph_def: The meta graph def to which the asset will be added.
753    asset_filename: The filename of the asset to be added.
754    asset_tensor: The asset tensor used to populate the tensor info of the asset
755      proto.
756  """
757  asset_proto = meta_graph_def.asset_file_def.add()
758  asset_proto.filename = asset_filename
759  asset_proto.tensor_info.name = asset_tensor.name
760
761
def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copy all assets from source path to destination path.

  Args:
    asset_filename_map: dict mapping the filename to use inside the assets
      directory to the source path the file is copied from.
    destination_dir: directory passed to
      `saved_model_utils.get_or_create_assets_dir`; the directory that call
      returns is where the assets are written.
  """
  assets_dir = saved_model_utils.get_or_create_assets_dir(destination_dir)

  for target_name, source_path in asset_filename_map.items():
    target_path = file_io.join(
        compat.as_bytes(assets_dir), compat.as_bytes(target_name))
    # An asset with the same name may be shared by several graphs; keep the
    # copy that is already there instead of copying it again.
    if file_io.file_exists(target_path):
      continue
    file_io.copy(source_path, target_path)

  tf_logging.info("Assets written to: %s", compat.as_text(assets_dir))
781
782
def _add_asset_to_collection(asset_filename, asset_tensor):
  """Packs an `AssetFileDef` into an `Any` and adds it to the assets collection.

  Args:
    asset_filename: Filename recorded in the `AssetFileDef.filename` field.
    asset_tensor: Tensor whose `name` is recorded in the
      `AssetFileDef.tensor_info.name` field.
  """
  asset_def = meta_graph_pb2.AssetFileDef()
  asset_def.tensor_info.name = asset_tensor.name
  asset_def.filename = asset_filename

  # The collection stores `Any` messages, so the AssetFileDef is packed first.
  packed_asset_def = Any()
  packed_asset_def.Pack(asset_def)
  ops.add_to_collection(constants.ASSETS_KEY, packed_asset_def)
798
799
800def _add_op_to_signature_def_map(signature_def_map, op, key):
801  if op is not None:
802    signature_def_map[key] = signature_def_utils.op_signature_def(op, key)
803