# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for losses."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.losses import util
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib


@test_util.run_deprecated_v1
class AbsoluteDifferenceLossTest(test.TestCase):
  """Tests for `losses.absolute_difference`."""

  def setUp(self):
    super(AbsoluteDifferenceLossTest, self).setUp()
    # |labels - predictions| = [[3, 1, 10], [13, 3, 3]], which sums to 33;
    # the expected values in the tests below are derived from these.
    self._predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
    self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.absolute_difference(
            self._predictions, self._predictions, weights=None)

  def testAllCorrectNoLossWeight(self):
    loss = losses.absolute_difference(self._predictions, self._predictions)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  def testNonZeroLoss(self):
    loss = losses.absolute_difference(self._labels, self._predictions)
    with self.cached_session():
      # 33 / 6 elements.
      self.assertAlmostEqual(5.5, self.evaluate(loss), 3)

  def testNonZeroLossWithPythonScalarWeight(self):
    weights = 2.3
    loss = losses.absolute_difference(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(5.5 * weights, self.evaluate(loss), 3)

  def testNonZeroLossWithScalarTensorWeight(self):
    weights = 2.3
    loss = losses.absolute_difference(self._labels, self._predictions,
                                      constant_op.constant(weights))
    with self.cached_session():
      self.assertAlmostEqual(5.5 * weights, self.evaluate(loss), 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    weights = constant_op.constant((1.2, 0.0), shape=(2, 1))
    loss = losses.absolute_difference(self._labels, self._predictions, weights)
    with self.cached_session():
      # Only the first row is weighted: 1.2 * (3 + 1 + 10) / 3 non-zero weights.
      self.assertAlmostEqual(5.6, self.evaluate(loss), 3)

  def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
    weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
    loss = losses.absolute_difference(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(5.6, self.evaluate(loss), 3)

  def testNonZeroLossWithSampleSpecificWeights(self):
    weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
    loss = losses.absolute_difference(self._labels, self._predictions, weights)
    with self.cached_session():
      # (9 + 6 + 50 + 0 + 12 + 6) / 5 non-zero weights = 83 / 5.
      self.assertAlmostEqual(16.6, self.evaluate(loss), 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
    weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
    loss = losses.absolute_difference(self._labels, self._predictions, weights)
    with self.cached_session():
      # Only the last element (error 3) is weighted, by 2.
      self.assertAlmostEqual(6.0, self.evaluate(loss), 3)

  def testLossWithSampleSpecificWeightsAllZero(self):
    weights = array_ops.zeros((2, 3))
    loss = losses.absolute_difference(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  # NOTE(review): this class is decorated with run_deprecated_v1, so this test
  # may run in graph mode where the eager leak check has nothing to observe.
  # Confirm the intended execution mode.
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerNoMemoryLeaked(self):
    # This is a somewhat convoluted way of testing that nothing gets added to
    # a global collection.
    predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
    labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
    losses.absolute_difference(labels, predictions)


class SoftmaxCrossEntropyLossTest(test.TestCase):
  """Tests for `losses.softmax_cross_entropy` with one-hot labels.

  In the "wrong" cases each example puts logit 10 on an incorrect class, so the
  per-example loss is approximately 10.
  """

  def testNoneWeightRaisesValueError(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.softmax_cross_entropy(labels, logits, weights=None)

  @test_util.run_deprecated_v1
  def testAllCorrect(self):
    with self.cached_session():
      logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                     [0.0, 0.0, 10.0]])
      labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
      loss = losses.softmax_cross_entropy(labels, logits)
      self.assertEqual('softmax_cross_entropy_loss/value', loss.op.name)
      self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)

  @test_util.run_deprecated_v1
  def testAllWrong(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])

    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPythonScalarWeight(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    weights = 2.3
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithScalarTensorWeight(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    weights = 2.3
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits,
                                          constant_op.constant(weights))
      self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    weights = constant_op.constant((1.2, 3.4, 5.6))
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0,
                             self.evaluate(loss), 3)

  def testAllWrongAllWeightsMissing(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    weights = constant_op.constant([0, 0, 0], shape=[3])
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  def testSomeWeightsMissing(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
    weights = constant_op.constant([1.2, 0, 0], shape=[3])
    with self.cached_session():
      loss = losses.softmax_cross_entropy(labels, logits, weights)
      # Only one example carries weight: 1.2 * 10 / 1 non-zero weight.
      self.assertAlmostEqual(12.0, self.evaluate(loss), 3)

  def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
      weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]])

      with self.assertRaises(ValueError):
        losses.softmax_cross_entropy(labels, logits, weights=weights).eval()

  @test_util.run_deprecated_v1
  def testSoftmaxLabelSmoothing(self):
    with self.cached_session():
      # Softmax Cross Entropy Loss is:
      #   -\sum_i p_i \log q_i
      # where for a softmax activation
      # \log q_i = x_i - \log \sum_j \exp x_j
      #          = x_i - x_max - \log \sum_j \exp (x_j - x_max)
      # For our activations, [100, -100, -100] the log partition function
      # becomes \log ( exp(0) + exp(-200) + exp(-200) ) = 0
      # so our log softmaxes become: [0, -200, -200]
      # so our cross entropy loss is:
      # -(1 - L + L/n) * 0 + 400 * L/n = 400 L/n
      logits = constant_op.constant([[100.0, -100.0, -100.0]])
      labels = constant_op.constant([[1, 0, 0]])
      label_smoothing = 0.1
      loss = losses.softmax_cross_entropy(
          labels, logits, label_smoothing=label_smoothing)
      self.assertEqual(loss.op.name, 'softmax_cross_entropy_loss/value')
      expected_value = 400.0 * label_smoothing / 3.0
      self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)


