# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tracing utilities used by SavedModel."""

from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun


def trace_save_and_restore(obj):
  """Traces `Trackable` serialize- and restore-from-tensors functions.

  The object's `_serialize_to_tensors` and `_restore_from_tensors` methods are
  each wrapped in a `def_function.function` and traced into a concrete
  function. A method that is already a `ConcreteFunction` is used as-is.

  If `saveable_compat.get_saveable_name(obj)` returns a non-empty name, the
  traced save function prefixes every key of the returned tensor dict with
  that name. The traced restore function removes that many leading characters
  from each incoming key before calling the object's restore method.

  Args:
    obj: A `Trackable` object.

  Returns:
    A tuple `(concrete_save, concrete_restore)` of concrete functions.
    `concrete_save` takes no arguments and returns the dict of tensors to
    save. `concrete_restore` is traced with `concrete_save.structured_outputs`
    as its input structure, and is called with that dict of restored tensors.
  """
  legacy_name = saveable_compat.get_saveable_name(obj)

  obj_save_fn = obj._serialize_to_tensors  # pylint: disable=protected-access
  obj_restore_fn = obj._restore_from_tensors  # pylint: disable=protected-access

  if isinstance(obj_save_fn, defun.ConcreteFunction):
    # Already traced; reuse it directly. Note that no legacy-name prefix is
    # applied to its output keys in this branch.
    concrete_save = obj_save_fn
  else:
    @def_function.function
    def save_fn():
      tensor_dict = obj_save_fn()
      if legacy_name:
        # If there is a legacy decorator, append the name to the keys.
        return {f"{legacy_name}{key}": value
                for key, value in tensor_dict.items()}
      return tensor_dict

    concrete_save = save_fn.get_concrete_function()

  if isinstance(obj_restore_fn, defun.ConcreteFunction):
    concrete_restore = obj_restore_fn
  else:
    @def_function.function
    def restore_fn(restored_tensors):
      if legacy_name:
        # Do the opposite operation of save_fn(): drop the legacy-name prefix.
        # NOTE(review): this slices unconditionally. If `_serialize_to_tensors`
        # was already a ConcreteFunction (so no prefix was added above) while
        # `_restore_from_tensors` is not, keys would be truncated incorrectly.
        # Confirm that combination cannot occur with a legacy name.
        restored_tensors = {key[len(legacy_name):]: value
                            for key, value in restored_tensors.items()}
      obj_restore_fn(restored_tensors)

    # Trace restore against the exact output structure of the save function,
    # so the two concrete functions agree on keys and tensor specs.
    concrete_restore = restore_fn.get_concrete_function(
        concrete_save.structured_outputs)

  return concrete_save, concrete_restore