# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deprecated experimental Keras SavedModel implementation."""

import os
import warnings
from tensorflow.python.checkpoint import graph_view
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving import utils_v1 as model_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
metrics_lib = LazyLoader("metrics_lib", globals(),
                         "tensorflow.python.keras.metrics")
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
sequential = LazyLoader(
    "sequential", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes


# File name for json format of SavedModel.
SAVED_MODEL_FILENAME_JSON = 'saved_model.json'


@keras_export(v1=['keras.experimental.export_saved_model'])
def export_saved_model(model,
                       saved_model_path,
                       custom_objects=None,
                       as_text=False,
                       input_signature=None,
                       serving_only=False):
  """Exports a `tf.keras.Model` as a TensorFlow SavedModel.

  Note that at this time, subclassed models can only be saved using
  `serving_only=True`.

  The exported `SavedModel` is a standalone serialization of TensorFlow objects,
  and is supported by TF language APIs and the TensorFlow Serving system.
  To load the model, use the function
  `tf.keras.experimental.load_from_saved_model`.

  The `SavedModel` contains:

  1. a checkpoint containing the model weights.
  2. a `SavedModel` proto containing the TensorFlow backend graph. Separate
     graphs are saved for prediction (serving), train, and evaluation. If
     the model has not been compiled, then only the graph computing predictions
     will be exported.
  3. the model's json config. If the model is subclassed, this will only be
     included if the model's `get_config()` method is overwritten.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
      `serving_only` must be set to True.
    saved_model_path: a string specifying the path to the SavedModel directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions (e.g. custom loss functions).
    as_text: bool, `False` by default. Whether to write the `SavedModel` proto
      in text format. Currently unavailable in serving-only mode.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
      to specify the expected model inputs. See `tf.function` for more details.
    serving_only: bool, `False` by default. When this is true, only the
      prediction graph is saved.

  Raises:
    NotImplementedError: If the model is a subclassed model, and serving_only is
      False.
    ValueError: If the input signature cannot be inferred from the model.
    AssertionError: If the SavedModel directory already exists and isn't empty.
  """
  # `stacklevel=2` attributes the warning to the caller of this function
  # rather than to this module.
  warnings.warn('`tf.keras.experimental.export_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please use `model.save(..., save_format="tf")` or '
                '`tf.keras.models.save_model(..., save_format="tf")`.',
                stacklevel=2)
  if serving_only:
    save_lib.save(
        model,
        saved_model_path,
        signatures=saving_utils.trace_model_call(model, input_signature))
  else:
    _save_v1_format(model, saved_model_path, custom_objects, as_text,
                    input_signature)

  # Writing the JSON config is best-effort: `to_json()` raises
  # NotImplementedError when the model has no usable `get_config()`.
  try:
    _export_model_json(model, saved_model_path)
  except NotImplementedError:
    logging.warning('Skipped saving model JSON, subclassed model does not have '
                    'get_config() defined.')


def _export_model_json(model, saved_model_path):
  """Saves model configuration as a json string under assets folder.

  Args:
    model: The model whose `to_json()` output is written.
    saved_model_path: Path to the SavedModel directory; its assets
      sub-directory is created if needed.
  """
  model_json = model.to_json()
  model_json_filepath = os.path.join(
      _get_or_create_assets_dir(saved_model_path),
      compat.as_text(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'w') as f:
    f.write(model_json)


def _export_model_variables(model, saved_model_path):
  """Saves model weights in checkpoint format under variables folder.

  Args:
    model: The model whose weights are saved via `save_weights`.
    saved_model_path: Path to the SavedModel directory; its variables
      sub-directory is created if needed.

  Returns:
    The checkpoint prefix the weights were written to.
  """
  _get_or_create_variables_dir(saved_model_path)
  checkpoint_prefix = _get_variables_path(saved_model_path)
  model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
  return checkpoint_prefix


def _save_v1_format(model, path, custom_objects, as_text, input_signature):
  """Exports model to v1 SavedModel format.

  Args:
    model: The model to export. Must be a graph network or a built
      `Sequential` model.
    path: Path to the SavedModel directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions; forwarded to `_export_mode`.
    as_text: Whether to write the `SavedModel` proto in text format.
    input_signature: Nested TensorSpec describing the expected inputs, or
      `None`; forwarded to `_export_mode`.

  Raises:
    ValueError: If `model` is a `Sequential` model that has not been built.
    NotImplementedError: If `model` is neither a graph network nor a
      `Sequential` model.
  """
  if not model._is_graph_network:  # pylint: disable=protected-access
    if isinstance(model, sequential.Sequential):
      # If input shape is not directly set in the model, the exported model
      # will infer the expected shapes of the input from the model.
      if not model.built:
        raise ValueError('Weights for sequential model have not yet been '
                         'created. Weights are created when the Model is first '
                         'called on inputs or `build()` is called with an '
                         '`input_shape`, or the first layer in the model has '
                         '`input_shape` during construction.')
      # TODO(kathywu): Build the model with input_signature to create the
      # weights before _export_model_variables().
    else:
      raise NotImplementedError(
          'Subclassed models can only be exported for serving. Please set '
          'argument serving_only=True.')

  builder = saved_model_builder._SavedModelBuilder(path)  # pylint: disable=protected-access

  # Manually save variables to export them in an object-based checkpoint. This
  # skips the `builder.add_meta_graph_and_variables()` step, which saves a
  # named-based checkpoint.
  # TODO(b/113134168): Add fn to Builder to save with object-based saver.
  # TODO(b/113178242): This should only export the model json structure. Only
  # one save is needed once the weights can be copied from the model to clone.
  checkpoint_path = _export_model_variables(model, path)

  # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
  # Keras models and `Estimator`s are exported with the same format.
  # Every time a mode is exported, the code checks to see if new variables have
  # been created (e.g. optimizer slot variables). If that is the case, the
  # checkpoint is re-saved to include the new variables.
  export_args = {'builder': builder,
                 'model': model,
                 'custom_objects': custom_objects,
                 'checkpoint_path': checkpoint_path,
                 'input_signature': input_signature}

  has_saved_vars = False
  if model.optimizer:
    if isinstance(model.optimizer, (optimizer_v1.TFOptimizer,
                                    optimizer_v2.OptimizerV2)):
      _export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
      has_saved_vars = True
      _export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
    else:
      logging.warning(
          'Model was compiled with an optimizer, but the optimizer is not from '
          '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
          'graph was exported. The train and evaluate graphs were not added to '
          'the SavedModel.')
  _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args)

  builder.save(as_text)


def _get_var_list(model):
  """Returns list of all checkpointed saveable objects in the model."""
  var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
  return var_list


def create_placeholder(spec):
  """Returns a backend placeholder with the shape, dtype and name of `spec`."""
  return backend.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name)


