1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for `_DirectedInterleaveDataset`, `sample_from_datasets` and `choose_from_datasets`."""
16from absl.testing import parameterized
17import numpy as np
18
19from tensorflow.python.data.kernel_tests import checkpoint_test_base
20from tensorflow.python.data.kernel_tests import test_base
21from tensorflow.python.data.ops import dataset_ops
22from tensorflow.python.framework import combinations
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import random_seed
28from tensorflow.python.platform import test
29
30
def _weights_type_combinations():
  """Returns combinations that parameterize a test over `weights_type`.

  The `weights_type` keyword takes each of "list", "tensor" and "dataset" in
  turn, i.e. every form accepted by `_get_weights_of_type`.
  """
  supported_weight_forms = ["list", "tensor", "dataset"]
  return combinations.combine(weights_type=supported_weight_forms)
33
34
35def _get_weights_of_type(weights_list, weights_type):
36  if weights_type == "list":
37    return weights_list
38  if weights_type == "tensor":
39    return ops.convert_to_tensor(weights_list, name="weights")
40  return dataset_ops.Dataset.from_tensors(weights_list).repeat()
41
42
class DirectedInterleaveDatasetTest(test_base.DatasetTestBase,
                                    parameterized.TestCase):
  """Tests for `_DirectedInterleaveDataset` and its public wrappers.

  Covers `Dataset.sample_from_datasets` (weighted random selection) and
  `Dataset.choose_from_datasets` (selection driven by a `choice_dataset`),
  including their `stop_on_empty_dataset` behavior and argument validation.
  """

  # Seed for NumPy-generated test inputs. `random_seed.set_random_seed` only
  # seeds TensorFlow's RNG, so NumPy draws need their own seed for the tests
  # to be reproducible from run to run.
  _NUMPY_SEED = 1619

  @combinations.generate(test_base.default_test_combinations())
  def testBasic(self):
    # The selector cycles through 0..9 and input dataset `i` only yields `i`,
    # so the output must be 0..9 repeated 100 times.
    selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
    input_datasets = [
        dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
    ]
    dataset = dataset_ops._DirectedInterleaveDataset(selector_dataset,
                                                     input_datasets)
    next_element = self.getNext(dataset)

    for _ in range(100):
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element()))
    # The selector (1000 elements) and each input (100 elements, selected 100
    # times) run out together, so the next call must signal end of sequence.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def _normalize(self, vec):
    """Returns `vec` divided by the sum of its entries."""
    return vec / vec.sum()

  def _chi2(self, expected, actual):
    """Returns `sum((actual - expected)**2 / expected)` along axis 0.

    Both arguments are converted with `np.asarray`. Callers in this class pass
    frequencies normalized by the sample count, so the value is the usual
    chi-squared statistic divided by that count.
    """
    actual = np.asarray(actual)
    expected = np.asarray(expected)
    diff = actual - expected
    chi2 = np.sum(diff * diff / expected, axis=0)
    return chi2

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasets(self, weights_type):
    random_seed.set_random_seed(1619)
    num_samples = 5000
    # `random_seed.set_random_seed` does not seed NumPy, so draw the random
    # weight vector from a dedicated seeded generator; otherwise the tested
    # distribution changes on every run.
    rand_probs = self._normalize(
        np.random.RandomState(self._NUMPY_SEED).random_sample((5,)))

    # Use chi-squared test to assert that the observed distribution matches the
    # expected distribution. Based on the implementation in
    # "third_party/tensorflow/python/kernel_tests/multinomial_op_test.py".
    for probs in [[.85, .05, .1], rand_probs, [1.]]:
      weights = _get_weights_of_type(np.asarray(probs), weights_type)
      classes = len(probs)

      # Create a dataset that samples each integer in `[0, num_datasets)`
      # with probability given by `weights[i]`.
      dataset = dataset_ops.Dataset.sample_from_datasets([
          dataset_ops.Dataset.from_tensors(i).repeat() for i in range(classes)
      ], weights)
      dataset = dataset.take(num_samples)

      next_element = self.getNext(dataset)
      freqs = np.zeros([classes])
      for _ in range(num_samples):
        freqs[self.evaluate(next_element())] += 1
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element())

      self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasetsStoppingOnEmptyDataset(self, weights_type):
    # Sampling stops when the first dataset is exhausted.
    weights = _get_weights_of_type(np.asarray([.5, .1, .4]), weights_type)
    datasets = [
        dataset_ops.Dataset.from_tensors(np.int64(-1)),
        dataset_ops.Dataset.from_tensors(np.int64(1)).repeat(),
        dataset_ops.Dataset.range(10).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=True)

    samples_list = self.getIteratorOutput(self.getNext(sample_dataset))
    self.assertEqual(samples_list.count(-1), 1)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasetsSkippingEmptyDataset(self, weights_type):
    # Sampling skips the first dataset after it becomes empty.
    weights = _get_weights_of_type(np.asarray([.5, .1, .4]), weights_type)
    datasets = [
        dataset_ops.Dataset.from_tensors(np.int64(-1)),
        dataset_ops.Dataset.from_tensors(np.int64(1)).repeat(),
        dataset_ops.Dataset.range(10).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=False).take(100)

    samples_list = self.getIteratorOutput(self.getNext(sample_dataset))
    self.assertLen(samples_list, 100)
    self.assertEqual(samples_list.count(-1), 1)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasetsWithZeroWeight(self, weights_type):
    # Sampling stops when the second dataset is exhausted.
    weights = _get_weights_of_type(np.asarray([0., 1.]), weights_type)
    datasets = [
        dataset_ops.Dataset.from_tensors(-1).repeat(2),
        dataset_ops.Dataset.from_tensors(1).repeat(2)
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=True)
    self.assertDatasetProduces(sample_dataset, [1, 1])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromEmptyDataset(self, weights_type):
    # The only dataset with nonzero weight is empty, so with
    # `stop_on_empty_dataset=True` nothing is produced.
    weights = _get_weights_of_type(np.asarray([1., 0.]), weights_type)
    datasets = [
        dataset_ops.Dataset.range(0),
        dataset_ops.Dataset.range(1).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=True)
    self.assertDatasetProduces(sample_dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsSkippingDatasetsWithZeroWeight(self):
    # Sampling skips the first dataset.
    weights = np.asarray([0., 1.])
    datasets = [
        dataset_ops.Dataset.from_tensors(-1).repeat(),
        dataset_ops.Dataset.from_tensors(1)
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=False)
    self.assertDatasetProduces(sample_dataset, [1])

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsAllWeightsAreZero(self):
    # Sampling skips both datasets.
    weights = np.asarray([0., 0.])
    datasets = [
        dataset_ops.Dataset.from_tensors(-1).repeat(),
        dataset_ops.Dataset.from_tensors(1).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=False)
    self.assertDatasetProduces(sample_dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsCardinality(self):
    # Sampling from infinite inputs yields an infinite dataset.
    ds1 = dataset_ops.Dataset.from_tensors([1.0]).repeat()
    ds2 = dataset_ops.Dataset.from_tensors([2.0]).repeat()
    ds = dataset_ops.Dataset.sample_from_datasets([ds1, ds2])
    self.assertEqual(self.evaluate(ds.cardinality()), dataset_ops.INFINITE)

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsNested(self):
    # Inputs whose elements are themselves datasets (from `window`) can be
    # sampled and then flattened.
    ds1 = dataset_ops.Dataset.range(10).window(2)
    ds2 = dataset_ops.Dataset.range(10, 20).window(2)
    ds = dataset_ops.Dataset.sample_from_datasets([ds1, ds2],
                                                  weights=[0.3, 0.7])
    ds = ds.flat_map(lambda x: x)
    next_element = self.getNext(ds)
    self.evaluate(next_element())

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasets(self):
    words = [b"foo", b"bar", b"baz"]
    datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
    # Seeded so that a failure reproduces with the same choice sequence.
    choice_array = np.random.RandomState(self._NUMPY_SEED).randint(
        3, size=(15,), dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    dataset = dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset)
    next_element = self.getNext(dataset)
    for i in choice_array:
      self.assertEqual(words[i], self.evaluate(next_element()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsStoppingOnEmptyDataset(self):
    # The third choice of dataset 0 hits its end and stops the whole output.
    datasets = [
        dataset_ops.Dataset.from_tensors(b"foo").repeat(2),
        dataset_ops.Dataset.from_tensors(b"bar").repeat(),
        dataset_ops.Dataset.from_tensors(b"baz").repeat(),
    ]
    choice_array = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    dataset = dataset_ops.Dataset.choose_from_datasets(
        datasets, choice_dataset, stop_on_empty_dataset=True)
    self.assertDatasetProduces(dataset, [b"foo", b"foo"])

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsSkippingEmptyDatasets(self):
    datasets = [
        dataset_ops.Dataset.from_tensors(b"foo").repeat(2),
        dataset_ops.Dataset.from_tensors(b"bar").repeat(),
        dataset_ops.Dataset.from_tensors(b"baz").repeat(),
    ]
    choice_array = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    dataset = dataset_ops.Dataset.choose_from_datasets(
        datasets, choice_dataset, stop_on_empty_dataset=False)
    # Chooses 2 elements from the first dataset while the selector specifies 3.
    self.assertDatasetProduces(
        dataset,
        [b"foo", b"foo", b"bar", b"bar", b"bar", b"baz", b"baz", b"baz"])

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsChoiceDatasetIsEmpty(self):
    # An empty `choice_dataset` produces no output even with infinite inputs.
    datasets = [
        dataset_ops.Dataset.from_tensors(b"foo").repeat(),
        dataset_ops.Dataset.from_tensors(b"bar").repeat(),
        dataset_ops.Dataset.from_tensors(b"baz").repeat(),
    ]
    dataset = dataset_ops.Dataset.choose_from_datasets(
        datasets,
        choice_dataset=dataset_ops.Dataset.range(0),
        stop_on_empty_dataset=False)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsNested(self):
    # Alternate between two windowed datasets; after flattening, window `i`
    # of input `j` contributes `10*j + 2*i` and `10*j + 2*i + 1`.
    ds1 = dataset_ops.Dataset.range(10).window(2)
    ds2 = dataset_ops.Dataset.range(10, 20).window(2)
    choice_dataset = dataset_ops.Dataset.range(2).repeat(5)
    ds = dataset_ops.Dataset.choose_from_datasets([ds1, ds2], choice_dataset)
    ds = ds.flat_map(lambda x: x)
    expected = []
    for i in range(5):
      for j in range(2):
        expected.extend([10*j + 2*i, 10*j + 2*i + 1])
    self.assertDatasetProduces(ds, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testErrors(self):
    # Each block below checks one invalid-argument path and its error message.
    with self.assertRaisesRegex(ValueError, r"should have the same length"):
      dataset_ops.Dataset.sample_from_datasets(
          [dataset_ops.Dataset.range(10),
           dataset_ops.Dataset.range(20)],
          weights=[0.25, 0.25, 0.25, 0.25])

    with self.assertRaisesRegex(TypeError, "`tf.float32` or `tf.float64`"):
      dataset_ops.Dataset.sample_from_datasets(
          [dataset_ops.Dataset.range(10),
           dataset_ops.Dataset.range(20)],
          weights=[1, 1])

    with self.assertRaisesRegex(TypeError, "must have compatible"):
      dataset_ops.Dataset.sample_from_datasets([
          dataset_ops.Dataset.from_tensors(0),
          dataset_ops.Dataset.from_tensors(0.0)
      ])

    with self.assertRaisesRegex(
        ValueError, r"Invalid `datasets`. `datasets` should not be empty."):
      dataset_ops.Dataset.sample_from_datasets(datasets=[], weights=[])

    with self.assertRaisesRegex(TypeError, "tf.int64"):
      dataset_ops.Dataset.choose_from_datasets(
          [
              dataset_ops.Dataset.from_tensors(0),
              dataset_ops.Dataset.from_tensors(1)
          ],
          choice_dataset=dataset_ops.Dataset.from_tensors(1.0))

    with self.assertRaisesRegex(TypeError, "scalar"):
      dataset_ops.Dataset.choose_from_datasets(
          [
              dataset_ops.Dataset.from_tensors(0),
              dataset_ops.Dataset.from_tensors(1)
          ],
          choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))

    # An out-of-range index is only detected at iteration time.
    with self.assertRaisesRegex(errors.InvalidArgumentError, "out of range"):
      dataset = dataset_ops.Dataset.choose_from_datasets(
          [dataset_ops.Dataset.from_tensors(0)],
          choice_dataset=dataset_ops.Dataset.from_tensors(
              constant_op.constant(1, dtype=dtypes.int64)))
      next_element = self.getNext(dataset)
      self.evaluate(next_element())

    with self.assertRaisesRegex(
        ValueError, r"Invalid `datasets`. `datasets` should not be empty."):
      dataset_ops.Dataset.choose_from_datasets(
          datasets=[], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))

    with self.assertRaisesRegex(
        TypeError, r"`choice_dataset` should be a `tf.data.Dataset`"):
      datasets = [dataset_ops.Dataset.range(42)]
      dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset=None)
331
332
class SampleFromDatasetsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):
  """Checkpoint save/restore tests for `Dataset.sample_from_datasets`."""

  def _build_dataset(self, probs, num_samples):
    """Builds a seeded sampling pipeline truncated to `num_samples` elements.

    Args:
      probs: Sampling weights. Input dataset `i` endlessly yields `i` and is
        selected with weight `probs[i]`.
      num_samples: Number of elements kept by the final `take`.

    Returns:
      A dataset of `num_samples` sampled input indices.
    """
    sources = [
        dataset_ops.Dataset.from_tensors(index).repeat(count=None)
        for index in range(len(probs))
    ]
    # A fixed op-level seed makes the sampled sequence the same before and
    # after a checkpoint restore.
    sampled = dataset_ops.Dataset.sample_from_datasets(
        sources, weights=probs, seed=1813)
    return sampled.take(num_samples)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    expected_count = 100
    verify_fn(
        self,
        lambda: self._build_dataset([0.5, 0.5], expected_count),
        num_outputs=expected_count)
351
352
if __name__ == "__main__":
  # Run every test case in this module through TensorFlow's test runner.
  test.main()
355