1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Iterator`.""" 16import warnings 17 18from absl.testing import parameterized 19import numpy as np 20 21from tensorflow.core.protobuf import cluster_pb2 22from tensorflow.core.protobuf import config_pb2 23from tensorflow.python.client import session 24from tensorflow.python.data.kernel_tests import test_base 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.data.ops import iterator_ops 27from tensorflow.python.data.util import structure 28from tensorflow.python.eager import context 29from tensorflow.python.eager import def_function 30from tensorflow.python.framework import combinations 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import errors 34from tensorflow.python.framework import function 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import sparse_tensor 37from tensorflow.python.framework import tensor_spec 38from tensorflow.python.framework import test_util 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import data_flow_ops 41from tensorflow.python.ops import functional_ops 42from tensorflow.python.ops import gradients_impl 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import parsing_ops 45from tensorflow.python.ops import script_ops 46from tensorflow.python.ops import variables 47from tensorflow.python.platform import test 48from tensorflow.python.training import server_lib 49from tensorflow.python.util import compat 50 51 52@test_util.with_eager_op_as_function 53class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): 54 55 @combinations.generate(test_base.graph_only_combinations()) 56 def testNoGradients(self): 57 component = constant_op.constant([1.]) 58 side = constant_op.constant(0.) 59 add = lambda x: x + side 60 dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) 61 value = dataset_ops.make_one_shot_iterator(dataset).get_next() 62 self.assertIsNone(gradients_impl.gradients(value, component)[0]) 63 self.assertIsNone(gradients_impl.gradients(value, side)[0]) 64 self.assertIsNone(gradients_impl.gradients(value, [component, side])[0]) 65 66 @combinations.generate(test_base.graph_only_combinations()) 67 def testCapturingStateInOneShotRaisesException(self): 68 var = variables.Variable(37.0, name="myvar") 69 dataset = ( 70 dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0]) 71 .map(lambda x: x + var)) 72 with self.assertRaisesRegex( 73 ValueError, r"A likely cause of this error is that the dataset for " 74 r"which you are calling `make_one_shot_iterator\(\)` captures a " 75 r"stateful object, such as a `tf.Variable` or " 76 r"`tf.lookup.StaticHashTable`, which is not supported. Use " 77 r"`make_initializable_iterator\(\)` instead."): 78 dataset_ops.make_one_shot_iterator(dataset) 79 80 @combinations.generate(test_base.graph_only_combinations()) 81 def testOneShotIterator(self): 82 components = (np.arange(7), 83 np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], 84 np.array(37.0) * np.arange(7)) 85 86 def _map_fn(x, y, z): 87 return math_ops.square(x), math_ops.square(y), math_ops.square(z) 88 89 iterator = dataset_ops.make_one_shot_iterator( 90 dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) 91 .repeat(14)) 92 get_next = iterator.get_next() 93 94 self.assertEqual([c.shape[1:] for c in components], 95 [t.shape for t in get_next]) 96 97 with self.cached_session() as sess: 98 for _ in range(14): 99 for i in range(7): 100 result = sess.run(get_next) 101 for component, result_component in zip(components, result): 102 self.assertAllEqual(component[i]**2, result_component) 103 with self.assertRaises(errors.OutOfRangeError): 104 sess.run(get_next) 105 106 @combinations.generate(test_base.graph_only_combinations()) 107 def testOneShotIteratorCaptureByValue(self): 108 components = (np.arange(7), 109 np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], 110 np.array(37.0) * np.arange(7)) 111 tensor_components = tuple([ops.convert_to_tensor(c) for c in components]) 112 113 def _map_fn(x, y, z): 114 return math_ops.square(x), math_ops.square(y), math_ops.square(z) 115 116 iterator = dataset_ops.make_one_shot_iterator( 117 dataset_ops.Dataset.from_tensor_slices(tensor_components) 118 .map(_map_fn).repeat(14)) 119 get_next = iterator.get_next() 120 121 self.assertEqual([c.shape[1:] for c in components], 122 [t.shape for t in get_next]) 123 124 with self.cached_session() as sess: 125 for _ in range(14): 126 for i in range(7): 127 result = sess.run(get_next) 128 for component, result_component in zip(components, result): 129 self.assertAllEqual(component[i]**2, result_component) 130 with self.assertRaises(errors.OutOfRangeError): 131 sess.run(get_next) 132 133 @combinations.generate(test_base.default_test_combinations()) 134 def testOneShotIteratorInsideContainer(self): 135 components = (np.arange(7), 136 np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], 137 np.array(37.0) * np.arange(7)) 138 139 def within_container(): 140 141 def _map_fn(x, y, z): 142 return math_ops.square(x), math_ops.square(y), math_ops.square(z) 143 144 iterator = dataset_ops.make_one_shot_iterator( 145 dataset_ops.Dataset.from_tensor_slices(components) 146 .map(_map_fn).repeat(14)) 147 return iterator.get_next() 148 149 server = server_lib.Server.create_local_server() 150 151 # Create two iterators within unique containers, and run them to 152 # make sure that the resources aren't shared. 153 # 154 # The test below would fail if cname were the same across both 155 # sessions. 156 for j in range(2): 157 with session.Session(server.target) as sess: 158 cname = "iteration%d" % j 159 with ops.container(cname): 160 get_next = within_container() 161 162 for _ in range(14): 163 for i in range(7): 164 result = sess.run(get_next) 165 for component, result_component in zip(components, result): 166 self.assertAllEqual(component[i]**2, result_component) 167 with self.assertRaises(errors.OutOfRangeError): 168 sess.run(get_next) 169 170 @combinations.generate(test_base.graph_only_combinations()) 171 def testOneShotIteratorNonBlocking(self): 172 dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) 173 iterator = dataset_ops.make_one_shot_iterator(dataset) 174 next_element = iterator.get_next() 175 176 # Create a session with a single thread to ensure that the 177 # one-shot iterator initializer does not deadlock. 178 config = config_pb2.ConfigProto( 179 inter_op_parallelism_threads=1, use_per_session_threads=True) 180 with session.Session(config=config) as sess: 181 self.assertAllEqual([1, 4, 9], sess.run(next_element)) 182 with self.assertRaises(errors.OutOfRangeError): 183 sess.run(next_element) 184 185 # Test with multiple threads invoking the one-shot iterator concurrently. 186 with session.Session(config=config) as sess: 187 results = [] 188 189 def consumer_thread(): 190 try: 191 results.append(sess.run(next_element)) 192 except errors.OutOfRangeError: 193 results.append(None) 194 195 num_threads = 8 196 threads = [ 197 self.checkedThread(consumer_thread) for _ in range(num_threads) 198 ] 199 for t in threads: 200 t.start() 201 for t in threads: 202 t.join() 203 204 self.assertLen(results, num_threads) 205 self.assertLen([None for r in results if r is None], num_threads - 1) 206 self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) 207 208 @combinations.generate(test_base.graph_only_combinations()) 209 def testOneShotIteratorInitializerFails(self): 210 # Define a dataset whose initialization will always fail. 211 dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4])) 212 iterator = dataset_ops.make_one_shot_iterator(dataset) 213 next_element = iterator.get_next() 214 215 with self.cached_session() as sess: 216 with self.assertRaisesRegex(errors.InvalidArgumentError, ""): 217 sess.run(next_element) 218 219 # Test that subsequent attempts to use the iterator also fail. 220 with self.assertRaisesRegex(errors.InvalidArgumentError, ""): 221 sess.run(next_element) 222 223 with self.cached_session() as sess: 224 225 def consumer_thread(): 226 with self.assertRaisesRegex(errors.InvalidArgumentError, ""): 227 sess.run(next_element) 228 229 num_threads = 8 230 threads = [ 231 self.checkedThread(consumer_thread) for _ in range(num_threads) 232 ] 233 for t in threads: 234 t.start() 235 for t in threads: 236 t.join() 237 238 @combinations.generate(test_base.default_test_combinations()) 239 def testOneShotIteratorEmptyDataset(self): 240 dataset = dataset_ops.Dataset.range(0) 241 iterator = dataset_ops.make_one_shot_iterator(dataset) 242 with self.assertRaises(errors.OutOfRangeError): 243 self.evaluate(iterator.get_next()) 244 245 @combinations.generate(test_base.graph_only_combinations()) 246 def testSimpleSharedResource(self): 247 components = (np.array(1, dtype=np.int64), 248 np.array([1, 2, 3], dtype=np.int64), 249 np.array(37.0, dtype=np.float64)) 250 251 server = server_lib.Server.create_local_server() 252 253 # Create two non-overlapping sessions that share the same iterator 254 # resource on the same server, and verify that an action of the 255 # first session (initializing the iterator) is visible in the 256 # second session. 257 with ops.Graph().as_default(): 258 iterator = dataset_ops.make_initializable_iterator( 259 dataset_ops.Dataset.from_tensors( 260 components).map(lambda x, y, z: (x, y, z)), 261 shared_name="shared_iterator") 262 init_op = iterator.initializer 263 get_next = iterator.get_next() 264 265 with session.Session(server.target) as sess: 266 sess.run(init_op) 267 results = sess.run(get_next) 268 for component, result_component in zip(components, results): 269 self.assertAllEqual(component, result_component) 270 with self.assertRaises(errors.OutOfRangeError): 271 sess.run(get_next) 272 273 # Re-initialize the iterator in the first session. 274 sess.run(init_op) 275 276 with ops.Graph().as_default(): 277 # Re-define the iterator manually, without defining any of the 278 # functions in this graph, to ensure that we are not 279 # accidentally redefining functions with the same names in the 280 # new graph. 281 iterator = iterator_ops.Iterator.from_structure( 282 shared_name="shared_iterator", 283 output_types=(dtypes.int64, dtypes.int64, dtypes.float64), 284 output_shapes=([], [3], [])) 285 get_next = iterator.get_next() 286 287 with session.Session(server.target) as sess: 288 # Use the iterator without re-initializing in the second session. 289 results = sess.run(get_next) 290 for component, result_component in zip(components, results): 291 self.assertAllEqual(component, result_component) 292 with self.assertRaises(errors.OutOfRangeError): 293 sess.run(get_next) 294 295 @combinations.generate(test_base.graph_only_combinations()) 296 def testNotInitializedError(self): 297 components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) 298 iterator = dataset_ops.make_initializable_iterator( 299 dataset_ops.Dataset.from_tensors(components)) 300 get_next = iterator.get_next() 301 302 with self.cached_session() as sess: 303 with self.assertRaisesRegex(errors.FailedPreconditionError, 304 "iterator has not been initialized"): 305 sess.run(get_next) 306 307 @combinations.generate(test_base.graph_only_combinations()) 308 def testReinitializableIterator(self): 309 dataset_3 = dataset_ops.Dataset.from_tensors( 310 constant_op.constant([1, 2, 3])) 311 dataset_4 = dataset_ops.Dataset.from_tensors( 312 constant_op.constant([4, 5, 6, 7])) 313 iterator = iterator_ops.Iterator.from_structure( 314 dataset_ops.get_legacy_output_types(dataset_3), [None]) 315 316 dataset_3_init_op = iterator.make_initializer(dataset_3) 317 dataset_4_init_op = iterator.make_initializer(dataset_4) 318 get_next = iterator.get_next() 319 320 self.assertEqual( 321 dataset_ops.get_legacy_output_types(dataset_3), 322 dataset_ops.get_legacy_output_types(iterator)) 323 self.assertEqual( 324 dataset_ops.get_legacy_output_types(dataset_4), 325 dataset_ops.get_legacy_output_types(iterator)) 326 self.assertEqual( 327 [None], dataset_ops.get_legacy_output_shapes(iterator).as_list()) 328 329 with self.cached_session() as sess: 330 # The iterator is initially uninitialized. 331 with self.assertRaises(errors.FailedPreconditionError): 332 sess.run(get_next) 333 334 # Initialize with one dataset. 335 sess.run(dataset_3_init_op) 336 self.assertAllEqual([1, 2, 3], sess.run(get_next)) 337 with self.assertRaises(errors.OutOfRangeError): 338 sess.run(get_next) 339 340 # Initialize with a different dataset. 341 sess.run(dataset_4_init_op) 342 self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) 343 with self.assertRaises(errors.OutOfRangeError): 344 sess.run(get_next) 345 346 # Reinitialize with the first dataset. 347 sess.run(dataset_3_init_op) 348 self.assertAllEqual([1, 2, 3], sess.run(get_next)) 349 with self.assertRaises(errors.OutOfRangeError): 350 sess.run(get_next) 351 352 @combinations.generate(test_base.graph_only_combinations()) 353 def testReinitializableIteratorWithFunctions(self): 354 355 def g(): 356 for i in range(10): 357 yield i 358 359 iterator = iterator_ops.Iterator.from_structure(dtypes.int64, []) 360 next_element = iterator.get_next() 361 362 with self.cached_session() as sess: 363 dataset_1 = dataset_ops.Dataset.from_generator( 364 g, output_types=dtypes.int64) 365 sess.run(iterator.make_initializer(dataset_1)) 366 for expected in range(10): 367 self.assertEqual(expected, sess.run(next_element)) 368 with self.assertRaises(errors.OutOfRangeError): 369 sess.run(next_element) 370 371 dataset_2 = dataset_ops.Dataset.from_generator( 372 g, output_types=dtypes.int64) 373 sess.run(iterator.make_initializer(dataset_2)) 374 for expected in range(10): 375 self.assertEqual(expected, sess.run(next_element)) 376 with self.assertRaises(errors.OutOfRangeError): 377 sess.run(next_element) 378 379 @combinations.generate(test_base.default_test_combinations()) 380 def testReinitializableIteratorStaticErrors(self): 381 # Non-matching structure for types and shapes. 382 with self.assertRaises(TypeError): 383 iterator = iterator_ops.Iterator.from_structure( 384 (dtypes.int64, dtypes.float64), [None]) 385 386 # Test validation of dataset argument. 387 iterator = iterator_ops.Iterator.from_structure((dtypes.int64, 388 dtypes.float64)) 389 390 # Incompatible structure. 391 with self.assertRaisesRegex( 392 ValueError, "The two structures don't have the same nested structure."): 393 iterator.make_initializer( 394 dataset_ops.Dataset.from_tensors(((constant_op.constant( 395 [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant( 396 [4., 5., 6., 7.], dtype=dtypes.float64),)))) 397 398 # Incompatible types. 399 with self.assertRaisesRegex( 400 TypeError, 401 r"Expected output types \(tf.int64, tf.float64\) but got dataset with " 402 r"output types \(tf.int32, tf.float32\)."): 403 iterator.make_initializer( 404 dataset_ops.Dataset.from_tensors( 405 (constant_op.constant([1, 2, 3], dtype=dtypes.int32), 406 constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32)))) 407 408 # Incompatible shapes. 409 iterator = iterator_ops.Iterator.from_structure( 410 (dtypes.int64, dtypes.float64), ([None], [])) 411 with self.assertRaisesRegex( 412 TypeError, 413 r"Expected output shapes compatible with .* but got dataset with " 414 r"output shapes.*"): 415 iterator.make_initializer( 416 dataset_ops.Dataset.from_tensors( 417 (constant_op.constant([1, 2, 3], dtype=dtypes.int64), 418 constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64)))) 419 420 @combinations.generate(test_base.default_test_combinations()) 421 def testReinitializableIteratorEmptyDataset(self): 422 dataset = dataset_ops.Dataset.range(0) 423 iterator = iterator_ops.Iterator.from_structure( 424 dataset_ops.get_legacy_output_types(dataset), []) 425 init_op = iterator.make_initializer(dataset) 426 427 with self.cached_session() as sess: 428 sess.run(init_op) 429 with self.assertRaises(errors.OutOfRangeError): 430 sess.run(iterator.get_next()) 431 432 @combinations.generate(test_base.graph_only_combinations()) 433 def testIteratorStringHandle(self): 434 dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 435 dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) 436 437 iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) 438 iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) 439 440 handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 441 feedable_iterator = iterator_ops.Iterator.from_string_handle( 442 handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3), 443 dataset_ops.get_legacy_output_shapes(dataset_3)) 444 next_element = feedable_iterator.get_next() 445 446 self.assertTrue( 447 structure.are_compatible( 448 dataset_ops.get_structure(dataset_3), 449 dataset_ops.get_structure(feedable_iterator))) 450 451 with self.cached_session() as sess: 452 iterator_3_handle = sess.run(iterator_3.string_handle()) 453 iterator_4_handle = sess.run(iterator_4.string_handle()) 454 455 self.assertEqual(10, 456 sess.run( 457 next_element, 458 feed_dict={handle_placeholder: iterator_4_handle})) 459 self.assertEqual(1, 460 sess.run( 461 next_element, 462 feed_dict={handle_placeholder: iterator_3_handle})) 463 self.assertEqual(20, 464 sess.run( 465 next_element, 466 feed_dict={handle_placeholder: iterator_4_handle})) 467 self.assertEqual(2, 468 sess.run( 469 next_element, 470 feed_dict={handle_placeholder: iterator_3_handle})) 471 self.assertEqual(30, 472 sess.run( 473 next_element, 474 feed_dict={handle_placeholder: iterator_4_handle})) 475 self.assertEqual(3, 476 sess.run( 477 next_element, 478 feed_dict={handle_placeholder: iterator_3_handle})) 479 self.assertEqual(40, 480 sess.run( 481 next_element, 482 feed_dict={handle_placeholder: iterator_4_handle})) 483 with self.assertRaises(errors.OutOfRangeError): 484 sess.run( 485 next_element, feed_dict={handle_placeholder: iterator_3_handle}) 486 with self.assertRaises(errors.OutOfRangeError): 487 sess.run( 488 next_element, feed_dict={handle_placeholder: iterator_4_handle}) 489 490 @combinations.generate(test_base.graph_only_combinations()) 491 def testIteratorStringHandleFuture(self): 492 dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 493 dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) 494 495 iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) 496 iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4) 497 498 handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 499 feedable_iterator = iterator_ops.Iterator.from_string_handle( 500 handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3), 501 dataset_ops.get_legacy_output_shapes(dataset_3)) 502 next_element = feedable_iterator.get_next() 503 504 self.assertTrue( 505 structure.are_compatible( 506 dataset_ops.get_structure(dataset_3), 507 dataset_ops.get_structure(feedable_iterator))) 508 509 with self.cached_session() as sess: 510 iterator_3_handle = sess.run(iterator_3.string_handle()) 511 iterator_4_handle = sess.run(iterator_4.string_handle()) 512 513 self.assertEqual( 514 10, 515 sess.run( 516 next_element, 517 feed_dict={handle_placeholder: iterator_4_handle})) 518 self.assertEqual( 519 1, 520 sess.run( 521 next_element, 522 feed_dict={handle_placeholder: iterator_3_handle})) 523 self.assertEqual( 524 20, 525 sess.run( 526 next_element, 527 feed_dict={handle_placeholder: iterator_4_handle})) 528 self.assertEqual( 529 2, 530 sess.run( 531 next_element, 532 feed_dict={handle_placeholder: iterator_3_handle})) 533 self.assertEqual( 534 30, 535 sess.run( 536 next_element, 537 feed_dict={handle_placeholder: iterator_4_handle})) 538 self.assertEqual( 539 3, 540 sess.run( 541 next_element, 542 feed_dict={handle_placeholder: iterator_3_handle})) 543 self.assertEqual( 544 40, 545 sess.run( 546 next_element, 547 feed_dict={handle_placeholder: iterator_4_handle})) 548 with self.assertRaises(errors.OutOfRangeError): 549 sess.run( 550 next_element, feed_dict={handle_placeholder: iterator_3_handle}) 551 with self.assertRaises(errors.OutOfRangeError): 552 sess.run( 553 next_element, feed_dict={handle_placeholder: iterator_4_handle}) 554 555 @combinations.generate(test_base.graph_only_combinations()) 556 def testIteratorStringHandleReuseTensorObject(self): 557 dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 558 one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset) 559 initializable_iterator = dataset_ops.make_initializable_iterator(dataset) 560 structure_iterator = iterator_ops.Iterator.from_structure( 561 dataset_ops.get_legacy_output_types(dataset)) 562 563 created_ops = len(ops.get_default_graph().get_operations()) 564 565 self.assertIs(one_shot_iterator.string_handle(), 566 one_shot_iterator.string_handle()) 567 self.assertIs(initializable_iterator.string_handle(), 568 initializable_iterator.string_handle()) 569 self.assertIs(structure_iterator.string_handle(), 570 structure_iterator.string_handle()) 571 572 # Assert that getting the (default) string handle creates no ops. 573 self.assertLen(ops.get_default_graph().get_operations(), created_ops) 574 575 # Specifying an explicit name will create a new op. 576 handle_with_name = one_shot_iterator.string_handle(name="foo") 577 self.assertEqual("foo", handle_with_name.op.name) 578 self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name) 579 580 handle_with_same_name = one_shot_iterator.string_handle(name="foo") 581 self.assertEqual("foo_1", handle_with_same_name.op.name) 582 self.assertIsNot(handle_with_name, handle_with_same_name) 583 584 @combinations.generate(test_base.graph_only_combinations()) 585 def testIteratorStringHandleError(self): 586 dataset_int_scalar = ( 587 dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat()) 588 dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])) 589 590 handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 591 592 feedable_int_scalar = iterator_ops.Iterator.from_string_handle( 593 handle_placeholder, dtypes.int32, []) 594 feedable_int_vector = iterator_ops.Iterator.from_string_handle( 595 handle_placeholder, dtypes.int32, [None]) 596 feedable_int_any = iterator_ops.Iterator.from_string_handle( 597 handle_placeholder, dtypes.int32) 598 599 with self.cached_session() as sess: 600 handle_int_scalar = sess.run(dataset_ops.make_one_shot_iterator( 601 dataset_int_scalar).string_handle()) 602 handle_float_vector = sess.run(dataset_ops.make_one_shot_iterator( 603 dataset_float_vector).string_handle()) 604 605 self.assertEqual(1, 606 sess.run( 607 feedable_int_scalar.get_next(), 608 feed_dict={handle_placeholder: handle_int_scalar})) 609 610 self.assertEqual(2, 611 sess.run( 612 feedable_int_any.get_next(), 613 feed_dict={handle_placeholder: handle_int_scalar})) 614 615 with self.assertRaises(errors.InvalidArgumentError): 616 print(sess.run( 617 feedable_int_vector.get_next(), 618 feed_dict={handle_placeholder: handle_int_scalar})) 619 620 with self.assertRaises(errors.InvalidArgumentError): 621 print(sess.run( 622 feedable_int_vector.get_next(), 623 feed_dict={handle_placeholder: handle_float_vector})) 624 625 @combinations.generate(test_base.graph_only_combinations()) 626 def testRemoteIteratorUsingRemoteCallOpDirectSession(self): 627 worker_config = config_pb2.ConfigProto() 628 worker_config.device_count["CPU"] = 3 629 630 with ops.device("/job:localhost/replica:0/task:0/cpu:1"): 631 dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 632 iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) 633 iterator_3_handle = iterator_3.string_handle() 634 635 @function.Defun(dtypes.string) 636 def _remote_fn(h): 637 remote_iterator = iterator_ops.Iterator.from_string_handle( 638 h, dataset_ops.get_legacy_output_types(dataset_3), 639 dataset_ops.get_legacy_output_shapes(dataset_3)) 640 return remote_iterator.get_next() 641 642 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 643 target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 644 remote_op = functional_ops.remote_call( 645 args=[iterator_3_handle], 646 Tout=[dtypes.int32], 647 f=_remote_fn, 648 target=target_placeholder) 649 650 with self.session(config=worker_config) as sess: 651 elem = sess.run( 652 remote_op, 653 feed_dict={ 654 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 655 }) 656 self.assertEqual(elem, [1]) 657 # Fails when target is cpu:2 where the resource is not located. 658 with self.assertRaises(errors.InvalidArgumentError): 659 sess.run( 660 remote_op, 661 feed_dict={ 662 target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" 663 }) 664 elem = sess.run( 665 remote_op, 666 feed_dict={ 667 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 668 }) 669 self.assertEqual(elem, [2]) 670 elem = sess.run( 671 remote_op, 672 feed_dict={ 673 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 674 }) 675 self.assertEqual(elem, [3]) 676 with self.assertRaises(errors.OutOfRangeError): 677 sess.run( 678 remote_op, 679 feed_dict={ 680 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 681 }) 682 683 @combinations.generate(test_base.graph_only_combinations()) 684 def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self): 685 s1 = server_lib.Server.create_local_server() 686 s2 = server_lib.Server.create_local_server() 687 s3 = server_lib.Server.create_local_server() 688 689 cluster_def = cluster_pb2.ClusterDef() 690 workers = cluster_def.job.add() 691 workers.name = "worker" 692 workers.tasks[0] = s1.target[len("grpc://"):] 693 workers.tasks[1] = s2.target[len("grpc://"):] 694 client = cluster_def.job.add() 695 client.name = "client" 696 client.tasks[0] = s3.target[len("grpc://"):] 697 config = config_pb2.ConfigProto(cluster_def=cluster_def) 698 699 worker_devices = [ 700 "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2) 701 ] 702 itr_handles = [] 703 for device in worker_devices: 704 with ops.device(device): 705 src = dataset_ops.Dataset.from_tensor_slices([device]) 706 itr = dataset_ops.make_one_shot_iterator(src) 707 itr_handles.append(itr.string_handle()) 708 709 targets = dataset_ops.Dataset.from_tensor_slices(worker_devices) 710 handles = dataset_ops.Dataset.from_tensor_slices(itr_handles) 711 712 @function.Defun(dtypes.string) 713 def loading_func(h): 714 remote_itr = iterator_ops.Iterator.from_string_handle( 715 h, dataset_ops.get_legacy_output_types(itr), 716 dataset_ops.get_legacy_output_shapes(itr)) 717 return remote_itr.get_next() 718 719 def map_fn(target, handle): 720 return functional_ops.remote_call( 721 args=[handle], Tout=[dtypes.string], f=loading_func, target=target) 722 723 with ops.device("/job:client"): 724 client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn) 725 itr = dataset_ops.make_initializable_iterator(client_dataset) 726 n = itr.get_next() 727 728 with session.Session(s3.target, config=config) as sess: 729 sess.run(itr.initializer) 730 expected_values = worker_devices 731 for expected in expected_values: 732 self.assertEqual((compat.as_bytes(expected),), sess.run(n)) 733 734 with self.assertRaises(errors.OutOfRangeError): 735 sess.run(n) 736 737 @combinations.generate(test_base.graph_only_combinations()) 738 def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): 739 if not test_util.is_gpu_available(): 740 self.skipTest("No GPU available") 741 742 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 743 dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 744 iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3) 745 iterator_3_handle = iterator_3.string_handle() 746 747 def _encode_raw(byte_array): 748 return bytes(bytearray(byte_array)) 749 750 @function.Defun(dtypes.uint8) 751 def _remote_fn(h): 752 handle = script_ops.py_func(_encode_raw, [h], dtypes.string) 753 remote_iterator = iterator_ops.Iterator.from_string_handle( 754 handle, dataset_ops.get_legacy_output_types(dataset_3), 755 dataset_ops.get_legacy_output_shapes(dataset_3)) 756 return remote_iterator.get_next() 757 758 with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): 759 target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 760 iterator_3_handle_uint8 = parsing_ops.decode_raw( 761 input_bytes=iterator_3_handle, out_type=dtypes.uint8) 762 remote_op = functional_ops.remote_call( 763 args=[iterator_3_handle_uint8], 764 Tout=[dtypes.int32], 765 f=_remote_fn, 766 target=target_placeholder) 767 768 with self.cached_session() as sess: 769 elem = sess.run( 770 remote_op, 771 feed_dict={ 772 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 773 }) 774 self.assertEqual(elem, [1]) 775 elem = sess.run( 776 remote_op, 777 feed_dict={ 778 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 779 }) 780 self.assertEqual(elem, [2]) 781 elem = sess.run( 782 remote_op, 783 feed_dict={ 784 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 785 }) 786 self.assertEqual(elem, [3]) 787 with self.assertRaises(errors.OutOfRangeError): 788 sess.run( 789 remote_op, 790 feed_dict={ 791 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 792 }) 793 794 @combinations.generate(test_base.graph_only_combinations()) 795 def testRepeatedGetNextWarning(self): 796 iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10)) 797 warnings.simplefilter("always") 798 with warnings.catch_warnings(record=True) as w: 799 for _ in range(100): 800 iterator.get_next() 801 self.assertLen(w, 100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD) 802 for warning in w: 803 self.assertIn( 804 iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message)) 805 806 @combinations.generate( 807 combinations.times( 808 test_base.default_test_combinations(), 809 combinations.combine( 810 expected_element_structure=tensor_spec.TensorSpec([], 811 dtypes.float32), 812 expected_output_classes=ops.Tensor, 813 expected_output_types=dtypes.float32, 814 expected_output_shapes=[[]]))) 815 def testTensorIteratorStructure(self, expected_element_structure, 816 expected_output_classes, 817 expected_output_types, 818 expected_output_shapes): 819 tf_value_fn = lambda: constant_op.constant(37.0) 820 tf_value = tf_value_fn() 821 iterator = dataset_ops.make_one_shot_iterator( 822 dataset_ops.Dataset.from_tensors(tf_value)) 823 824 self.assertTrue( 825 structure.are_compatible( 826 dataset_ops.get_structure(iterator), expected_element_structure)) 827 self.assertEqual(expected_output_classes, 828 dataset_ops.get_legacy_output_classes(iterator)) 829 self.assertEqual(expected_output_types, 830 dataset_ops.get_legacy_output_types(iterator)) 831 self.assertEqual(expected_output_shapes, 832 dataset_ops.get_legacy_output_shapes(iterator)) 833 834 @combinations.generate( 835 combinations.times( 836 test_base.default_test_combinations(), 837 combinations.combine( 838 expected_element_structure=sparse_tensor.SparseTensorSpec( 839 [1], dtypes.int32), 840 expected_output_classes=sparse_tensor.SparseTensor, 841 expected_output_types=dtypes.int32, 842 expected_output_shapes=[[1]]))) 843 def testSparseTensorIteratorStructure(self, expected_element_structure, 844 expected_output_classes, 845 expected_output_types, 846 expected_output_shapes): 847 848 def tf_value_fn(): 849 return sparse_tensor.SparseTensor( 850 indices=[[0]], 851 values=constant_op.constant([0], dtype=dtypes.int32), 852 dense_shape=[1]) 853 854 tf_value = tf_value_fn() 855 iterator = dataset_ops.make_one_shot_iterator( 856 dataset_ops.Dataset.from_tensors(tf_value)) 857 858 self.assertTrue( 859 structure.are_compatible( 860 dataset_ops.get_structure(iterator), expected_element_structure)) 861 self.assertEqual(expected_output_classes, 862 dataset_ops.get_legacy_output_classes(iterator)) 863 self.assertEqual(expected_output_types, 864 dataset_ops.get_legacy_output_types(iterator)) 865 self.assertEqual(expected_output_shapes, 866 dataset_ops.get_legacy_output_shapes(iterator)) 867 868 @combinations.generate( 869 combinations.times( 870 test_base.default_test_combinations(), 871 combinations.combine( 872 expected_element_structure={ 873 "a": 874 tensor_spec.TensorSpec([], dtypes.float32), 875 "b": (tensor_spec.TensorSpec([1], dtypes.string), 876 tensor_spec.TensorSpec([], dtypes.string)) 877 }, 878 expected_output_classes={ 879 "a": ops.Tensor, 880 "b": (ops.Tensor, ops.Tensor) 881 }, 882 expected_output_types={ 883 "a": dtypes.float32, 884 "b": (dtypes.string, dtypes.string) 885 }, 886 expected_output_shapes={ 887 "a": [], 888 "b": ([1], []) 889 }))) 890 def testNestedTensorIteratorStructure(self, expected_element_structure, 891 expected_output_classes, 892 expected_output_types, 893 expected_output_shapes): 894 895 def tf_value_fn(): 896 return { 897 "a": constant_op.constant(37.0), 898 "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) 899 } 900 901 tf_value = tf_value_fn() 902 iterator = dataset_ops.make_one_shot_iterator( 903 dataset_ops.Dataset.from_tensors(tf_value)) 904 905 self.assertTrue( 906 structure.are_compatible( 907 dataset_ops.get_structure(iterator), expected_element_structure)) 908 self.assertEqual(expected_output_classes, 909 dataset_ops.get_legacy_output_classes(iterator)) 910 self.assertEqual(expected_output_types, 911 dataset_ops.get_legacy_output_types(iterator)) 912 self.assertEqual(expected_output_shapes, 913 dataset_ops.get_legacy_output_shapes(iterator)) 914 915 @combinations.generate(test_base.graph_only_combinations()) 916 def testIteratorGetNextName(self): 917 with ops.Graph().as_default(): 918 iterator = dataset_ops.make_one_shot_iterator( 919 dataset_ops.Dataset.from_tensors(37.0)) 920 next_element = iterator.get_next(name="overridden_name") 921 self.assertEqual("overridden_name", next_element.op.name) 922 923 @combinations.generate( 924 combinations.combine( 925 tf_api_version=[1, 2], 926 mode="eager", 927 execution_mode=[context.ASYNC, context.SYNC])) 928 def testIteratorEagerIteration(self, execution_mode): 929 with context.eager_mode(), context.execution_mode(execution_mode): 930 val = 0 931 dataset = dataset_ops.Dataset.range(10) 932 iterator = iter(dataset) 933 for foo in iterator: 934 self.assertEqual(val, foo.numpy()) 935 val += 1 936 937 @combinations.generate(test_base.eager_only_combinations()) 938 def testOwnedIteratorFunction(self): 939 940 queue = data_flow_ops.FIFOQueue(10, dtypes.int64) 941 942 @def_function.function 943 def fn(): 944 dataset = dataset_ops.Dataset.range(10) 945 iterator = iter(dataset) 946 for _ in range(10): 947 queue.enqueue(next(iterator)) 948 949 fn() 950 951 for i in range(10): 952 self.assertEqual(queue.dequeue().numpy(), i) 953 954 @combinations.generate(test_base.eager_only_combinations()) 955 def testOwnedIteratorFunctionError(self): 956 # In this test we verify that a function that raises an error ends up 957 # properly deallocating the iterator resource. 958 959 queue = data_flow_ops.FIFOQueue(10, dtypes.int64) 960 queue.enqueue(0) 961 962 def init_fn(n): 963 return n 964 965 def next_fn(_): 966 ds = dataset_ops.Dataset.range(0) 967 return next(iter(ds)) 968 969 def finalize_fn(n): 970 queue.enqueue(0) 971 return n 972 973 @def_function.function 974 def fn(): 975 output_signature = tensor_spec.TensorSpec((), dtypes.int64) 976 dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn, 977 output_signature) 978 iterator = iter(dataset) 979 next(iterator) 980 981 with self.assertRaises(errors.OutOfRangeError): 982 fn() 983 984 self.assertEqual(queue.size().numpy(), 2) 985 986 @combinations.generate(test_base.default_test_combinations()) 987 def testNoInitializer(self): 988 dataset = dataset_ops.Dataset.range(10) 989 iterator = iterator_ops.Iterator.from_structure( 990 dataset_ops.get_legacy_output_types(dataset), []) 991 with self.assertRaisesRegex( 992 ValueError, "The iterator does not have an initializer."): 993 _ = iterator.initializer 994 995 @combinations.generate(test_base.default_test_combinations()) 996 def testtestMissingInput(self): 997 with self.assertRaisesRegex( 998 ValueError, 999 "When `dataset` is not provided, both `components` and `element_spec` " 1000 "must be specified."): 1001 iterator_ops.OwnedIterator(dataset=None) 1002 1003 @combinations.generate(test_base.eager_only_combinations()) 1004 def testExtraElementSpecInput(self): 1005 dataset = dataset_ops.Dataset.range(1000) 1006 with self.assertRaisesRegex( 1007 ValueError, 1008 "When `dataset` is provided, `element_spec` and `components` must " 1009 "not be specified."): 1010 iterator_ops.OwnedIterator( 1011 dataset, element_spec=dataset.element_spec) 1012 1013 @combinations.generate(test_base.eager_only_combinations()) 1014 def testLimitedRetracing(self): 1015 trace_count = [0] 1016 1017 @def_function.function 1018 def f(iterator): 1019 trace_count[0] += 1 1020 counter = np.int64(0) 1021 for elem in iterator: 1022 counter += elem 1023 return counter 1024 1025 dataset = dataset_ops.Dataset.range(5) 1026 dataset2 = dataset_ops.Dataset.range(10) 1027 1028 for _ in range(10): 1029 self.assertEqual(self.evaluate(f(iter(dataset))), 10) 1030 self.assertEqual(self.evaluate(f(iter(dataset2))), 45) 1031 self.assertEqual(trace_count[0], 1) 1032 1033 @combinations.generate(test_base.eager_only_combinations()) 1034 def testNestedFunctionsIteratorResource(self): 1035 1036 @def_function.function 1037 def sum_dataset(ds): 1038 it = iter(ds) 1039 1040 @def_function.function 1041 def next_element(it): 1042 return next(it) 1043 1044 total = 0 1045 for _ in range(10): 1046 total += next_element(it) 1047 return total 1048 1049 ds = dataset_ops.Dataset.range(10) 1050 self.assertEqual(sum_dataset(ds).numpy(), 45) 1051 self.assertEqual(sum_dataset(ds).numpy(), 45) 1052 1053 @combinations.generate(test_base.default_test_combinations()) 1054 def testNestedAutomaticControlDependencies(self): 1055 counter_var = variables.Variable(0) 1056 1057 def map_fn(x): 1058 counter_var.assign_add(1) 1059 return x 1060 1061 def dataset_fn(): 1062 return dataset_ops.Dataset.range(10).map(map_fn) 1063 1064 @def_function.function 1065 def fn(): 1066 it = iter(dataset_fn()) 1067 for _ in range(10): 1068 _ = next(it) 1069 return counter_var 1070 1071 self.evaluate(counter_var.initializer) 1072 self.assertEqual(self.evaluate(fn()), 10) 1073 1074 1075if __name__ == "__main__": 1076 test.main() 1077