xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/experimental/ops/interleave_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Non-deterministic dataset transformations."""
16from tensorflow.python import tf2
17from tensorflow.python.data.ops import dataset_ops
18from tensorflow.python.data.ops import readers
19from tensorflow.python.util import deprecation
20from tensorflow.python.util.tf_export import tf_export
21
22
@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
    "num_parallel_calls=tf.data.AUTOTUNE)` instead. If sloppy "
    "execution is desired, use `tf.data.Options.deterministic`.")
@tf_export("data.experimental.parallel_interleave")
def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` maps `map_func` over its input to produce nested
  datasets and outputs their elements interleaved. Unlike
  `tf.data.Dataset.interleave`, it pulls elements from `cycle_length` nested
  datasets in parallel, which increases throughput, especially in the
  presence of stragglers. Furthermore, the `sloppy` argument can be used to
  improve performance by relaxing the requirement that outputs are produced
  in a deterministic order, allowing the implementation to skip over nested
  datasets whose elements are not readily available when requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.data.experimental.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in
      parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: A boolean controlling whether determinism should be traded for
      performance by allowing elements to be produced out of order.  If
      `sloppy` is `None`, the `tf.data.Options.deterministic` dataset option
      (`True` by default) is used to decide whether to enforce a
      deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation
      for each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  # `Dataset.apply` calls the returned function with the upstream dataset;
  # every configuration knob is captured by closure and forwarded unchanged.
  return lambda dataset: readers.ParallelInterleaveDataset(
      dataset, map_func, cycle_length, block_length, sloppy,
      buffer_output_elements, prefetch_input_elements)
87
88
@deprecation.deprecated(None,
                        "Use `tf.data.Dataset.sample_from_datasets(...)`.")
@tf_export("data.experimental.sample_from_datasets", v1=[])
def sample_from_datasets_v2(datasets,
                            weights=None,
                            seed=None,
                            stop_on_empty_dataset=False):
  """Samples elements at random from the datasets in `datasets`.

  Creates a dataset by interleaving elements of `datasets` with `weight[i]`
  probability of picking an element from dataset `i`. Sampling is done
  without replacement. For example, suppose we have 2 datasets:

  ```python
  dataset1 = tf.data.Dataset.range(0, 3)
  dataset2 = tf.data.Dataset.range(100, 103)
  ```

  Suppose also that we sample from these 2 datasets with the following
  weights:

  ```python
  sample_dataset = tf.data.Dataset.sample_from_datasets(
      [dataset1, dataset2], weights=[0.5, 0.5])
  ```

  One possible outcome of elements in sample_dataset is:

  ```
  print(list(sample_dataset.as_numpy_iterator()))
  # [100, 0, 1, 101, 2, 102]
  ```

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    weights: (Optional.) A list or Tensor of `len(datasets)` floating-point
      values where `weights[i]` represents the probability to sample from
      `datasets[i]`, or a `tf.data.Dataset` object where each element is
      such a list. Defaults to a uniform distribution across `datasets`.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.
    stop_on_empty_dataset: If `True`, sampling stops if it encounters an
      empty dataset. If `False`, it skips empty datasets. It is recommended
      to set it to `True`. Otherwise, the distribution of samples starts off
      as the user intends, but may change as input datasets become empty.
      This can be difficult to detect since the dataset starts off looking
      correct. Default to `False` for backward compatibility.

  Returns:
    A dataset that interleaves elements from `datasets` at random, according
    to `weights` if provided, otherwise with uniform probability.

  Raises:
    TypeError: If the `datasets` or `weights` arguments have the wrong type.
    ValueError:
      - If `datasets` is empty, or
      - If `weights` is specified and does not match the length of `datasets`.
  """
  # This experimental endpoint is a deprecated alias; defer entirely to the
  # core `tf.data.Dataset` implementation.
  sample = dataset_ops.Dataset.sample_from_datasets
  return sample(
      datasets=datasets,
      weights=weights,
      seed=seed,
      stop_on_empty_dataset=stop_on_empty_dataset)
153
154
@deprecation.deprecated(None,
                        "Use `tf.data.Dataset.sample_from_datasets(...)`.")
@tf_export(v1=["data.experimental.sample_from_datasets"])
def sample_from_datasets_v1(datasets,
                            weights=None,
                            seed=None,
                            stop_on_empty_dataset=False):
  # Build the v2 dataset, then wrap it so it exposes the TF1 `Dataset` API.
  v2_dataset = sample_from_datasets_v2(
      datasets,
      weights=weights,
      seed=seed,
      stop_on_empty_dataset=stop_on_empty_dataset)
  return dataset_ops.DatasetV1Adapter(v2_dataset)


# The v1 endpoint shares the v2 documentation verbatim.
sample_from_datasets_v1.__doc__ = sample_from_datasets_v2.__doc__
167
168
@deprecation.deprecated(
    None, "Use `tf.data.Dataset.choose_from_datasets(...)` instead. Note that, "
    "unlike the experimental endpoint, the non-experimental endpoint "
    "sets `stop_on_empty_dataset=True` by default. You should set this "
    "argument explicitly in case you would like to match the behavior of the "
    "experimental endpoint.")
@tf_export("data.experimental.choose_from_datasets", v1=[])
def choose_from_datasets_v2(datasets,
                            choice_dataset,
                            stop_on_empty_dataset=False):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.
    stop_on_empty_dataset: If `True`, selection stops if it encounters an
      empty dataset. If `False`, it skips empty datasets. It is recommended
      to set it to `True`. Otherwise, the selected elements start off as the
      user intends, but may change as input datasets become empty. This can
      be difficult to detect since the dataset starts off looking correct.
      Default to `False` for backward compatibility.

  Returns:
    A dataset that interleaves elements from `datasets` according to the
    values of `choice_dataset`.

  Raises:
    TypeError: If `datasets` or `choice_dataset` has the wrong type.
    ValueError: If `datasets` is empty.
  """
  # This experimental endpoint is a deprecated alias; defer entirely to the
  # core `tf.data.Dataset` implementation.
  choose = dataset_ops.Dataset.choose_from_datasets
  return choose(
      datasets=datasets,
      choice_dataset=choice_dataset,
      stop_on_empty_dataset=stop_on_empty_dataset)
224
225
@deprecation.deprecated(
    None, "Use `tf.data.Dataset.choose_from_datasets(...)` instead. Note that, "
    "unlike the experimental endpoint, the non-experimental endpoint "
    "sets `stop_on_empty_dataset=True` by default. You should set this "
    "argument explicitly in case you would like to match the behavior of the "
    "experimental endpoint.")
@tf_export(v1=["data.experimental.choose_from_datasets"])
def choose_from_datasets_v1(datasets,
                            choice_dataset,
                            stop_on_empty_dataset=False):
  # Build the v2 dataset, then wrap it so it exposes the TF1 `Dataset` API.
  v2_dataset = choose_from_datasets_v2(
      datasets, choice_dataset, stop_on_empty_dataset)
  return dataset_ops.DatasetV1Adapter(v2_dataset)


# The v1 endpoint shares the v2 documentation verbatim.
choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__
241
# Export the unsuffixed names, resolving to the v2 or v1 implementation
# depending on which major TensorFlow API version is active.
choose_from_datasets = (
    choose_from_datasets_v2 if tf2.enabled() else choose_from_datasets_v1)
sample_from_datasets = (
    sample_from_datasets_v2 if tf2.enabled() else sample_from_datasets_v1)
248