class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
  """Tests for `losses.sparse_softmax_cross_entropy` with class-index labels.

  In the "wrong" cases each example puts logit 10 on an incorrect class, so the
  per-example loss is approximately 10.
  """

  def testNoneWeightRaisesValueError(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0], [1], [2]])
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.sparse_softmax_cross_entropy(labels, logits, weights=None)

  @test_util.run_deprecated_v1
  def testAllCorrectInt32Labels(self):
    with self.cached_session():
      logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                     [0.0, 0.0, 10.0]])
      labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32)
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerNoMemoryLeaked(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32)
    losses.sparse_softmax_cross_entropy(labels, logits)

  @test_util.run_deprecated_v1
  def testAllCorrectInt64Labels(self):
    with self.cached_session():
      logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                     [0.0, 0.0, 10.0]])
      labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int64)
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)

  @test_util.run_deprecated_v1
  def testAllCorrectNonColumnLabels(self):
    with self.cached_session():
      logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                     [0.0, 0.0, 10.0]])
      labels = constant_op.constant([0, 1, 2])
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)

  @test_util.run_deprecated_v1
  def testAllWrongInt32Labels(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)

    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)

  @test_util.run_deprecated_v1
  def testAllWrongInt64Labels(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)

    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)

  @test_util.run_deprecated_v1
  def testAllWrongNonColumnLabels(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([2, 0, 1])

    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits)
      self.assertEqual(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
      self.assertAlmostEqual(self.evaluate(loss), 10.0, 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPythonScalarWeight(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = 2.3
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithScalarTensorWeight(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = 2.3
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits,
                                                 constant_op.constant(weights))
      self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)

  def testNonZeroLossWith1DTensorWeight(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = 2.3
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(
          labels, logits, constant_op.constant((weights,)))
      self.assertAlmostEqual(weights * 10.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPlaceholderForWeights(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0],
                                   [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = array_ops.placeholder(dtypes.float32)
    with self.cached_session() as sess:
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      loss_val = sess.run(loss,
                          feed_dict={weights: ((1.2,), (3.4,), (5.6,))})
      self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss_val, 3)

  @test_util.run_deprecated_v1
  def testUnknownShapePlaceholderForLogitsLabelsButScalarWeights(self):
    logits = array_ops.placeholder(dtypes.float32)
    labels = array_ops.placeholder(dtypes.int32)
    weights = 1.0
    with self.cached_session() as sess:
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      loss_val = sess.run(loss,
                          feed_dict={
                              logits: [[10.0, 0.0, 0.0],
                                       [0.0, 10.0, 0.0],
                                       [0.0, 0.0, 10.0]],
                              labels: [[2], [0], [1]],
                          })
      self.assertAlmostEqual((1.0 + 1.0 + 1.0) * 10.0 / 3.0, loss_val, 3)

  @test_util.run_deprecated_v1
  def testNonZeroLossWithPlaceholderForLogitsLabelsAndWeights(self):
    logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
    labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
    weights = array_ops.placeholder(dtypes.float32)
    with self.cached_session() as sess:
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      loss_val = sess.run(loss,
                          feed_dict={
                              logits: [[10.0, 0.0, 0.0],
                                       [0.0, 10.0, 0.0],
                                       [0.0, 0.0, 10.0]],
                              labels: [[2], [0], [1]],
                              weights: ((1.2,), (3.4,), (5.6,)),
                          })
      self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss_val, 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = constant_op.constant([1.2, 3.4, 5.6], shape=(3, 1))
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithColumnWeights(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = constant_op.constant([[1.2], [3.4], [5.6]])
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0,
                             self.evaluate(loss), 3)

  def testAllWrongAllWeightsMissing(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = constant_op.constant([0, 0, 0], shape=(3, 1))
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  def testSomeWeightsMissing(self):
    logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                   [0.0, 0.0, 10.0]])
    labels = constant_op.constant([[2], [0], [1]])
    weights = constant_op.constant([1.2, 0, 0], shape=(3, 1))
    with self.cached_session():
      loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
      # Only one example carries weight: 1.2 * 10 / 1 non-zero weight.
      self.assertAlmostEqual(12.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testMeasurementSpecificWeightsRaisesException(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant([[0], [1], [2]])
      weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]])

      with self.assertRaises(ValueError):
        losses.sparse_softmax_cross_entropy(
            labels, logits, weights=weights).eval()

  def testInconsistentWeightSizeRaisesException(self):
    """The weight tensor has incorrect number of elements."""
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant([[0], [1], [2]])
      weights = constant_op.constant([1.2, 3.4, 5.6, 7.8])

      with self.assertRaises(ValueError):
        losses.sparse_softmax_cross_entropy(
            labels, logits, weights=weights).eval()

  def testInconsistentLabelSizeRaisesException(self):
    """The label tensor has incorrect number of elements."""
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant([[0], [1], [2], [3]])
      weights = constant_op.constant([1.2, 3.4, 5.6])

      with self.assertRaises(ValueError):
        losses.sparse_softmax_cross_entropy(
            labels, logits, weights=weights).eval()

  @test_util.run_deprecated_v1
  def testInconsistentWeightShapeRaisesException(self):
    """The weight tensor has incorrect shape."""
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0, -100.0],
                                     [-100.0, -100.0, 100.0, -100.0],
                                     [-100.0, -100.0, -100.0, 100.0]])
      labels = constant_op.constant([[0], [1], [2], [3]])
      weights = constant_op.constant([[1.2, 3.4], [5.6, 7.8]])

      with self.assertRaises(ValueError):
        losses.sparse_softmax_cross_entropy(
            labels, logits, weights=weights).eval()

  @test_util.run_deprecated_v1
  def testInconsistentLabelShapeRaisesException(self):
    """The label tensor has incorrect shape."""
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0, -100.0],
                                     [-100.0, -100.0, 100.0, -100.0],
                                     [-100.0, -100.0, -100.0, 100.0]])
      labels = constant_op.constant([[0, 1], [2, 3]])
      weights = constant_op.constant(1.2)

      with self.assertRaisesRegex(
          ValueError,
          '`labels.shape.rank` must equal `logits.shape.rank - 1`'):
        losses.sparse_softmax_cross_entropy(
            labels, logits, weights=weights).eval()


