1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Interface that provides access to Keras dependencies. 17 18This library is a common interface that contains Keras functions needed by 19TensorFlow and TensorFlow Lite and is required as per the dependency inversion 20principle (https://en.wikipedia.org/wiki/Dependency_inversion_principle). As per 21this principle, high-level modules (eg: TensorFlow and TensorFlow Lite) should 22not depend on low-level modules (eg: Keras) and instead both should depend on a 23common interface such as this file. 24""" 25 26 27from tensorflow.python.util.tf_export import tf_export 28 29_KERAS_CALL_CONTEXT_FUNCTION = None 30_KERAS_CLEAR_SESSION_FUNCTION = None 31_KERAS_GET_SESSION_FUNCTION = None 32_KERAS_LOAD_MODEL_FUNCTION = None 33 34# TODO(scottzhu): Disable duplicated inject once keras is moved to 35# third_party/py/keras. 
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF


# Register functions
@tf_export('__internal__.register_call_context_function', v1=[])
def register_call_context_function(func):
  """Stores `func` as the registered call-context function.

  Any previously registered value is overwritten. `func` is stored as-is; no
  validation is performed on it.

  Args:
    func: The object to store in `_KERAS_CALL_CONTEXT_FUNCTION`.
  """
  global _KERAS_CALL_CONTEXT_FUNCTION
  _KERAS_CALL_CONTEXT_FUNCTION = func


@tf_export('__internal__.register_clear_session_function', v1=[])
def register_clear_session_function(func):
  """Stores `func` as the registered clear-session function.

  Any previously registered value is overwritten. `func` is stored as-is; no
  validation is performed on it.

  Args:
    func: The object to store in `_KERAS_CLEAR_SESSION_FUNCTION`.
  """
  global _KERAS_CLEAR_SESSION_FUNCTION
  _KERAS_CLEAR_SESSION_FUNCTION = func


@tf_export('__internal__.register_get_session_function', v1=[])
def register_get_session_function(func):
  """Stores `func` as the registered get-session function.

  Any previously registered value is overwritten. `func` is stored as-is; no
  validation is performed on it.

  Args:
    func: The object to store in `_KERAS_GET_SESSION_FUNCTION`.
  """
  global _KERAS_GET_SESSION_FUNCTION
  _KERAS_GET_SESSION_FUNCTION = func


@tf_export('__internal__.register_load_model_function', v1=[])
def register_load_model_function(func):
  """Stores `func` as the registered load-model function.

  Any previously registered value is overwritten. `func` is stored as-is; no
  validation is performed on it.

  Args:
    func: The object to store in `_KERAS_LOAD_MODEL_FUNCTION`.
  """
  global _KERAS_LOAD_MODEL_FUNCTION
  _KERAS_LOAD_MODEL_FUNCTION = func


# Get functions
# The getters only read the module-level slots, so they need no `global`
# declaration; that statement is required only where a name is rebound.
def get_call_context_function():
  """Returns the value last passed to `register_call_context_function`.

  Returns:
    The current value of `_KERAS_CALL_CONTEXT_FUNCTION`; `None` if nothing has
    been registered yet.
  """
  return _KERAS_CALL_CONTEXT_FUNCTION


def get_clear_session_function():
  """Returns the value last passed to `register_clear_session_function`.

  Returns:
    The current value of `_KERAS_CLEAR_SESSION_FUNCTION`; `None` if nothing
    has been registered yet.
  """
  return _KERAS_CLEAR_SESSION_FUNCTION


def get_get_session_function():
  """Returns the value last passed to `register_get_session_function`.

  Returns:
    The current value of `_KERAS_GET_SESSION_FUNCTION`; `None` if nothing has
    been registered yet.
  """
  return _KERAS_GET_SESSION_FUNCTION


def get_load_model_function():
  """Returns the value last passed to `register_load_model_function`.

  Returns:
    The current value of `_KERAS_LOAD_MODEL_FUNCTION`; `None` if nothing has
    been registered yet.
  """
  return _KERAS_LOAD_MODEL_FUNCTION