# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/kernel_tests/iterator_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Iterator`."""
16import warnings
17
18from absl.testing import parameterized
19import numpy as np
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.core.protobuf import config_pb2
23from tensorflow.python.client import session
24from tensorflow.python.data.kernel_tests import test_base
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import iterator_ops
27from tensorflow.python.data.util import structure
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import combinations
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors
34from tensorflow.python.framework import function
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.framework import tensor_spec
38from tensorflow.python.framework import test_util
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import data_flow_ops
41from tensorflow.python.ops import functional_ops
42from tensorflow.python.ops import gradients_impl
43from tensorflow.python.ops import math_ops
44from tensorflow.python.ops import parsing_ops
45from tensorflow.python.ops import script_ops
46from tensorflow.python.ops import variables
47from tensorflow.python.platform import test
48from tensorflow.python.training import server_lib
49from tensorflow.python.util import compat
50
51
52@test_util.with_eager_op_as_function
53class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
54
55  @combinations.generate(test_base.graph_only_combinations())
56  def testNoGradients(self):
57    component = constant_op.constant([1.])
58    side = constant_op.constant(0.)
59    add = lambda x: x + side
60    dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
61    value = dataset_ops.make_one_shot_iterator(dataset).get_next()
62    self.assertIsNone(gradients_impl.gradients(value, component)[0])
63    self.assertIsNone(gradients_impl.gradients(value, side)[0])
64    self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
65
66  @combinations.generate(test_base.graph_only_combinations())
67  def testCapturingStateInOneShotRaisesException(self):
68    var = variables.Variable(37.0, name="myvar")
69    dataset = (
70        dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
71        .map(lambda x: x + var))
72    with self.assertRaisesRegex(
73        ValueError, r"A likely cause of this error is that the dataset for "
74        r"which you are calling `make_one_shot_iterator\(\)` captures a "
75        r"stateful object, such as a `tf.Variable` or "
76        r"`tf.lookup.StaticHashTable`, which is not supported. Use "
77        r"`make_initializable_iterator\(\)` instead."):
78      dataset_ops.make_one_shot_iterator(dataset)
79
80  @combinations.generate(test_base.graph_only_combinations())
81  def testOneShotIterator(self):
82    components = (np.arange(7),
83                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
84                  np.array(37.0) * np.arange(7))
85
86    def _map_fn(x, y, z):
87      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
88
89    iterator = dataset_ops.make_one_shot_iterator(
90        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
91        .repeat(14))
92    get_next = iterator.get_next()
93
94    self.assertEqual([c.shape[1:] for c in components],
95                     [t.shape for t in get_next])
96
97    with self.cached_session() as sess:
98      for _ in range(14):
99        for i in range(7):
100          result = sess.run(get_next)
101          for component, result_component in zip(components, result):
102            self.assertAllEqual(component[i]**2, result_component)
103      with self.assertRaises(errors.OutOfRangeError):
104        sess.run(get_next)
105
106  @combinations.generate(test_base.graph_only_combinations())
107  def testOneShotIteratorCaptureByValue(self):
108    components = (np.arange(7),
109                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
110                  np.array(37.0) * np.arange(7))
111    tensor_components = tuple([ops.convert_to_tensor(c) for c in components])
112
113    def _map_fn(x, y, z):
114      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
115
116    iterator = dataset_ops.make_one_shot_iterator(
117        dataset_ops.Dataset.from_tensor_slices(tensor_components)
118        .map(_map_fn).repeat(14))
119    get_next = iterator.get_next()
120
121    self.assertEqual([c.shape[1:] for c in components],
122                     [t.shape for t in get_next])
123
124    with self.cached_session() as sess:
125      for _ in range(14):
126        for i in range(7):
127          result = sess.run(get_next)
128          for component, result_component in zip(components, result):
129            self.assertAllEqual(component[i]**2, result_component)
130      with self.assertRaises(errors.OutOfRangeError):
131        sess.run(get_next)
132
133  @combinations.generate(test_base.default_test_combinations())
134  def testOneShotIteratorInsideContainer(self):
135    components = (np.arange(7),
136                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
137                  np.array(37.0) * np.arange(7))
138
139    def within_container():
140
141      def _map_fn(x, y, z):
142        return math_ops.square(x), math_ops.square(y), math_ops.square(z)
143
144      iterator = dataset_ops.make_one_shot_iterator(
145          dataset_ops.Dataset.from_tensor_slices(components)
146          .map(_map_fn).repeat(14))
147      return iterator.get_next()
148
149    server = server_lib.Server.create_local_server()
150
151    # Create two iterators within unique containers, and run them to
152    # make sure that the resources aren't shared.
153    #
154    # The test below would fail if cname were the same across both
155    # sessions.
156    for j in range(2):
157      with session.Session(server.target) as sess:
158        cname = "iteration%d" % j
159        with ops.container(cname):
160          get_next = within_container()
161
162        for _ in range(14):
163          for i in range(7):
164            result = sess.run(get_next)
165            for component, result_component in zip(components, result):
166              self.assertAllEqual(component[i]**2, result_component)
167        with self.assertRaises(errors.OutOfRangeError):
168          sess.run(get_next)
169
170  @combinations.generate(test_base.graph_only_combinations())
171  def testOneShotIteratorNonBlocking(self):
172    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
173    iterator = dataset_ops.make_one_shot_iterator(dataset)
174    next_element = iterator.get_next()
175
176    # Create a session with a single thread to ensure that the
177    # one-shot iterator initializer does not deadlock.
178    config = config_pb2.ConfigProto(
179        inter_op_parallelism_threads=1, use_per_session_threads=True)
180    with session.Session(config=config) as sess:
181      self.assertAllEqual([1, 4, 9], sess.run(next_element))
182      with self.assertRaises(errors.OutOfRangeError):
183        sess.run(next_element)
184
185    # Test with multiple threads invoking the one-shot iterator concurrently.
186    with session.Session(config=config) as sess:
187      results = []
188
189      def consumer_thread():
190        try:
191          results.append(sess.run(next_element))
192        except errors.OutOfRangeError:
193          results.append(None)
194
195      num_threads = 8
196      threads = [
197          self.checkedThread(consumer_thread) for _ in range(num_threads)
198      ]
199      for t in threads:
200        t.start()
201      for t in threads:
202        t.join()
203
204      self.assertLen(results, num_threads)
205      self.assertLen([None for r in results if r is None], num_threads - 1)
206      self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])
207
208  @combinations.generate(test_base.graph_only_combinations())
209  def testOneShotIteratorInitializerFails(self):
210    # Define a dataset whose initialization will always fail.
211    dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
212    iterator = dataset_ops.make_one_shot_iterator(dataset)
213    next_element = iterator.get_next()
214
215    with self.cached_session() as sess:
216      with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
217        sess.run(next_element)
218
219      # Test that subsequent attempts to use the iterator also fail.
220      with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
221        sess.run(next_element)
222
223    with self.cached_session() as sess:
224
225      def consumer_thread():
226        with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
227          sess.run(next_element)
228
229      num_threads = 8
230      threads = [
231          self.checkedThread(consumer_thread) for _ in range(num_threads)
232      ]
233      for t in threads:
234        t.start()
235      for t in threads:
236        t.join()
237
238  @combinations.generate(test_base.default_test_combinations())
239  def testOneShotIteratorEmptyDataset(self):
240    dataset = dataset_ops.Dataset.range(0)
241    iterator = dataset_ops.make_one_shot_iterator(dataset)
242    with self.assertRaises(errors.OutOfRangeError):
243      self.evaluate(iterator.get_next())
244
  @combinations.generate(test_base.graph_only_combinations())
  def testSimpleSharedResource(self):
    """Tests that a `shared_name` iterator resource spans sessions.

    An iterator initialized in one session/graph should be usable, with its
    state intact, from a second session that redefines a structurally
    compatible iterator with the same `shared_name` on the same server.
    """
    components = (np.array(1, dtype=np.int64),
                  np.array([1, 2, 3], dtype=np.int64),
                  np.array(37.0, dtype=np.float64))

    server = server_lib.Server.create_local_server()

    # Create two non-overlapping sessions that share the same iterator
    # resource on the same server, and verify that an action of the
    # first session (initializing the iterator) is visible in the
    # second session.
    with ops.Graph().as_default():
      iterator = dataset_ops.make_initializable_iterator(
          dataset_ops.Dataset.from_tensors(
              components).map(lambda x, y, z: (x, y, z)),
          shared_name="shared_iterator")
      init_op = iterator.initializer
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        sess.run(init_op)
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Re-initialize the iterator in the first session, so the second
        # session below finds one element available.
        sess.run(init_op)

    with ops.Graph().as_default():
      # Re-define the iterator manually, without defining any of the
      # functions in this graph, to ensure that we are not
      # accidentally redefining functions with the same names in the
      # new graph.
      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        # Use the iterator without re-initializing in the second session.
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
294
295  @combinations.generate(test_base.graph_only_combinations())
296  def testNotInitializedError(self):
297    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
298    iterator = dataset_ops.make_initializable_iterator(
299        dataset_ops.Dataset.from_tensors(components))
300    get_next = iterator.get_next()
301
302    with self.cached_session() as sess:
303      with self.assertRaisesRegex(errors.FailedPreconditionError,
304                                  "iterator has not been initialized"):
305        sess.run(get_next)
306
307  @combinations.generate(test_base.graph_only_combinations())
308  def testReinitializableIterator(self):
309    dataset_3 = dataset_ops.Dataset.from_tensors(
310        constant_op.constant([1, 2, 3]))
311    dataset_4 = dataset_ops.Dataset.from_tensors(
312        constant_op.constant([4, 5, 6, 7]))
313    iterator = iterator_ops.Iterator.from_structure(
314        dataset_ops.get_legacy_output_types(dataset_3), [None])
315
316    dataset_3_init_op = iterator.make_initializer(dataset_3)
317    dataset_4_init_op = iterator.make_initializer(dataset_4)
318    get_next = iterator.get_next()
319
320    self.assertEqual(
321        dataset_ops.get_legacy_output_types(dataset_3),
322        dataset_ops.get_legacy_output_types(iterator))
323    self.assertEqual(
324        dataset_ops.get_legacy_output_types(dataset_4),
325        dataset_ops.get_legacy_output_types(iterator))
326    self.assertEqual(
327        [None], dataset_ops.get_legacy_output_shapes(iterator).as_list())
328
329    with self.cached_session() as sess:
330      # The iterator is initially uninitialized.
331      with self.assertRaises(errors.FailedPreconditionError):
332        sess.run(get_next)
333
334      # Initialize with one dataset.
335      sess.run(dataset_3_init_op)
336      self.assertAllEqual([1, 2, 3], sess.run(get_next))
337      with self.assertRaises(errors.OutOfRangeError):
338        sess.run(get_next)
339
340      # Initialize with a different dataset.
341      sess.run(dataset_4_init_op)
342      self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
343      with self.assertRaises(errors.OutOfRangeError):
344        sess.run(get_next)
345
346      # Reinitialize with the first dataset.
347      sess.run(dataset_3_init_op)
348      self.assertAllEqual([1, 2, 3], sess.run(get_next))
349      with self.assertRaises(errors.OutOfRangeError):
350        sess.run(get_next)
351
352  @combinations.generate(test_base.graph_only_combinations())
353  def testReinitializableIteratorWithFunctions(self):
354
355    def g():
356      for i in range(10):
357        yield i
358
359    iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
360    next_element = iterator.get_next()
361
362    with self.cached_session() as sess:
363      dataset_1 = dataset_ops.Dataset.from_generator(
364          g, output_types=dtypes.int64)
365      sess.run(iterator.make_initializer(dataset_1))
366      for expected in range(10):
367        self.assertEqual(expected, sess.run(next_element))
368      with self.assertRaises(errors.OutOfRangeError):
369        sess.run(next_element)
370
371      dataset_2 = dataset_ops.Dataset.from_generator(
372          g, output_types=dtypes.int64)
373      sess.run(iterator.make_initializer(dataset_2))
374      for expected in range(10):
375        self.assertEqual(expected, sess.run(next_element))
376      with self.assertRaises(errors.OutOfRangeError):
377        sess.run(next_element)
378
  @combinations.generate(test_base.default_test_combinations())
  def testReinitializableIteratorStaticErrors(self):
    """Tests graph-construction-time validation of iterator initializers."""
    # Non-matching structure for types and shapes.
    with self.assertRaises(TypeError):
      iterator = iterator_ops.Iterator.from_structure(
          (dtypes.int64, dtypes.float64), [None])

    # Test validation of dataset argument.
    iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
                                                     dtypes.float64))

    # Incompatible structure: nested tuples vs. the iterator's flat pair.
    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(((constant_op.constant(
              [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant(
                  [4., 5., 6., 7.], dtype=dtypes.float64),))))

    # Incompatible types: 32-bit dataset vs. 64-bit iterator.
    with self.assertRaisesRegex(
        TypeError,
        r"Expected output types \(tf.int64, tf.float64\) but got dataset with "
        r"output types \(tf.int32, tf.float32\)."):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int32),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32))))

    # Incompatible shapes: dataset yields ([3], [4]) vs. expected ([None], []).
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64), ([None], []))
    with self.assertRaisesRegex(
        TypeError,
        r"Expected output shapes compatible with .* but got dataset with "
        r"output shapes.*"):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))