def _export_mode(
    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path,
    input_signature):
  """Exports a model, and optionally saves new vars from the clone model.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
    model: A `tf.keras.Model` object.
    custom_objects: A dictionary mapping string names to custom classes
      or functions.
    checkpoint_path: String path to checkpoint.
    input_signature: Nested TensorSpec containing the expected inputs. Can be
      `None`, in which case the signature will be inferred from the model.

  Raises:
    ValueError: If the train/eval mode is being exported, but the model does
      not have an optimizer.
  """
  compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
  if compile_clone and not model.optimizer:
    raise ValueError(
        'Model does not have an optimizer. Cannot export mode %s' % mode)

  model_graph = ops.get_default_graph()
  with ops.Graph().as_default() as g, backend.learning_phase_scope(
      mode == mode_keys.ModeKeys.TRAIN):

    if input_signature is None:
      input_tensors = None
    else:
      input_tensors = nest.map_structure(create_placeholder, input_signature)

    # Clone the model into blank graph. This will create placeholders for inputs
    # and targets.
    clone = models_lib.clone_and_build_model(
        model, input_tensors=input_tensors, custom_objects=custom_objects,
        compile_clone=compile_clone)

    # Make sure that iterations variable is added to the global step collection,
    # to ensure that, when the SavedModel graph is loaded, the iterations
    # variable is returned by `tf.compat.v1.train.get_global_step()`. This is
    # required for compatibility with the SavedModelEstimator.
    if compile_clone:
      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

    # Extract update and train ops from train/test/predict functions.
    train_op = None
    if mode == mode_keys.ModeKeys.TRAIN:
      clone._make_train_function()  # pylint: disable=protected-access
      train_op = clone.train_function.updates_op
    elif mode == mode_keys.ModeKeys.TEST:
      clone._make_test_function()  # pylint: disable=protected-access
    else:
      clone._make_predict_function()  # pylint: disable=protected-access
    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

    with session.Session().as_default():
      clone_var_list = _get_var_list(clone)
      if has_saved_vars:
        # Confirm all variables in the clone have an entry in the checkpoint.
        status = clone.load_weights(checkpoint_path)
        status.assert_existing_objects_matched()
      else:
        # Confirm that variables between the clone and model match up exactly,
        # not counting optimizer objects. Optimizer objects are ignored because
        # if the model has not trained, the slot variables will not have been
        # created yet.
        # TODO(b/113179535): Replace with trackable equivalence.
        _assert_same_non_optimizer_objects(model, model_graph, clone, g)

        # TODO(b/113178242): Use value transfer for trackable objects.
        clone.load_weights(checkpoint_path)

        # Add graph and variables to SavedModel.
        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
        builder._has_saved_variables = True  # pylint: disable=protected-access

      # Add graph to the SavedModel builder.
      builder.add_meta_graph(
          model_utils.EXPORT_TAG_MAP[mode],
          signature_def_map=_create_signature_def_map(clone, mode),
          saver=saver_lib.Saver(
              clone_var_list,
              # Allow saving Models with no variables. This is somewhat odd, but
              # it's not necessarily a bug.
              allow_empty=True),
          init_op=variables.local_variables_initializer(),
          train_op=train_op)