class SigmoidCrossEntropyLossTest(test.TestCase):
  """Tests for `losses.sigmoid_cross_entropy`.

  With logits of +/-100, a correctly-signed element contributes ~0 and an
  incorrectly-signed element contributes ~100 to the loss.
  """

  @test_util.run_deprecated_v1
  def testAllCorrectSigmoid(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
      loss = losses.sigmoid_cross_entropy(labels, logits)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual('sigmoid_cross_entropy_loss/value', loss.op.name)
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self):
    logits = array_ops.placeholder(dtypes.float32, shape=(None, 1))
    labels = array_ops.placeholder(dtypes.float32, shape=(None, 1))
    weights = array_ops.ones_like(logits, dtype=dtypes.float32)

    loss = losses.sigmoid_cross_entropy(labels, logits, weights)
    self.assertEqual(logits.dtype, loss.dtype)

    with self.cached_session() as sess:
      loss = sess.run(loss,
                      feed_dict={
                          logits: np.ones((32, 1)),
                          labels: np.ones((32, 1)),
                      })
      # Logit 1 with label 1 gives log(1 + exp(-1)) ~= 0.3133 per element.
      self.assertAlmostEqual(0.313, loss, 3)

  @test_util.run_deprecated_v1
  def testLossWithSingleDimPlaceholderForLogitsAndWeights2(self):
    logits = array_ops.placeholder(dtypes.float32, shape=(None, 2))
    labels = array_ops.placeholder(dtypes.float32, shape=(None, 2))
    weights = array_ops.ones_like(logits, dtype=dtypes.float32)

    loss = losses.sigmoid_cross_entropy(labels, logits, weights)
    self.assertEqual(logits.dtype, loss.dtype)

    with self.cached_session() as sess:
      loss = sess.run(loss,
                      feed_dict={
                          logits: np.ones((32, 2)),
                          labels: np.ones((32, 2)),
                      })
      self.assertAlmostEqual(0.313, loss, 3)

  @test_util.run_deprecated_v1
  def testAllWrongSigmoid(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
      loss = losses.sigmoid_cross_entropy(labels, logits)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual('sigmoid_cross_entropy_loss/value', loss.op.name)
      # Two wrong elements per row (6 * 100), averaged over 9 elements.
      self.assertAlmostEqual(self.evaluate(loss), 600.0 / 9.0, 3)

  @test_util.run_deprecated_v1
  def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0],
                                     [-100.0, 100.0, -100.0],
                                     [-100.0, -100.0, 100.0]])
      labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
      weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]])
      loss = losses.sigmoid_cross_entropy(labels, logits, weights)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual('sigmoid_cross_entropy_loss/value', loss.op.name)
      # Weighted wrong elements: (3+5)*100 + (2+6)*100 + 1*100 = 1700,
      # divided by the 7 non-zero weights.
      self.assertAlmostEqual(1700.0 / 7.0, self.evaluate(loss), 3)

  @test_util.run_deprecated_v1
  def testMultiCorrectSigmoid(self):
    logits = constant_op.constant([[100.0, -100.0, 100.0],
                                   [100.0, 100.0, -100.0],
                                   [-100.0, 100.0, 100.0]])
    labels = constant_op.constant([[1, 0, 1], [1, 1, 0], [0, 1, 1]])
    loss = losses.sigmoid_cross_entropy(labels, logits)
    self.assertEqual(logits.dtype, loss.dtype)
    self.assertEqual('sigmoid_cross_entropy_loss/value', loss.op.name)

    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  def testSigmoidFloat64(self):
    logits = constant_op.constant((
        (100.0, -100.0, 100.0),
        (100.0, -100.0, 100.0),
        (100.0, 100.0, -100.0)
    ), dtype=dtypes.float64)
    labels = constant_op.constant((
        (1, 0, 1), (1, 1, 0), (0, 1, 1)
    ), dtype=dtypes.int64)
    loss = losses.sigmoid_cross_entropy(labels, logits)
    self.assertEqual(logits.dtype, loss.dtype)

    with self.cached_session():
      # Four wrongly-signed elements: 4 * 100 / 9.
      self.assertAlmostEqual(44.444, self.evaluate(loss), 3)

  def testSigmoidNoReduction(self):
    logits = constant_op.constant((
        (100.0, -100.0, 100.0),
        (100.0, -100.0, 100.0),
        (100.0, 100.0, -100.0)))
    labels = constant_op.constant(((1, 0, 1), (1, 1, 0), (0, 1, 1)))
    loss = losses.sigmoid_cross_entropy(
        labels, logits, reduction=losses.Reduction.NONE)
    self.assertEqual(logits.dtype, loss.dtype)

    with self.cached_session():
      # The tolerance must be passed by keyword. The third positional
      # parameter of assertAllClose is `rtol`, so a bare `3` would accept up
      # to 300% relative error and the check would be nearly vacuous.
      self.assertAllClose(((0., 0., 0.), (0., 100., 100.), (100., 0., 100.)),
                          self.evaluate(loss), atol=1e-3)

  @test_util.run_deprecated_v1
  def testSigmoidLabelSmoothingCorrect(self):
    with self.cached_session():
      logits = constant_op.constant([[100.0, -100.0, -100.0]])
      labels = constant_op.constant([[1, 0, 1]])
      # Sigmoid cross entropy loss is:
      #   max(x,0) - x*z + log(1 + exp(-abs(x)))
      # The new labels are:
      #    z' = z * (1 - L) + 0.5 L
      #    1 -> 1 - 0.5 L
      #    0 -> 0.5 L
      # here we expect:
      # 1/3 * (100 - 100 * (1 - 0.5 L)  + 0
      #       + 0  + 100 * (0.5 L)      + 0
      #       + 0  + 100 * (1 - 0.5 L)  + 0)
      # = 1/3 * (100 + 50 L)
      label_smoothing = 0.1
      loss = losses.sigmoid_cross_entropy(
          labels, logits, label_smoothing=label_smoothing)
      self.assertEqual(logits.dtype, loss.dtype)
      self.assertEqual('sigmoid_cross_entropy_loss/value', loss.op.name)
      expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
      self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)

  @test_util.run_deprecated_v1
  def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
    with self.cached_session():
      label_smoothing = 0.1
      sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
      sigmoid_labels = constant_op.constant([[1, 0, 1]])
      sigmoid_loss = losses.sigmoid_cross_entropy(
          sigmoid_labels, sigmoid_logits, label_smoothing=label_smoothing)
      self.assertEqual(sigmoid_logits.dtype, sigmoid_loss.dtype)

      # Each sigmoid element is re-expressed as a two-class softmax row.
      softmax_logits = constant_op.constant(
          [[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]])
      softmax_labels = constant_op.constant([[0, 1], [1, 0], [0, 1]])
      softmax_loss = losses.softmax_cross_entropy(
          softmax_labels, softmax_logits, label_smoothing=label_smoothing)
      self.assertAlmostEqual(
          self.evaluate(sigmoid_loss), self.evaluate(softmax_loss), 3)


@test_util.run_deprecated_v1
class LogLossTest(test.TestCase):
  """Tests for `losses.log_loss`."""

  def setUp(self):
    super(LogLossTest, self).setUp()
    predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3))
    labels = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3))

    self._np_predictions = predictions
    self._np_labels = labels

    # Per-element log-likelihood (non-positive), which is why the tests negate
    # its sum. epsilon guards against log(0).
    epsilon = 1e-7
    self._expected_losses = np.multiply(
        labels, np.log(predictions + epsilon)) + np.multiply(
            1 - labels, np.log(1 - predictions + epsilon))

    self._predictions = constant_op.constant(predictions)
    self._labels = constant_op.constant(labels)

  def testValueErrorThrownWhenWeightIsNone(self):
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.log_loss(self._labels, self._labels, weights=None)

  def testAllCorrectNoLossWeight(self):
    loss = losses.log_loss(self._labels, self._labels)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)

  def testAllCorrectNoLossWeightWithPlaceholder(self):
    tf_predictions = array_ops.placeholder(
        dtypes.float32, shape=self._np_labels.shape)
    loss = losses.log_loss(self._labels, tf_predictions)
    with self.cached_session():
      self.assertAlmostEqual(
          0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)

  def testNonZeroLoss(self):
    loss = losses.log_loss(self._labels, self._predictions)
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithPythonScalarWeight(self):
    weights = 2.3
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithScalarTensorWeight(self):
    weights = 2.3
    loss = losses.log_loss(self._labels, self._predictions,
                           constant_op.constant(weights))
    with self.cached_session():
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
    tf_predictions = array_ops.placeholder(
        dtypes.float32, shape=self._np_predictions.shape)
    weights = 2.3
    loss = losses.log_loss(self._labels, tf_predictions,
                           constant_op.constant(weights))
    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             loss, 3)

  def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
    tf_predictions = array_ops.placeholder(dtypes.float32, shape=[None, None])
    weights = 2.3
    loss = losses.log_loss(self._labels, tf_predictions,
                           constant_op.constant(weights))
    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                             loss, 3)

  def testNonZeroLossWithOneDimBatchSpecificWeights(self):
    weights = constant_op.constant((1.2, 3.4), shape=(2, 1))
    expected_losses = np.multiply(
        self._expected_losses,
        np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses) / 6.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
    weights = constant_op.constant((1.2, 0), shape=(2, 1))
    expected_losses = np.multiply(self._expected_losses,
                                  np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                      (2, 3)))
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      # Only the 3 elements of the first row carry non-zero weight.
      self.assertAlmostEqual(-np.sum(expected_losses) / 3.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
    weights = constant_op.constant([1.2, 0], shape=[2, 1])
    expected_losses = np.multiply(self._expected_losses,
                                  np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                      (2, 3)))
    loss = losses.log_loss(self._labels, self._predictions, weights)
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses) / 3.0,
                             self.evaluate(loss), 3)

  def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
    weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
    with self.cached_session():
      with self.assertRaises(ValueError):
        losses.log_loss(self._labels, self._predictions, weights)

  def testNonZeroLossWithMeasurementSpecificWeights(self):
    weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    loss = losses.log_loss(
        self._labels,
        self._predictions,
        constant_op.constant(
            weights, shape=(2, 3)))
    with self.cached_session():
      # 5 of the 6 weights are non-zero.
      self.assertAlmostEqual(-np.sum(expected_losses) / 5.0,
                             self.evaluate(loss), 3)

  def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
    weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    tf_predictions = array_ops.placeholder(dtypes.float32, shape=[2, 3])
    loss = losses.log_loss(
        self._labels,
        tf_predictions,
        constant_op.constant(
            weights, shape=(2, 3)))

    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
    weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    loss = losses.log_loss(
        self._labels,
        self._predictions,
        constant_op.constant(
            weights, shape=(2, 3)))
    with self.cached_session():
      self.assertAlmostEqual(-np.sum(expected_losses), self.evaluate(loss), 3)

  def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
    weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
    expected_losses = np.multiply(self._expected_losses, weights)

    tf_predictions = array_ops.placeholder(dtypes.float32, shape=[2, 3])
    tf_weights = constant_op.constant(weights, shape=(2, 3))
    loss = losses.log_loss(self._labels, tf_predictions, tf_weights)

    with self.cached_session() as sess:
      loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
      self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)

  def testLossWithSampleSpecificWeightsAllZero(self):
    tf_weights = array_ops.zeros(shape=(2, 3))
    loss = losses.log_loss(self._labels, self._predictions, tf_weights)
    with self.cached_session():
      self.assertAlmostEqual(0.0, self.evaluate(loss), 3)


