1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Utilities related to loss functions.""" 17 18from tensorflow.python.distribute import distribution_strategy_context 19from tensorflow.python.framework import ops 20from tensorflow.python.keras import backend 21from tensorflow.python.keras.engine import keras_tensor 22from tensorflow.python.ops import array_ops 23from tensorflow.python.ops import control_flow_ops 24from tensorflow.python.ops import math_ops 25from tensorflow.python.ops.ragged import ragged_tensor 26from tensorflow.python.util.tf_export import keras_export 27 28 29@keras_export('keras.losses.Reduction', v1=[]) 30class ReductionV2(object): 31 """Types of loss reduction. 32 33 Contains the following values: 34 35 * `AUTO`: Indicates that the reduction option will be determined by the usage 36 context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When 37 used with `tf.distribute.Strategy`, outside of built-in training loops such 38 as `tf.keras` `compile` and `fit`, we expect reduction value to be 39 `SUM` or `NONE`. Using `AUTO` in that case will raise an error. 40 * `NONE`: No **additional** reduction is applied to the output of the wrapped 41 loss function. When non-scalar losses are returned to Keras functions like 42 `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer 43 but the reported loss will be a scalar value. 44 45 Caution: **Verify the shape of the outputs when using** `Reduction.NONE`. 46 The builtin loss functions wrapped by the loss classes reduce 47 one dimension (`axis=-1`, or `axis` if specified by loss function). 48 `Reduction.NONE` just means that no **additional** reduction is applied by 49 the class wrapper. For categorical losses with an example input shape of 50 `[batch, W, H, n_classes]` the `n_classes` dimension is reduced. For 51 pointwise losses your must include a dummy axis so that `[batch, W, H, 1]` 52 is reduced to `[batch, W, H]`. Without the dummy axis `[batch, W, H]` 53 will be incorrectly reduced to `[batch, W]`. 54 55 * `SUM`: Scalar sum of weighted losses. 56 * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses. 57 This reduction type is not supported when used with 58 `tf.distribute.Strategy` outside of built-in training loops like `tf.keras` 59 `compile`/`fit`. 60 61 You can implement 'SUM_OVER_BATCH_SIZE' using global batch size like: 62 ``` 63 with strategy.scope(): 64 loss_obj = tf.keras.losses.CategoricalCrossentropy( 65 reduction=tf.keras.losses.Reduction.NONE) 66 .... 67 loss = tf.reduce_sum(loss_obj(labels, predictions)) * 68 (1. / global_batch_size) 69 ``` 70 71 Please see the [custom training guide]( 72 https://www.tensorflow.org/tutorials/distribute/custom_training) for more 73 details on this. 74 """ 75 76 AUTO = 'auto' 77 NONE = 'none' 78 SUM = 'sum' 79 SUM_OVER_BATCH_SIZE = 'sum_over_batch_size' 80 81 @classmethod 82 def all(cls): 83 return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE) 84 85 @classmethod 86 def validate(cls, key): 87 if key not in cls.all(): 88 raise ValueError('Invalid Reduction Key %s.' % key) 89 90 91def remove_squeezable_dimensions( 92 labels, predictions, expected_rank_diff=0, name=None): 93 """Squeeze last dim if ranks differ from expected by exactly 1. 94 95 In the common case where we expect shapes to match, `expected_rank_diff` 96 defaults to 0, and we squeeze the last dimension of the larger rank if they 97 differ by 1. 98 99 But, for example, if `labels` contains class IDs and `predictions` contains 1 100 probability per class, we expect `predictions` to have 1 more dimension than 101 `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze 102 `labels` if `rank(predictions) - rank(labels) == 0`, and 103 `predictions` if `rank(predictions) - rank(labels) == 2`. 104 105 This will use static shape if available. Otherwise, it will add graph 106 operations, which could result in a performance hit. 107 108 Args: 109 labels: Label values, a `Tensor` whose dimensions match `predictions`. 110 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 111 expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`. 112 name: Name of the op. 113 114 Returns: 115 Tuple of `labels` and `predictions`, possibly with last dim squeezed. 116 """ 117 with backend.name_scope(name or 'remove_squeezable_dimensions'): 118 if not isinstance(predictions, ragged_tensor.RaggedTensor): 119 predictions = ops.convert_to_tensor_v2_with_dispatch(predictions) 120 if not isinstance(labels, ragged_tensor.RaggedTensor): 121 labels = ops.convert_to_tensor_v2_with_dispatch(labels) 122 predictions_shape = predictions.shape 123 predictions_rank = predictions_shape.ndims 124 labels_shape = labels.shape 125 labels_rank = labels_shape.ndims 126 if (labels_rank is not None) and (predictions_rank is not None): 127 # Use static rank. 128 rank_diff = predictions_rank - labels_rank 129 if (rank_diff == expected_rank_diff + 1 and 130 predictions_shape.dims[-1].is_compatible_with(1)): 131 predictions = array_ops.squeeze(predictions, [-1]) 132 elif (rank_diff == expected_rank_diff - 1 and 133 labels_shape.dims[-1].is_compatible_with(1)): 134 labels = array_ops.squeeze(labels, [-1]) 135 return labels, predictions 136 137 # Use dynamic rank. 138 rank_diff = array_ops.rank(predictions) - array_ops.rank(labels) 139 if (predictions_rank is None) or ( 140 predictions_shape.dims[-1].is_compatible_with(1)): 141 predictions = control_flow_ops.cond( 142 math_ops.equal(expected_rank_diff + 1, rank_diff), 143 lambda: array_ops.squeeze(predictions, [-1]), 144 lambda: predictions) 145 if (labels_rank is None) or ( 146 labels_shape.dims[-1].is_compatible_with(1)): 147 labels = control_flow_ops.cond( 148 math_ops.equal(expected_rank_diff - 1, rank_diff), 149 lambda: array_ops.squeeze(labels, [-1]), 150 lambda: labels) 151 return labels, predictions 152 153 154def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None): 155 """Squeeze or expand last dimension if needed. 156 157 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1 158 (using `remove_squeezable_dimensions`). 159 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1 160 from the new rank of `y_pred`. 161 If `sample_weight` is scalar, it is kept scalar. 162 163 This will use static shape if available. Otherwise, it will add graph 164 operations, which could result in a performance hit. 165 166 Args: 167 y_pred: Predicted values, a `Tensor` of arbitrary dimensions. 168 y_true: Optional label `Tensor` whose dimensions match `y_pred`. 169 sample_weight: Optional weight scalar or `Tensor` whose dimensions match 170 `y_pred`. 171 172 Returns: 173 Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has 174 the last dimension squeezed, 175 `sample_weight` could be extended by one dimension. 176 If `sample_weight` is None, (y_pred, y_true) is returned. 177 """ 178 y_pred_shape = y_pred.shape 179 y_pred_rank = y_pred_shape.ndims 180 if y_true is not None: 181 182 # If sparse matrix is provided as `y_true`, the last dimension in `y_pred` 183 # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)), 184 # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3)) 185 # In this case, we should not try to remove squeezable dimension. 186 y_true_shape = y_true.shape 187 y_true_rank = y_true_shape.ndims 188 if (y_true_rank is not None) and (y_pred_rank is not None): 189 # Use static rank for `y_true` and `y_pred`. 190 if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1: 191 y_true, y_pred = remove_squeezable_dimensions( 192 y_true, y_pred) 193 else: 194 # Use dynamic rank. 195 rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true) 196 squeeze_dims = lambda: remove_squeezable_dimensions( # pylint: disable=g-long-lambda 197 y_true, y_pred) 198 is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1]) 199 maybe_squeeze_dims = lambda: control_flow_ops.cond( # pylint: disable=g-long-lambda 200 is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred)) 201 y_true, y_pred = control_flow_ops.cond( 202 math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims) 203 204 if sample_weight is None: 205 return y_pred, y_true 206 207 weights_shape = sample_weight.shape 208 weights_rank = weights_shape.ndims 209 if weights_rank == 0: # If weights is scalar, do nothing. 210 return y_pred, y_true, sample_weight 211 212 if (y_pred_rank is not None) and (weights_rank is not None): 213 # Use static rank. 214 if weights_rank - y_pred_rank == 1: 215 sample_weight = array_ops.squeeze(sample_weight, [-1]) 216 elif y_pred_rank - weights_rank == 1: 217 sample_weight = array_ops.expand_dims(sample_weight, [-1]) 218 return y_pred, y_true, sample_weight 219 220 # Use dynamic rank. 221 weights_rank_tensor = array_ops.rank(sample_weight) 222 rank_diff = weights_rank_tensor - array_ops.rank(y_pred) 223 maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1]) 224 225 def _maybe_expand_weights(): 226 expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1]) 227 return control_flow_ops.cond( 228 math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight) 229 230 def _maybe_adjust_weights(): 231 return control_flow_ops.cond( 232 math_ops.equal(rank_diff, 1), maybe_squeeze_weights, 233 _maybe_expand_weights) 234 235 # squeeze or expand last dim of `sample_weight` if its rank differs by 1 236 # from the new rank of `y_pred`. 237 sample_weight = control_flow_ops.cond( 238 math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight, 239 _maybe_adjust_weights) 240 return y_pred, y_true, sample_weight 241 242 243def _safe_mean(losses, num_present): 244 """Computes a safe mean of the losses. 245 246 Args: 247 losses: `Tensor` whose elements contain individual loss measurements. 248 num_present: The number of measurable elements in `losses`. 249 250 Returns: 251 A scalar representing the mean of `losses`. If `num_present` is zero, 252 then zero is returned. 253 """ 254 total_loss = math_ops.reduce_sum(losses) 255 return math_ops.div_no_nan(total_loss, num_present, name='value') 256 257 258def _num_elements(losses): 259 """Computes the number of elements in `losses` tensor.""" 260 with backend.name_scope('num_elements') as scope: 261 return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype) 262 263 264def reduce_weighted_loss(weighted_losses, 265 reduction=ReductionV2.SUM_OVER_BATCH_SIZE): 266 """Reduces the individual weighted loss measurements.""" 267 if reduction == ReductionV2.NONE: 268 loss = weighted_losses 269 else: 270 loss = math_ops.reduce_sum(weighted_losses) 271 if reduction == ReductionV2.SUM_OVER_BATCH_SIZE: 272 loss = _safe_mean(loss, _num_elements(weighted_losses)) 273 return loss 274 275 276@keras_export('keras.__internal__.losses.compute_weighted_loss', v1=[]) 277def compute_weighted_loss(losses, 278 sample_weight=None, 279 reduction=ReductionV2.SUM_OVER_BATCH_SIZE, 280 name=None): 281 """Computes the weighted loss. 282 283 Args: 284 losses: `Tensor` of shape `[batch_size, d1, ... dN]`. 285 sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as 286 `losses`, or be broadcastable to `losses`. 287 reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss. 288 Default value is `SUM_OVER_BATCH_SIZE`. 289 name: Optional name for the op. 290 291 Raises: 292 ValueError: If the shape of `sample_weight` is not compatible with `losses`. 293 294 Returns: 295 Weighted loss `Tensor` of the same type as `losses`. If `reduction` is 296 `NONE`, this has the same shape as `losses`; otherwise, it is scalar. 297 """ 298 ReductionV2.validate(reduction) 299 300 # If this function is called directly, then we just default 'AUTO' to 301 # 'SUM_OVER_BATCH_SIZE'. Eg. Canned estimator use cases. 302 if reduction == ReductionV2.AUTO: 303 reduction = ReductionV2.SUM_OVER_BATCH_SIZE 304 if sample_weight is None: 305 sample_weight = 1.0 306 with backend.name_scope(name or 'weighted_loss'): 307 # Save the `reduction` argument for loss normalization when distributing 308 # to multiple replicas. Used only for estimator + v1 optimizer flow. 309 ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access 310 311 if not isinstance(losses, 312 (keras_tensor.KerasTensor, ragged_tensor.RaggedTensor)): 313 losses = ops.convert_to_tensor_v2_with_dispatch(losses) 314 input_dtype = losses.dtype 315 316 if not isinstance(sample_weight, keras_tensor.KerasTensor): 317 sample_weight = ops.convert_to_tensor_v2_with_dispatch(sample_weight) 318 319 # TODO(psv): Handle casting here in a better way, eg. if losses is float64 320 # we do not want to lose precision. 321 losses = math_ops.cast(losses, 'float32') 322 sample_weight = math_ops.cast(sample_weight, 'float32') 323 # Update dimensions of `sample_weight` to match with `losses` if possible. 324 losses, _, sample_weight = squeeze_or_expand_dimensions( # pylint: disable=unbalanced-tuple-unpacking 325 losses, None, sample_weight) 326 weighted_losses = math_ops.multiply(losses, sample_weight) 327 328 # Apply reduction function to the individual weighted losses. 329 loss = reduce_weighted_loss(weighted_losses, reduction) 330 # Convert the result back to the input type. 331 loss = math_ops.cast(loss, input_dtype) 332 return loss 333 334 335def scale_loss_for_distribution(loss_value): 336 """Scales and returns the given loss value by the number of replicas.""" 337 num_replicas = ( 338 distribution_strategy_context.get_strategy().num_replicas_in_sync) 339 if num_replicas > 1: 340 loss_value *= (1. / num_replicas) 341 return loss_value 342 343 344def cast_losses_to_common_dtype(losses): 345 """Cast a list of losses to a common dtype. 346 347 If any loss is floating-point, they will all be casted to the most-precise 348 floating-point loss. Otherwise the losses are not casted. We also skip casting 349 losses if there are any complex losses. 350 351 Args: 352 losses: A list of losses. 353 354 Returns: 355 `losses`, but they have been casted to a common dtype. 356 """ 357 highest_float = None 358 for loss in losses: 359 if loss.dtype.is_floating: 360 if highest_float is None or loss.dtype.size > highest_float.size: 361 highest_float = loss.dtype 362 elif {loss.dtype, highest_float} == {'bfloat16', 'float16'}: 363 highest_float = 'float32' 364 if loss.dtype.is_complex: 365 return losses # If we find any complex losses, do not cast any losses 366 if highest_float: 367 losses = [math_ops.cast(loss, highest_float) for loss in losses] 368 return losses 369