419
420  @combinations.generate(test_base.default_test_combinations())
421  def testReinitializableIteratorEmptyDataset(self):
422    dataset = dataset_ops.Dataset.range(0)
423    iterator = iterator_ops.Iterator.from_structure(
424        dataset_ops.get_legacy_output_types(dataset), [])
425    init_op = iterator.make_initializer(dataset)
426
427    with self.cached_session() as sess:
428      sess.run(init_op)
429      with self.assertRaises(errors.OutOfRangeError):
430        sess.run(iterator.get_next())
431
432  @combinations.generate(test_base.graph_only_combinations())
433  def testIteratorStringHandle(self):
434    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
435    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
436
437    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
438    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
439
440    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
441    feedable_iterator = iterator_ops.Iterator.from_string_handle(
442        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
443        dataset_ops.get_legacy_output_shapes(dataset_3))
444    next_element = feedable_iterator.get_next()
445
446    self.assertTrue(
447        structure.are_compatible(
448            dataset_ops.get_structure(dataset_3),
449            dataset_ops.get_structure(feedable_iterator)))
450
451    with self.cached_session() as sess:
452      iterator_3_handle = sess.run(iterator_3.string_handle())
453      iterator_4_handle = sess.run(iterator_4.string_handle())
454
455      self.assertEqual(10,
456                       sess.run(
457                           next_element,
458                           feed_dict={handle_placeholder: iterator_4_handle}))
459      self.assertEqual(1,
460                       sess.run(
461                           next_element,
462                           feed_dict={handle_placeholder: iterator_3_handle}))
463      self.assertEqual(20,
464                       sess.run(
465                           next_element,
466                           feed_dict={handle_placeholder: iterator_4_handle}))
467      self.assertEqual(2,
468                       sess.run(
469                           next_element,
470                           feed_dict={handle_placeholder: iterator_3_handle}))
471      self.assertEqual(30,
472                       sess.run(
473                           next_element,
474                           feed_dict={handle_placeholder: iterator_4_handle}))
475      self.assertEqual(3,
476                       sess.run(
477                           next_element,
478                           feed_dict={handle_placeholder: iterator_3_handle}))
479      self.assertEqual(40,
480                       sess.run(
481                           next_element,
482                           feed_dict={handle_placeholder: iterator_4_handle}))
483      with self.assertRaises(errors.OutOfRangeError):
484        sess.run(
485            next_element, feed_dict={handle_placeholder: iterator_3_handle})
486      with self.assertRaises(errors.OutOfRangeError):
487        sess.run(
488            next_element, feed_dict={handle_placeholder: iterator_4_handle})
489
490  @combinations.generate(test_base.graph_only_combinations())
491  def testIteratorStringHandleFuture(self):
492    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
493    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
494
495    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
496    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
497
498    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
499    feedable_iterator = iterator_ops.Iterator.from_string_handle(
500        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
501        dataset_ops.get_legacy_output_shapes(dataset_3))
502    next_element = feedable_iterator.get_next()
503
504    self.assertTrue(
505        structure.are_compatible(
506            dataset_ops.get_structure(dataset_3),
507            dataset_ops.get_structure(feedable_iterator)))
508
509    with self.cached_session() as sess:
510      iterator_3_handle = sess.run(iterator_3.string_handle())
511      iterator_4_handle = sess.run(iterator_4.string_handle())
512
513      self.assertEqual(
514          10,
515          sess.run(
516              next_element,
517              feed_dict={handle_placeholder: iterator_4_handle}))
518      self.assertEqual(
519          1,
520          sess.run(
521              next_element,
522              feed_dict={handle_placeholder: iterator_3_handle}))
523      self.assertEqual(
524          20,
525          sess.run(
526              next_element,
527              feed_dict={handle_placeholder: iterator_4_handle}))
528      self.assertEqual(
529          2,
530          sess.run(
531              next_element,
532              feed_dict={handle_placeholder: iterator_3_handle}))
533      self.assertEqual(
534          30,
535          sess.run(
536              next_element,
537              feed_dict={handle_placeholder: iterator_4_handle}))
538      self.assertEqual(
539          3,
540          sess.run(
541              next_element,
542              feed_dict={handle_placeholder: iterator_3_handle}))
543      self.assertEqual(
544          40,
545          sess.run(
546              next_element,
547              feed_dict={handle_placeholder: iterator_4_handle}))
548      with self.assertRaises(errors.OutOfRangeError):
549        sess.run(
550            next_element, feed_dict={handle_placeholder: iterator_3_handle})
551      with self.assertRaises(errors.OutOfRangeError):
552        sess.run(
553            next_element, feed_dict={handle_placeholder: iterator_4_handle})
554
555  @combinations.generate(test_base.graph_only_combinations())
556  def testIteratorStringHandleReuseTensorObject(self):
557    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
558    one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
559    initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
560    structure_iterator = iterator_ops.Iterator.from_structure(
561        dataset_ops.get_legacy_output_types(dataset))
562
563    created_ops = len(ops.get_default_graph().get_operations())
564
565    self.assertIs(one_shot_iterator.string_handle(),
566                  one_shot_iterator.string_handle())
567    self.assertIs(initializable_iterator.string_handle(),
568                  initializable_iterator.string_handle())
569    self.assertIs(structure_iterator.string_handle(),
570                  structure_iterator.string_handle())
571
572    # Assert that getting the (default) string handle creates no ops.
573    self.assertLen(ops.get_default_graph().get_operations(), created_ops)
574
575    # Specifying an explicit name will create a new op.
576    handle_with_name = one_shot_iterator.string_handle(name="foo")
577    self.assertEqual("foo", handle_with_name.op.name)
578    self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)
579
580    handle_with_same_name = one_shot_iterator.string_handle(name="foo")
581    self.assertEqual("foo_1", handle_with_same_name.op.name)
582    self.assertIsNot(handle_with_name, handle_with_same_name)
583
584  @combinations.generate(test_base.graph_only_combinations())
585  def testIteratorStringHandleError(self):
586    dataset_int_scalar = (
587        dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
588    dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]))
589
590    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
591
592    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
593        handle_placeholder, dtypes.int32, [])
594    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
595        handle_placeholder, dtypes.int32, [None])
596    feedable_int_any = iterator_ops.Iterator.from_string_handle(
597        handle_placeholder, dtypes.int32)
598
599    with self.cached_session() as sess:
600      handle_int_scalar = sess.run(dataset_ops.make_one_shot_iterator(
601          dataset_int_scalar).string_handle())
602      handle_float_vector = sess.run(dataset_ops.make_one_shot_iterator(
603          dataset_float_vector).string_handle())
604
605      self.assertEqual(1,
606                       sess.run(
607                           feedable_int_scalar.get_next(),
608                           feed_dict={handle_placeholder: handle_int_scalar}))
609
610      self.assertEqual(2,
611                       sess.run(
612                           feedable_int_any.get_next(),
613                           feed_dict={handle_placeholder: handle_int_scalar}))
614
615      with self.assertRaises(errors.InvalidArgumentError):
616        print(sess.run(
617            feedable_int_vector.get_next(),
618            feed_dict={handle_placeholder: handle_int_scalar}))
619
620      with self.assertRaises(errors.InvalidArgumentError):
621        print(sess.run(
622            feedable_int_vector.get_next(),
623            feed_dict={handle_placeholder: handle_float_vector}))
624
  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
    """Drives a remote iterator via `remote_call` across CPU devices.

    The iterator lives on cpu:1; a Defun-wrapped `get_next` is invoked
    through `remote_call` from cpu:0, with the target device chosen at run
    time via a placeholder. Calls targeting the device that owns the
    resource succeed in order; targeting a device without it fails.
    """
    worker_config = config_pb2.ConfigProto()
    # Three CPU devices: cpu:1 hosts the iterator, cpu:0 issues the calls,
    # and cpu:2 exists only to demonstrate the resource-not-found failure.
    worker_config.device_count["CPU"] = 3

    with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      # Rebuild an iterator from the string handle inside the remote
      # function, then pull one element from it.
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.session(config=worker_config) as sess:
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [1])
      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
            })
      # The failed call did not consume an element: the sequence resumes.
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [2])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
            })
