# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils to create distributed datasets based on TF version."""

from tensorflow.python import tf2
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1


def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            num_replicas_in_sync=None,
                            input_context=None,
                            options=None,
                            build=True):
  """Returns a distributed dataset from the given tf.data.Dataset instance.

  This is a common function used by all strategies to return a distributed
  dataset. The distributed dataset instance returned differs depending on
  whether we are in a TF 1 or TF 2 context, and the two variants differ in
  the APIs they support.

  Args:
    dataset: a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle the last partial batch.
    num_replicas_in_sync: Optional integer. If this is not None, the value is
      used to decide how to rebatch datasets into smaller batches so that the
      total batch size for each step (across all workers and replicas) adds up
      to `dataset`'s batch size.
    input_context: `InputContext` for sharding. Only pass this in for
      between-graph multi-worker cases where there is only one `input_worker`.
      In these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.
    options: Optional `tf.distribute.InputOptions` used to control how this
      dataset is distributed. Defaults to None.
    build: whether to build the underlying datasets when a DistributedDataset
      is created. This is currently only useful for `ParameterServerStrategy`.

  Returns:
    A distributed dataset instance.
  """
  if tf2.enabled():
    return input_lib.DistributedDataset(
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        build=build,
        options=options)
  else:
    return input_lib_v1.DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)
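

# Usage sketch for `get_distributed_dataset` (illustrative only; kept in a
# comment so nothing executes at import time). The `_input_workers` attribute
# and `_container_strategy()` call are assumptions standing in for however a
# concrete strategy extension wires up its input machinery:
#
#   dataset = tf.data.Dataset.range(16).batch(8)
#   dist_dataset = get_distributed_dataset(
#       dataset,
#       input_workers=self._input_workers,       # assumed strategy attribute
#       strategy=self._container_strategy(),     # assumed strategy accessor
#       num_replicas_in_sync=2,
#       options=tf.distribute.InputOptions())
#   for per_replica_batch in dist_dataset:
#     ...  # each element holds one batch per replica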


def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None,
                                           build=True):
  """Returns a distributed dataset from the given input function.

  This is a common function used by all strategies to return a distributed
  dataset. The distributed dataset instance returned differs depending on
  whether we are in a TF 1 or TF 2 context, and the two variants differ in
  the APIs they support.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
      to `dataset_fn`. Length and order should match the worker order in
      `input_workers`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle the last partial batch.
    options: Optional `tf.distribute.InputOptions` used to control how this
      dataset is distributed. Defaults to None.
    build: whether to build the underlying datasets when a
      `DistributedDatasetsFromFunction` is created. This is currently only
      useful for `ParameterServerStrategy`.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if `options.experimental_replication_mode` and
      `options.experimental_place_dataset_on_device` are not consistent.
  """
  if (options is not None and
      options.experimental_replication_mode !=
      input_lib.InputReplicationMode.PER_REPLICA and
      options.experimental_place_dataset_on_device):
    raise ValueError(
        "When `experimental_place_dataset_on_device` is set for dataset "
        "placement, you must also specify `PER_REPLICA` for the "
        "replication mode.")

  if (options is not None and
      options.experimental_replication_mode ==
      input_lib.InputReplicationMode.PER_REPLICA and
      options.experimental_fetch_to_device and
      options.experimental_place_dataset_on_device):
    raise ValueError(
        "`experimental_place_dataset_on_device` cannot be set to True "
        "when `experimental_fetch_to_device` is True and the "
        "replication mode is set to `PER_REPLICA`.")

  if tf2.enabled():
    return input_lib.DistributedDatasetsFromFunction(
        input_workers,
        strategy,
        input_contexts=input_contexts,
        dataset_fn=dataset_fn,
        options=options,
        build=build,
    )
  else:
    return input_lib_v1.DistributedDatasetsFromFunctionV1(
        input_workers, strategy, input_contexts, dataset_fn, options)
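

# Usage sketch for `get_distributed_datasets_from_function` (illustrative
# only; kept in a comment so nothing executes at import time). Building one
# `InputContext` per worker as shown is an assumption about how a caller
# might prepare `input_contexts`; `distribute_lib` would be
# `tensorflow.python.distribute.distribute_lib`, which this module does not
# import:
#
#   def dataset_fn(input_context):
#     batch_size = input_context.get_per_replica_batch_size(
#         global_batch_size=8)
#     dataset = tf.data.Dataset.range(64).batch(batch_size)
#     return dataset.shard(input_context.num_input_pipelines,
#                          input_context.input_pipeline_id)
#
#   input_contexts = [
#       distribute_lib.InputContext(
#           num_input_pipelines=input_workers.num_workers,
#           input_pipeline_id=i,
#           num_replicas_in_sync=num_replicas)
#       for i in range(input_workers.num_workers)
#   ]
#   dist_dataset = get_distributed_datasets_from_function(
#       dataset_fn, input_workers, input_contexts, strategy,
#       options=tf.distribute.InputOptions())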