# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from tensorflow.python.data.experimental.ops import cardinality as cardinality_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util.deprecation import deprecated


class DistributedDatasetV1(input_lib.DistributedDataset):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None,
               options=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)

  def make_one_shot_iterator(self):
    """Get a one-time-use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1."""
    # One-shot iterators are disabled in graph mode because we have to call
    # `initialize` on the iterator, which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already-initialized iterators, so we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    self._options)
    cardinality = input_lib._cardinality(self._cloned_datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access

    # When async eager is enabled, the iterator may not finish initialization
    # before being passed to a multi-device function; add a sync point here to
    # make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  # pylint: enable=non-iterator-returned


class DistributedDatasetsFromFunctionV1(
    input_lib.DistributedDatasetsFromFunction):
  """Inputs created from a dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already-initialized iterators, so we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # One-shot iterators are disabled in graph mode because we have to call
    # `initialize` on the iterator, which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers,
                                             self._options)
    cardinality = input_lib._cardinality(self._datasets[0])  # pylint: disable=protected-access
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy, cardinality,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, the iterator may not finish initialization
    # before being passed to a multi-device function; add a sync point here to
    # make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  # pylint: disable=non-iterator-returned
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  # pylint: enable=non-iterator-returned


class DistributedIteratorV1(input_lib.DistributedIteratorBase):
  """Input iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multi-device
  # iterators when used with Keras training loops. If we don't reinitialize
  # the iterator, we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      An op that, when run, initializes all the underlying iterators.
    """
    return self._initializer

  @property
  def initializer(self):
    """Returns the op that initializes the underlying iterators."""
    return self.initialize()

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec
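

# Illustrative sketch only, assuming TF1 graph mode, a `tf.distribute`
# strategy `strategy`, a finite batched `tf.data.Dataset` `dataset`, and
# placeholder names `train_op` and `num_epochs`: a `DistributedIteratorV1`
# can be re-initialized between epochs via its `initializer` property instead
# of being rebuilt (see the note on b/123315763 above).
#
#   dist_dataset = strategy.experimental_distribute_dataset(dataset)
#   iterator = dist_dataset.make_initializable_iterator()
#   per_replica_batch = iterator.get_next()  # consumed when building `train_op`
#   with tf.compat.v1.Session() as sess:
#     for _ in range(num_epochs):
#       sess.run(iterator.initializer)
#       while True:
#         try:
#           sess.run(train_op)
#         except tf.errors.OutOfRangeError:
#           break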


class DatasetIterator(DistributedIteratorV1):
  """Iterator created from an input dataset."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None):
    """Make an iterator for the dataset on the given devices.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
      num_replicas_in_sync: Optional integer. If this is not None, the value
        is used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for
        between-graph multi-worker cases where there is only one
        `input_worker`. In these cases, we will shard based on the
        `input_pipeline_id` and `num_input_pipelines` in the `InputContext`.
    """
    dist_dataset = DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)
    # pylint: disable=protected-access
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)
    super(DatasetIterator,
          self).__init__(input_workers, worker_iterators, strategy,
                         dist_dataset.cardinality,
                         dist_dataset._enable_get_next_as_optional)
    self._element_spec = dist_dataset.element_spec
    # pylint: enable=protected-access
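

# Worked example of the rebatching described in `DatasetIterator` above, with
# illustrative numbers: if `dataset` is batched to a global batch size of 64
# and `num_replicas_in_sync` is 8, each replica receives batches of 8
# elements, so a single global step still consumes 64 examples across all
# workers and replicas.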


class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from an input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to
        call(s) to `input_fn`. Length and order should match the worker order
        in `input_workers`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle the last partial batch.
    """
    assert isinstance(input_workers, input_lib.InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError("Number of input workers (%d) is not the same as the "
                       "number of input_contexts (%d)" %
                       (input_workers.num_workers, len(input_contexts)))

    iterators = []
    for i, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[i]
      with ops.device(worker):
        result = input_fn(ctx)
        devices = input_workers.compute_devices_for_worker(i)
        if isinstance(result, dataset_ops.DatasetV2):
          iterator = _SingleWorkerDatasetIterator(result, worker, devices)
        elif callable(result):
          iterator = _SingleWorkerCallableIterator(result, worker, devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator)

    super(InputFunctionIterator, self).__init__(
        input_workers,
        iterators,
        strategy,
        cardinality=cardinality_lib.UNKNOWN,
        enable_get_next_as_optional=False)
    self._enable_get_next_as_optional = False


class _SingleWorkerDatasetIterator(input_lib._SingleWorkerDatasetIteratorBase):  # pylint: disable=protected-access
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make an appropriate iterator on the dataset."""
    with ops.device(self._worker):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
            max_buffer_size=self._options.experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
        )

  def initialize(self):
    """Initialize the underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get the next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list(self, name=None):
    """Get the next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return data_list

  def get_next_as_optional_list(self):
    with ops.device(self._worker):
      data_list = [
          optional_ops.Optional.from_value(self._fn()) for _ in self._devices
      ]
      return data_list

  def initialize(self):
    # TODO(petebu): Should this throw an exception instead?
    return []


def _create_iterators_per_worker(worker_datasets, input_workers, options=None):
  """Create a multi-device iterator on each of the workers."""
  assert isinstance(input_workers, input_lib.InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(
          worker_datasets[i],  # pylint: disable=protected-access
          worker,
          worker_devices,
          options)
      iterators.append(iterator)
  return iterators