682
  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
    """Pulls elements from per-worker iterators via `remote_call`.

    Two worker tasks each host a one-shot iterator; a client-side dataset
    zips (device, handle) pairs and maps them through `remote_call`, so
    each element is fetched on the worker that owns the iterator.
    """
    s1 = server_lib.Server.create_local_server()
    s2 = server_lib.Server.create_local_server()
    s3 = server_lib.Server.create_local_server()

    # Assemble a cluster with two "worker" tasks and one "client" task
    # backed by the three in-process servers.
    cluster_def = cluster_pb2.ClusterDef()
    workers = cluster_def.job.add()
    workers.name = "worker"
    workers.tasks[0] = s1.target[len("grpc://"):]
    workers.tasks[1] = s2.target[len("grpc://"):]
    client = cluster_def.job.add()
    client.name = "client"
    client.tasks[0] = s3.target[len("grpc://"):]
    config = config_pb2.ConfigProto(cluster_def=cluster_def)

    worker_devices = [
        "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
    ]
    itr_handles = []
    for device in worker_devices:
      with ops.device(device):
        # Each worker's dataset holds one element: its own device name,
        # which lets the client verify where each element came from.
        src = dataset_ops.Dataset.from_tensor_slices([device])
        itr = dataset_ops.make_one_shot_iterator(src)
        itr_handles.append(itr.string_handle())

    targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
    handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

    @function.Defun(dtypes.string)
    def loading_func(h):
      # Rebuild the worker-side iterator from its handle and fetch one
      # element. NOTE(review): `itr` here is the loop variable from above
      # and is used only for its static types/shapes, which are identical
      # across both workers.
      remote_itr = iterator_ops.Iterator.from_string_handle(
          h, dataset_ops.get_legacy_output_types(itr),
          dataset_ops.get_legacy_output_shapes(itr))
      return remote_itr.get_next()

    def map_fn(target, handle):
      return functional_ops.remote_call(
          args=[handle], Tout=[dtypes.string], f=loading_func, target=target)

    with ops.device("/job:client"):
      client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn)
      itr = dataset_ops.make_initializable_iterator(client_dataset)
      n = itr.get_next()

    with session.Session(s3.target, config=config) as sess:
      sess.run(itr.initializer)
      # One element per worker, in worker order; remote_call wraps the
      # result in a single-element tuple.
      expected_values = worker_devices
      for expected in expected_values:
        self.assertEqual((compat.as_bytes(expected),), sess.run(n))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)
