xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/saving/saved_model_experimental.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Deprecated experimental Keras SavedModel implementation."""
16
17import os
18import warnings
19from tensorflow.python.checkpoint import graph_view
20from tensorflow.python.client import session
21from tensorflow.python.framework import ops
22from tensorflow.python.keras import backend
23from tensorflow.python.keras import optimizer_v1
24from tensorflow.python.keras.optimizer_v2 import optimizer_v2
25from tensorflow.python.keras.saving import model_config
26from tensorflow.python.keras.saving import saving_utils
27from tensorflow.python.keras.saving import utils_v1 as model_utils
28from tensorflow.python.keras.utils import mode_keys
29from tensorflow.python.keras.utils.generic_utils import LazyLoader
30from tensorflow.python.lib.io import file_io
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import gfile
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.saved_model import builder as saved_model_builder
35from tensorflow.python.saved_model import constants
36from tensorflow.python.saved_model import save as save_lib
37from tensorflow.python.training import saver as saver_lib
38from tensorflow.python.util import compat
39from tensorflow.python.util import nest
40from tensorflow.python.util.tf_export import keras_export
41
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
metrics_lib = LazyLoader("metrics_lib", globals(),
                         "tensorflow.python.keras.metrics")
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
sequential = LazyLoader(
    "sequential", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes

# File name for the JSON-format Keras model config in the SavedModel assets.
SAVED_MODEL_FILENAME_JSON = 'saved_model.json'

@keras_export(v1=['keras.experimental.export_saved_model'])
def export_saved_model(model,
                       saved_model_path,
                       custom_objects=None,
                       as_text=False,
                       input_signature=None,
                       serving_only=False):
69  """Exports a `tf.keras.Model` as a Tensorflow SavedModel.
70
71  Note that at this time, subclassed models can only be saved using
72  `serving_only=True`.
73
74  The exported `SavedModel` is a standalone serialization of Tensorflow objects,
75  and is supported by TF language APIs and the Tensorflow Serving system.
76  To load the model, use the function
77  `tf.keras.experimental.load_from_saved_model`.
78
79  The `SavedModel` contains:
80
81  1. a checkpoint containing the model weights.
82  2. a `SavedModel` proto containing the Tensorflow backend graph. Separate
83     graphs are saved for prediction (serving), train, and evaluation. If
84     the model has not been compiled, then only the graph computing predictions
85     will be exported.
86  3. the model's json config. If the model is subclassed, this will only be
87     included if the model's `get_config()` method is overwritten.
88
  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
      `serving_only` must be set to True.
    saved_model_path: a string specifying the path to the SavedModel directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions (e.g. custom loss functions).
    as_text: bool, `False` by default. Whether to write the `SavedModel` proto
      in text format. Currently unavailable in serving-only mode.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
      to specify the expected model inputs. See `tf.function` for more details.
    serving_only: bool, `False` by default. When this is `True`, only the
      prediction graph is saved.

  Raises:
    NotImplementedError: If the model is a subclassed model and `serving_only`
      is False.
    ValueError: If the input signature cannot be inferred from the model.
    AssertionError: If the SavedModel directory already exists and isn't empty.
  """
  warnings.warn('`tf.keras.experimental.export_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please use `model.save(..., save_format="tf")` or '
                '`tf.keras.models.save_model(..., save_format="tf")`.')
  if serving_only:
    save_lib.save(
        model,
        saved_model_path,
        signatures=saving_utils.trace_model_call(model, input_signature))
  else:
    _save_v1_format(model, saved_model_path, custom_objects, as_text,
                    input_signature)

  try:
    _export_model_json(model, saved_model_path)
  except NotImplementedError:
    logging.warning('Skipped saving model JSON, subclassed model does not have '
                    'get_config() defined.')


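# Illustrative sketch, not part of the original module: subclassed models can
# only be exported with `serving_only=True`, and may need an explicit
# `input_signature` when the input shapes cannot be inferred from the model.
# The helper name and the spec shape below are assumptions for the example.
def _example_export_subclassed_for_serving(model, export_dir):
  """Hypothetical helper showing a serving-only export with a signature."""
  from tensorflow.python.framework import dtypes  # pylint: disable=g-import-not-at-top
  from tensorflow.python.framework import tensor_spec  # pylint: disable=g-import-not-at-top
  signature = [tensor_spec.TensorSpec(shape=[None, 10], dtype=dtypes.float32)]
  export_saved_model(
      model, export_dir, serving_only=True, input_signature=signature)

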
def _export_model_json(model, saved_model_path):
  """Saves the model config as a JSON string under the assets folder."""
  model_json = model.to_json()
  model_json_filepath = os.path.join(
      _get_or_create_assets_dir(saved_model_path),
      compat.as_text(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'w') as f:
    f.write(model_json)


def _export_model_variables(model, saved_model_path):
  """Saves model weights in checkpoint format under the variables folder."""
  _get_or_create_variables_dir(saved_model_path)
  checkpoint_prefix = _get_variables_path(saved_model_path)
  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
  return checkpoint_prefix


def _save_v1_format(model, path, custom_objects, as_text, input_signature):
  """Exports the model to the v1 SavedModel format."""
  if not model._is_graph_network:  # pylint: disable=protected-access
    if isinstance(model, sequential.Sequential):
      # If the input shape is not set directly on the model, the export infers
      # the expected input shapes from the model.
      if not model.built:
        raise ValueError('Weights for the sequential model have not yet been '
                         'created. Weights are created when the model is '
                         'first called on inputs, when `build()` is called '
                         'with an `input_shape`, or when the first layer in '
                         'the model receives an `input_shape` during '
                         'construction.')
      # TODO(kathywu): Build the model with input_signature to create the
      # weights before _export_model_variables().
    else:
      raise NotImplementedError(
          'Subclassed models can only be exported for serving. Please set '
          'argument serving_only=True.')

  builder = saved_model_builder._SavedModelBuilder(path)  # pylint: disable=protected-access

  # Manually save variables to export them in an object-based checkpoint. This
  # skips the `builder.add_meta_graph_and_variables()` step, which saves a
  # name-based checkpoint.
  # TODO(b/113134168): Add fn to Builder to save with object-based saver.
  # TODO(b/113178242): This should only export the model json structure. Only
  # one save is needed once the weights can be copied from the model to clone.
  checkpoint_path = _export_model_variables(model, path)

  # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
  # Keras models and `Estimator`s are exported with the same format.
  # Every time a mode is exported, the code checks to see if new variables have
  # been created (e.g. optimizer slot variables). If that is the case, the
  # checkpoint is re-saved to include the new variables.
  export_args = {'builder': builder,
                 'model': model,
                 'custom_objects': custom_objects,
                 'checkpoint_path': checkpoint_path,
                 'input_signature': input_signature}

  has_saved_vars = False
  if model.optimizer:
    if isinstance(model.optimizer, (optimizer_v1.TFOptimizer,
                                    optimizer_v2.OptimizerV2)):
      _export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
      has_saved_vars = True
      _export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
    else:
      logging.warning(
          'Model was compiled with an optimizer, but the optimizer is not from '
          '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
          'graph was exported. The train and evaluate graphs were not added to '
          'the SavedModel.')
  _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args)

  builder.save(as_text)

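
# Illustrative sketch, not part of the original module: each mode above is
# written as a separate MetaGraph whose tags come from
# `model_utils.EXPORT_TAG_MAP`, so an individual graph can be loaded with the
# TF1 SavedModel loader. The helper name below is an assumption.
def _example_load_serving_metagraph(export_dir):
  """Hypothetical helper that loads the serving MetaGraph into a session."""
  from tensorflow.python.saved_model import loader  # pylint: disable=g-import-not-at-top
  from tensorflow.python.saved_model import tag_constants  # pylint: disable=g-import-not-at-top
  sess = session.Session(graph=ops.Graph())
  meta_graph_def = loader.load(sess, [tag_constants.SERVING], export_dir)
  return sess, meta_graph_def
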

def _get_var_list(model):
  """Returns a list of all checkpointed saveable objects in the model."""
  var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
  return var_list


def create_placeholder(spec):
  """Creates a Keras backend placeholder matching the given `TensorSpec`."""
  return backend.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name)


def _export_mode(
    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path,
    input_signature):
  """Exports a model, and optionally saves new vars from the clone model.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
    model: A `tf.keras.Model` object.
    custom_objects: A dictionary mapping string names to custom classes
      or functions.
    checkpoint_path: String path to checkpoint.
    input_signature: Nested TensorSpec containing the expected inputs. Can be
      `None`, in which case the signature will be inferred from the model.

  Raises:
    ValueError: If the train/eval mode is being exported, but the model does
      not have an optimizer.
  """
  compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
  if compile_clone and not model.optimizer:
    raise ValueError(
        'Model does not have an optimizer. Cannot export mode %s' % mode)

  model_graph = ops.get_default_graph()
  with ops.Graph().as_default() as g, backend.learning_phase_scope(
      mode == mode_keys.ModeKeys.TRAIN):

    if input_signature is None:
      input_tensors = None
    else:
      input_tensors = nest.map_structure(create_placeholder, input_signature)

    # Clone the model into a blank graph. This will create placeholders for
    # inputs and targets.
    clone = models_lib.clone_and_build_model(
        model, input_tensors=input_tensors, custom_objects=custom_objects,
        compile_clone=compile_clone)

    # Make sure that the iterations variable is added to the global step
    # collection, to ensure that, when the SavedModel graph is loaded, the
    # iterations variable is returned by
    # `tf.compat.v1.train.get_global_step()`. This is required for
    # compatibility with the SavedModelEstimator.
    if compile_clone:
      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

    # Extract update and train ops from train/test/predict functions.
    train_op = None
    if mode == mode_keys.ModeKeys.TRAIN:
      clone._make_train_function()  # pylint: disable=protected-access
      train_op = clone.train_function.updates_op
    elif mode == mode_keys.ModeKeys.TEST:
      clone._make_test_function()  # pylint: disable=protected-access
    else:
      clone._make_predict_function()  # pylint: disable=protected-access
    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

    with session.Session().as_default():
      clone_var_list = _get_var_list(clone)
      if has_saved_vars:
        # Confirm all variables in the clone have an entry in the checkpoint.
        status = clone.load_weights(checkpoint_path)
        status.assert_existing_objects_matched()
      else:
        # Confirm that variables between the clone and model match up exactly,
        # not counting optimizer objects. Optimizer objects are ignored because
        # if the model has not trained, the slot variables will not have been
        # created yet.
        # TODO(b/113179535): Replace with trackable equivalence.
        _assert_same_non_optimizer_objects(model, model_graph, clone, g)

        # TODO(b/113178242): Use value transfer for trackable objects.
        clone.load_weights(checkpoint_path)

        # Add graph and variables to SavedModel.
        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
        builder._has_saved_variables = True  # pylint: disable=protected-access

      # Add graph to the SavedModel builder.
      builder.add_meta_graph(
          model_utils.EXPORT_TAG_MAP[mode],
          signature_def_map=_create_signature_def_map(clone, mode),
          saver=saver_lib.Saver(
              clone_var_list,
              # Allow saving Models with no variables. This is somewhat odd, but
              # it's not necessarily a bug.
              allow_empty=True),
          init_op=variables.local_variables_initializer(),
          train_op=train_op)
    return None


def _create_signature_def_map(model, mode):
  """Creates a SignatureDef map from a Keras model."""
  inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
  if model.optimizer:
    targets_dict = {x.name.split(':')[0]: x
                    for x in model._targets if x is not None}  # pylint: disable=protected-access
    inputs_dict.update(targets_dict)
  outputs_dict = {name: x
                  for name, x in zip(model.output_names, model.outputs)}
  metrics = saving_utils.extract_model_metrics(model)

  # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
  # are by default not added to any collections. We do this here so that
  # metric variables get initialized.
  local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
  vars_to_add = set()
  if metrics is not None:
    for key, value in metrics.items():
      if isinstance(value, metrics_lib.Metric):
        vars_to_add.update(value.variables)
        # Convert Metric instances to (value_tensor, update_op) tuple.
        metrics[key] = (value.result(), value.updates[0])
  # Remove variables that are in the local variables collection already.
  vars_to_add = vars_to_add.difference(local_vars)
  for v in vars_to_add:
    ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)

  export_outputs = model_utils.export_outputs_for_mode(
      mode,
      predictions=outputs_dict,
      loss=model.total_loss if model.optimizer else None,
      metrics=metrics)
  return model_utils.build_all_signature_defs(
      inputs_dict,
      export_outputs=export_outputs,
      serving_only=(mode == mode_keys.ModeKeys.PREDICT))


def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):  # pylint: disable=unused-argument
  """Asserts model and clone contain the same trackable objects."""

  # TODO(fchollet, kathywu): make sure this works in eager mode.
  return True


@keras_export(v1=['keras.experimental.load_from_saved_model'])
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  This function reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  warnings.warn('`tf.keras.experimental.load_from_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please switch to `tf.keras.models.load_model`.')
  # restore model topology from json string
  model_json_filepath = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'r') as f:
    model_json = f.read()
  model = model_config.model_from_json(
      model_json, custom_objects=custom_objects)

  # restore model weights
  checkpoint_prefix = os.path.join(
      compat.as_text(saved_model_path),
      compat.as_text(constants.VARIABLES_DIRECTORY),
      compat.as_text(constants.VARIABLES_FILENAME))
  model.load_weights(checkpoint_prefix)
  return model


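# Illustrative sketch, not part of the original module: when the exported
# model uses custom layers or losses, pass them through `custom_objects` so
# deserialization can resolve them. The helper name and the serialized name
# 'MyCustomLayer' are assumptions for the example.
def _example_load_with_custom_objects(saved_model_path, custom_layer_cls):
  """Hypothetical helper showing `custom_objects` during reload."""
  return load_from_saved_model(
      saved_model_path, custom_objects={'MyCustomLayer': custom_layer_cls})

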
#### Directory / path helpers


def _get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  variables_dir = _get_variables_dir(export_dir)
  file_io.recursive_create_dir(variables_dir)
  return variables_dir


def _get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.VARIABLES_DIRECTORY))


def _get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return os.path.join(
      compat.as_text(_get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def _get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  assets_destination_dir = _get_assets_dir(export_dir)

  file_io.recursive_create_dir(assets_destination_dir)

  return assets_destination_dir


def _get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.ASSETS_DIRECTORY))

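
# For reference, a directory produced by `export_saved_model` looks roughly
# like the following (exact checkpoint shard names may vary):
#
#   saved_model_path/
#     saved_model.pb          # or saved_model.pbtxt when `as_text=True`
#     assets/
#       saved_model.json      # model config, only when `get_config()` exists
#     variables/
#       variables.index
#       variables.data-00000-of-00001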