def _create_signature_def_map(model, mode):
  """Creates a SignatureDef map from a Keras model.

  Args:
    model: The (cloned) Keras model whose inputs, targets, outputs and metrics
      are exposed in the signatures.
    mode: The mode key being exported; only the PREDICT mode produces
      serving-only signatures.

  Returns:
    The result of `model_utils.build_all_signature_defs` for this model.
  """
  inputs_dict = dict(zip(model.input_names, model.inputs))
  if model.optimizer:
    # Targets are keyed by tensor name with the output index (":0") removed.
    targets_dict = {
        x.name.split(':')[0]: x
        for x in model._targets  # pylint: disable=protected-access
        if x is not None
    }
    inputs_dict.update(targets_dict)
  outputs_dict = dict(zip(model.output_names, model.outputs))
  metrics = saving_utils.extract_model_metrics(model)

  # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
  # are by default not added to any collections. We are doing this here, so
  # that metric variables get initialized.
  local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
  vars_to_add = set()
  if metrics is not None:
    # Only values of existing keys are reassigned, so the dict does not change
    # size during iteration.
    for key, value in metrics.items():
      if isinstance(value, metrics_lib.Metric):
        vars_to_add.update(value.variables)
        # Convert Metric instances to (value_tensor, update_op) tuple.
        metrics[key] = (value.result(), value.updates[0])
  # Remove variables that are in the local variables collection already.
  vars_to_add = vars_to_add.difference(local_vars)
  for v in vars_to_add:
    ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)

  export_outputs = model_utils.export_outputs_for_mode(
      mode,
      predictions=outputs_dict,
      loss=model.total_loss if model.optimizer else None,
      metrics=metrics)
  return model_utils.build_all_signature_defs(
      inputs_dict,
      export_outputs=export_outputs,
      serving_only=(mode == mode_keys.ModeKeys.PREDICT))


def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):  # pylint: disable=unused-argument
  """Asserts model and clone contain the same trackable objects.

  Currently a placeholder: no comparison is performed and `True` is always
  returned.
  """

  # TODO(fchollet, kathywu): make sure this works in eager mode.
  return True


@keras_export(v1=['keras.experimental.load_from_saved_model'])
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  This function reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  # `stacklevel=2` attributes the warning to the caller of this function
  # rather than to this module.
  warnings.warn('`tf.keras.experimental.load_from_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please switch to `tf.keras.models.load_model`.',
                stacklevel=2)
  # Restore model topology from the json string written at export time. The
  # path is built with the same helper the export side uses.
  model_json_filepath = os.path.join(
      _get_assets_dir(saved_model_path),
      compat.as_text(SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'r') as f:
    model_json = f.read()
  model = model_config.model_from_json(
      model_json, custom_objects=custom_objects)

  # Restore model weights from the checkpoint under the variables directory.
  checkpoint_prefix = _get_variables_path(saved_model_path)
  model.load_weights(checkpoint_prefix)
  return model


#### Directory / path helpers


def _get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  variables_dir = _get_variables_dir(export_dir)
  file_io.recursive_create_dir(variables_dir)
  return variables_dir


def _get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.VARIABLES_DIRECTORY))


def _get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return os.path.join(
      compat.as_text(_get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def _get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  assets_destination_dir = _get_assets_dir(export_dir)

  file_io.recursive_create_dir(assets_destination_dir)

  return assets_destination_dir


def _get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.ASSETS_DIRECTORY))