xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/nn_ops/losses_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for losses."""
16
17import numpy as np
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import errors_impl
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import random_seed
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import init_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import random_ops
29from tensorflow.python.ops import variable_scope
30from tensorflow.python.ops import variables
31from tensorflow.python.ops.losses import losses
32from tensorflow.python.ops.losses import util
33from tensorflow.python.platform import test
34from tensorflow.python.training import momentum as momentum_lib
35
36
@test_util.run_deprecated_v1
class AbsoluteDifferenceLossTest(test.TestCase):
  """Tests for `losses.absolute_difference`.

  The shared fixture pairs predictions [[4, 8, 12], [8, 1, 3]] with labels
  [[1, 9, 2], [-5, -2, 6]]; their elementwise absolute differences are
  [[3, 1, 10], [13, 3, 3]], which sum to 33 (mean 5.5).
  """

  # Raw values for the (2, 3) prediction and label fixtures.
  _PREDICTION_VALUES = [4, 8, 12, 8, 1, 3]
  _LABEL_VALUES = [1, 9, 2, -5, -2, 6]
  _FIXTURE_SHAPE = (2, 3)

  def setUp(self):
    super(AbsoluteDifferenceLossTest, self).setUp()
    self._predictions = constant_op.constant(
        self._PREDICTION_VALUES, shape=self._FIXTURE_SHAPE)
    self._labels = constant_op.constant(
        self._LABEL_VALUES, shape=self._FIXTURE_SHAPE)

  def _loss_with_weights(self, weights):
    """Returns the labels-vs-predictions loss built with `weights`."""
    return losses.absolute_difference(self._labels, self._predictions, weights)

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.cached_session(), self.assertRaises(ValueError):
      losses.absolute_difference(
          self._predictions, self._predictions, weights=None)

  def testAllCorrectNoLossWeight(self):
    # Comparing the predictions with themselves gives zero error everywhere.
    loss = losses.absolute_difference(self._predictions, self._predictions)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), places=3)

  def testNonZeroLoss(self):
    loss = losses.absolute_difference(self._labels, self._predictions)
    with self.cached_session():
      self.assertAlmostEqual(5.5, self.evaluate(loss), places=3)

  def testNonZeroLossWithPythonScalarWeight(self):
    scale = 2.3
    loss = self._loss_with_weights(scale)
    with self.cached_session():
      self.assertAlmostEqual(5.5 * scale, self.evaluate(loss), places=3)

  def testNonZeroLossWithScalarTensorWeight(self):
    scale = 2.3
    loss = self._loss_with_weights(constant_op.constant(scale))
    with self.cached_session():
      self.assertAlmostEqual(5.5 * scale, self.evaluate(loss), places=3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    # Only row 0 (errors 3 + 1 + 10 = 14) is weighted: 1.2 * 14 / 3 = 5.6.
    loss = self._loss_with_weights(
        constant_op.constant((1.2, 0.0), shape=(2, 1)))
    with self.cached_session():
      self.assertAlmostEqual(5.6, self.evaluate(loss), places=3)

  def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
    loss = self._loss_with_weights(
        constant_op.constant([1.2, 0.0], shape=[2, 1]))
    with self.cached_session():
      self.assertAlmostEqual(5.6, self.evaluate(loss), places=3)

  def testNonZeroLossWithSampleSpecificWeights(self):
    # Weighted errors: 3*3 + 1*6 + 10*5 + 13*0 + 3*4 + 3*2 = 83, and the
    # expected value divides by the 5 non-zero weights: 83 / 5 = 16.6.
    loss = self._loss_with_weights(
        constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]))
    with self.cached_session():
      self.assertAlmostEqual(16.6, self.evaluate(loss), places=3)

  def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
    # Only the last element (error 3, weight 2) is weighted: 3 * 2 / 1 = 6.
    loss = self._loss_with_weights(
        constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]))
    with self.cached_session():
      self.assertAlmostEqual(6.0, self.evaluate(loss), places=3)

  def testLossWithSampleSpecificWeightsAllZero(self):
    loss = self._loss_with_weights(array_ops.zeros((2, 3)))
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), places=3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerNoMemoryLeaked(self):
    # An indirect way of checking that building the loss adds nothing to a
    # global collection; the decorator performs the object accounting.
    predictions = constant_op.constant(self._PREDICTION_VALUES, shape=(2, 3))
    labels = constant_op.constant(self._LABEL_VALUES, shape=(2, 3))
    losses.absolute_difference(labels, predictions)
111
112
class SoftmaxCrossEntropyLossTest(test.TestCase):
  """Tests for `losses.softmax_cross_entropy`."""

  # Each row puts a logit of 10 on a different class.
  _LOGITS = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
  # One-hot labels that match the largest logit of every `_LOGITS` row.
  _CORRECT_LABELS = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
  # One-hot labels that miss the largest logit of every `_LOGITS` row; each
  # example then contributes a loss of about 10.
  _WRONG_LABELS = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]

  def _tensors(self, label_values):
    """Builds `_LOGITS` and `label_values` as constants.

    Args:
      label_values: Nested list of one-hot labels.

    Returns:
      A `(labels, logits)` pair of constant tensors.
    """
    logits = constant_op.constant(self._LOGITS)
    labels = constant_op.constant(label_values)
    return labels, logits

  def testNoneWeightRaisesValueError(self):
    labels, logits = self._tensors(self._CORRECT_LABELS)
    with self.cached_session(), self.assertRaises(ValueError):
      losses.softmax_cross_entropy(labels, logits, weights=None)

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    with self.cached_session():
      labels, logits = self._tensors(self._CORRECT_LABELS)
      loss = losses.softmax_cross_entropy(labels, logits)
      self.assertEqual('softmax_cross_entropy_loss/value', loss.op.name)
      self.assertAlmostEqual(self.evaluate(loss), 0.0, places=3)

  @test_util.run_deprecated_v1
  def testAllWrong(self):
    labels, logits = self._tensors(self._WRONG_LABELS)
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 10.0, places=3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPythonScalarWeight(self):
    labels, logits = self._tensors(self._WRONG_LABELS)
    scale = 2.3
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, scale)
      self.assertAlmostEqual(scale * 10.0, self.evaluate(loss), places=3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithScalarTensorWeight(self):
    labels, logits = self._tensors(self._WRONG_LABELS)
    scale = 2.3
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits,
                                          constant_op.constant(scale))
      self.assertAlmostEqual(scale * 10.0, self.evaluate(loss), places=3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    labels, logits = self._tensors(self._WRONG_LABELS)
    weights = constant_op.constant((1.2, 3.4, 5.6))
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, weights)
      # Every example loses ~10, so the expectation is 10 * mean(weights).
      expected = (1.2 + 3.4 + 5.6) * 10.0 / 3.0
      self.assertAlmostEqual(expected, self.evaluate(loss), places=3)

  def testAllWrongAllWeightsMissing(self):
    labels, logits = self._tensors(self._WRONG_LABELS)
    weights = constant_op.constant([0, 0, 0], shape=[3])
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(0.0, self.evaluate(loss), places=3)

  def testSomeWeightsMissing(self):
    labels, logits = self._tensors(self._WRONG_LABELS)
    # Only the first example (loss ~10, weight 1.2) has a non-zero weight.
    weights = constant_op.constant([1.2, 0, 0], shape=[3])
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(12.0, self.evaluate(loss), places=3)

  def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant(self._CORRECT_LABELS)
      # Per-element (3x3) weights are expected to be rejected.
      weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]])

      with self.assertRaises(ValueError):
        losses.softmax_cross_entropy(labels, logits, weights=weights).eval()

  @test_util.run_deprecated_v1
  def testSoftmaxLabelSmoothing(self):
    with self.cached_session():
      # Softmax cross entropy is -\sum_i p_i \log q_i, where
      #   \log q_i = x_i - \log \sum_j \exp x_j
      #            = x_i - x_max - \log \sum_j \exp (x_j - x_max).
      # For the activations [100, -100, -100] the log partition term is
      # \log(exp(0) + exp(-200) + exp(-200)) ~= 0, so the log-softmax values
      # are [0, -200, -200]. Smoothing turns the one-hot target into
      # (1 - L + L/n) for the true class and L/n for each other class, so the
      # loss is (1 - L + L/n) * 0 + 2 * (L/n) * 200 = 400 L/n.
      smoothing = 0.1
      logits = constant_op.constant([[100.0, -100.0, -100.0]])
      labels = constant_op.constant([[1, 0, 0]])
      loss = losses.softmax_cross_entropy(
          labels, logits, label_smoothing=smoothing)
      self.assertEqual(loss.op.name, 'softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(
          self.evaluate(loss), 400.0 * smoothing / 3.0, places=3)
225
226
class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
  """Tests for `losses.sparse_softmax_cross_entropy`."""

  # Each row puts a logit of 10 on a different class.
  _LOGITS = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]
  # Column of class indices matching the largest logit of each row.
  _CORRECT_LABELS = [[0], [1], [2]]
  # Column of class indices missing the largest logit of each row; each
  # example then contributes a loss of about 10.
  _WRONG_LABELS = [[2], [0], [1]]
  # Expected loss when the three wrong examples are weighted 1.2, 3.4, 5.6.
  _WEIGHTED_MEAN_LOSS = (1.2 + 3.4 + 5.6) * 10.0 / 3.0
  # Logits with a +/-100 margin used by the validation tests.
  _CONFIDENT_LOGITS_3 = [[100.0, -100.0, -100.0],
                         [-100.0, 100.0, -100.0],
                         [-100.0, -100.0, 100.0]]
  _CONFIDENT_LOGITS_4 = [[100.0, -100.0, -100.0, -100.0],
                         [-100.0, 100.0, -100.0, -100.0],
                         [-100.0, -100.0, 100.0, -100.0],
                         [-100.0, -100.0, -100.0, 100.0]]
  # Name of the final op produced by the loss.
  _LOSS_OP_NAME = 'sparse_softmax_cross_entropy_loss/value'

  def _check_zero_loss(self, label_values, dtype=None):
    """Asserts `_LOGITS` scored against `label_values` has ~0 loss."""
    with self.cached_session():
      logits = constant_op.constant(self._LOGITS)
      labels = constant_op.constant(label_values, dtype=dtype)
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, self._LOSS_OP_NAME)
      self.assertAlmostEqual(self.evaluate(loss), 0.0, places=3)

  def _check_all_wrong_loss(self, label_values, dtype=None):
    """Asserts `_LOGITS` scored against `label_values` has ~10 loss."""
    logits = constant_op.constant(self._LOGITS)
    labels = constant_op.constant(label_values, dtype=dtype)
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, self._LOSS_OP_NAME)
      self.assertAlmostEqual(self.evaluate(loss), 10.0, places=3)

  def _wrong_label_tensors(self):
    """Returns `(labels, logits)` constants for `_WRONG_LABELS`/`_LOGITS`."""
    logits = constant_op.constant(self._LOGITS)
    labels = constant_op.constant(self._WRONG_LABELS)
    return labels, logits

  def _check_raises_value_error(self, logits_values, label_values,
                                weight_values):
    """Asserts that building (and evaluating) the loss raises ValueError."""
    with self.cached_session():
      logits = constant_op.constant(logits_values)
      labels = constant_op.constant(label_values)
      weights = constant_op.constant(weight_values)

      with self.assertRaises(ValueError):
        losses.sparse_softmax_cross_entropy(
            labels, logits, weights=weights).eval()

  def testNoneWeightRaisesValueError(self):
    logits = constant_op.constant(self._LOGITS)
    labels = constant_op.constant(self._CORRECT_LABELS)
    with self.cached_session(), self.assertRaises(ValueError):
      losses.sparse_softmax_cross_entropy(labels, logits, weights=None)

  @test_util.run_deprecated_v1
  def testAllCorrectInt32Labels(self):
    self._check_zero_loss(self._CORRECT_LABELS, dtype=dtypes.int32)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerNoMemoryLeaked(self):
    logits = constant_op.constant(self._LOGITS)
    labels = constant_op.constant(self._CORRECT_LABELS, dtype=dtypes.int32)
    losses.sparse_softmax_cross_entropy(labels, logits)

  @test_util.run_deprecated_v1
  def testAllCorrectInt64Labels(self):
    self._check_zero_loss(self._CORRECT_LABELS, dtype=dtypes.int64)

  @test_util.run_deprecated_v1
  def testAllCorrectNonColumnLabels(self):
    self._check_zero_loss([0, 1, 2])

  @test_util.run_deprecated_v1
  def testAllWrongInt32Labels(self):
    self._check_all_wrong_loss(self._WRONG_LABELS, dtype=dtypes.int32)

  @test_util.run_deprecated_v1
  def testAllWrongInt64Labels(self):
    self._check_all_wrong_loss(self._WRONG_LABELS, dtype=dtypes.int64)

  @test_util.run_deprecated_v1
  def testAllWrongNonColumnLabels(self):
    self._check_all_wrong_loss([2, 0, 1])

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPythonScalarWeight(self):
    labels, logits = self._wrong_label_tensors()
    scale = 2.3
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, scale)
      self.assertAlmostEqual(scale * 10.0, self.evaluate(loss), places=3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithScalarTensorWeight(self):
    labels, logits = self._wrong_label_tensors()
    scale = 2.3
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits,
                                                 constant_op.constant(scale))
      self.assertAlmostEqual(scale * 10.0, self.evaluate(loss), places=3)

  def testNonZeroLossWith1DTensorWeight(self):
    labels, logits = self._wrong_label_tensors()
    scale = 2.3
    with self.cached_session():
      # A length-1 weight vector is expected to apply to all three examples.
      loss = losses.sparse_softmax_cross_entropy(
          labels, logits, constant_op.constant((scale,)))
      self.assertAlmostEqual(scale * 10.0, self.evaluate(loss), places=3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPlaceholderForWeights(self):
    labels, logits = self._wrong_label_tensors()
    weights = array_ops.placeholder(dtypes.float32)
    with self.cached_session() as sess:
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      loss_val = sess.run(loss,
                          feed_dict={weights: ((1.2,), (3.4,), (5.6,))})
      self.assertAlmostEqual(self._WEIGHTED_MEAN_LOSS, loss_val, places=3)

  @test_util.run_deprecated_v1
  def testUnknownShapePlaceholderForLogitsLabelsButScalarWeights(self):
    logits = array_ops.placeholder(dtypes.float32)
    labels = array_ops.placeholder(dtypes.int32)
    with self.cached_session() as sess:
      loss = losses.sparse_softmax_cross_entropy(labels, logits, 1.0)
      feed = {logits: self._LOGITS, labels: self._WRONG_LABELS}
      loss_val = sess.run(loss, feed_dict=feed)
      self.assertAlmostEqual(
          (1.0 + 1.0 + 1.0) * 10.0 / 3.0, loss_val, places=3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPlaceholderForLogitsLabelsAndWeights(self):
    logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
    labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
    weights = array_ops.placeholder(dtypes.float32)
    with self.cached_session() as sess:
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      feed = {
          logits: self._LOGITS,
          labels: self._WRONG_LABELS,
          weights: ((1.2,), (3.4,), (5.6,)),
      }
      loss_val = sess.run(loss, feed_dict=feed)
      self.assertAlmostEqual(self._WEIGHTED_MEAN_LOSS, loss_val, places=3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    labels, logits = self._wrong_label_tensors()
    weights = constant_op.constant([1.2, 3.4, 5.6], shape=(3, 1))
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(self._WEIGHTED_MEAN_LOSS, self.evaluate(loss),
                             places=3)

  def testNonZeroLossWithColumnWeights(self):
    labels, logits = self._wrong_label_tensors()
    weights = constant_op.constant([[1.2], [3.4], [5.6]])
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(self._WEIGHTED_MEAN_LOSS, self.evaluate(loss),
                             places=3)

  def testAllWrongAllWeightsMissing(self):
    labels, logits = self._wrong_label_tensors()
    weights = constant_op.constant([0, 0, 0], shape=(3, 1))
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(0.0, self.evaluate(loss), places=3)

  def testSomeWeightsMissing(self):
    labels, logits = self._wrong_label_tensors()
    # Only the first example (loss ~10, weight 1.2) has a non-zero weight.
    weights = constant_op.constant([1.2, 0, 0], shape=(3, 1))
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(12.0, self.evaluate(loss), places=3)

  @test_util.run_deprecated_v1
  def testMeasurementSpecificWeightsRaisesException(self):
    self._check_raises_value_error(
        self._CONFIDENT_LOGITS_3, self._CORRECT_LABELS,
        [[3, 4, 5], [2, 6, 0], [8, 0, 1]])

  def testInconsistentWeightSizeRaisesException(self):
    """The weight tensor has an incorrect number of elements."""
    self._check_raises_value_error(
        self._CONFIDENT_LOGITS_3, self._CORRECT_LABELS, [1.2, 3.4, 5.6, 7.8])

  def testInconsistentLabelSizeRaisesException(self):
    """The label tensor has an incorrect number of elements."""
    self._check_raises_value_error(
        self._CONFIDENT_LOGITS_3, [[0], [1], [2], [3]], [1.2, 3.4, 5.6])

  @test_util.run_deprecated_v1
  def testInconsistentWeightShapeRaisesException(self):
    """The weight tensor has an incorrect shape."""
    self._check_raises_value_error(
        self._CONFIDENT_LOGITS_4, [[0], [1], [2], [3]],
        [[1.2, 3.4], [5.6, 7.8]])

  @test_util.run_deprecated_v1
  def testInconsistentLabelShapeRaisesException(self):
    """The label tensor has an incorrect shape."""
    with self.cached_session():
      logits = constant_op.constant(self._CONFIDENT_LOGITS_4)
      labels = constant_op.constant([[0, 1], [2, 3]])
      weights = constant_op.constant(1.2)

      with self.assertRaisesRegex(
          ValueError,
          '`labels.shape.rank` must equal `logits.shape.rank - 1`'):
        losses.sparse_softmax_cross_entropy(
            labels, logits, weights=weights).eval()
492
493
class SigmoidCrossEntropyLossTest(test.TestCase):
  """Tests for `losses.sigmoid_cross_entropy`.

  Most cases use logits of +/-100. The expected values assume each element
  costs ~0 when the logit's sign agrees with its 0/1 label and ~100 when it
  does not.
  """

  # Logits whose positive entry sits on the diagonal.
  _DIAGONAL_LOGITS = [[100.0, -100.0, -100.0],
                      [-100.0, 100.0, -100.0],
                      [-100.0, -100.0, 100.0]]
  # Name of the final op produced by `losses.sigmoid_cross_entropy`.
  _LOSS_OP_NAME = 'sigmoid_cross_entropy_loss/value'

  def _check_all_ones_placeholder_loss(self, num_columns):
    """Feeds all-ones logits and labels of shape (32, `num_columns`).

    Every element then contributes log(1 + exp(-1)) ~= 0.313, so the
    unit-weighted loss is ~0.313 regardless of `num_columns`.

    Args:
      num_columns: Static size of the placeholders' second dimension.
    """
    logits = array_ops.placeholder(dtypes.float32, shape=(None, num_columns))
    labels = array_ops.placeholder(dtypes.float32, shape=(None, num_columns))
    weights = array_ops.ones_like(logits, dtype=dtypes.float32)

    loss = losses.sigmoid_cross_entropy(labels, logits, weights)
    self.assertEqual(logits.dtype, loss.dtype)

    with self.cached_session() as sess:
      # Keep `loss` bound to the tensor; the fetched value gets its own name.
      loss_val = sess.run(loss,
                          feed_dict={
                              logits: np.ones((32, num_columns)),
                              labels: np.ones((32, num_columns)),
                          })
      self.assertAlmostEqual(0.313, loss_val, 3)

  @test_util.run_deprecated_v1
  def testAllCorrectSigmoid(self):
    with self.cached_session():
      logits = constant_op.constant(self._DIAGONAL_LOGITS)
      labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
      loss = losses.sigmoid_cross_entropy(labels, logits)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual(self._LOSS_OP_NAME, loss.op.name)
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self):
    self._check_all_ones_placeholder_loss(1)

  @test_util.run_deprecated_v1
  def testLossWithSingleDimPlaceholderForLogitsAndWeights2(self):
    self._check_all_ones_placeholder_loss(2)

  @test_util.run_deprecated_v1
  def testAllWrongSigmoid(self):
    with self.cached_session():
      logits = constant_op.constant(self._DIAGONAL_LOGITS)
      labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
      loss = losses.sigmoid_cross_entropy(labels, logits)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual(self._LOSS_OP_NAME, loss.op.name)
      # Six of the nine elements disagree with their label: 6 * 100 / 9.
      self.assertAlmostEqual(self.evaluate(loss), 600.0 / 9.0, 3)

  @test_util.run_deprecated_v1
  def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
    with self.cached_session():
      logits = constant_op.constant(self._DIAGONAL_LOGITS)
      labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
      weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]])
      loss = losses.sigmoid_cross_entropy(labels, logits, weights)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual(self._LOSS_OP_NAME, loss.op.name)
      # Weighted errors: 100 * (3 + 5 + 2 + 6 + 1) = 1700, divided by the
      # 7 non-zero weights.
      self.assertAlmostEqual(1700.0 / 7.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testMultiCorrectSigmoid(self):
    logits = constant_op.constant([[100.0, -100.0, 100.0],
                                   [100.0, 100.0, -100.0],
                                   [-100.0, 100.0, 100.0]])
    labels = constant_op.constant([[1, 0, 1], [1, 1, 0], [0, 1, 1]])
    loss = losses.sigmoid_cross_entropy(labels, logits)
    self.assertEqual(logits.dtype, loss.dtype)
    self.assertEqual(self._LOSS_OP_NAME, loss.op.name)

    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  def testSigmoidFloat64(self):
    logits = constant_op.constant((
        (100.0, -100.0, 100.0),
        (100.0, -100.0, 100.0),
        (100.0, 100.0, -100.0)
    ), dtype=dtypes.float64)
    labels = constant_op.constant((
        (1, 0, 1), (1, 1, 0), (0, 1, 1)
    ), dtype=dtypes.int64)
    loss = losses.sigmoid_cross_entropy(labels, logits)
    self.assertEqual(logits.dtype, loss.dtype)

    with self.cached_session():
      # Four of the nine elements disagree with their label: 400 / 9.
      self.assertAlmostEqual(44.444, self.evaluate(loss), 3)

  def testSigmoidNoReduction(self):
    logits = constant_op.constant((
        (100.0, -100.0, 100.0),
        (100.0, -100.0, 100.0),
        (100.0, 100.0, -100.0)))
    labels = constant_op.constant(((1, 0, 1), (1, 1, 0), (0, 1, 1)))
    loss = losses.sigmoid_cross_entropy(
        labels, logits, reduction=losses.Reduction.NONE)
    self.assertEqual(logits.dtype, loss.dtype)

    with self.cached_session():
      # assertAllClose's third positional parameter is `rtol`, not a number of
      # decimal places (a positional 3 would allow 300% relative error), so
      # give an explicit absolute tolerance instead.
      self.assertAllClose(((0., 0., 0.), (0., 100., 100.), (100., 0., 100.)),
                          self.evaluate(loss), atol=1e-3)

  @test_util.run_deprecated_v1
  def testSigmoidLabelSmoothingCorrect(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0]])
      labels = constant_op.constant([[1, 0, 1]])
      # Sigmoid cross entropy loss is:
      #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
      # Label smoothing replaces each label z with
      #   z' = z * (1 - L) + 0.5 L,  i.e.  1 -> 1 - 0.5 L  and  0 -> 0.5 L.
      # With |x| = 100 the log term is ~0, so we expect:
      #   1/3 * ((100 - 100 * (1 - 0.5 L))
      #          + (0 + 100 * (0.5 L))
      #          + (0 + 100 * (1 - 0.5 L)))
      #   = 1/3 * (100 + 50 L)
      label_smoothing = 0.1
      loss = losses.sigmoid_cross_entropy(
          labels, logits, label_smoothing=label_smoothing)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual(self._LOSS_OP_NAME, loss.op.name)
      expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
      self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)

  @test_util.run_deprecated_v1
  def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
    with self.cached_session():
      label_smoothing = 0.1
      sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
      sigmoid_labels = constant_op.constant([[1, 0, 1]])
      sigmoid_loss = losses.sigmoid_cross_entropy(
          sigmoid_labels, sigmoid_logits, label_smoothing=label_smoothing)
      self.assertEqual(sigmoid_logits.dtype, sigmoid_loss.dtype)

      # Each sigmoid element is re-expressed as a two-class softmax row.
      softmax_logits = constant_op.constant(
          [[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]])
      softmax_labels = constant_op.constant([[0, 1], [1, 0], [0, 1]])
      softmax_loss = losses.softmax_cross_entropy(
          softmax_labels, softmax_logits, label_smoothing=label_smoothing)
      self.assertAlmostEqual(
          self.evaluate(sigmoid_loss), self.evaluate(softmax_loss), 3)
650
651
@test_util.run_deprecated_v1
class LogLossTest(test.TestCase):
  """Tests for `losses.log_loss` with scalar, batch and per-element weights.

  `self._expected_losses` holds the per-element log-likelihood
  `y * log(p + eps) + (1 - y) * log(1 - p + eps)` of the (2, 3) fixture. Each
  expected loss below is the negated, weighted sum of that array divided by
  the number of elements that carry a non-zero weight.
  """

  def setUp(self):
    super(LogLossTest, self).setUp()
    predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3))
    labels = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3))

    # Raw NumPy copies, fed into the placeholder-based variants below.
    self._np_predictions = predictions
    self._np_labels = labels

    # Per-element log-likelihood, NOT yet negated. NOTE(review): epsilon is
    # presumably meant to match the epsilon `log_loss` adds inside its logs;
    # confirm if that default ever changes.
    epsilon = 1e-7
    self._expected_losses = np.multiply(
        labels, np.log(predictions + epsilon)) + np.multiply(
            1 - labels, np.log(1 - predictions + epsilon))

    self._predictions = constant_op.constant(predictions)
    self._labels = constant_op.constant(labels)

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.log_loss(self._labels, self._labels, weights=None)

  def testAllCorrectNoLossWeight(self):
    # Predicting the labels exactly yields (approximately) zero loss.
    loss = losses.log_loss(self._labels, self._labels)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  def testAllCorrectNoLossWeightWithPlaceholder(self):
    # Same as above, but predictions arrive through a fully-shaped placeholder.
    tf_predictions = array_ops.placeholder(
        dtypes.float32, shape=self._np_labels.shape)
    loss = losses.log_loss(self._labels, tf_predictions)
    with self.cached_session():
      self.assertAlmostEqual(
          0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)

  def testNonZeroLoss(self):
    # No explicit weights: expected value is the plain mean over 6 elements.
    loss = losses.log_loss(self._labels, self._predictions)
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithPythonScalarWeight(self):
    weights = 2.3
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithScalarTensorWeight(self):
    weights = 2.3
    loss = losses.log_loss(self._labels, self._predictions,
                           constant_op.constant(weights))
    with self.cached_session():
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
    # Placeholder with a fully-known static shape.
    tf_predictions = array_ops.placeholder(
        dtypes.float32, shape=self._np_predictions.shape)
    weights = 2.3
    loss = losses.log_loss(self._labels, tf_predictions,
                           constant_op.constant(weights))
    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             loss, 3)

  def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
    # Placeholder whose rank is known but whose dimensions are not.
    tf_predictions = array_ops.placeholder(dtypes.float32, shape=[None, None])
    weights = 2.3
    loss = losses.log_loss(self._labels, tf_predictions,
                           constant_op.constant(weights))
    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             loss, 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    # A (2, 1) weight broadcasts one value across each row of 3 elements.
    weights = constant_op.constant((1.2, 3.4), shape=(2, 1))
    expected_losses = np.multiply(
        self._expected_losses,
        np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
    # Only the first row (3 elements) is weighted, hence the division by 3.
    weights = constant_op.constant((1.2, 0), shape=(2, 1))
    expected_losses = np.multiply(self._expected_losses,
                                  np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                      (2, 3)))
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses) / 3.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
    # NOTE(review): these inputs are identical to the OneDim...SomeZero test
    # in this class (both use a (2, 1) weight tensor); presumably one of them
    # once used rank-1 weights. Confirm the intended coverage.
    weights = constant_op.constant([1.2, 0], shape=[2, 1])
    expected_losses = np.multiply(self._expected_losses,
                                  np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                      (2, 3)))
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses) / 3.0,
                             self.evaluate(loss), 3)

  def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
    # (2, 4) weights share the rank of the (2, 3) losses but cannot be
    # broadcast to them. Only the shape matters, so deterministic values are
    # used rather than unseeded random draws, keeping the test hermetic.
    weights = constant_op.constant(np.ones((2, 4)), shape=[2, 4])
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.log_loss(self._labels, self._predictions, weights)

  def testNonZeroLossWithMeasurementSpecificWeights(self):
    # One zero weight among six -> divide by the 5 non-zero-weight elements.
    weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    loss = losses.log_loss(
        self._labels,
        self._predictions,
        constant_op.constant(
            weights, shape=(2, 3)))
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses) / 5.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
    weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    tf_predictions = array_ops.placeholder(dtypes.float32, shape=[2, 3])
    loss = losses.log_loss(
        self._labels,
        tf_predictions,
        constant_op.constant(
            weights, shape=(2, 3)))

    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
    # A single non-zero weight: the divisor is 1, so no division appears.
    weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    loss = losses.log_loss(
        self._labels,
        self._predictions,
        constant_op.constant(
            weights, shape=(2, 3)))
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses), self.evaluate(loss), 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
    weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    tf_predictions = array_ops.placeholder(dtypes.float32, shape=[2, 3])
    tf_weights = constant_op.constant(weights, shape=(2, 3))
    loss = losses.log_loss(self._labels, tf_predictions, tf_weights)

    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)

  def testLossWithSampleSpecificWeightsAllZero(self):
    # All-zero weights must give 0 rather than a 0/0 NaN.
    tf_weights = array_ops.zeros(shape=(2, 3))
    loss = losses.log_loss(self._labels, self._predictions, tf_weights)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)
824
825
class HingeLossTest(test.TestCase):
  """Tests for `losses.hinge_loss` with {0, 1} labels.

  The expected values below are consistent with the mean over all elements of
  max(0, 1 - (2 * label - 1) * logit).
  """

  def testIncompatibleShapes(self):
    with self.cached_session():
      # (2, 1) logits cannot be paired with (2,) labels.
      column_logits = constant_op.constant([[-1.0], [2.1]])
      flat_labels = constant_op.constant([0.0, 1.0])
      with self.assertRaises(ValueError):
        _ = losses.hinge_loss(flat_labels, column_logits).eval()

  @test_util.run_deprecated_v1
  def testAllOutsideMargin(self):
    with self.cached_session():
      # Every logit has the correct sign and magnitude >= 1 -> zero loss.
      scores = constant_op.constant([1.2, -1.4, -1.0, 2.1])
      targets = constant_op.constant([1.0, 0.0, 0.0, 1.0])
      hinge = losses.hinge_loss(targets, scores)
      self.assertAllClose(hinge, 0.0, atol=1e-3)

  @test_util.run_deprecated_v1
  def testSomeInsideMargin(self):
    with self.cached_session():
      scores = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
      targets = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
      hinge = losses.hinge_loss(targets, scores)
      # Examples 1 and 4 are correctly classified but inside the margin:
      # per-example terms 0.3, 0, 0, 0.4 -> mean 0.175.
      self.assertAllClose(hinge, 0.175, atol=1e-3)

  @test_util.run_deprecated_v1
  def testSomeMisclassified(self):
    with self.cached_session():
      scores = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
      targets = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
      hinge = losses.hinge_loss(targets, scores)
      # Examples 2 and 4 sit on the wrong side of the hyperplane:
      # per-example terms 0, 1.4, 0, 2.1 -> mean 0.875.
      self.assertAllClose(hinge, 0.875, atol=1e-3)
862
863
class HuberLossTest(test.TestCase):
  """Tests for `losses.huber_loss` in its quadratic and linear regimes.

  For an absolute error `d` and threshold `delta`, the expectations below use
  0.5 * d**2 when d <= delta and delta * d - 0.5 * delta**2 otherwise,
  averaged over all elements. Tests passing no `delta` expect delta == 1.
  """

  def testIncompatibleShapes(self):
    with self.cached_session():
      # (2, 1) predictions cannot be paired with (2,) labels.
      column_predictions = constant_op.constant([[-1.0], [2.1]])
      flat_labels = constant_op.constant([0.0, 1.0])
      with self.assertRaises(ValueError):
        _ = losses.huber_loss(flat_labels, column_predictions).eval()

  @test_util.run_deprecated_v1
  def testAllQuadratic(self):
    with self.cached_session():
      # |errors| = 0.5, 0.4, 1.0, 0.5: every element is in the quadratic zone.
      preds = constant_op.constant([1.5, -1.4, -1.0, 0.0])
      targets = constant_op.constant([1.0, -1.0, 0.0, 0.5])
      huber = losses.huber_loss(targets, preds)
      quadratic_mean = 0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4.
      self.assertAllClose(huber, quadratic_mean, atol=1e-5)

  @test_util.run_deprecated_v1
  def testAllLinear(self):
    with self.cached_session():
      # |errors| = 1.5, 2.4, 1.0, 1.5: linear zone, each term is d - 0.5.
      preds = constant_op.constant([1.5, -1.4, -1.0, 0.0])
      targets = constant_op.constant([0.0, 1.0, 0.0, 1.5])
      huber = losses.huber_loss(targets, preds)
      linear_mean = (1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5
      self.assertAllClose(huber, linear_mean, atol=1e-5)

  @test_util.run_deprecated_v1
  def testMixedQuadraticLinear(self):
    with self.cached_session():
      # Row 0 reuses the all-quadratic data, row 1 the all-linear data.
      preds = constant_op.constant([[1.5, -1.4, -1.0, 0.0],
                                    [1.5, -1.4, -1.0, 0.0]])
      targets = constant_op.constant([[1.0, -1.0, 0.0, 0.5],
                                      [0.0, 1.0, 0.0, 1.5]])
      huber = losses.huber_loss(targets, preds)
      quadratic = 0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4.
      linear = (1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5
      self.assertAllClose(huber, (quadratic + linear) / 2., atol=1e-5)

  def testAllQuadraticDelta(self):
    with self.cached_session():
      delta = 0.5
      # |errors| = 0.5, 0.4, 0.5, 0.5, all <= delta.
      preds = constant_op.constant([1.5, -1.4, -0.5, 0.0])
      targets = constant_op.constant([1.0, -1.0, 0.0, 0.5])
      squared_errors = np.array([0.5**2, 0.4**2, 0.5**2, 0.5**2])
      expected = 0.5 * squared_errors.mean()
      huber = losses.huber_loss(targets, preds, delta=delta)
      self.assertAllClose(expected, self.evaluate(huber), atol=1e-5)

  def testAllLinearDelta(self):
    delta = 0.5
    # |errors| = 1.5, 2.4, 1.0, 1.5, all > delta.
    preds = constant_op.constant([1.5, -1.4, -1.0, 0.0])
    targets = constant_op.constant([0.0, 1.0, 0.0, 1.5])
    abs_errors = np.array([1.5, 2.4, 1.0, 1.5])
    expected = delta * abs_errors.mean() - 0.5 * delta**2
    huber = losses.huber_loss(targets, preds, delta=delta)
    with self.cached_session():
      self.assertAllClose(expected, self.evaluate(huber), atol=1e-5)
921
922
@test_util.run_deprecated_v1
class MeanSquaredErrorTest(test.TestCase):
  """Tests for `losses.mean_squared_error` under different weightings.

  With labels [[1, 9, 2], [-5, -2, 6]] and predictions [[4, 8, 12], [8, 1, 3]]
  the squared errors are [[9, 1, 100], [169, 9, 9]]: row sums 110 and 187,
  total 297 (297 / 6 == 49.5). The expected constants below derive from these.
  """

  def setUp(self):
    super(MeanSquaredErrorTest, self).setUp()
    grid = (2, 3)
    self._predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=grid)
    self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=grid)

  def _weighted_mse(self, weights):
    """Returns the MSE op of the fixture (labels first) using `weights`."""
    return losses.mean_squared_error(self._labels, self._predictions, weights)

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.mean_squared_error(
            self._predictions, self._predictions, weights=None)

  @test_util.run_deprecated_v1
  def testScalar(self):
    with self.cached_session():
      zero_predictions = constant_op.constant(0)
      zero_labels = constant_op.constant(0)
      scalar_loss = losses.mean_squared_error(
          predictions=zero_predictions, labels=zero_labels)
      self.assertEqual(0.0, scalar_loss.eval())

  @test_util.run_deprecated_v1
  def testAllCorrectNoLossWeight(self):
    perfect = losses.mean_squared_error(self._predictions, self._predictions)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(perfect), 3)

  @test_util.run_deprecated_v1
  def testNonZeroLoss(self):
    mse = losses.mean_squared_error(self._labels, self._predictions)
    with self.cached_session():
      self.assertAlmostEqual(49.5, self.evaluate(mse), 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPythonScalarWeight(self):
    scale = 2.3
    mse = self._weighted_mse(scale)
    with self.cached_session():
      self.assertAlmostEqual(49.5 * scale, self.evaluate(mse), 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithScalarTensorWeight(self):
    scale = 2.3
    mse = self._weighted_mse(constant_op.constant(scale))
    with self.cached_session():
      self.assertAlmostEqual(49.5 * scale, self.evaluate(mse), 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    # (1.2 * 110 + 3.4 * 187) == 767.8, averaged over all 6 elements.
    mse = self._weighted_mse(constant_op.constant([1.2, 3.4], shape=(2, 1)))
    with self.cached_session():
      self.assertAlmostEqual(767.8 / 6.0, self.evaluate(mse), 3)

  def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
    mse = self._weighted_mse(constant_op.constant([1.2, 3.4], shape=[2, 1]))
    with self.cached_session():
      self.assertAlmostEqual(767.8 / 6.0, self.evaluate(mse), 3)

  def testNonZeroLossWithSampleSpecificWeights(self):
    # Weighted sum 587 over the 5 elements with non-zero weight.
    mse = self._weighted_mse(
        constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]))
    with self.cached_session():
      self.assertAlmostEqual(587 / 5.0, self.evaluate(mse), 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
    # Only the last element (squared error 9) is weighted, by 2.
    mse = self._weighted_mse(
        constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]))
    with self.cached_session():
      self.assertAlmostEqual(18.0, self.evaluate(mse), 3)

  def testLossWithSampleSpecificWeightsAllZero(self):
    mse = self._weighted_mse(array_ops.zeros((2, 3)))
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(mse), 3)
1001
1002
@test_util.run_deprecated_v1
class MeanPairwiseSquaredErrorTest(test.TestCase):
  """Tests for `losses.mean_pairwise_squared_error`.

  For each batch element the expected loss is the mean, over all unordered
  position pairs (i, j), of ((p_i - p_j) - (l_i - l_j))**2. The per-batch
  values are then summed (see `setUp` and `testNonZeroLoss`).
  """

  def setUp(self):
    super(MeanPairwiseSquaredErrorTest, self).setUp()
    self._predictions = np.array([[4, 8, 12], [8, 1, 3]])
    self._labels = np.array([[1, 9, 2], [-5, -5, 7]])

    batch_size, dims = self._labels.shape  # pylint: disable=unpacking-non-sequence

    # Compute the expected loss 'manually'.
    total = np.zeros((batch_size,))
    for b in range(batch_size):
      for i in range(dims - 1):
        for j in range(i + 1, dims):
          x = self._predictions[b, i].item() - self._predictions[b, j].item()
          y = self._labels[b, i].item() - self._labels[b, j].item()
          diff = (x - y)
          total[b] += (diff * diff)

    # Average over the dims * (dims - 1) / 2 unordered pairs visited above
    # (3 pairs for this 3-wide fixture) instead of hard-coding that count.
    num_pairs = dims * (dims - 1) / 2.0
    self._expected_losses = np.divide(total, num_pairs)

  @staticmethod
  def _inputs_3d():
    """Returns fresh (labels, predictions) integer arrays of shape (2, 3, 3)."""
    labels = np.array([
        [[1, 9, 2], [12, 11, 10], [9, 8, 7]],
        [[-5, -5, 7], [6, 5, 4], [3, 2, 1]],
    ])
    predictions = np.array([
        [[4, 8, 12], [1, 2, 3], [4, 5, 6]],
        [[8, 1, 3], [7, 8, 9], [10, 11, 12]],
    ])
    return labels, predictions

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.mean_pairwise_squared_error(
            predictions=constant_op.constant(self._labels),
            labels=constant_op.constant(self._labels),
            weights=None)

  def _test_valid_weights(
      self, labels, predictions, expected_loss, weights=1.0):
    """Checks `expected_loss` with static inputs and with placeholder inputs.

    The placeholder variant feeds the same NumPy values through float32
    predictions/weights and int32 labels placeholders.
    """
    with self.cached_session():
      static_inputs_op = losses.mean_pairwise_squared_error(
          predictions=predictions, labels=labels, weights=weights)
      self.assertAlmostEqual(
          expected_loss, self.evaluate(static_inputs_op), places=3)

      predictions_placeholder = array_ops.placeholder(
          dtypes.float32, shape=np.asarray(predictions.shape))
      labels_placeholder = array_ops.placeholder(
          dtypes.int32, shape=np.asarray(labels.shape))
      weights_placeholder = array_ops.placeholder(
          dtypes.float32, shape=np.asarray(weights).shape)
      dynamic_inputs_op = losses.mean_pairwise_squared_error(
          predictions=predictions_placeholder,
          labels=labels_placeholder,
          weights=weights_placeholder)
      feed_dict = {
          predictions_placeholder: predictions,
          labels_placeholder: labels,
          weights_placeholder: weights,
      }
      self.assertAlmostEqual(
          expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3)

  def testAllCorrectNoLossWeight(self):
    self._test_valid_weights(
        self._labels, self._labels, expected_loss=0.0)

  def testNonZeroLoss(self):
    self._test_valid_weights(
        self._labels, self._predictions,
        expected_loss=np.sum(self._expected_losses))

  def testGradientWithZeroWeight(self):
    # With weight 0 the loss is constant, but its gradients must not be NaN.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)

      inputs = array_ops.ones((2, 3))
      weights = variable_scope.get_variable(
          'weights',
          shape=[3, 4],
          initializer=init_ops.truncated_normal_initializer())
      predictions = math_ops.matmul(inputs, weights)

      optimizer = momentum_lib.MomentumOptimizer(
          learning_rate=0.001, momentum=0.9)
      loss = losses.mean_pairwise_squared_error(predictions, predictions, 0)

      gradients_to_variables = optimizer.compute_gradients(loss)

      init_op = variables.global_variables_initializer()

      with self.cached_session():
        self.evaluate(init_op)
        for grad, _ in gradients_to_variables:
          np_grad = self.evaluate(grad)
          self.assertFalse(np.isnan(np_grad).any())

  def testNonZeroLossWithPythonScalarWeight(self):
    weight = 2.3
    self._test_valid_weights(
        self._labels, self._predictions,
        expected_loss=weight * np.sum(self._expected_losses),
        weights=weight)

  def testNonZeroLossWithScalarTensorWeight(self):
    weights = 2.3
    loss = losses.mean_pairwise_squared_error(
        predictions=constant_op.constant(self._predictions),
        labels=constant_op.constant(self._labels),
        weights=constant_op.constant(weights))
    with self.cached_session():
      self.assertAlmostEqual(weights * np.sum(self._expected_losses),
                             self.evaluate(loss), 3)

  def testNonZeroLossWithScalarZeroWeight(self):
    self._test_valid_weights(
        self._labels, self._predictions, expected_loss=0.0, weights=0.0)

  def test3d(self):
    labels, predictions = self._inputs_3d()
    self._test_valid_weights(labels, predictions, expected_loss=137.5)

  def test3dWeightedScalar(self):
    labels, predictions = self._inputs_3d()
    weight = 3.0
    self._test_valid_weights(
        labels, predictions, expected_loss=weight * 137.5, weights=weight)

  def _test_invalid_weights(
      self, labels, predictions, weights=1.0):
    """Asserts the broadcast error is raised statically and at run time."""
    expected_error_msg = 'weights can not be broadcast to values'

    # Static check.
    with self.assertRaisesRegex(ValueError, expected_error_msg):
      losses.mean_pairwise_squared_error(
          predictions=predictions, labels=labels, weights=weights)

    # Dynamic check: unshaped placeholders defer the check to run time.
    predictions_placeholder = array_ops.placeholder(dtypes.float32)
    labels_placeholder = array_ops.placeholder(dtypes.int32)
    weights_placeholder = array_ops.placeholder(dtypes.float32)
    dynamic_inputs_op = losses.mean_pairwise_squared_error(
        predictions=predictions_placeholder,
        labels=labels_placeholder,
        weights=weights_placeholder)
    with self.cached_session():
      with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg):
        dynamic_inputs_op.eval(feed_dict={
            predictions_placeholder: predictions,
            labels_placeholder: labels,
            weights_placeholder: weights,
        })

  def testInvalid3dWeighted2x0(self):
    labels, predictions = self._inputs_3d()
    self._test_invalid_weights(
        labels, predictions, weights=np.asarray((1.2, 3.4)))

  def test3dWeighted2x3x3(self):
    labels, predictions = self._inputs_3d()
    self._test_valid_weights(
        # TODO(ptucker): This doesn't look right.
        labels,
        predictions,
        expected_loss=9 * 137.5,
        weights=np.ones((2, 3, 3)))

  def testLossWithAllZeroBatchSpecificWeights(self):
    self._test_valid_weights(
        self._labels, self._predictions, expected_loss=0.0,
        weights=np.zeros((2, 1)))

  def testLossIsAssociativeAcrossBatchElements(self):
    # Loss over a concatenated batch equals the sum of per-batch losses.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)

      height = 3
      width = 4
      shape = (1, height, width, 1)

      labels0 = random_ops.random_uniform(
          shape, minval=0, maxval=1, dtype=dtypes.float32)
      predictions0 = random_ops.random_uniform(
          shape, minval=0, maxval=1, dtype=dtypes.float32)

      labels1 = random_ops.random_uniform(
          shape, minval=0, maxval=1, dtype=dtypes.float32)
      predictions1 = random_ops.random_uniform(
          shape, minval=0, maxval=1, dtype=dtypes.float32)

      loss0 = losses.mean_pairwise_squared_error(
          labels=labels0,
          predictions=predictions0)
      loss1 = losses.mean_pairwise_squared_error(
          labels=labels1,
          predictions=predictions1)
      loss0_1 = losses.mean_pairwise_squared_error(
          labels=array_ops.concat([labels0, labels1], 0),
          predictions=array_ops.concat([predictions0, predictions1], 0))

      with self.cached_session() as session:
        loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])

        # assertGreater reports the offending value on failure.
        self.assertGreater(loss0, 0)
        self.assertGreater(loss1, 0)
        self.assertAlmostEqual(loss0 + loss1, loss0_1, 5)
1230
1231
@test_util.run_deprecated_v1
class CosineDistanceLossTest(test.TestCase):
  """Tests for `losses.cosine_distance` reduced along axis 2 (`dim=2`).

  The fixtures are stacks of unit vectors. The expected values below are
  consistent with each (label, prediction) pair contributing
  1 - dot(label, prediction): 0 when aligned, 1 when orthogonal and 2 when
  opposite. These terms are averaged over the pairs with non-zero weight.
  """

  def setUp(self):
    super(CosineDistanceLossTest, self).setUp()
    self._predictions = np.asarray([
        [1, 0, 0],  # Batch 1
        [0, 0, -1],
        [1, 0, 0],  # Batch 2
        [1, 0, 0],
        [0, 0, -1],  # Batch 3
        [1, 0, 0]
    ]).reshape((3, 2, 3))

    # Per-pair distances against the predictions above:
    #   batch 1: 0, 2    batch 2: 1, 0    batch 3: 2, 1
    self._labels = np.asarray([[1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0],
                               [0, 0, 1], [0, 1, 0]]).reshape((3, 2, 3))

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.cosine_distance(
            predictions=constant_op.constant(self._labels),
            labels=constant_op.constant(self._labels),
            dim=2,
            weights=None)

  def testAllCorrectNoWeights(self):
    loss = losses.cosine_distance(
        predictions=constant_op.constant(self._labels),
        labels=constant_op.constant(self._labels),
        dim=2)
    with self.cached_session():
      self.assertAlmostEqual(0, self.evaluate(loss), 5)

  def testPartiallyCorrectWithIntegerValues(self):
    # (0 + 2 + 1 + 0 + 2 + 1) / 6 == 1.
    loss = losses.cosine_distance(
        predictions=constant_op.constant(self._predictions),
        labels=constant_op.constant(self._labels),
        dim=2)
    with self.cached_session():
      self.assertAlmostEqual(1, self.evaluate(loss), 5)

  def testPartiallyCorrectFloatingPointValues(self):
    # Plain ndarrays rather than the discouraged `np.matrix`. Rows are
    # identical (distance 0), opposite (distance 2) and orthogonal
    # (distance 1) unit vectors, so the mean distance is 1.
    predictions = np.array([
        [0.819031913261206, 0.567041924552012, 0.087465312324590],
        [-0.665139432070255, -0.739487441769973, -0.103671883216994],
        [0.707106781186548, -0.707106781186548, 0.0],
    ])
    labels = np.array([
        [0.819031913261206, 0.567041924552012, 0.087465312324590],
        [0.665139432070255, 0.739487441769973, 0.103671883216994],
        [0.707106781186548, 0.707106781186548, 0.0],
    ])

    tf_preds = constant_op.constant(
        predictions, shape=(3, 1, 3), dtype=dtypes.float32)
    tf_labels = constant_op.constant(
        labels, shape=(3, 1, 3), dtype=dtypes.float32)
    loss = losses.cosine_distance(tf_labels, tf_preds, dim=2)

    with self.cached_session():
      self.assertAlmostEqual(1.0, self.evaluate(loss), 5)

  def testSampleSpecificWeights(self):
    # Only batch 1 is weighted: (0 + 2) / 2 == 1.
    loss = losses.cosine_distance(
        predictions=constant_op.constant(self._predictions),
        labels=constant_op.constant(self._labels),
        dim=2,
        weights=np.asarray((1, 0, 0)).reshape((3, 1, 1)))
    with self.cached_session():
      self.assertEqual(1.0, self.evaluate(loss))

  def testMeasurementSpecificWeights(self):
    # Kept pairs have distances 0, 0, 2, 1 -> 3 / 4.
    loss = losses.cosine_distance(
        predictions=constant_op.constant(self._predictions),
        labels=constant_op.constant(self._labels),
        dim=2,
        weights=constant_op.constant(
            [1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
    with self.cached_session():
      self.assertEqual(3.0 / 4.0, self.evaluate(loss))

  def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
    tf_predictions = array_ops.placeholder(
        dtypes.float32, shape=self._labels.shape)
    loss = losses.cosine_distance(
        predictions=tf_predictions,
        labels=constant_op.constant(self._labels),
        dim=2,
        weights=constant_op.constant(
            [1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
      self.assertEqual(3.0 / 4.0, loss)

  def testZeroLossWhenAllSampleSpecificWeightsAreZero(self):
    loss = losses.cosine_distance(
        predictions=constant_op.constant(self._predictions),
        labels=constant_op.constant(self._labels),
        dim=2,
        weights=array_ops.zeros((3, 1, 1)))
    with self.cached_session():
      self.assertEqual(0, self.evaluate(loss))

  def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
    loss = losses.cosine_distance(
        predictions=constant_op.constant(self._predictions),
        labels=constant_op.constant(self._labels),
        dim=2,
        weights=array_ops.zeros((3, 2, 1)))
    with self.cached_session():
      self.assertEqual(0, self.evaluate(loss))
1341
1342
class AddLossTest(test.TestCase):
  """Checks that `loss_collection=None` keeps losses out of the collection."""

  def testNoCollectLossesBatch2(self):
    # Two identical rows, i.e. a batch of 2.
    logits = constant_op.constant([[1.2, 0.4, -1.0, -1.1]] * 2)
    labels = constant_op.constant([[1.0, 0.0, 0.0, 1.0]] * 2)
    self.assertFalse(util.get_losses())
    # NOTE(review): `logits` is passed in the first (labels) position of each
    # loss below. That is harmless here, because only the collection side
    # effect is checked.
    loss_fns = (
        losses.absolute_difference,
        losses.log_loss,
        losses.mean_squared_error,
        losses.sigmoid_cross_entropy,
        losses.softmax_cross_entropy,
    )
    for loss_fn in loss_fns:
      loss_fn(logits, labels, loss_collection=None)
    self.assertFalse(util.get_losses())
1355
1356
1357class ComputeWeightedLossTest(test.TestCase):
1358
1359  def setUp(self):
1360    super(ComputeWeightedLossTest, self).setUp()
1361    self._shape = (3, 2, 4)
1362    raw_losses = np.zeros(self._shape)
1363    next_loss = 0.0
1364    for i in range(self._shape[0]):
1365      for j in range(self._shape[1]):
1366        for k in range(self._shape[2]):
1367          raw_losses[i][j][k] = next_loss
1368          next_loss += 1.0
1369    raw_losses.setflags(write=False)
1370    self._raw_losses = raw_losses
1371
1372  def testUnweighted(self):
1373    for reduction in losses.Reduction.all():
1374      with ops.Graph().as_default() as g:
1375        self.assertEqual(0, len(util.get_losses()))
1376        raw_losses = self._raw_losses
1377        unweighted_losses = (
1378            losses.compute_weighted_loss(raw_losses, reduction=reduction),
1379            losses.compute_weighted_loss(
1380                raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction),
1381            losses.compute_weighted_loss(
1382                raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction),
1383            losses.compute_weighted_loss(
1384                raw_losses, weights=np.ones((1, 2, 1)), reduction=reduction),
1385            losses.compute_weighted_loss(
1386                raw_losses, weights=np.ones((1, 2, 4)), reduction=reduction),
1387            losses.compute_weighted_loss(
1388                raw_losses, weights=np.ones((3, 1, 1)), reduction=reduction),
1389            losses.compute_weighted_loss(
1390                raw_losses, weights=np.ones((3, 1, 4)), reduction=reduction),
1391            losses.compute_weighted_loss(
1392                raw_losses, weights=np.ones((3, 2, 1)), reduction=reduction),
1393            losses.compute_weighted_loss(
1394                raw_losses, weights=np.ones(self._shape), reduction=reduction)
1395        )
1396        self.assertEqual(9, len(util.get_losses()))
1397        with self.session(g):
1398          for unweighted_loss in unweighted_losses:
1399            if reduction == losses.Reduction.NONE:
1400              self.assertAllClose(self._raw_losses,
1401                                  self.evaluate(unweighted_loss))
1402            elif reduction == losses.Reduction.SUM:
1403              self.assertAllClose(
1404                  np.sum(self._raw_losses), self.evaluate(unweighted_loss))
1405            else:
1406              # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS,
1407              # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE.
1408              self.assertAllClose(
1409                  np.mean(self._raw_losses), self.evaluate(unweighted_loss))
1410
1411  def testUnweightedFromPlaceholder(self):
1412    for reduction in losses.Reduction.all():
1413      with ops.Graph().as_default() as g:
1414        self.assertEqual(0, len(util.get_losses()))
1415        raw_losses = array_ops.placeholder(dtype=dtypes.float32)
1416        feed_dict = {raw_losses: self._raw_losses}
1417        unweighted_losses = (
1418            losses.compute_weighted_loss(raw_losses, reduction=reduction),
1419            losses.compute_weighted_loss(
1420                raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction),
1421            losses.compute_weighted_loss(
1422                raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction),
1423        )
1424        self.assertEqual(3, len(util.get_losses()))
1425        with self.session(g):
1426          for unweighted_loss in unweighted_losses:
1427            if reduction == losses.Reduction.NONE:
1428              self.assertAllClose(
1429                  self._raw_losses, unweighted_loss.eval(feed_dict))
1430            elif reduction == losses.Reduction.SUM:
1431              self.assertAllClose(
1432                  np.sum(self._raw_losses), unweighted_loss.eval(feed_dict))
1433            else:
1434              # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS,
1435              # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE.
1436              self.assertAllClose(
1437                  np.mean(self._raw_losses), unweighted_loss.eval(feed_dict))
1438
1439  def testScalarWeight(self):
1440    with ops.Graph().as_default():
1441      self.assertEqual(0, len(util.get_losses()))
1442      weight = 17.0
1443      weighted_loss = losses.compute_weighted_loss(
1444          self._raw_losses, weights=weight)
1445      self.assertEqual(1, len(util.get_losses()))
1446      with self.cached_session():
1447        self.assertAllClose(
1448            np.mean(weight * self._raw_losses), self.evaluate(weighted_loss))
1449
1450  def _test_invalid_weights(self, weights):
1451    with ops.Graph().as_default():
1452      self.assertEqual(0, len(util.get_losses()))
1453      expected_error_msg = 'weights can not be broadcast to values'
1454
1455      # Static check.
1456      with self.assertRaisesRegex(ValueError, expected_error_msg):
1457        losses.compute_weighted_loss(self._raw_losses, weights=weights)
1458
1459      # Dynamic check.
1460      weights_placeholder = array_ops.placeholder(dtypes.float32)
1461      weighted_loss = losses.compute_weighted_loss(
1462          self._raw_losses, weights=weights_placeholder)
1463      self.assertEqual(1, len(util.get_losses()))
1464      with self.cached_session():
1465        with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg):
1466          weighted_loss.eval(feed_dict={weights_placeholder: weights})
1467
1468  def testInvalidWeightTooManyDims(self):
1469    self._test_invalid_weights(np.zeros(shape=(2, 2, 2, 2)))
1470
1471  def testInvalidWeightMismatchedDim(self):
1472    with ops.Graph().as_default():
1473      raw_losses = array_ops.reshape(self._raw_losses, shape=(3, 2, 4, 1))
1474      weights = np.ones(shape=(3, 2, 4, 2))
1475      expected_error_msg = 'weights can not be broadcast to values'
1476      self.assertEqual(0, len(util.get_losses()))
1477
1478      # Static check.
1479      with self.assertRaisesRegex(ValueError, expected_error_msg):
1480        losses.compute_weighted_loss(raw_losses, weights=weights)
1481
1482      # Dynamic check.
1483      weights_placeholder = array_ops.placeholder(dtypes.float32)
1484      weighted_loss = losses.compute_weighted_loss(
1485          raw_losses, weights=weights_placeholder)
1486      self.assertEqual(1, len(util.get_losses()))
1487      with self.cached_session():
1488        with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg):
1489          weighted_loss.eval(feed_dict={weights_placeholder: weights})
1490
1491  def testInvalid3Weight(self):
1492    self._test_invalid_weights((17.0, 5.0, 2.0))
1493
1494  def testInvalid3x1Weight(self):
1495    self._test_invalid_weights(((17.0,), (5.0,), (2.0,),))
1496
1497  def testInvalid3x2Weight(self):
1498    self._test_invalid_weights((
1499        (17.0, 3.0),
1500        (5.0, 31.0),
1501        (2.0, 7.0),))
1502
1503  def testInvalid1x2Weight(self):
1504    self._test_invalid_weights((17.0, 3.0,),)
1505
1506  def testInvalidScalar1DWeight(self):
1507    self._test_invalid_weights((17.0,),)
1508
1509  def _test_valid_weights(self, weights):
1510    for reduction in losses.Reduction.all():
1511      with ops.Graph().as_default() as g:
1512        self.assertEqual(0, len(util.get_losses()))
1513        weighted_loss = losses.compute_weighted_loss(
1514            self._raw_losses, weights=weights, reduction=reduction)
1515        self.assertEqual(1, len(util.get_losses()))
1516        with self.session(g):
1517          weighted_losses = weights * self._raw_losses
1518          weighted_sum = np.sum(weighted_losses)
1519          if reduction == losses.Reduction.NONE:
1520            self.assertAllClose(weighted_losses, self.evaluate(weighted_loss))
1521          elif reduction == losses.Reduction.SUM:
1522            self.assertAllClose(weighted_sum, self.evaluate(weighted_loss))
1523          else:
1524            broadcast_weights = weights * np.ones_like(self._raw_losses)
1525            if reduction == losses.Reduction.MEAN:
1526              self.assertAllClose(weighted_sum / np.sum(broadcast_weights),
1527                                  self.evaluate(weighted_loss))
1528            elif (reduction == losses.Reduction.SUM_OVER_NONZERO_WEIGHTS or
1529                  reduction == losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
1530              self.assertAllClose(
1531                  weighted_sum / np.count_nonzero(broadcast_weights),
1532                  self.evaluate(weighted_loss))
1533            elif reduction == losses.Reduction.SUM_OVER_BATCH_SIZE:
1534              self.assertAllClose(weighted_sum / self._raw_losses.size,
1535                                  self.evaluate(weighted_loss))
1536
1537  def test1x1x1Weight(self):
1538    self._test_valid_weights((((17.0,),),))
1539
1540  def test1x2x1Weight(self):
1541    self._test_valid_weights((((17.0,), (3.0,),),))
1542
1543  def test1x1x4Weight(self):
1544    self._test_valid_weights((((17.0, 0.0, 2.0, 5.0),),))
1545
1546  def test3x1x1Weight(self):
1547    self._test_valid_weights((((17.0,),), ((5.0,),), ((2.0,),),))
1548
1549  def test3x2x1Weight(self):
1550    self._test_valid_weights((
1551        ((17.0,), (3.0,)),
1552        ((5.0,), (31.0,)),
1553        ((2.0,), (7.0,)),
1554    ))
1555
1556  def test3x1x4Weight(self):
1557    self._test_valid_weights((
1558        ((17.0, 0.0, 2.0, 5.0),),
1559        ((5.0, 31.0, 17.0, 5.0),),
1560        ((7.0, 3.0, 11.0, 5.0),),
1561    ))
1562
1563  def test1x2x4Weight(self):
1564    self._test_valid_weights(((
1565        (17.0, 0.0, 2.0, 5.0),
1566        (3.0, 13.0, 11.0, 2.0),
1567    ),))
1568
1569  def test3x2x4Weight(self):
1570    self._test_valid_weights((
1571        ((17.0, 0.0, 2.0, 5.0), (3.0, 13.0, 11.0, 2.0),),
1572        ((5.0, 31.0, 17.0, 5.0), (13.0, 3.0, 0.0, 11.0),),
1573        ((0.0, 3.0, 11.0, 5.0), (13.0, 11.0, 1.0, 7.0),),
1574    ))
1575
1576
if __name__ == '__main__':
  # Run this module's test cases with tensorflow.python.platform.test.
  test.main()
1579