# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.sample_from_datasets()` and
`tf.data.Dataset.choose_from_datasets()`."""
from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import test


def _weights_type_combinations():
  return combinations.combine(weights_type=["list", "tensor", "dataset"])


def _get_weights_of_type(weights_list, weights_type):
  # Returns the weights in the representation under test: a plain Python
  # list, a constant tensor, or an infinitely repeated dataset of weights.
  if weights_type == "list":
    return weights_list
  if weights_type == "tensor":
    return ops.convert_to_tensor(weights_list, name="weights")
  return dataset_ops.Dataset.from_tensors(weights_list).repeat()


class DirectedInterleaveDatasetTest(test_base.DatasetTestBase,
                                    parameterized.TestCase):

  @combinations.generate(test_base.default_test_combinations())
  def testBasic(self):
    selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
    input_datasets = [
        dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
    ]
    dataset = dataset_ops._DirectedInterleaveDataset(selector_dataset,
                                                     input_datasets)
    next_element = self.getNext(dataset)

    for _ in range(100):
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def _normalize(self, vec):
    return vec / vec.sum()

  def _chi2(self, expected, actual):
    # Pearson's chi-squared statistic between the expected and observed
    # distributions.
    actual = np.asarray(actual)
    expected = np.asarray(expected)
    diff = actual - expected
    chi2 = np.sum(diff * diff / expected, axis=0)
    return chi2

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasets(self, weights_type):
    random_seed.set_random_seed(1619)
    num_samples = 5000
    rand_probs = self._normalize(np.random.random_sample((5,)))

    # Use chi-squared test to assert that the observed distribution matches the
    # expected distribution. Based on the implementation in
    # "third_party/tensorflow/python/kernel_tests/multinomial_op_test.py".
    for probs in [[.85, .05, .1], rand_probs, [1.]]:
      weights = _get_weights_of_type(np.asarray(probs), weights_type)
      classes = len(probs)

      # Create a dataset that samples each integer in `[0, num_datasets)`
      # with probability given by `weights[i]`.
      dataset = dataset_ops.Dataset.sample_from_datasets([
          dataset_ops.Dataset.from_tensors(i).repeat() for i in range(classes)
      ], weights)
      dataset = dataset.take(num_samples)

      next_element = self.getNext(dataset)
      freqs = np.zeros([classes])
      for _ in range(num_samples):
        freqs[self.evaluate(next_element())] += 1
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element())

      self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasetsStoppingOnEmptyDataset(self, weights_type):
    # Sampling stops when the first dataset is exhausted.
    weights = _get_weights_of_type(np.asarray([.5, .1, .4]), weights_type)
    datasets = [
        dataset_ops.Dataset.from_tensors(np.int64(-1)),
        dataset_ops.Dataset.from_tensors(np.int64(1)).repeat(),
        dataset_ops.Dataset.range(10).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=True)

    samples_list = self.getIteratorOutput(self.getNext(sample_dataset))
    self.assertEqual(samples_list.count(-1), 1)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasetsSkippingEmptyDataset(self, weights_type):
    # Sampling skips the first dataset after it becomes empty.
    weights = _get_weights_of_type(np.asarray([.5, .1, .4]), weights_type)
    datasets = [
        dataset_ops.Dataset.from_tensors(np.int64(-1)),
        dataset_ops.Dataset.from_tensors(np.int64(1)).repeat(),
        dataset_ops.Dataset.range(10).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=False).take(100)

    samples_list = self.getIteratorOutput(self.getNext(sample_dataset))
    self.assertLen(samples_list, 100)
    self.assertEqual(samples_list.count(-1), 1)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromDatasetsWithZeroWeight(self, weights_type):
    # Sampling stops when the second dataset is exhausted.
    weights = _get_weights_of_type(np.asarray([0., 1.]), weights_type)
    datasets = [
        dataset_ops.Dataset.from_tensors(-1).repeat(2),
        dataset_ops.Dataset.from_tensors(1).repeat(2)
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=True)
    self.assertDatasetProduces(sample_dataset, [1, 1])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _weights_type_combinations()))
  def testSampleFromEmptyDataset(self, weights_type):
    weights = _get_weights_of_type(np.asarray([1., 0.]), weights_type)
    datasets = [
        dataset_ops.Dataset.range(0),
        dataset_ops.Dataset.range(1).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=True)
    self.assertDatasetProduces(sample_dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsSkippingDatasetsWithZeroWeight(self):
    # Sampling skips the first dataset.
    weights = np.asarray([0., 1.])
    datasets = [
        dataset_ops.Dataset.from_tensors(-1).repeat(),
        dataset_ops.Dataset.from_tensors(1)
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=False)
    self.assertDatasetProduces(sample_dataset, [1])

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsAllWeightsAreZero(self):
    # Sampling skips both datasets.
    weights = np.asarray([0., 0.])
    datasets = [
        dataset_ops.Dataset.from_tensors(-1).repeat(),
        dataset_ops.Dataset.from_tensors(1).repeat()
    ]
    sample_dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, weights=weights, stop_on_empty_dataset=False)
    self.assertDatasetProduces(sample_dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsCardinality(self):
    ds1 = dataset_ops.Dataset.from_tensors([1.0]).repeat()
    ds2 = dataset_ops.Dataset.from_tensors([2.0]).repeat()
    ds = dataset_ops.Dataset.sample_from_datasets([ds1, ds2])
    self.assertEqual(self.evaluate(ds.cardinality()), dataset_ops.INFINITE)

  @combinations.generate(test_base.default_test_combinations())
  def testSampleFromDatasetsNested(self):
    ds1 = dataset_ops.Dataset.range(10).window(2)
    ds2 = dataset_ops.Dataset.range(10, 20).window(2)
    ds = dataset_ops.Dataset.sample_from_datasets([ds1, ds2],
                                                  weights=[0.3, 0.7])
    ds = ds.flat_map(lambda x: x)
    next_element = self.getNext(ds)
    self.evaluate(next_element())

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasets(self):
    words = [b"foo", b"bar", b"baz"]
    datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
    choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    dataset = dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset)
    next_element = self.getNext(dataset)
    for i in choice_array:
      self.assertEqual(words[i], self.evaluate(next_element()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsStoppingOnEmptyDataset(self):
    datasets = [
        dataset_ops.Dataset.from_tensors(b"foo").repeat(2),
        dataset_ops.Dataset.from_tensors(b"bar").repeat(),
        dataset_ops.Dataset.from_tensors(b"baz").repeat(),
    ]
    choice_array = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    dataset = dataset_ops.Dataset.choose_from_datasets(
        datasets, choice_dataset, stop_on_empty_dataset=True)
    self.assertDatasetProduces(dataset, [b"foo", b"foo"])

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsSkippingEmptyDatasets(self):
    datasets = [
        dataset_ops.Dataset.from_tensors(b"foo").repeat(2),
        dataset_ops.Dataset.from_tensors(b"bar").repeat(),
        dataset_ops.Dataset.from_tensors(b"baz").repeat(),
    ]
    choice_array = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=np.int64)
    choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
    dataset = dataset_ops.Dataset.choose_from_datasets(
        datasets, choice_dataset, stop_on_empty_dataset=False)
    # Chooses 2 elements from the first dataset while the selector
    # specifies 3.
    self.assertDatasetProduces(
        dataset,
        [b"foo", b"foo", b"bar", b"bar", b"bar", b"baz", b"baz", b"baz"])

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsChoiceDatasetIsEmpty(self):
    datasets = [
        dataset_ops.Dataset.from_tensors(b"foo").repeat(),
        dataset_ops.Dataset.from_tensors(b"bar").repeat(),
        dataset_ops.Dataset.from_tensors(b"baz").repeat(),
    ]
    dataset = dataset_ops.Dataset.choose_from_datasets(
        datasets,
        choice_dataset=dataset_ops.Dataset.range(0),
        stop_on_empty_dataset=False)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testChooseFromDatasetsNested(self):
    ds1 = dataset_ops.Dataset.range(10).window(2)
    ds2 = dataset_ops.Dataset.range(10, 20).window(2)
    choice_dataset = dataset_ops.Dataset.range(2).repeat(5)
    ds = dataset_ops.Dataset.choose_from_datasets([ds1, ds2], choice_dataset)
    ds = ds.flat_map(lambda x: x)
    expected = []
    for i in range(5):
      for j in range(2):
        expected.extend([10*j + 2*i, 10*j + 2*i + 1])
    self.assertDatasetProduces(ds, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testErrors(self):
    with self.assertRaisesRegex(ValueError, r"should have the same length"):
      dataset_ops.Dataset.sample_from_datasets(
          [dataset_ops.Dataset.range(10),
           dataset_ops.Dataset.range(20)],
          weights=[0.25, 0.25, 0.25, 0.25])

    with self.assertRaisesRegex(TypeError, "`tf.float32` or `tf.float64`"):
      dataset_ops.Dataset.sample_from_datasets(
          [dataset_ops.Dataset.range(10),
           dataset_ops.Dataset.range(20)],
          weights=[1, 1])

    with self.assertRaisesRegex(TypeError, "must have compatible"):
      dataset_ops.Dataset.sample_from_datasets([
          dataset_ops.Dataset.from_tensors(0),
          dataset_ops.Dataset.from_tensors(0.0)
      ])

    with self.assertRaisesRegex(
        ValueError, r"Invalid `datasets`. `datasets` should not be empty."):
      dataset_ops.Dataset.sample_from_datasets(datasets=[], weights=[])

    with self.assertRaisesRegex(TypeError, "tf.int64"):
      dataset_ops.Dataset.choose_from_datasets(
          [
              dataset_ops.Dataset.from_tensors(0),
              dataset_ops.Dataset.from_tensors(1)
          ],
          choice_dataset=dataset_ops.Dataset.from_tensors(1.0))

    with self.assertRaisesRegex(TypeError, "scalar"):
      dataset_ops.Dataset.choose_from_datasets(
          [
              dataset_ops.Dataset.from_tensors(0),
              dataset_ops.Dataset.from_tensors(1)
          ],
          choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))

    with self.assertRaisesRegex(errors.InvalidArgumentError, "out of range"):
      dataset = dataset_ops.Dataset.choose_from_datasets(
          [dataset_ops.Dataset.from_tensors(0)],
          choice_dataset=dataset_ops.Dataset.from_tensors(
              constant_op.constant(1, dtype=dtypes.int64)))
      next_element = self.getNext(dataset)
      self.evaluate(next_element())

    with self.assertRaisesRegex(
        ValueError, r"Invalid `datasets`. `datasets` should not be empty."):
      dataset_ops.Dataset.choose_from_datasets(
          datasets=[], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))

    with self.assertRaisesRegex(
        TypeError, r"`choice_dataset` should be a `tf.data.Dataset`"):
      datasets = [dataset_ops.Dataset.range(42)]
      dataset_ops.Dataset.choose_from_datasets(datasets, choice_dataset=None)


class SampleFromDatasetsCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                                       parameterized.TestCase):

  def _build_dataset(self, probs, num_samples):
    datasets = [
        dataset_ops.Dataset.from_tensors(i).repeat(None)
        for i in range(len(probs))
    ]
    dataset = dataset_ops.Dataset.sample_from_datasets(
        datasets, probs, seed=1813)
    return dataset.take(num_samples)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    verify_fn(
        self, lambda: self._build_dataset([0.5, 0.5], 100), num_outputs=100)


if __name__ == "__main__":
  test.main()