# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utils to create distributed datasets based on TF version."""
16
17from tensorflow.python import tf2
18from tensorflow.python.distribute import input_lib
19from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
20
21
def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            num_replicas_in_sync=None,
                            input_context=None,
                            options=None,
                            build=True):
29  """Returns a distributed dataset from the given tf.data.Dataset instance.
30
31  This is a common function that is used by all strategies to return a
32  distributed dataset. The distributed dataset instance returned is different
33  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
34  instances returned differ from each other in the APIs supported by each of
35  them.
36
37  Args:
38    dataset: a tf.data.Dataset instance.
39    input_workers: an InputWorkers object which specifies devices on which
40      iterators should be created.
41    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
42      handle last partial batch.
43    num_replicas_in_sync: Optional integer. If this is not None, the value is
44      used to decide how to rebatch datasets into smaller batches so that the
45      total batch size for each step (across all workers and replicas) adds up
46      to `dataset`'s batch size.
47    input_context: `InputContext` for sharding. Only pass this in for between
48      graph multi-worker cases where there is only one `input_worker`. In these
49      cases, we will shard based on the `input_pipeline_id` and
50      `num_input_pipelines` in the `InputContext`.
51    options: Default is None. `tf.distribute.InputOptions` used to control
52      options on how this dataset is distributed.
53    build: whether to build underlying datasets when a DistributedDataset is
54      created. This is only useful for `ParameterServerStrategy` now.
55
56  Returns:
57    A distributed dataset instance.
58  """
  if tf2.enabled():
    return input_lib.DistributedDataset(
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        build=build,
        options=options)
  else:
    return input_lib_v1.DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)
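
# Example (a minimal, hypothetical sketch, not part of this module): a
# strategy's dataset-distribution path might call this helper roughly as
# follows. `my_strategy` and `my_input_workers` stand in for a real
# `tf.distribute.Strategy` and an `input_lib.InputWorkers` instance that a
# strategy would construct internally.
#
#   import tensorflow as tf
#
#   dataset = tf.data.Dataset.range(16).batch(8)
#   dist_dataset = get_distributed_dataset(
#       dataset,
#       my_input_workers,
#       my_strategy,
#       num_replicas_in_sync=my_strategy.num_replicas_in_sync)
#   for batch in dist_dataset:  # yields per-replica values
#     ...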


def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None,
                                           build=True):
84  """Returns a distributed dataset from the given input function.
85
86  This is a common function that is used by all strategies to return a
87  distributed dataset. The distributed dataset instance returned is different
88  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
89  instances returned differ from each other in the APIs supported by each of
90  them.
91
92  Args:
93    dataset_fn: a function that returns a tf.data.Dataset instance.
94    input_workers: an InputWorkers object which specifies devices on which
95      iterators should be created.
96    input_contexts: A list of `InputContext` instances to be passed to call(s)
97      to `dataset_fn`. Length and order should match worker order in
98      `worker_device_pairs`.
99    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
100      handle last partial batch.
101    options: Default is None. `tf.distribute.InputOptions` used to control
102      options on how this dataset is distributed.
103    build: whether to build underlying datasets when a
104      `DistributedDatasetFromFunction` is created. This is only useful for
105      `ParameterServerStrategy` now.
106
107  Returns:
108    A distributed dataset instance.
109
110  Raises:
111    ValueError: if `options.experimental_replication_mode` and
112    `options.experimental_place_dataset_on_device` are not consistent
113  """
  if (options is not None and options.experimental_replication_mode !=
      input_lib.InputReplicationMode.PER_REPLICA and
      options.experimental_place_dataset_on_device):
    raise ValueError(
        "When `experimental_place_dataset_on_device` is set for dataset "
        "placement, you must also specify `PER_REPLICA` for the "
        "replication mode.")

  if (options is not None and options.experimental_replication_mode
      == input_lib.InputReplicationMode.PER_REPLICA and
      options.experimental_fetch_to_device and
      options.experimental_place_dataset_on_device):
    raise ValueError(
        "`experimental_place_dataset_on_device` cannot be set to True "
        "when `experimental_fetch_to_device` is True and the "
        "replication mode is set to `PER_REPLICA`.")
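
  # In short: `experimental_place_dataset_on_device=True` is only valid when
  # `experimental_replication_mode` is `PER_REPLICA` and
  # `experimental_fetch_to_device` is False; the two checks above reject
  # every other combination.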

  if tf2.enabled():
    return input_lib.DistributedDatasetsFromFunction(
        input_workers,
        strategy,
        input_contexts=input_contexts,
        dataset_fn=dataset_fn,
        options=options,
        build=build,
    )
  else:
    return input_lib_v1.DistributedDatasetsFromFunctionV1(
        input_workers, strategy, input_contexts, dataset_fn, options)
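
# Example (a minimal, hypothetical sketch, not part of this module): a
# per-worker dataset function is paired with one `InputContext` per worker so
# that each call of `dataset_fn` can size and shard its input. `my_strategy`,
# `my_input_workers`, and `my_input_contexts` are placeholders for objects a
# real strategy would construct internally.
#
#   import tensorflow as tf
#
#   def dataset_fn(input_context):
#     batch_size = input_context.get_per_replica_batch_size(
#         global_batch_size=64)
#     dataset = tf.data.Dataset.range(1024).batch(batch_size)
#     return dataset.shard(input_context.num_input_pipelines,
#                          input_context.input_pipeline_id)
#
#   dist_dataset = get_distributed_datasets_from_function(
#       dataset_fn, my_input_workers, my_input_contexts, my_strategy)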