# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util.deprecation import deprecated


class DistributedDatasetV1(input_lib.DistributedDataset):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None,
               options=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)

  def make_one_shot_iterator(self):
    """Get a one-time-use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()
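  # Usage sketch (not from the original source; `strategy` and `dataset` are
  # assumed names for a `tf.distribute` strategy and a `tf.data.Dataset`):
  #
  #   dist_dataset = strategy.experimental_distribute_dataset(dataset)
  #   iterator = dist_dataset.make_one_shot_iterator()  # eager mode only
  #   per_replica_batch = iterator.get_next()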

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()
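  # Usage sketch (not from the original source; graph mode, with `strategy`
  # and `dataset` assumed names): the iterator must be initialized before the
  # first `get_next()`:
  #
  #   dist_dataset = strategy.experimental_distribute_dataset(dataset)
  #   iterator = dist_dataset.make_initializable_iterator()
  #   batch = iterator.get_next()
  #   with tf.compat.v1.Session() as sess:
  #     sess.run(iterator.initializer)
  #     sess.run(strategy.experimental_local_results(batch))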

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    self._options)
    cardinality = input_lib._cardinality(self._cloned_datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access

    # When async eager is enabled, the iterator may not finish initialization
    # before being passed to a multi-device function, so add a sync point here
    # to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  # pylint: enable=non-iterator-returned
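  # Usage sketch (not from the original source; `strategy` and `step_fn` are
  # assumed names): `__iter__` works eagerly or inside a `tf.function`, e.g.
  #
  #   @tf.function
  #   def train(dist_dataset):
  #     for per_replica_batch in dist_dataset:
  #       strategy.run(step_fn, args=(per_replica_batch,))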


class DistributedDatasetsFromFunctionV1(
    input_lib.DistributedDatasetsFromFunction):
  """Inputs created from a dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers, self._options)
    cardinality = input_lib._cardinality(self._datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, the iterator may not finish initialization
    # before being passed to a multi-device function, so add a sync point here
    # to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  # pylint: enable=non-iterator-returned


class DistributedIteratorV1(input_lib.DistributedIteratorBase):
  """Input Iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator, we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    return self._initializer

  @property
  def initializer(self):
    """Returns a list of ops that initialize the iterator."""
    return self.initialize()
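  # Usage sketch (not from the original source; `iterator`, `train_op`, and
  # `num_epochs` are assumed names): v1-style loops re-run the grouped
  # initializer between epochs:
  #
  #   with tf.compat.v1.Session() as sess:
  #     for _ in range(num_epochs):
  #       sess.run(iterator.initializer)
  #       while True:
  #         try:
  #           sess.run(train_op)
  #         except tf.errors.OutOfRangeError:
  #           break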

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec


class DatasetIterator(DistributedIteratorV1):
  """Iterator created from an input dataset."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None):
    """Make an iterator for the dataset on given devices.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
      num_replicas_in_sync: Optional integer. If this is not None, the value is
        used to decide how to rebatch datasets into smaller batches so that the
        total batch size for each step (across all workers and replicas) adds up
        to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for
        between-graph multi-worker cases where there is only one
        `input_worker`. In these cases, we will shard based on the
        `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    """
    dist_dataset = DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)
    # pylint: disable=protected-access
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)
    super(DatasetIterator,
          self).__init__(input_workers, worker_iterators, strategy,
                         dist_dataset.cardinality,
                         dist_dataset._enable_get_next_as_optional)
    self._element_spec = dist_dataset.element_spec
    # pylint: enable=protected-access
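  # Worked example (illustrative, not from the original source): with `dataset`
  # batched to a global batch size of 64 and `num_replicas_in_sync=4`, each
  # incoming batch is rebatched into 4 sub-batches of 16, so every replica
  # processes 16 elements per step while the global step still covers 64.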


class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from an input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match the worker order in
        `input_workers`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
    """
    assert isinstance(input_workers, input_lib.InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError("Number of input workers (%d) is not the same as "
                       "number of input_contexts (%d)" %
                       (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, dataset_ops.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(
        input_workers,
        iterators,
        strategy,
        cardinality=cardinality_lib.UNKNOWN,
        enable_get_next_as_optional=False)
    self._enable_get_next_as_optional = False
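  # Usage sketch (not from the original source; `features` and
  # `global_batch_size` are assumed names): a typical `input_fn` shards and
  # batches according to the `InputContext` it receives:
  #
  #   def input_fn(ctx):
  #     ds = tf.data.Dataset.from_tensor_slices(features)
  #     ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  #     return ds.batch(ctx.get_per_replica_batch_size(global_batch_size))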


class _SingleWorkerDatasetIterator(input_lib._SingleWorkerDatasetIteratorBase):  # pylint: disable=protected-access
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
            max_buffer_size=self._options.experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
        )
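  # Note (illustrative, not from the original source): `self._options` is the
  # input options object passed down from the strategy; its
  # `experimental_per_replica_buffer_size` feeds both MultiDeviceIterator
  # buffer arguments above, e.g.
  #
  #   options = tf.distribute.InputOptions(
  #       experimental_per_replica_buffer_size=2)
  #   dist_dataset = strategy.experimental_distribute_dataset(
  #       dataset, options=options)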

  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return data_list

  def get_next_as_optional_list(self):
    with ops.device(self._worker):
      data_list = [
          optional_ops.Optional.from_value(self._fn()) for _ in self._devices
      ]
      return data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []
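  # Usage sketch (illustrative, not from the original source; `batch_size` and
  # `feature_dim` are assumed names): the callable path is taken when
  # `input_fn` returns a function rather than a dataset; each call must yield
  # one batch of tensors for one replica, e.g.
  #
  #   def feed():
  #     return tf.random.uniform([batch_size, feature_dim])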


def _create_iterators_per_worker(worker_datasets, input_workers, options=None):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, input_lib.InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(
          worker_datasets[i],  # pylint: disable=protected-access
          worker,
          worker_devices,
          options)
      iterators.append(iterator)
  return iterators