class HingeLossTest(test.TestCase):
  """Tests for `losses.hinge_loss` with {0, 1} labels."""

  def testIncompatibleShapes(self):
    with self.cached_session():
      logits = constant_op.constant([[-1.0], [2.1]])
      labels = constant_op.constant([0.0, 1.0])
      with self.assertRaises(ValueError):
        _ = losses.hinge_loss(labels, logits).eval()

  @test_util.run_deprecated_v1
  def testAllOutsideMargin(self):
    with self.cached_session():
      logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
      labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
      loss = losses.hinge_loss(labels, logits)
      self.assertAllClose(loss, 0.0, atol=1e-3)

  @test_util.run_deprecated_v1
  def testSomeInsideMargin(self):
    with self.cached_session():
      logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
      labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
      loss = losses.hinge_loss(labels, logits)
      # Examples 1 and 4 are on the correct side of the hyperplane but within
      # the margin so they incur some (small) loss: (0.3 + 0.4) / 4.
      self.assertAllClose(loss, 0.175, atol=1e-3)

  @test_util.run_deprecated_v1
  def testSomeMisclassified(self):
    with self.cached_session():
      logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
      labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
      loss = losses.hinge_loss(labels, logits)
      # Examples 2 and 4 are on the wrong side of the hyperplane so they incur
      # some (fairly large) loss: (1.4 + 2.1) / 4.
      self.assertAllClose(loss, 0.875, atol=1e-3)


