# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conditional expressions (e.g. the ternary if statement)."""


from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.ops import control_flow_ops


def if_exp(cond, if_true, if_false, expr_repr):
  """Functional form of the conditional (ternary) expression.

  Dispatches on the type of `cond`: when it is a dense tensor (as reported by
  `tensors.is_dense_tensor`) the expression is staged as a TF cond, otherwise
  it is evaluated eagerly using Python truthiness.

  Args:
    cond: The condition; either a dense tensor or any Python value.
    if_true: Zero-argument callable producing the value when `cond` holds.
    if_false: Zero-argument callable producing the value otherwise.
    expr_repr: Forwarded to `control_flow.verify_single_cond_var` in the
      staged case only (presumably a printable form of the expression used
      in error messages -- TODO confirm).

  Returns:
    In the Python case, the result of whichever callable was selected; in the
    tensor case, the result of `control_flow_ops.cond`.
  """
  if tensors.is_dense_tensor(cond):
    return _tf_if_exp(cond, if_true, if_false, expr_repr)
  else:
    return _py_if_exp(cond, if_true, if_false)


def _tf_if_exp(cond, if_true, if_false, expr_repr):
  """Overload of if_exp that stages a TF cond.

  Each branch records the value it produces. Whichever branch runs second,
  once both values are available, passes them to
  `control_flow.verify_single_cond_var` (which presumably checks that the two
  branch outputs are compatible -- TODO confirm against its definition).
  """
  # Sentinel meaning "this branch has not produced a value yet". A plain
  # truthiness or `is None` test would be wrong, because a branch may
  # legitimately return a falsy value such as None or 0.
  unset = object()
  true_val = unset
  false_val = unset

  def verify_if_both_traced():
    # The check needs both branch outputs, so it can only run after the
    # second branch has produced its value.
    if true_val is not unset and false_val is not unset:
      control_flow.verify_single_cond_var(expr_repr, true_val, false_val)

  def true_fn():
    nonlocal true_val
    true_val = if_true()
    verify_if_both_traced()
    # Return the value just computed, never a stale one from an earlier call.
    return true_val

  def false_fn():
    nonlocal false_val
    false_val = if_false()
    verify_if_both_traced()
    return false_val

  return control_flow_ops.cond(cond, true_fn, false_fn)


def _py_if_exp(cond, if_true, if_false):
  """Overload of if_exp for non-tensor conditions.

  Uses Python truthiness of `cond` and calls only the selected branch.
  """
  return if_true() if cond else if_false()