736
737  @combinations.generate(test_base.graph_only_combinations())
738  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
739    if not test_util.is_gpu_available():
740      self.skipTest("No GPU available")
741
742    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
743      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
744      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
745      iterator_3_handle = iterator_3.string_handle()
746
747    def _encode_raw(byte_array):
748      return bytes(bytearray(byte_array))
749
750    @function.Defun(dtypes.uint8)
751    def _remote_fn(h):
752      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
753      remote_iterator = iterator_ops.Iterator.from_string_handle(
754          handle, dataset_ops.get_legacy_output_types(dataset_3),
755          dataset_ops.get_legacy_output_shapes(dataset_3))
756      return remote_iterator.get_next()
757
758    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
759      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
760      iterator_3_handle_uint8 = parsing_ops.decode_raw(
761          input_bytes=iterator_3_handle, out_type=dtypes.uint8)
762      remote_op = functional_ops.remote_call(
763          args=[iterator_3_handle_uint8],
764          Tout=[dtypes.int32],
765          f=_remote_fn,
766          target=target_placeholder)
767
768    with self.cached_session() as sess:
769      elem = sess.run(
770          remote_op,
771          feed_dict={
772              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
773          })
774      self.assertEqual(elem, [1])
775      elem = sess.run(
776          remote_op,
777          feed_dict={
778              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
779          })
780      self.assertEqual(elem, [2])
781      elem = sess.run(
782          remote_op,
783          feed_dict={
784              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
785          })
786      self.assertEqual(elem, [3])
787      with self.assertRaises(errors.OutOfRangeError):
788        sess.run(
789            remote_op,
790            feed_dict={
791                target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
792            })
793
794  @combinations.generate(test_base.graph_only_combinations())
795  def testRepeatedGetNextWarning(self):
796    iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
797    warnings.simplefilter("always")
798    with warnings.catch_warnings(record=True) as w:
799      for _ in range(100):
800        iterator.get_next()
801    self.assertLen(w, 100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD)
802    for warning in w:
803      self.assertIn(
804          iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message))
805
806  @combinations.generate(
807      combinations.times(
808          test_base.default_test_combinations(),
809          combinations.combine(
810              expected_element_structure=tensor_spec.TensorSpec([],
811                                                                dtypes.float32),
812              expected_output_classes=ops.Tensor,
813              expected_output_types=dtypes.float32,
814              expected_output_shapes=[[]])))
815  def testTensorIteratorStructure(self, expected_element_structure,
816                                  expected_output_classes,
817                                  expected_output_types,
818                                  expected_output_shapes):
819    tf_value_fn = lambda: constant_op.constant(37.0)
820    tf_value = tf_value_fn()
821    iterator = dataset_ops.make_one_shot_iterator(
822        dataset_ops.Dataset.from_tensors(tf_value))
823
824    self.assertTrue(
825        structure.are_compatible(
826            dataset_ops.get_structure(iterator), expected_element_structure))
827    self.assertEqual(expected_output_classes,
828                     dataset_ops.get_legacy_output_classes(iterator))
829    self.assertEqual(expected_output_types,
830                     dataset_ops.get_legacy_output_types(iterator))
831    self.assertEqual(expected_output_shapes,
832                     dataset_ops.get_legacy_output_shapes(iterator))
833
834  @combinations.generate(
835      combinations.times(
836          test_base.default_test_combinations(),
837          combinations.combine(
838              expected_element_structure=sparse_tensor.SparseTensorSpec(
839                  [1], dtypes.int32),
840              expected_output_classes=sparse_tensor.SparseTensor,
841              expected_output_types=dtypes.int32,
842              expected_output_shapes=[[1]])))
843  def testSparseTensorIteratorStructure(self, expected_element_structure,
844                                        expected_output_classes,
845                                        expected_output_types,
846                                        expected_output_shapes):
847
848    def tf_value_fn():
849      return sparse_tensor.SparseTensor(
850          indices=[[0]],
851          values=constant_op.constant([0], dtype=dtypes.int32),
852          dense_shape=[1])
853
854    tf_value = tf_value_fn()
855    iterator = dataset_ops.make_one_shot_iterator(
856        dataset_ops.Dataset.from_tensors(tf_value))
857
858    self.assertTrue(
859        structure.are_compatible(
860            dataset_ops.get_structure(iterator), expected_element_structure))
861    self.assertEqual(expected_output_classes,
862                     dataset_ops.get_legacy_output_classes(iterator))
863    self.assertEqual(expected_output_types,
864                     dataset_ops.get_legacy_output_types(iterator))
865    self.assertEqual(expected_output_shapes,
866                     dataset_ops.get_legacy_output_shapes(iterator))
867
868  @combinations.generate(
869      combinations.times(
870          test_base.default_test_combinations(),
871          combinations.combine(
872              expected_element_structure={
873                  "a":
874                      tensor_spec.TensorSpec([], dtypes.float32),
875                  "b": (tensor_spec.TensorSpec([1], dtypes.string),
876                        tensor_spec.TensorSpec([], dtypes.string))
877              },
878              expected_output_classes={
879                  "a": ops.Tensor,
880                  "b": (ops.Tensor, ops.Tensor)
881              },
882              expected_output_types={
883                  "a": dtypes.float32,
884                  "b": (dtypes.string, dtypes.string)
885              },
886              expected_output_shapes={
887                  "a": [],
888                  "b": ([1], [])
889              })))
890  def testNestedTensorIteratorStructure(self, expected_element_structure,
891                                        expected_output_classes,
892                                        expected_output_types,
893                                        expected_output_shapes):
894
895    def tf_value_fn():
896      return {
897          "a": constant_op.constant(37.0),
898          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
899      }
900
901    tf_value = tf_value_fn()
902    iterator = dataset_ops.make_one_shot_iterator(
903        dataset_ops.Dataset.from_tensors(tf_value))
904
905    self.assertTrue(
906        structure.are_compatible(
907            dataset_ops.get_structure(iterator), expected_element_structure))
908    self.assertEqual(expected_output_classes,
909                     dataset_ops.get_legacy_output_classes(iterator))
910    self.assertEqual(expected_output_types,
911                     dataset_ops.get_legacy_output_types(iterator))
912    self.assertEqual(expected_output_shapes,
913                     dataset_ops.get_legacy_output_shapes(iterator))
914
915  @combinations.generate(test_base.graph_only_combinations())
916  def testIteratorGetNextName(self):
917    with ops.Graph().as_default():
918      iterator = dataset_ops.make_one_shot_iterator(
919          dataset_ops.Dataset.from_tensors(37.0))
920      next_element = iterator.get_next(name="overridden_name")
921      self.assertEqual("overridden_name", next_element.op.name)
922
923  @combinations.generate(
924      combinations.combine(
925          tf_api_version=[1, 2],
926          mode="eager",
927          execution_mode=[context.ASYNC, context.SYNC]))
928  def testIteratorEagerIteration(self, execution_mode):
929    with context.eager_mode(), context.execution_mode(execution_mode):
930      val = 0
931      dataset = dataset_ops.Dataset.range(10)
932      iterator = iter(dataset)
933      for foo in iterator:
934        self.assertEqual(val, foo.numpy())
935        val += 1
936
937  @combinations.generate(test_base.eager_only_combinations())
938  def testOwnedIteratorFunction(self):
939
940    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
941
942    @def_function.function
943    def fn():
944      dataset = dataset_ops.Dataset.range(10)
945      iterator = iter(dataset)
946      for _ in range(10):
947        queue.enqueue(next(iterator))
948
949    fn()
950
951    for i in range(10):
952      self.assertEqual(queue.dequeue().numpy(), i)
953
954  @combinations.generate(test_base.eager_only_combinations())
955  def testOwnedIteratorFunctionError(self):
956    # In this test we verify that a function that raises an error ends up
957    # properly deallocating the iterator resource.
958
959    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
960    queue.enqueue(0)
961
962    def init_fn(n):
963      return n
964
965    def next_fn(_):
966      ds = dataset_ops.Dataset.range(0)
967      return next(iter(ds))
968
969    def finalize_fn(n):
970      queue.enqueue(0)
971      return n
972
973    @def_function.function
974    def fn():
975      output_signature = tensor_spec.TensorSpec((), dtypes.int64)
976      dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn,
977                                              output_signature)
978      iterator = iter(dataset)
979      next(iterator)
980
981    with self.assertRaises(errors.OutOfRangeError):
982      fn()
983
984    self.assertEqual(queue.size().numpy(), 2)
985
986  @combinations.generate(test_base.default_test_combinations())
987  def testNoInitializer(self):
988    dataset = dataset_ops.Dataset.range(10)
989    iterator = iterator_ops.Iterator.from_structure(
990        dataset_ops.get_legacy_output_types(dataset), [])
991    with self.assertRaisesRegex(
992        ValueError, "The iterator does not have an initializer."):
993      _ = iterator.initializer
994
995  @combinations.generate(test_base.default_test_combinations())
996  def testtestMissingInput(self):
997    with self.assertRaisesRegex(
998        ValueError,
999        "When `dataset` is not provided, both `components` and `element_spec` "
1000        "must be specified."):
1001      iterator_ops.OwnedIterator(dataset=None)
1002
1003  @combinations.generate(test_base.eager_only_combinations())
1004  def testExtraElementSpecInput(self):
1005    dataset = dataset_ops.Dataset.range(1000)
1006    with self.assertRaisesRegex(
1007        ValueError,
1008        "When `dataset` is provided, `element_spec` and `components` must "
1009        "not be specified."):
1010      iterator_ops.OwnedIterator(
1011          dataset, element_spec=dataset.element_spec)
1012
1013  @combinations.generate(test_base.eager_only_combinations())
1014  def testLimitedRetracing(self):
1015    trace_count = [0]
1016
1017    @def_function.function
1018    def f(iterator):
1019      trace_count[0] += 1
1020      counter = np.int64(0)
1021      for elem in iterator:
1022        counter += elem
1023      return counter
1024
1025    dataset = dataset_ops.Dataset.range(5)
1026    dataset2 = dataset_ops.Dataset.range(10)
1027
1028    for _ in range(10):
1029      self.assertEqual(self.evaluate(f(iter(dataset))), 10)
1030      self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
1031      self.assertEqual(trace_count[0], 1)
1032
1033  @combinations.generate(test_base.eager_only_combinations())
1034  def testNestedFunctionsIteratorResource(self):
1035
1036    @def_function.function
1037    def sum_dataset(ds):
1038      it = iter(ds)
1039
1040      @def_function.function
1041      def next_element(it):
1042        return next(it)
1043
1044      total = 0
1045      for _ in range(10):
1046        total += next_element(it)
1047      return total
1048
1049    ds = dataset_ops.Dataset.range(10)
1050    self.assertEqual(sum_dataset(ds).numpy(), 45)
1051    self.assertEqual(sum_dataset(ds).numpy(), 45)
1052
1053  @combinations.generate(test_base.default_test_combinations())
1054  def testNestedAutomaticControlDependencies(self):
1055    counter_var = variables.Variable(0)
1056
1057    def map_fn(x):
1058      counter_var.assign_add(1)
1059      return x
1060
1061    def dataset_fn():
1062      return dataset_ops.Dataset.range(10).map(map_fn)
1063
1064    @def_function.function
1065    def fn():
1066      it = iter(dataset_fn())
1067      for _ in range(10):
1068        _ = next(it)
1069      return counter_var
1070
1071    self.evaluate(counter_var.initializer)
1072    self.assertEqual(self.evaluate(fn()), 10)
1073
1074
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
  test.main()
1077