class HuberLossTest(test.TestCase):

  def testIncompatibleShapes(self):
    with self.cached_session():
      predictions = constant_op.constant([[-1.0], [2.1]])
      labels = constant_op.constant([0.0, 1.0])
      with self.assertRaises(ValueError):
        _ = losses.huber_loss(labels, predictions).eval()

  @test_util.run_deprecated_v1
  def testAllQuadratic(self):
    with self.cached_session():
      predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
      labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
      loss = losses.huber_loss(labels, predictions)
      self.assertAllClose(
          loss, 0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4., atol=1e-5)

  @test_util.run_deprecated_v1
  def testAllLinear(self):
    with self.cached_session():
      predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
      labels = constant_op.constant([0.0, 1.0, 0.0, 1.5])
      loss = losses.huber_loss(labels, predictions)
      self.assertAllClose(loss, (1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5, atol=1e-5)

  @test_util.run_deprecated_v1
  def testMixedQuadraticLinear(self):
    with self.cached_session():
      predictions = constant_op.constant([[1.5, -1.4, -1.0, 0.0],
                                          [1.5, -1.4, -1.0, 0.0]])
      labels = constant_op.constant([[1.0, -1.0, 0.0, 0.5],
                                     [0.0, 1.0, 0.0, 1.5]])
      loss = losses.huber_loss(labels, predictions)
      quadratic = 0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4.
      linear = (1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5
      expected_loss = (quadratic + linear) / 2.
901 self.assertAllClose(loss, expected_loss, atol=1e-5) 902 903 def testAllQuadraticDelta(self): 904 with self.cached_session(): 905 delta = 0.5 906 predictions = constant_op.constant([1.5, -1.4, -0.5, 0.0]) 907 labels = constant_op.constant([1.0, -1.0, 0.0, 0.5]) 908 expected = 0.5 * np.array([0.5**2, 0.4**2, 0.5**2, 0.5**2]).mean() 909 loss = losses.huber_loss(labels, predictions, delta=delta) 910 self.assertAllClose(expected, self.evaluate(loss), atol=1e-5) 911 912 def testAllLinearDelta(self): 913 delta = 0.5 914 predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0]) 915 labels = constant_op.constant([0.0, 1.0, 0.0, 1.5]) 916 expected = delta * np.array([1.5, 2.4, 1.0, 1.5]).mean() 917 expected -= 0.5 * delta**2 918 loss = losses.huber_loss(labels, predictions, delta=delta) 919 with self.cached_session(): 920 self.assertAllClose(expected, self.evaluate(loss), atol=1e-5) 921 922 923@test_util.run_deprecated_v1 924class MeanSquaredErrorTest(test.TestCase): 925 926 def setUp(self): 927 super(MeanSquaredErrorTest, self).setUp() 928 self._predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) 929 self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) 930 931 def testValueErrorThrownWhenWeightIsNone(self): 932 with self.cached_session(): 933 with self.assertRaises(ValueError): 934 losses.mean_squared_error( 935 self._predictions, self._predictions, weights=None) 936 937 @test_util.run_deprecated_v1 938 def testScalar(self): 939 with self.cached_session(): 940 self.assertEqual( 941 0.0, 942 losses.mean_squared_error(predictions=constant_op.constant(0), 943 labels=constant_op.constant(0)).eval()) 944 945 @test_util.run_deprecated_v1 946 def testAllCorrectNoLossWeight(self): 947 loss = losses.mean_squared_error(self._predictions, self._predictions) 948 with self.cached_session(): 949 self.assertAlmostEqual(0.0, self.evaluate(loss), 3) 950 951 @test_util.run_deprecated_v1 952 def testNonZeroLoss(self): 953 loss = 
losses.mean_squared_error(self._labels, self._predictions) 954 with self.cached_session(): 955 self.assertAlmostEqual(49.5, self.evaluate(loss), 3) 956 957 @test_util.run_deprecated_v1 958 def testNonZeroLossWithPythonScalarWeight(self): 959 weights = 2.3 960 loss = losses.mean_squared_error(self._labels, self._predictions, weights) 961 with self.cached_session(): 962 self.assertAlmostEqual(49.5 * weights, self.evaluate(loss), 3) 963 964 @test_util.run_deprecated_v1 965 def testNonZeroLossWithScalarTensorWeight(self): 966 weights = 2.3 967 loss = losses.mean_squared_error(self._labels, self._predictions, 968 constant_op.constant(weights)) 969 with self.cached_session(): 970 self.assertAlmostEqual(49.5 * weights, self.evaluate(loss), 3) 971 972 def testNonZeroLossWithOneDimBatchSpecificWeights(self): 973 weights = constant_op.constant([1.2, 3.4], shape=(2, 1)) 974 loss = losses.mean_squared_error(self._labels, self._predictions, weights) 975 with self.cached_session(): 976 self.assertAlmostEqual(767.8 / 6.0, self.evaluate(loss), 3) 977 978 def testNonZeroLossWithTwoDimBatchSpecificWeights(self): 979 weights = constant_op.constant([1.2, 3.4], shape=[2, 1]) 980 loss = losses.mean_squared_error(self._labels, self._predictions, weights) 981 with self.cached_session(): 982 self.assertAlmostEqual(767.8 / 6.0, self.evaluate(loss), 3) 983 984 def testNonZeroLossWithSampleSpecificWeights(self): 985 weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) 986 loss = losses.mean_squared_error(self._labels, self._predictions, weights) 987 with self.cached_session(): 988 self.assertAlmostEqual(587 / 5.0, self.evaluate(loss), 3) 989 990 def testNonZeroLossWithSampleSpecificWeightsMostZero(self): 991 weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) 992 loss = losses.mean_squared_error(self._labels, self._predictions, weights) 993 with self.cached_session(): 994 self.assertAlmostEqual(18.0, self.evaluate(loss), 3) 995 996 def 
testLossWithSampleSpecificWeightsAllZero(self): 997 weights = array_ops.zeros((2, 3)) 998 loss = losses.mean_squared_error(self._labels, self._predictions, weights) 999 with self.cached_session(): 1000 self.assertAlmostEqual(0.0, self.evaluate(loss), 3) 1001 1002 1003@test_util.run_deprecated_v1 1004class MeanPairwiseSquaredErrorTest(test.TestCase): 1005 1006 def setUp(self): 1007 super(MeanPairwiseSquaredErrorTest, self).setUp() 1008 self._predictions = np.array([[4, 8, 12], [8, 1, 3]]) 1009 self._labels = np.array([[1, 9, 2], [-5, -5, 7]]) 1010 1011 batch_size, dims = self._labels.shape # pylint: disable=unpacking-non-sequence 1012 1013 # Compute the expected loss 'manually'. 1014 total = np.zeros((batch_size,)) 1015 for b in range(batch_size): 1016 for i in range(dims - 1): 1017 for j in range(i + 1, dims): 1018 x = self._predictions[b, i].item() - self._predictions[b, j].item() 1019 y = self._labels[b, i].item() - self._labels[b, j].item() 1020 diff = (x - y) 1021 total[b] += (diff * diff) 1022 1023 self._expected_losses = np.divide(total, 3.0) 1024 1025 def testValueErrorThrownWhenWeightIsNone(self): 1026 with self.cached_session(): 1027 with self.assertRaises(ValueError): 1028 losses.mean_pairwise_squared_error( 1029 predictions=constant_op.constant(self._labels), 1030 labels=constant_op.constant(self._labels), 1031 weights=None) 1032 1033 def _test_valid_weights( 1034 self, labels, predictions, expected_loss, weights=1.0): 1035 with self.cached_session(): 1036 static_inputs_op = losses.mean_pairwise_squared_error( 1037 predictions=predictions, labels=labels, weights=weights) 1038 self.assertAlmostEqual( 1039 expected_loss, self.evaluate(static_inputs_op), places=3) 1040 1041 predictions_placeholder = array_ops.placeholder( 1042 dtypes.float32, shape=np.asarray(predictions.shape)) 1043 labels_placeholder = array_ops.placeholder( 1044 dtypes.int32, shape=np.asarray(labels.shape)) 1045 weights_placeholder = array_ops.placeholder( 1046 dtypes.float32, 
shape=np.asarray(weights).shape) 1047 dynamic_inputs_op = losses.mean_pairwise_squared_error( 1048 predictions=predictions_placeholder, 1049 labels=labels_placeholder, 1050 weights=weights_placeholder) 1051 feed_dict = { 1052 predictions_placeholder: predictions, 1053 labels_placeholder: labels, 1054 weights_placeholder: weights, 1055 } 1056 self.assertAlmostEqual( 1057 expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3) 1058 1059 def testAllCorrectNoLossWeight(self): 1060 self._test_valid_weights( 1061 self._labels, self._labels, expected_loss=0.0) 1062 1063 def testNonZeroLoss(self): 1064 self._test_valid_weights( 1065 self._labels, self._predictions, 1066 expected_loss=np.sum(self._expected_losses)) 1067 1068 def testGradientWithZeroWeight(self): 1069 with ops.Graph().as_default(): 1070 random_seed.set_random_seed(0) 1071 1072 inputs = array_ops.ones((2, 3)) 1073 weights = variable_scope.get_variable( 1074 'weights', 1075 shape=[3, 4], 1076 initializer=init_ops.truncated_normal_initializer()) 1077 predictions = math_ops.matmul(inputs, weights) 1078 1079 optimizer = momentum_lib.MomentumOptimizer( 1080 learning_rate=0.001, momentum=0.9) 1081 loss = losses.mean_pairwise_squared_error(predictions, predictions, 0) 1082 1083 gradients_to_variables = optimizer.compute_gradients(loss) 1084 1085 init_op = variables.global_variables_initializer() 1086 1087 with self.cached_session() as sess: 1088 self.evaluate(init_op) 1089 for grad, _ in gradients_to_variables: 1090 np_grad = self.evaluate(grad) 1091 self.assertFalse(np.isnan(np_grad).any()) 1092 1093 def testNonZeroLossWithPythonScalarWeight(self): 1094 weight = 2.3 1095 self._test_valid_weights( 1096 self._labels, self._predictions, 1097 expected_loss=weight * np.sum(self._expected_losses), 1098 weights=weight) 1099 1100 def testNonZeroLossWithScalarTensorWeight(self): 1101 weights = 2.3 1102 loss = losses.mean_pairwise_squared_error( 1103 predictions=constant_op.constant(self._predictions), 1104 
labels=constant_op.constant(self._labels), 1105 weights=constant_op.constant(weights)) 1106 with self.cached_session(): 1107 self.assertAlmostEqual(weights * np.sum(self._expected_losses), 1108 self.evaluate(loss), 3) 1109 1110 def testNonZeroLossWithScalarZeroWeight(self): 1111 self._test_valid_weights( 1112 self._labels, self._predictions, expected_loss=0.0, weights=0.0) 1113 1114 def test3d(self): 1115 labels = np.array([ 1116 [[1, 9, 2], [12, 11, 10], [9, 8, 7]], 1117 [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], 1118 ]) 1119 predictions = np.array([ 1120 [[4, 8, 12], [1, 2, 3], [4, 5, 6]], 1121 [[8, 1, 3], [7, 8, 9], [10, 11, 12]], 1122 ]) 1123 self._test_valid_weights(labels, predictions, expected_loss=137.5) 1124 1125 def test3dWeightedScalar(self): 1126 labels = np.array([ 1127 [[1, 9, 2], [12, 11, 10], [9, 8, 7]], 1128 [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], 1129 ]) 1130 predictions = np.array([ 1131 [[4, 8, 12], [1, 2, 3], [4, 5, 6]], 1132 [[8, 1, 3], [7, 8, 9], [10, 11, 12]], 1133 ]) 1134 weight = 3.0 1135 self._test_valid_weights( 1136 labels, predictions, expected_loss=weight * 137.5, weights=weight) 1137 1138 def _test_invalid_weights( 1139 self, labels, predictions, weights=1.0): 1140 expected_error_msg = 'weights can not be broadcast to values' 1141 1142 # Static check. 1143 with self.assertRaisesRegex(ValueError, expected_error_msg): 1144 losses.mean_pairwise_squared_error( 1145 predictions=predictions, labels=labels, weights=weights) 1146 1147 # Dynamic check. 
1148 predictions_placeholder = array_ops.placeholder(dtypes.float32) 1149 labels_placeholder = array_ops.placeholder(dtypes.int32) 1150 weights_placeholder = array_ops.placeholder(dtypes.float32) 1151 dynamic_inputs_op = losses.mean_pairwise_squared_error( 1152 predictions=predictions_placeholder, 1153 labels=labels_placeholder, 1154 weights=weights_placeholder) 1155 with self.cached_session(): 1156 with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg): 1157 dynamic_inputs_op.eval(feed_dict={ 1158 predictions_placeholder: predictions, 1159 labels_placeholder: labels, 1160 weights_placeholder: weights, 1161 }) 1162 1163 def testInvalid3dWeighted2x0(self): 1164 labels = np.array([ 1165 [[1, 9, 2], [12, 11, 10], [9, 8, 7]], 1166 [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], 1167 ]) 1168 predictions = np.array([ 1169 [[4, 8, 12], [1, 2, 3], [4, 5, 6]], 1170 [[8, 1, 3], [7, 8, 9], [10, 11, 12]], 1171 ]) 1172 self._test_invalid_weights( 1173 labels, predictions, weights=np.asarray((1.2, 3.4))) 1174 1175 def test3dWeighted2x3x3(self): 1176 labels = np.array([ 1177 [[1, 9, 2], [12, 11, 10], [9, 8, 7]], 1178 [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], 1179 ]) 1180 predictions = np.array([ 1181 [[4, 8, 12], [1, 2, 3], [4, 5, 6]], 1182 [[8, 1, 3], [7, 8, 9], [10, 11, 12]], 1183 ]) 1184 self._test_valid_weights( 1185 # TODO(ptucker): This doesn't look right. 
1186 labels, 1187 predictions, 1188 expected_loss=9 * 137.5, 1189 weights=np.ones((2, 3, 3))) 1190 1191 def testLossWithAllZeroBatchSpecificWeights(self): 1192 self._test_valid_weights( 1193 self._labels, self._predictions, expected_loss=0.0, 1194 weights=np.zeros((2, 1))) 1195 1196 def testLossIsAssociativeAcrossBatchElements(self): 1197 with ops.Graph().as_default(): 1198 random_seed.set_random_seed(0) 1199 1200 height = 3 1201 width = 4 1202 shape = (1, height, width, 1) 1203 1204 labels0 = random_ops.random_uniform( 1205 shape, minval=0, maxval=1, dtype=dtypes.float32) 1206 predictions0 = random_ops.random_uniform( 1207 shape, minval=0, maxval=1, dtype=dtypes.float32) 1208 1209 labels1 = random_ops.random_uniform( 1210 shape, minval=0, maxval=1, dtype=dtypes.float32) 1211 predictions1 = random_ops.random_uniform( 1212 shape, minval=0, maxval=1, dtype=dtypes.float32) 1213 1214 loss0 = losses.mean_pairwise_squared_error( 1215 labels=labels0, 1216 predictions=predictions0) 1217 loss1 = losses.mean_pairwise_squared_error( 1218 labels=labels1, 1219 predictions=predictions1) 1220 loss0_1 = losses.mean_pairwise_squared_error( 1221 labels=array_ops.concat([labels0, labels1], 0), 1222 predictions=array_ops.concat([predictions0, predictions1], 0)) 1223 1224 with self.cached_session() as session: 1225 loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1]) 1226 1227 self.assertTrue(loss0 > 0) 1228 self.assertTrue(loss1 > 0) 1229 self.assertAlmostEqual(loss0 + loss1, loss0_1, 5) 1230 1231 1232@test_util.run_deprecated_v1 1233class CosineDistanceLossTest(test.TestCase): 1234 1235 def setUp(self): 1236 super(CosineDistanceLossTest, self).setUp() 1237 self._predictions = np.asarray([ 1238 [1, 0, 0], # Batch 1 1239 [0, 0, -1], 1240 [1, 0, 0], # Batch 2 1241 [1, 0, 0], 1242 [0, 0, -1], # Batch 3 1243 [1, 0, 0] 1244 ]).reshape((3, 2, 3)) 1245 1246 self._labels = np.asarray([[1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], 1247 [0, 0, 1], [0, 1, 0]]).reshape((3, 2, 3)) 1248 1249 
def testValueErrorThrownWhenWeightIsNone(self): 1250 with self.cached_session(): 1251 with self.assertRaises(ValueError): 1252 losses.cosine_distance( 1253 predictions=constant_op.constant(self._labels), 1254 labels=constant_op.constant(self._labels), 1255 dim=2, 1256 weights=None) 1257 1258 def testAllCorrectNoWeights(self): 1259 loss = losses.cosine_distance( 1260 predictions=constant_op.constant(self._labels), 1261 labels=constant_op.constant(self._labels), 1262 dim=2) 1263 with self.cached_session(): 1264 self.assertAlmostEqual(0, self.evaluate(loss), 5) 1265 1266 def testPartiallyCorrectWithIntegerValues(self): 1267 loss = losses.cosine_distance( 1268 predictions=constant_op.constant(self._predictions), 1269 labels=constant_op.constant(self._labels), 1270 dim=2) 1271 with self.cached_session(): 1272 self.assertAlmostEqual(1, self.evaluate(loss), 5) 1273 1274 def testPartiallyCorrectFloatingPointValues(self): 1275 predictions = np.matrix( 1276 ('0.819031913261206 0.567041924552012 0.087465312324590;' 1277 '-0.665139432070255 -0.739487441769973 -0.103671883216994;' 1278 '0.707106781186548 -0.707106781186548 0')) 1279 labels = np.matrix(('0.819031913261206 0.567041924552012 0.087465312324590;' 1280 '0.665139432070255 0.739487441769973 0.103671883216994;' 1281 '0.707106781186548 0.707106781186548 0')) 1282 1283 tf_preds = constant_op.constant( 1284 predictions, shape=(3, 1, 3), dtype=dtypes.float32) 1285 tf_labels = constant_op.constant( 1286 labels, shape=(3, 1, 3), dtype=dtypes.float32) 1287 loss = losses.cosine_distance(tf_labels, tf_preds, dim=2) 1288 1289 with self.cached_session(): 1290 self.assertAlmostEqual(1.0, self.evaluate(loss), 5) 1291 1292 def testSampleSpecificWeights(self): 1293 loss = losses.cosine_distance( 1294 predictions=constant_op.constant(self._predictions), 1295 labels=constant_op.constant(self._labels), 1296 dim=2, 1297 weights=np.asarray((1, 0, 0)).reshape((3, 1, 1))) 1298 with self.cached_session(): 1299 self.assertEqual(1.0, 
self.evaluate(loss)) 1300 1301 def testMeasurementSpecificWeights(self): 1302 loss = losses.cosine_distance( 1303 predictions=constant_op.constant(self._predictions), 1304 labels=constant_op.constant(self._labels), 1305 dim=2, 1306 weights=constant_op.constant( 1307 [1, 0, 0, 1, 1, 1], shape=(3, 2, 1))) 1308 with self.cached_session(): 1309 self.assertEqual(3.0 / 4.0, self.evaluate(loss)) 1310 1311 def testMeasurementSpecificWeightsWithPlaceholderWithShape(self): 1312 tf_predictions = array_ops.placeholder( 1313 dtypes.float32, shape=self._labels.shape) 1314 loss = losses.cosine_distance( 1315 predictions=tf_predictions, 1316 labels=constant_op.constant(self._labels), 1317 dim=2, 1318 weights=constant_op.constant( 1319 [1, 0, 0, 1, 1, 1], shape=(3, 2, 1))) 1320 with self.cached_session() as sess: 1321 loss = sess.run(loss, feed_dict={tf_predictions: self._predictions}) 1322 self.assertEqual(3.0 / 4.0, loss) 1323 1324 def testZeroLossWhenAllSampleSpecificWeightsAreZero(self): 1325 loss = losses.cosine_distance( 1326 predictions=constant_op.constant(self._predictions), 1327 labels=constant_op.constant(self._labels), 1328 dim=2, 1329 weights=array_ops.zeros((3, 1, 1))) 1330 with self.cached_session(): 1331 self.assertEqual(0, self.evaluate(loss)) 1332 1333 def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self): 1334 loss = losses.cosine_distance( 1335 predictions=constant_op.constant(self._predictions), 1336 labels=constant_op.constant(self._labels), 1337 dim=2, 1338 weights=array_ops.zeros((3, 2, 1))) 1339 with self.cached_session(): 1340 self.assertEqual(0, self.evaluate(loss)) 1341 1342 1343class AddLossTest(test.TestCase): 1344 1345 def testNoCollectLossesBatch2(self): 1346 logits = constant_op.constant([[1.2, 0.4, -1.0, -1.1]] * 2) 1347 labels = constant_op.constant([[1.0, 0.0, 0.0, 1.0]] * 2) 1348 self.assertFalse(util.get_losses()) 1349 losses.absolute_difference(logits, labels, loss_collection=None) 1350 losses.log_loss(logits, labels, 
loss_collection=None) 1351 losses.mean_squared_error(logits, labels, loss_collection=None) 1352 losses.sigmoid_cross_entropy(logits, labels, loss_collection=None) 1353 losses.softmax_cross_entropy(logits, labels, loss_collection=None) 1354 self.assertFalse(util.get_losses()) 1355 1356 1357class ComputeWeightedLossTest(test.TestCase): 1358 1359 def setUp(self): 1360 super(ComputeWeightedLossTest, self).setUp() 1361 self._shape = (3, 2, 4) 1362 raw_losses = np.zeros(self._shape) 1363 next_loss = 0.0 1364 for i in range(self._shape[0]): 1365 for j in range(self._shape[1]): 1366 for k in range(self._shape[2]): 1367 raw_losses[i][j][k] = next_loss 1368 next_loss += 1.0 1369 raw_losses.setflags(write=False) 1370 self._raw_losses = raw_losses 1371 1372 def testUnweighted(self): 1373 for reduction in losses.Reduction.all(): 1374 with ops.Graph().as_default() as g: 1375 self.assertEqual(0, len(util.get_losses())) 1376 raw_losses = self._raw_losses 1377 unweighted_losses = ( 1378 losses.compute_weighted_loss(raw_losses, reduction=reduction), 1379 losses.compute_weighted_loss( 1380 raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction), 1381 losses.compute_weighted_loss( 1382 raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction), 1383 losses.compute_weighted_loss( 1384 raw_losses, weights=np.ones((1, 2, 1)), reduction=reduction), 1385 losses.compute_weighted_loss( 1386 raw_losses, weights=np.ones((1, 2, 4)), reduction=reduction), 1387 losses.compute_weighted_loss( 1388 raw_losses, weights=np.ones((3, 1, 1)), reduction=reduction), 1389 losses.compute_weighted_loss( 1390 raw_losses, weights=np.ones((3, 1, 4)), reduction=reduction), 1391 losses.compute_weighted_loss( 1392 raw_losses, weights=np.ones((3, 2, 1)), reduction=reduction), 1393 losses.compute_weighted_loss( 1394 raw_losses, weights=np.ones(self._shape), reduction=reduction) 1395 ) 1396 self.assertEqual(9, len(util.get_losses())) 1397 with self.session(g): 1398 for unweighted_loss in unweighted_losses: 1399 
if reduction == losses.Reduction.NONE: 1400 self.assertAllClose(self._raw_losses, 1401 self.evaluate(unweighted_loss)) 1402 elif reduction == losses.Reduction.SUM: 1403 self.assertAllClose( 1404 np.sum(self._raw_losses), self.evaluate(unweighted_loss)) 1405 else: 1406 # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS, 1407 # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE. 1408 self.assertAllClose( 1409 np.mean(self._raw_losses), self.evaluate(unweighted_loss)) 1410 1411 def testUnweightedFromPlaceholder(self): 1412 for reduction in losses.Reduction.all(): 1413 with ops.Graph().as_default() as g: 1414 self.assertEqual(0, len(util.get_losses())) 1415 raw_losses = array_ops.placeholder(dtype=dtypes.float32) 1416 feed_dict = {raw_losses: self._raw_losses} 1417 unweighted_losses = ( 1418 losses.compute_weighted_loss(raw_losses, reduction=reduction), 1419 losses.compute_weighted_loss( 1420 raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction), 1421 losses.compute_weighted_loss( 1422 raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction), 1423 ) 1424 self.assertEqual(3, len(util.get_losses())) 1425 with self.session(g): 1426 for unweighted_loss in unweighted_losses: 1427 if reduction == losses.Reduction.NONE: 1428 self.assertAllClose( 1429 self._raw_losses, unweighted_loss.eval(feed_dict)) 1430 elif reduction == losses.Reduction.SUM: 1431 self.assertAllClose( 1432 np.sum(self._raw_losses), unweighted_loss.eval(feed_dict)) 1433 else: 1434 # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS, 1435 # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE. 
1436 self.assertAllClose( 1437 np.mean(self._raw_losses), unweighted_loss.eval(feed_dict)) 1438 1439 def testScalarWeight(self): 1440 with ops.Graph().as_default(): 1441 self.assertEqual(0, len(util.get_losses())) 1442 weight = 17.0 1443 weighted_loss = losses.compute_weighted_loss( 1444 self._raw_losses, weights=weight) 1445 self.assertEqual(1, len(util.get_losses())) 1446 with self.cached_session(): 1447 self.assertAllClose( 1448 np.mean(weight * self._raw_losses), self.evaluate(weighted_loss)) 1449 1450 def _test_invalid_weights(self, weights): 1451 with ops.Graph().as_default(): 1452 self.assertEqual(0, len(util.get_losses())) 1453 expected_error_msg = 'weights can not be broadcast to values' 1454 1455 # Static check. 1456 with self.assertRaisesRegex(ValueError, expected_error_msg): 1457 losses.compute_weighted_loss(self._raw_losses, weights=weights) 1458 1459 # Dynamic check. 1460 weights_placeholder = array_ops.placeholder(dtypes.float32) 1461 weighted_loss = losses.compute_weighted_loss( 1462 self._raw_losses, weights=weights_placeholder) 1463 self.assertEqual(1, len(util.get_losses())) 1464 with self.cached_session(): 1465 with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg): 1466 weighted_loss.eval(feed_dict={weights_placeholder: weights}) 1467 1468 def testInvalidWeightTooManyDims(self): 1469 self._test_invalid_weights(np.zeros(shape=(2, 2, 2, 2))) 1470 1471 def testInvalidWeightMismatchedDim(self): 1472 with ops.Graph().as_default(): 1473 raw_losses = array_ops.reshape(self._raw_losses, shape=(3, 2, 4, 1)) 1474 weights = np.ones(shape=(3, 2, 4, 2)) 1475 expected_error_msg = 'weights can not be broadcast to values' 1476 self.assertEqual(0, len(util.get_losses())) 1477 1478 # Static check. 1479 with self.assertRaisesRegex(ValueError, expected_error_msg): 1480 losses.compute_weighted_loss(raw_losses, weights=weights) 1481 1482 # Dynamic check. 
1483 weights_placeholder = array_ops.placeholder(dtypes.float32) 1484 weighted_loss = losses.compute_weighted_loss( 1485 raw_losses, weights=weights_placeholder) 1486 self.assertEqual(1, len(util.get_losses())) 1487 with self.cached_session(): 1488 with self.assertRaisesRegex(errors_impl.OpError, expected_error_msg): 1489 weighted_loss.eval(feed_dict={weights_placeholder: weights}) 1490 1491 def testInvalid3Weight(self): 1492 self._test_invalid_weights((17.0, 5.0, 2.0)) 1493 1494 def testInvalid3x1Weight(self): 1495 self._test_invalid_weights(((17.0,), (5.0,), (2.0,),)) 1496 1497 def testInvalid3x2Weight(self): 1498 self._test_invalid_weights(( 1499 (17.0, 3.0), 1500 (5.0, 31.0), 1501 (2.0, 7.0),)) 1502 1503 def testInvalid1x2Weight(self): 1504 self._test_invalid_weights((17.0, 3.0,),) 1505 1506 def testInvalidScalar1DWeight(self): 1507 self._test_invalid_weights((17.0,),) 1508 1509 def _test_valid_weights(self, weights): 1510 for reduction in losses.Reduction.all(): 1511 with ops.Graph().as_default() as g: 1512 self.assertEqual(0, len(util.get_losses())) 1513 weighted_loss = losses.compute_weighted_loss( 1514 self._raw_losses, weights=weights, reduction=reduction) 1515 self.assertEqual(1, len(util.get_losses())) 1516 with self.session(g): 1517 weighted_losses = weights * self._raw_losses 1518 weighted_sum = np.sum(weighted_losses) 1519 if reduction == losses.Reduction.NONE: 1520 self.assertAllClose(weighted_losses, self.evaluate(weighted_loss)) 1521 elif reduction == losses.Reduction.SUM: 1522 self.assertAllClose(weighted_sum, self.evaluate(weighted_loss)) 1523 else: 1524 broadcast_weights = weights * np.ones_like(self._raw_losses) 1525 if reduction == losses.Reduction.MEAN: 1526 self.assertAllClose(weighted_sum / np.sum(broadcast_weights), 1527 self.evaluate(weighted_loss)) 1528 elif (reduction == losses.Reduction.SUM_OVER_NONZERO_WEIGHTS or 1529 reduction == losses.Reduction.SUM_BY_NONZERO_WEIGHTS): 1530 self.assertAllClose( 1531 weighted_sum / 
np.count_nonzero(broadcast_weights), 1532 self.evaluate(weighted_loss)) 1533 elif reduction == losses.Reduction.SUM_OVER_BATCH_SIZE: 1534 self.assertAllClose(weighted_sum / self._raw_losses.size, 1535 self.evaluate(weighted_loss)) 1536 1537 def test1x1x1Weight(self): 1538 self._test_valid_weights((((17.0,),),)) 1539 1540 def test1x2x1Weight(self): 1541 self._test_valid_weights((((17.0,), (3.0,),),)) 1542 1543 def test1x1x4Weight(self): 1544 self._test_valid_weights((((17.0, 0.0, 2.0, 5.0),),)) 1545 1546 def test3x1x1Weight(self): 1547 self._test_valid_weights((((17.0,),), ((5.0,),), ((2.0,),),)) 1548 1549 def test3x2x1Weight(self): 1550 self._test_valid_weights(( 1551 ((17.0,), (3.0,)), 1552 ((5.0,), (31.0,)), 1553 ((2.0,), (7.0,)), 1554 )) 1555 1556 def test3x1x4Weight(self): 1557 self._test_valid_weights(( 1558 ((17.0, 0.0, 2.0, 5.0),), 1559 ((5.0, 31.0, 17.0, 5.0),), 1560 ((7.0, 3.0, 11.0, 5.0),), 1561 )) 1562 1563 def test1x2x4Weight(self): 1564 self._test_valid_weights((( 1565 (17.0, 0.0, 2.0, 5.0), 1566 (3.0, 13.0, 11.0, 2.0), 1567 ),)) 1568 1569 def test3x2x4Weight(self): 1570 self._test_valid_weights(( 1571 ((17.0, 0.0, 2.0, 5.0), (3.0, 13.0, 11.0, 2.0),), 1572 ((5.0, 31.0, 17.0, 5.0), (13.0, 3.0, 0.0, 11.0),), 1573 ((0.0, 3.0, 11.0, 5.0), (13.0, 11.0, 1.0, 7.0),), 1574 )) 1575 1576 1577if __name__ == '__main__': 1578 test.main() 1579