# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to loss functions."""

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.losses.Reduction', v1=[])
class ReductionV2(object):
  """Types of loss reduction.

  Contains the following values:

  * `AUTO`: Indicates that the reduction option will be determined by the usage
     context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When
     used with `tf.distribute.Strategy`, outside of built-in training loops such
     as `tf.keras` `compile` and `fit`, we expect the reduction value to be
     `SUM` or `NONE`. Using `AUTO` in that case will raise an error.
  * `NONE`: No **additional** reduction is applied to the output of the wrapped
     loss function. When non-scalar losses are returned to Keras functions like
     `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer
     but the reported loss will be a scalar value.

     Caution: **Verify the shape of the outputs when using** `Reduction.NONE`.
     The builtin loss functions wrapped by the loss classes reduce
     one dimension (`axis=-1`, or `axis` if specified by the loss function).
     `Reduction.NONE` just means that no **additional** reduction is applied by
     the class wrapper. For categorical losses with an example input shape of
     `[batch, W, H, n_classes]` the `n_classes` dimension is reduced. For
     pointwise losses you must include a dummy axis so that `[batch, W, H, 1]`
     is reduced to `[batch, W, H]`. Without the dummy axis `[batch, W, H]`
     will be incorrectly reduced to `[batch, W]`. See the sketch at the end of
     this docstring.

  * `SUM`: Scalar sum of weighted losses.
  * `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by the number of elements in
     losses. This reduction type is not supported when used with
     `tf.distribute.Strategy` outside of built-in training loops like `tf.keras`
     `compile`/`fit`.

     You can implement 'SUM_OVER_BATCH_SIZE' using the global batch size like:
     ```
     with strategy.scope():
       loss_obj = tf.keras.losses.CategoricalCrossentropy(
           reduction=tf.keras.losses.Reduction.NONE)
       ....
       loss = tf.reduce_sum(loss_obj(labels, predictions)) * (
           1. / global_batch_size)
     ```

  Please see the [custom training guide](
  https://www.tensorflow.org/tutorials/distribute/custom_training) for more
  details on this.
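
  A minimal sketch of the `Reduction.NONE` caution above (an assumed eager-mode
  example; shapes are illustrative):
  ```
  mse = tf.keras.losses.MeanSquaredError(
      reduction=tf.keras.losses.Reduction.NONE)
  y_true = tf.zeros([8, 32, 32, 1])  # dummy axis included for a pointwise loss
  y_pred = tf.ones([8, 32, 32, 1])
  mse(y_true, y_pred).shape  # [8, 32, 32]: only the last axis was reduced
  ```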
74  """
75
76  AUTO = 'auto'
77  NONE = 'none'
78  SUM = 'sum'
79  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
80
81  @classmethod
82  def all(cls):
83    return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
84
85  @classmethod
86  def validate(cls, key):
87    if key not in cls.all():
88      raise ValueError('Invalid Reduction Key %s.' % key)
89
90
def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze last dim if ranks differ from expected by exactly 1.

  In the common case where we expect shapes to match, `expected_rank_diff`
  defaults to 0, and we squeeze the last dimension of the larger rank if they
  differ by 1.

  But, for example, if `labels` contains class IDs and `predictions` contains 1
  probability per class, we expect `predictions` to have 1 more dimension than
  `labels`, so `expected_rank_diff` would be 1. In this case, we'd squeeze
  `labels` if `rank(predictions) - rank(labels) == 0`, and
  `predictions` if `rank(predictions) - rank(labels) == 2`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
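
  Example (a rough sketch; shapes are illustrative):
  ```
  labels = tf.zeros([8, 1])      # trailing dimension of size 1
  predictions = tf.zeros([8])
  labels, predictions = remove_squeezable_dimensions(labels, predictions)
  labels.shape                   # [8]; the size-1 dimension was squeezed
  predictions.shape              # [8]; unchanged
  ```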
116  """
117  with backend.name_scope(name or 'remove_squeezable_dimensions'):
118    if not isinstance(predictions, ragged_tensor.RaggedTensor):
119      predictions = ops.convert_to_tensor_v2_with_dispatch(predictions)
120    if not isinstance(labels, ragged_tensor.RaggedTensor):
121      labels = ops.convert_to_tensor_v2_with_dispatch(labels)
122    predictions_shape = predictions.shape
123    predictions_rank = predictions_shape.ndims
124    labels_shape = labels.shape
125    labels_rank = labels_shape.ndims
126    if (labels_rank is not None) and (predictions_rank is not None):
127      # Use static rank.
128      rank_diff = predictions_rank - labels_rank
129      if (rank_diff == expected_rank_diff + 1 and
130          predictions_shape.dims[-1].is_compatible_with(1)):
131        predictions = array_ops.squeeze(predictions, [-1])
132      elif (rank_diff == expected_rank_diff - 1 and
133            labels_shape.dims[-1].is_compatible_with(1)):
134        labels = array_ops.squeeze(labels, [-1])
135      return labels, predictions
136
137    # Use dynamic rank.
138    rank_diff = array_ops.rank(predictions) - array_ops.rank(labels)
139    if (predictions_rank is None) or (
140        predictions_shape.dims[-1].is_compatible_with(1)):
141      predictions = control_flow_ops.cond(
142          math_ops.equal(expected_rank_diff + 1, rank_diff),
143          lambda: array_ops.squeeze(predictions, [-1]),
144          lambda: predictions)
145    if (labels_rank is None) or (
146        labels_shape.dims[-1].is_compatible_with(1)):
147      labels = control_flow_ops.cond(
148          math_ops.equal(expected_rank_diff - 1, rank_diff),
149          lambda: array_ops.squeeze(labels, [-1]),
150          lambda: labels)
151    return labels, predictions


def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred`.
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed; `sample_weight` could be extended by one
    dimension. If `sample_weight` is None, (y_pred, y_true) is returned.
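
  Example (a rough sketch; shapes are illustrative):
  ```
  y_pred = tf.zeros([8, 10])
  y_true = tf.zeros([8, 10])
  sample_weight = tf.ones([8])
  y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
      y_pred, y_true, sample_weight)
  sample_weight.shape  # [8, 1]: expanded by one dim to match rank of `y_pred`
  ```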
177  """
178  y_pred_shape = y_pred.shape
179  y_pred_rank = y_pred_shape.ndims
180  if y_true is not None:
181
182    # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
183    # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.05, .89, .06], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: control_flow_ops.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = control_flow_ops.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true

  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    expand_weights = lambda: array_ops.expand_dims(sample_weight, [-1])
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, -1), expand_weights, lambda: sample_weight)

  def _maybe_adjust_weights():
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`.
  sample_weight = control_flow_ops.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight


def _safe_mean(losses, num_present):
  """Computes a safe mean of the losses.

  Args:
    losses: `Tensor` whose elements contain individual loss measurements.
    num_present: The number of measurable elements in `losses`.

  Returns:
    A scalar representing the mean of `losses`. If `num_present` is zero,
      then zero is returned.
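
  Example (a rough sketch of the zero-denominator behavior):
  ```
  _safe_mean(tf.constant([1., 3.]), 2.)  # 2.0
  _safe_mean(tf.constant([]), 0.)        # 0.0 rather than NaN
  ```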
253  """
254  total_loss = math_ops.reduce_sum(losses)
255  return math_ops.div_no_nan(total_loss, num_present, name='value')
256
257
258def _num_elements(losses):
259  """Computes the number of elements in `losses` tensor."""
260  with backend.name_scope('num_elements') as scope:
261    return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
262
263
def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements."""
  if reduction == ReductionV2.NONE:
    loss = weighted_losses
  else:
    loss = math_ops.reduce_sum(weighted_losses)
    if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
      loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss
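
# A minimal sketch of `reduce_weighted_loss` semantics (values illustrative):
# given weighted_losses = tf.constant([1., 2., 3.]):
#   ReductionV2.NONE                -> [1., 2., 3.]  (returned unchanged)
#   ReductionV2.SUM                 -> 6.0
#   ReductionV2.SUM_OVER_BATCH_SIZE -> 2.0  (sum divided by the 3 elements)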


@keras_export('keras.__internal__.losses.compute_weighted_loss', v1=[])
def compute_weighted_loss(losses,
                          sample_weight=None,
                          reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
                          name=None):
  """Computes the weighted loss.

  Args:
    losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
      as `losses`, or is broadcastable to `losses`.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `SUM_OVER_BATCH_SIZE`.
    name: Optional name for the op.

  Raises:
    ValueError: If the shape of `sample_weight` is not compatible with `losses`.

  Returns:
    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
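
  Example (a rough sketch; values are illustrative):
  ```
  losses = tf.constant([[1., 2.], [3., 4.]])
  sample_weight = tf.constant([1., 0.5])
  compute_weighted_loss(losses, sample_weight)  # (1 + 2 + 1.5 + 2) / 4 = 1.625
  ```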
297  """
298  ReductionV2.validate(reduction)
299
300  # If this function is called directly, then we just default 'AUTO' to
301  # 'SUM_OVER_BATCH_SIZE'. Eg. Canned estimator use cases.
302  if reduction == ReductionV2.AUTO:
303    reduction = ReductionV2.SUM_OVER_BATCH_SIZE
304  if sample_weight is None:
305    sample_weight = 1.0
306  with backend.name_scope(name or 'weighted_loss'):
307    # Save the `reduction` argument for loss normalization when distributing
308    # to multiple replicas. Used only for estimator + v1 optimizer flow.
309    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access
310
311    if not isinstance(losses,
312                      (keras_tensor.KerasTensor, ragged_tensor.RaggedTensor)):
313      losses = ops.convert_to_tensor_v2_with_dispatch(losses)
314    input_dtype = losses.dtype
315
316    if not isinstance(sample_weight, keras_tensor.KerasTensor):
317      sample_weight = ops.convert_to_tensor_v2_with_dispatch(sample_weight)
318
319    # TODO(psv): Handle casting here in a better way, eg. if losses is float64
320    # we do not want to lose precision.
321    losses = math_ops.cast(losses, 'float32')
322    sample_weight = math_ops.cast(sample_weight, 'float32')
323    # Update dimensions of `sample_weight` to match with `losses` if possible.
324    losses, _, sample_weight = squeeze_or_expand_dimensions(  # pylint: disable=unbalanced-tuple-unpacking
325        losses, None, sample_weight)
326    weighted_losses = math_ops.multiply(losses, sample_weight)
327
328    # Apply reduction function to the individual weighted losses.
329    loss = reduce_weighted_loss(weighted_losses, reduction)
330    # Convert the result back to the input type.
331    loss = math_ops.cast(loss, input_dtype)
332    return loss
333
334
def scale_loss_for_distribution(loss_value):
  """Scales the given loss value by the number of replicas and returns it."""
  num_replicas = (
      distribution_strategy_context.get_strategy().num_replicas_in_sync)
  if num_replicas > 1:
    loss_value *= (1. / num_replicas)
  return loss_value
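
# A rough illustration: with 4 replicas in sync, a per-replica mean loss of 8.0
# is scaled to 8.0 * (1. / 4) = 2.0, so that summing the scaled losses across
# replicas yields the mean over the global batch.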


def cast_losses_to_common_dtype(losses):
  """Cast a list of losses to a common dtype.

  If any loss is floating-point, they will all be cast to the most-precise
  floating-point loss. Otherwise the losses are not cast. We also skip casting
  losses if there are any complex losses.

  Args:
    losses: A list of losses.

  Returns:
    `losses`, but they have been cast to a common dtype.
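
  Example (a rough sketch):
  ```
  losses = [tf.constant(1., tf.float16), tf.constant(2., tf.float32)]
  losses = cast_losses_to_common_dtype(losses)
  [loss.dtype for loss in losses]  # [tf.float32, tf.float32]
  ```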
356  """
357  highest_float = None
358  for loss in losses:
359    if loss.dtype.is_floating:
360      if highest_float is None or loss.dtype.size > highest_float.size:
361        highest_float = loss.dtype
362      elif {loss.dtype, highest_float} == {'bfloat16', 'float16'}:
363        highest_float = 'float32'
364    if loss.dtype.is_complex:
365      return losses  # If we find any complex losses, do not cast any losses
366  if highest_float:
367    losses = [math_ops.cast(loss, highest_float) for loss in losses]
368  return losses
369