# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module-level flags shared by the mixed precision APIs.

The flags are kept here, rather than in mixed_precision.py, to avoid a
circular import: mixed_precision.py depends on Session, while Session depends
on this module.
"""

from tensorflow.python.util.tf_export import tf_export

# Flipped on by `enable_mixed_precision_graph_rewrite`. When True, the
# auto_mixed_precision rewrite is turned on in the ConfigProtos that are passed
# to Sessions.
_mixed_precision_graph_rewrite_is_enabled: bool = False


# Becomes True once a Session has been created while the mixed precision graph
# rewrite was NOT enabled. Consulted so that enabling mixed precision after such
# a Session already exists can emit a warning.
_non_mixed_precision_session_created: bool = False

# Whether the global tf.keras.mixed_precision.Policy uses mixed precision.
# Consulted so an error can be raised when a mixed Policy and the graph rewrite
# are used together.
_using_mixed_precision_policy: bool = False


@tf_export('__internal__.train.is_mixed_precision_graph_rewrite_enabled', v1=[])
def is_mixed_precision_graph_rewrite_enabled() -> bool:
  """Reports whether the mixed precision graph rewrite is enabled.

  Returns:
    The value last stored by `set_mixed_precision_graph_rewrite_enabled`,
    or False if that setter has never been called.
  """
  return _mixed_precision_graph_rewrite_is_enabled


def set_mixed_precision_graph_rewrite_enabled(enabled: bool) -> None:
  """Stores whether the mixed precision graph rewrite is enabled.

  Args:
    enabled: New value for the module-level flag. It is stored as given,
      without any conversion.
  """
  global _mixed_precision_graph_rewrite_is_enabled
  _mixed_precision_graph_rewrite_is_enabled = enabled


def non_mixed_precision_session_created() -> bool:
  """Reports whether a Session was created without the graph rewrite.

  Returns:
    The value last stored by `set_non_mixed_precision_session_created`,
    or False if that setter has never been called.
  """
  return _non_mixed_precision_session_created


def set_non_mixed_precision_session_created(created: bool) -> None:
  """Stores whether a Session was created without the graph rewrite.

  Args:
    created: New value for the module-level flag. It is stored as given,
      without any conversion.
  """
  global _non_mixed_precision_session_created
  _non_mixed_precision_session_created = created


def is_using_mixed_precision_policy() -> bool:
  """Reports whether the global Keras policy uses mixed precision.

  Returns:
    The value last stored by `set_using_mixed_precision_policy`, or False
    if that setter has never been called.
  """
  return _using_mixed_precision_policy


@tf_export('__internal__.train.set_using_mixed_precision_policy', v1=[])
def set_using_mixed_precision_policy(is_using: bool) -> None:
  """Stores whether the global Keras policy uses mixed precision.

  Args:
    is_using: New value for the module-level flag. It is stored as given,
      without any conversion.
  """
  global _using_mixed_precision_policy
  _using_mixed_precision_policy = is_using