# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python API for executing a tf.data.Dataset using a tf.data service."""

import enum
import functools

from tensorflow.core.protobuf import data_service_pb2
from tensorflow.python import tf2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import compression_ops
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_AUTO = "AUTO"
COMPRESSION_NONE = None
_PARALLEL_EPOCHS = "parallel_epochs"
_DISTRIBUTED_EPOCH = "distributed_epoch"

# TODO(b/176933539): Use the regular import.
# TODO(b/238903802): Use TypeSpec serialization methods directly.
nested_structure_coder = lazy_loader.LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


@tf_export("data.experimental.service.ShardingPolicy")
class ShardingPolicy(enum.IntEnum):
  """Specifies how to shard data among tf.data service workers.

  OFF: No sharding will be performed. Each worker produces the entire dataset
  without any sharding. With this mode, the best practice is to shuffle the
  dataset nondeterministically so that workers process the dataset in different
  orders. If workers are restarted or join the cluster mid-job, they will begin
  processing the dataset from the beginning.

  DYNAMIC: The input dataset is dynamically split among workers at runtime.
  Each worker gets the next split when it reads data from the dispatcher. Data
  is produced non-deterministically in this mode. Dynamic sharding works well
  with varying-sized tf.data service clusters, e.g., when you need to
  auto-scale your workers. Dynamic sharding provides at-most-once visitation
  guarantees. No examples will be repeated, but some may be missed if a tf.data
  service worker gets restarted while processing a file.

  The following are static sharding policies. The semantics are similar to
  `tf.data.experimental.AutoShardPolicy`. These policies require:
    * The tf.data service cluster is configured with a fixed list of workers
      in DispatcherConfig.
    * Each client only reads from the local tf.data service worker.

  If a worker is restarted while performing static sharding, the worker will
  begin processing its shard again from the beginning.

  FILE: Shards by input files (i.e. each worker will get a fixed set of files
  to process). When this option is selected, make sure that there are at least
  as many files as workers. If there are fewer input files than workers, a
  runtime error will be raised.

  DATA: Shards by elements produced by the dataset. Each worker will process
  the whole dataset and discard the portion that is not for itself. Note that
  for this mode to correctly partition the dataset elements, the dataset needs
  to produce elements in a deterministic order.

  FILE_OR_DATA: Attempts FILE-based sharding, falling back to DATA-based
  sharding on failure.

  HINT: Looks for the presence of `shard(SHARD_HINT, ...)` which is treated as
  a placeholder to replace with `shard(num_workers, worker_index)`.
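
  A sketch of `HINT` usage (assumes `tf.data.experimental.SHARD_HINT` is
  available in this version of TensorFlow and that passing it for both `shard`
  arguments marks the placeholder; the service substitutes the actual worker
  count and index at runtime):

  ```
  dataset = tf.data.Dataset.range(100)
  # Placeholder shard, rewritten by the tf.data service.
  dataset = dataset.shard(tf.data.experimental.SHARD_HINT,
                          tf.data.experimental.SHARD_HINT)
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.HINT,
      service=FLAGS.tf_data_service_address))
  ```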
  """

  # LINT.IfChange(tf_data_service_sharding_policy)
  OFF = 0
  DYNAMIC = 1
  FILE = 2
  DATA = 3
  FILE_OR_DATA = 4
  HINT = 5
  # LINT.ThenChange()

  def _to_proto(self):
    """Converts the policy to ProcessingModeDef proto enum."""

    if self == ShardingPolicy.OFF:
      return data_service_pb2.ProcessingModeDef.OFF
    if self == ShardingPolicy.DYNAMIC:
      return data_service_pb2.ProcessingModeDef.DYNAMIC
    if self == ShardingPolicy.FILE:
      return data_service_pb2.ProcessingModeDef.FILE
    if self == ShardingPolicy.DATA:
      return data_service_pb2.ProcessingModeDef.DATA
    if self == ShardingPolicy.FILE_OR_DATA:
      return data_service_pb2.ProcessingModeDef.FILE_OR_DATA
    if self == ShardingPolicy.HINT:
      return data_service_pb2.ProcessingModeDef.HINT
    raise ValueError(f"Unable to convert sharding policy {self!r} to proto.")


@tf_export("data.experimental.service.CrossTrainerCache")
class CrossTrainerCache:
  """Options related to the tf.data service cross-trainer cache.

  This is used to enable cross-trainer cache when distributing a dataset. For
  example:

  ```
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
      service=FLAGS.tf_data_service_address,
      job_name="job",
      cross_trainer_cache=data_service_ops.CrossTrainerCache(
          trainer_id=trainer_id())))
  ```

  For more details, refer to
  https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
  """

  def __init__(self, trainer_id):
    """Constructs a CrossTrainerCache.

    Args:
      trainer_id: Each training job has a unique ID. Once a job has consumed
        data, the data remains in the cache and is re-used by jobs with
        different `trainer_id`s. Requests with the same `trainer_id` do not
        re-use data.

    Raises:
      ValueError: If `trainer_id` is empty.
    """
    if not trainer_id:
      raise ValueError(
          "tf.data service cross-trainer cache requires a non-empty trainer "
          "ID.")
    self.trainer_id = trainer_id

  def _to_proto(self):
    return data_service_pb2.CrossTrainerCacheOptions(trainer_id=self.trainer_id)


def _get_validated_sharding_policy(processing_mode):
  """Validates `processing_mode` and converts it to ShardingPolicy."""

  if isinstance(processing_mode, ShardingPolicy):
    return processing_mode
  if processing_mode == _PARALLEL_EPOCHS:
    return ShardingPolicy.OFF
  if processing_mode == _DISTRIBUTED_EPOCH:
    return ShardingPolicy.DYNAMIC

  raise ValueError("tf.data service processing mode should be a "
                   "`tf.data.experimental.service.ShardingPolicy`, "
                   "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
                   f"{processing_mode!r}.")


def _validate_job_name(job_name):
  if job_name is None:
    return
  if not isinstance(job_name, str):
    raise ValueError("`job_name` must be a string, but `job_name` was of type "
                     f"{type(job_name)}. job_name={job_name}")
  if not job_name:
    raise ValueError("`job_name` must not be empty.")


def _validate_compression(compression):
  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
  if compression not in valid_compressions:
    raise ValueError(f"Invalid `compression` argument: {compression}. "
                     f"Must be one of {valid_compressions}.")


def _get_compression_proto(compression):
  if compression == COMPRESSION_AUTO:
    return data_service_pb2.DataServiceMetadata.COMPRESSION_SNAPPY
  if compression == COMPRESSION_NONE:
    return data_service_pb2.DataServiceMetadata.COMPRESSION_OFF
  raise ValueError(f"Invalid `compression` argument: {compression}. "
                   f"Must be one of {[COMPRESSION_AUTO, COMPRESSION_NONE]}.")


def _decide_compression(compression, data_transfer_protocol):
  # Disable compression when data is transferred over a protocol other than
  # gRPC.
  if (compression == COMPRESSION_AUTO and data_transfer_protocol != "grpc" and
      data_transfer_protocol is not None):
    return COMPRESSION_NONE
  return compression


def _to_tensor(dataset_id):
  """Converts `dataset_id` to Tensor."""

  if isinstance(dataset_id, ops.Tensor):
    return dataset_id
  if isinstance(dataset_id, (str, bytes)):
    return ops.convert_to_tensor(
        dataset_id, dtype=dtypes.string, name="dataset_id")
  return ops.convert_to_tensor(
      dataset_id, dtype=dtypes.int64, name="dataset_id")


def _to_string(dataset_id):
  """Converts `dataset_id` to string."""

  if isinstance(dataset_id, ops.Tensor):
    return (dataset_id if dataset_id.dtype == dtypes.string else
            string_ops.as_string(dataset_id))
  return (dataset_id.decode()
          if isinstance(dataset_id, bytes) else str(dataset_id))


class _DataServiceDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` that reads elements from the tf.data service."""

  def __init__(self,
               dataset_id,
               processing_mode,
               address,
               element_spec,
               protocol,
               data_transfer_protocol,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               task_refresh_interval_hint_ms=None,
               cross_trainer_cache=None,
               target_workers="AUTO"):
    """Constructs a _DataServiceDatasetV2.

    Args:
      dataset_id: The dataset id for the dataset to read from.
      processing_mode: A `tf.data.experimental.service.ShardingPolicy`
        specifying how to shard the dataset among tf.data workers. See
        `tf.data.experimental.service.ShardingPolicy` for details. For
        backwards compatibility, `processing_mode` may also be set to the
        strings `"parallel_epochs"` or `"distributed_epoch"`, which are
        respectively equivalent to `ShardingPolicy.OFF` and
        `ShardingPolicy.DYNAMIC`.
      address: The tf.data service address, e.g. "localhost:5000".
      element_spec: The dataset element spec for the dataset to read from.
      protocol: The protocol to use for communicating with the tf.data
        service, e.g. "grpc".
      data_transfer_protocol: (Optional.) The protocol to use for transferring
        data with the tf.data service. By default, data is transferred using
        gRPC.
      job_name: (Optional.) The name of the job. If provided, it must be a
        non-empty string or Tensor. This argument makes it possible for
        multiple datasets to share the same job. The default behavior is that
        the dataset creates anonymous, exclusively owned jobs.
      consumer_index: (Optional.) The index of the consumer in the range from
        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
        When specified, consumers will read from the job in a strict
        round-robin order, instead of the default first-come-first-served
        order.
      num_consumers: (Optional.) The number of consumers which will consume
        from the job. Must be specified alongside `consumer_index`. When
        specified, consumers will read from the job in a strict round-robin
        order, instead of the default first-come-first-served order. When
        `num_consumers` is specified, the dataset must have infinite
        cardinality to prevent a producer from running out of data early and
        causing consumers to go out of sync.
      max_outstanding_requests: (Optional.) A limit on how many elements may
        be requested at the same time. You can use this option to control the
        amount of memory used, since `distribute` won't use more than
        `element_size` * `max_outstanding_requests` of memory.
      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
        the dispatcher for task changes.
      cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
        provided, dataset iteration will be shared across concurrently running
        trainers. See
        https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
        for details.
      target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
        tf.data runtime decides which workers to read from. If `"ANY"`, reads
        from any tf.data service workers. If `"LOCAL"`, only reads from local
        in-process tf.data service workers. `"AUTO"` works well for most
        cases, while users can specify other targets. For example, `"LOCAL"`
        helps avoid RPCs and data copy if every TF worker colocates with a
        tf.data service worker. Consumers of a shared job must use the same
        `target_workers`. Defaults to `"AUTO"`.
    """
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both `consumer_index` and `num_consumers`, "
          "or neither. "
          f"consumer_index={consumer_index}, num_consumers={num_consumers}")
    if num_consumers is not None and job_name is None:
      raise ValueError("`job_name` must be set when setting `num_consumers`. "
                       f"num_consumers was set to {num_consumers}.")

    processing_mode_def = data_service_pb2.ProcessingModeDef(
        sharding_policy=_get_validated_sharding_policy(
            processing_mode)._to_proto())
    if job_name is None:
      job_name = ""
    if max_outstanding_requests is None:
      max_outstanding_requests = dataset_ops.AUTOTUNE
    if task_refresh_interval_hint_ms is None:
      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE

    self._dataset_id = _to_tensor(dataset_id)
    self._processing_mode = ops.convert_to_tensor(
        processing_mode_def.SerializeToString(),
        dtype=dtypes.string,
        name="processing_mode")
    self._address = ops.convert_to_tensor(
        address, dtype=dtypes.string, name="address")
    self._protocol = ops.convert_to_tensor(
        protocol, dtype=dtypes.string, name="protocol")
    self._job_name = ops.convert_to_tensor(
        job_name, dtype=dtypes.string, name="job_name")
    self._consumer_index = ops.convert_to_tensor(
        -1 if consumer_index is None else consumer_index,
        dtype=dtypes.int64,
        name="consumer_index")
    self._num_consumers = ops.convert_to_tensor(
        -1 if num_consumers is None else num_consumers,
        dtype=dtypes.int64,
        name="num_consumers")
    self._max_outstanding_requests = ops.convert_to_tensor(
        max_outstanding_requests,
        dtype=dtypes.int64,
        name="max_outstanding_requests")
    self._element_spec = element_spec
    uncompress_func = structured_function.StructuredFunctionWrapper(
        lambda x: compression_ops.uncompress(x, output_spec=element_spec),
        transformation_name="DataServiceDataset.uncompress()",
        input_structure=tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant))
    cross_trainer_cache_options = (
        cross_trainer_cache._to_proto().SerializeToString()
        if cross_trainer_cache else None)

    compat_kwargs = {}
    if data_transfer_protocol is not None:
      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol

    if (compat.forward_compatible(2022, 8, 31) or
        self._dataset_id.dtype == dtypes.string):
      data_service_dataset = (
          gen_experimental_dataset_ops.data_service_dataset_v4)
    else:
      data_service_dataset = (
          gen_experimental_dataset_ops.data_service_dataset_v3)

    # If `uncompress` is `True`, the dataset will query the servers to find
    # out the actual compression used. It is always set to `True` the first
    # time the graph is built, and set to false when serializing, so we will
    # uncompress at most once.
    uncompress = True
    variant_tensor = data_service_dataset(
        dataset_id=self._dataset_id,
        processing_mode=self._processing_mode,
        address=self._address,
        protocol=self._protocol,
        job_name=self._job_name,
        consumer_index=self._consumer_index,
        num_consumers=self._num_consumers,
        max_outstanding_requests=self._max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        iteration_counter=(
            gen_experimental_dataset_ops.dummy_iteration_counter()),
        target_workers=target_workers,
        uncompress=uncompress,
        uncompress_fn=uncompress_func.function,
        cross_trainer_cache_options=cross_trainer_cache_options,
        **compat_kwargs,
        **self._flat_structure)
    super(_DataServiceDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` that executes its input through the tf.data service."""

  @functools.wraps(_DataServiceDatasetV2.__init__)
  def __init__(self, dataset_id, processing_mode, address, element_spec,
               protocol, data_transfer_protocol, job_name, consumer_index,
               num_consumers, max_outstanding_requests,
               task_refresh_interval_hint_ms, cross_trainer_cache,
               target_workers):

    self._wrapped = _DataServiceDatasetV2(
        dataset_id=dataset_id,
        processing_mode=processing_mode,
        address=address,
        element_spec=element_spec,
        protocol=protocol,
        data_transfer_protocol=data_transfer_protocol,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        cross_trainer_cache=cross_trainer_cache,
        target_workers=target_workers)
    super(_DataServiceDatasetV1, self).__init__(self._wrapped)


if tf2.enabled():
  _DataServiceDataset = _DataServiceDatasetV2
else:
  _DataServiceDataset = _DataServiceDatasetV1


def _parse_service(service):
  """Converts a tf.data service string into a (protocol, address) tuple.

  Args:
    service: A string in the format "protocol://address" or just "address". If
      the string is only an address, the default protocol will be used.

  Returns:
    The (protocol, address) tuple.
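
  For example (a sketch of the expected behavior; the default protocol comes
  from `_pywrap_utils.TF_DATA_DefaultProtocol()`):

  ```
  _parse_service("grpc://localhost:5000")  # ("grpc", "localhost:5000")
  _parse_service("localhost:5000")  # (default protocol, "localhost:5000")
  ```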
  """
  if not isinstance(service, str):
    raise ValueError("`service` must be a string, but `service` was of type "
                     f"{type(service)}. service={service}")
  if not service:
    raise ValueError("`service` must not be empty.")
  parts = service.split("://")
  if len(parts) == 2:
    protocol, address = parts
  elif len(parts) == 1:
    address = parts[0]
    protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
  else:
    raise ValueError("Malformed `service` string has multiple '://': "
                     f"{service}.")
  # TODO(aaudibert): Consider validating reachability of address here.
  return (protocol, address)


def _distribute(processing_mode,
                service,
                job_name=None,
                consumer_index=None,
                num_consumers=None,
                max_outstanding_requests=None,
                task_refresh_interval_hint_ms=None,
                data_transfer_protocol=None,
                compression="AUTO",
                cross_trainer_cache=None,
                target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  This transformation is similar to `distribute`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
      the dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  processing_mode = _get_validated_sharding_policy(processing_mode)
  _validate_compression(compression)
  compression = _decide_compression(compression, data_transfer_protocol)

  def _apply_fn(dataset):  # pylint: disable=missing-docstring
    dataset_id = _register_dataset(service, dataset, compression=compression)
    return _from_dataset_id(
        processing_mode,
        service,
        dataset_id,
        dataset.element_spec,
        job_name=job_name,
        consumer_index=consumer_index,
        num_consumers=num_consumers,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
        data_transfer_protocol=data_transfer_protocol,
        compression=compression,
        cross_trainer_cache=cross_trainer_cache,
        target_workers=target_workers)

  return _apply_fn


@tf_export("data.experimental.service.distribute")
def distribute(processing_mode,
               service,
               job_name=None,
               consumer_index=None,
               num_consumers=None,
               max_outstanding_requests=None,
               data_transfer_protocol=None,
               compression="AUTO",
               cross_trainer_cache=None,
               target_workers="AUTO"):
  """A transformation that moves dataset processing to the tf.data service.

  When you iterate over a dataset containing the `distribute` transformation,
  the tf.data service creates a "job" which produces data for the dataset
  iteration.

  The tf.data service uses a cluster of workers to prepare data for training
  your model. The `processing_mode` argument to
  `tf.data.experimental.service.distribute` describes how to leverage multiple
  workers to process the input dataset. Currently, there are two processing
  modes to choose from: "distributed_epoch" and "parallel_epochs".

  "distributed_epoch" means that the dataset will be split across all tf.data
  service workers. The dispatcher produces "splits" for the dataset and sends
  them to workers for further processing. For example, if a dataset begins
  with a list of filenames, the dispatcher will iterate through the filenames
  and send the filenames to tf.data workers, which will perform the rest of
  the dataset transformations on those files. "distributed_epoch" is useful
  when your model needs to see each element of the dataset exactly once, or if
  it needs to see the data in a generally-sequential order. "distributed_epoch"
  only works for datasets with splittable sources, such as
  `Dataset.from_tensor_slices`, `Dataset.list_files`, or `Dataset.range`.

  "parallel_epochs" means that the entire input dataset will be processed
  independently by each of the tf.data service workers. For this reason, it is
  important to shuffle data (e.g. filenames) non-deterministically, so that
  each worker will process the elements of the dataset in a different order.
  "parallel_epochs" can be used to distribute datasets that aren't splittable.

  With two workers, "parallel_epochs" will produce every element of the
  dataset twice:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> # Start two workers
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]

  "distributed_epoch", on the other hand, will still produce each element
  once:

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="distributed_epoch", service=dispatcher.target))
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When using `apply(tf.data.experimental.service.distribute(...))`, the
  dataset before the `apply` transformation executes within the tf.data
  service, while the operations after `apply` happen within the local process.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> workers = [
  ...     tf.data.experimental.service.WorkerServer(
  ...         tf.data.experimental.service.WorkerConfig(
  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
  ... ]
  >>> dataset = tf.data.Dataset.range(5)
  >>> dataset = dataset.map(lambda x: x*x)
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.service.distribute("parallel_epochs",
  ...                                             dispatcher.target))
  >>> dataset = dataset.map(lambda x: x+1)
  >>> print(sorted(list(dataset.as_numpy_iterator())))
  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]

  In the above example, the dataset operations before the call to `distribute`
  (here, `map(lambda x: x*x)`) will be executed on the tf.data workers, and
  the elements are provided over RPC. The remaining transformations (after the
  call to `distribute`) will be executed locally. The dispatcher and the
  workers will bind to unused ports (which are chosen at random), in order to
  communicate with each other. However, to bind them to specific ports, the
  `port` parameter can be passed.
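
  For example (a sketch; port `5051` is an arbitrary choice of a free port):

  ```
  worker = tf.data.experimental.service.WorkerServer(
      tf.data.experimental.service.WorkerConfig(
          dispatcher_address=dispatcher_address, port=5051))
  ```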

  The `job_name` argument allows jobs to be shared across multiple datasets.
  Instead of each dataset creating its own job, all datasets with the same
  `job_name` will consume from the same job. A new job will be created for
  each iteration of the dataset (with each repetition of `Dataset.repeat`
  counting as a new iteration). Suppose the `DispatchServer` is serving on
  `localhost:5000` and two training workers (in either a single client or
  multi-client setup) iterate over the below dataset, and there is a single
  tf.data worker:

  ```
  range5_dataset = tf.data.Dataset.range(5)
  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
      "parallel_epochs", "localhost:5000", job_name="my_job_name"))
  for iteration in range(3):
    print(list(dataset))
  ```

  The elements of each job will be split between the two processes, with
  elements being consumed by the processes on a first-come first-served basis.
  One possible result is that process 1 prints

  ```
  [0, 2, 4]
  [0, 1, 3]
  [1]
  ```

  and process 2 prints

  ```
  [1, 3]
  [2, 4]
  [0, 2, 3, 4]
  ```

  Job names must not be re-used across different training jobs within the
  lifetime of the tf.data service. In general, the tf.data service is expected
  to live for the duration of a single training job. To use the tf.data
  service with multiple training jobs, make sure to use different job names to
  avoid conflicts. For example, suppose a training job calls `distribute` with
  `job_name="job"` and reads until end of input. If another independent job
  connects to the same tf.data service and tries to read from
  `job_name="job"`, it will immediately receive end of input, without getting
  any data.

  **Coordinated data read**

  By default, when multiple consumers read from the same job, they receive
  data on a first-come first-served basis. In some use cases, it is
  advantageous to coordinate the consumers so that, at each step, they read
  data from the same worker.

  For example, the tf.data service can be used to coordinate example sizes
  across a cluster during synchronous training, so that during each step all
  replicas train on similar-sized elements. To achieve this, define a dataset
  which generates rounds of `num_consumers` consecutive similar-sized batches,
  then enable coordinated reads by setting `consumer_index` and
  `num_consumers`, as in the sketch below.

  NOTE: To keep consumers in sync, round-robin data consumption requires that
  the dataset have infinite cardinality. You can get this by adding
  `.repeat()` at the end of the dataset definition.
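
  A sketch of enabling coordinated reads for two consumers (assumes
  `FLAGS.tf_data_service_address` holds the dispatcher address and that each
  consumer process knows its own `consumer_index`):

  ```
  dataset = dataset.repeat()  # Coordinated reads need infinite cardinality.
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
      service=FLAGS.tf_data_service_address,
      job_name="coordinated_job",  # `job_name` is required for shared jobs.
      consumer_index=consumer_index,  # 0 or 1, depending on the consumer.
      num_consumers=2))
  ```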

  **Keras and Distribution Strategies**

  The dataset produced by the `distribute` transformation can be passed to
  Keras' `Model.fit` or Distribution Strategy's
  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
  `distribute` so that if there are multiple workers, they read data from the
  same job. Note that the autosharding normally performed by
  `experimental_distribute_dataset` will be disabled when setting a
  `job_name`, since sharing the job already results in splitting data across
  the workers. When using a shared job, data will be dynamically balanced
  across workers, so that they reach end of input about the same time. This
  results in better worker utilization than with autosharding, where each
  worker processes an independent set of files, and some workers may run out
  of data earlier than others.
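
  For example (a sketch; assumes `strategy` is an existing
  `tf.distribute.Strategy` and `FLAGS.tf_data_service_address` holds the
  dispatcher address):

  ```
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
      service=FLAGS.tf_data_service_address,
      job_name="shared_job"))
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
  ```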

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
  _validate_job_name(job_name)
  return _distribute(
      processing_mode=processing_mode,
      service=service,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      compression=compression,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)


def _register_dataset(service, dataset, compression, dataset_id=None):
  """Registers a dataset with the tf.data service.

  This transformation is similar to `register_dataset`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: How to compress the dataset's elements before transferring
      them over the network. "AUTO" leaves the decision of how to compress up
      to the tf.data service runtime. `None` indicates not to compress.
    dataset_id: (Optional.) By default, tf.data service generates a unique
      (string) ID for each registered dataset. If a `dataset_id` is provided,
      it will use the specified ID. If a dataset with a matching ID already
      exists, no new dataset is registered. This is useful if multiple
      training jobs want to (re)use the same dataset for training. In this
      case, they can register the dataset with the same dataset ID.

  Returns:
    A scalar string tensor representing the dataset ID.
  """
  _validate_compression(compression)
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  external_state_policy = dataset.options().experimental_external_state_policy
  if external_state_policy is None:
    external_state_policy = ExternalStatePolicy.WARN

  encoded_spec = None
  if context.executing_eagerly():
    encoded_spec = nested_structure_coder.encode_structure(
        dataset.element_spec).SerializeToString()

  if compression == COMPRESSION_AUTO:
    dataset = dataset.map(
        lambda *x: compression_ops.compress(x),
        num_parallel_calls=dataset_ops.AUTOTUNE)
  dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access

  metadata = data_service_pb2.DataServiceMetadata(
      element_spec=encoded_spec,
      compression=_get_compression_proto(compression))

  if compat.forward_compatible(2022, 8, 31) or dataset_id:
    return gen_experimental_dataset_ops.register_dataset_v2(
        dataset._variant_tensor,  # pylint: disable=protected-access
        address=address,
        protocol=protocol,
        external_state_policy=external_state_policy.value,
        requested_dataset_id=dataset_id,
        metadata=metadata.SerializeToString())
  else:
    return gen_experimental_dataset_ops.register_dataset(
        dataset._variant_tensor,  # pylint: disable=protected-access
        address=address,
        protocol=protocol,
        external_state_policy=external_state_policy.value,
        metadata=metadata.SerializeToString())


@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset, compression="AUTO", dataset_id=None):
  """Registers a dataset with the tf.data service.

  `register_dataset` registers a dataset with the tf.data service so that
  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.

  If the dataset is already registered with the tf.data service,
  `register_dataset` returns the already-registered dataset's id.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
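
  If a `dataset_id` is provided, multiple training jobs can register and
  (re)use the same dataset by passing the same ID (a sketch; assumes the
  `dispatcher` from the example above):

  ```
  dataset_id = tf.data.experimental.service.register_dataset(
      dispatcher.target, dataset, dataset_id="my_dataset_id")
  ```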

  Args:
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset: A `tf.data.Dataset` to register with the tf.data service.
    compression: (Optional.) How to compress the dataset's elements before
      transferring them over the network. "AUTO" leaves the decision of how to
      compress up to the tf.data service runtime. `None` indicates not to
      compress.
    dataset_id: (Optional.) By default, tf.data service generates a unique
      (string) ID for each registered dataset. If a `dataset_id` is provided,
      it will use the specified ID. If a dataset with a matching ID already
      exists, no new dataset is registered. This is useful if multiple
      training jobs want to (re)use the same dataset for training. In this
      case, they can register the dataset with the same dataset ID.

  Returns:
    A scalar string tensor representing the dataset ID.
  """
  return _register_dataset(service, dataset, compression, dataset_id)


def _from_dataset_id(processing_mode,
                     service,
                     dataset_id,
                     element_spec,
                     job_name=None,
                     consumer_index=None,
                     num_consumers=None,
                     max_outstanding_requests=None,
                     task_refresh_interval_hint_ms=None,
                     data_transfer_protocol=None,
                     compression="AUTO",
                     cross_trainer_cache=None,
                     target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This transformation is similar to `from_dataset_id`, but supports additional
  parameters which we do not yet want to add to the public Python API.

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type
      of elements produced by the dataset. This argument is only required
      inside a tf.function. Use `tf.data.Dataset.element_spec` to get the
      element spec for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string or tensor. This argument makes it possible for multiple
      datasets to share the same job. The default behavior is that the dataset
      creates anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
      the dispatcher for task changes.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    compression: An indication of how the dataset's elements were compressed,
      so that `from_dataset_id` can uncompress them if necessary.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """

  def _get_element_spec():
    """Fetches the element spec from the server."""
    data_service_metadata = None
    dataset_id_val = tensor_util.constant_value(dataset_id)
    try:
      if isinstance(dataset_id_val, (str, bytes)):
        data_service_metadata = (
            _pywrap_server_lib.TF_DATA_GetDataServiceMetadataByID(
                dataset_id_val, address, protocol))
      else:
        # TODO(b/236725000): Remove this after the forward compatibility
        # window has passed.
        data_service_metadata = (
            _pywrap_server_lib.TF_DATA_GetDataServiceMetadata(
                dataset_id_val, address, protocol))
    except NotImplementedError as err:
      raise ValueError(
          "The tf.data service is running an earlier version of TensorFlow "
          "that requires specifying `element_spec` as an argument to "
          "`from_dataset_id`. Please either supply an element spec or update "
          "the tf.data service to the latest version.") from err
    except RuntimeError:
      # This error results from the dataset ID not being found. A more
      # appropriate error will be raised when the dataset is created.
      pass

    if not data_service_metadata or not data_service_metadata.element_spec:
      dataset_id_val = tensor_util.constant_value(dataset_id)
      raise ValueError(
          f"Failed to fetch element spec for dataset id {dataset_id_val} from "
          "tf.data service. If the dataset was registered in graph mode or "
          "inside a tf.function, the `element_spec` must be specified as an "
          "argument to `from_dataset_id`.")

    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
    struct_pb.ParseFromString(data_service_metadata.element_spec)
    return nested_structure_coder.decode_proto(struct_pb)

  processing_mode = _get_validated_sharding_policy(processing_mode)
  if isinstance(service, tuple):
    protocol, address = service
  else:
    protocol, address = _parse_service(service)
  _validate_compression(compression)
  if job_name is not None:
    if not isinstance(job_name, str) and not isinstance(job_name, ops.Tensor):
      raise ValueError(
          "`job_name` must be a string or Tensor, but `job_name` was of type "
          f"{type(job_name)}. job_name={job_name}.")

  if not element_spec:
    if not context.executing_eagerly():
      raise ValueError(
          "In graph mode `element_spec` must be provided manually.")
    element_spec = _get_element_spec()

  dataset = _DataServiceDataset(
      dataset_id=dataset_id,
      processing_mode=processing_mode,
      address=address,
      element_spec=element_spec,
      protocol=protocol,
      data_transfer_protocol=data_transfer_protocol,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)

  # Disable autosharding for shared jobs.
  if job_name is not None:
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
  return dataset


@tf_export("data.experimental.service.from_dataset_id")
def from_dataset_id(processing_mode,
                    service,
                    dataset_id,
                    element_spec=None,
                    job_name=None,
                    consumer_index=None,
                    num_consumers=None,
                    max_outstanding_requests=None,
                    data_transfer_protocol=None,
                    cross_trainer_cache=None,
                    target_workers="AUTO"):
  """Creates a dataset which reads data from the tf.data service.

  This is useful when the dataset is registered by one process, then used in
  another process. When the same process is both registering and reading from
  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
  instead.

  Before using `from_dataset_id`, the dataset must have been registered with
  the tf.data service using `tf.data.experimental.service.register_dataset`.
  `register_dataset` returns a dataset id for the registered dataset. That is
  the `dataset_id` which should be passed to `from_dataset_id`.

  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. In eager mode, `element_spec` may be omitted
  (defaulting to `None`), in which case it is discovered automatically by
  querying the tf.data service. Inside a tf.function or in graph mode,
  `element_spec` must be specified explicitly, and it must match the dataset
  registered under `dataset_id`.
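
  For example, when creating the dataset inside a `tf.function`, the element
  spec must be passed explicitly (a sketch; assumes a running `dispatcher` and
  a `dataset_id` obtained from `register_dataset` for a dataset of scalar
  `tf.int64` elements, as in the example below):

  ```
  @tf.function
  def make_dataset():
    return tf.data.experimental.service.from_dataset_id(
        processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
        service=dispatcher.target,
        dataset_id=dataset_id,
        element_spec=tf.TensorSpec(shape=(), dtype=tf.int64))
  ```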

  `tf.data.experimental.service.distribute` is a convenience method which
  combines `register_dataset` and `from_dataset_id` into a dataset
  transformation. See the documentation for
  `tf.data.experimental.service.distribute` for more detail about how
  `from_dataset_id` works.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...         dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset_id = tf.data.experimental.service.register_dataset(
  ...     dispatcher.target, dataset)
  >>> dataset = tf.data.experimental.service.from_dataset_id(
  ...     processing_mode="parallel_epochs",
  ...     service=dispatcher.target,
  ...     dataset_id=dataset_id,
  ...     element_spec=dataset.element_spec)
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  Args:
    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
      how to shard the dataset among tf.data workers. See
      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
      compatibility, `processing_mode` may also be set to the strings
      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
    service: A string or a tuple indicating how to connect to the tf.data
      service. If it's a string, it should be in the format
      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
      address and `<protocol>` can optionally be used to override the default
      protocol to use. If it's a tuple, it should be (protocol, address).
    dataset_id: The id of the dataset to read from. This id is returned by
      `register_dataset` when the dataset is registered with the tf.data
      service.
    element_spec: A nested structure of `tf.TypeSpec`s representing the type
      of elements produced by the dataset. This argument is only required
      inside a tf.function. Use `tf.data.Dataset.element_spec` to get the
      element spec for a given dataset.
    job_name: (Optional.) The name of the job. If provided, it must be a
      non-empty string. This argument makes it possible for multiple datasets
      to share the same job. The default behavior is that the dataset creates
      anonymous, exclusively owned jobs.
    consumer_index: (Optional.) The index of the consumer in the range from
      `0` to `num_consumers`. Must be specified alongside `num_consumers`.
      When specified, consumers will read from the job in a strict round-robin
      order, instead of the default first-come-first-served order.
    num_consumers: (Optional.) The number of consumers which will consume from
      the job. Must be specified alongside `consumer_index`. When specified,
      consumers will read from the job in a strict round-robin order, instead
      of the default first-come-first-served order. When `num_consumers` is
      specified, the dataset must have infinite cardinality to prevent a
      producer from running out of data early and causing consumers to go out
      of sync.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the
      amount of memory used, since `distribute` won't use more than
      `element_size` * `max_outstanding_requests` of memory.
    data_transfer_protocol: (Optional.) The protocol to use for transferring
      data with the tf.data service. By default, data is transferred using
      gRPC.
    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
      provided, dataset iteration will be shared across concurrently running
      trainers. See
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
      for details.
    target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
      tf.data runtime decides which workers to read from. If `"ANY"`, reads
      from any tf.data service workers. If `"LOCAL"`, only reads from local
      in-process tf.data service workers. `"AUTO"` works well for most cases,
      while users can specify other targets. For example, `"LOCAL"` helps
      avoid RPCs and data copy if every TF worker colocates with a tf.data
      service worker. Consumers of a shared job must use the same
      `target_workers`. Defaults to `"AUTO"`.

  Returns:
    A `tf.data.Dataset` which reads from the tf.data service.
  """
  _validate_job_name(job_name)
  if job_name is not None:
    # Prefix the job name with the dataset ID, so that jobs are only shared
    # among datasets created from the same registered dataset.
    job_name = string_ops.string_join(
        ["dataset_id=", _to_string(dataset_id), job_name], "/")

  return _from_dataset_id(
      processing_mode=processing_mode,
      service=service,
      dataset_id=dataset_id,
      element_spec=element_spec,
      job_name=job_name,
      consumer_index=consumer_index,
      num_consumers=num_consumers,
      max_outstanding_requests=max_outstanding_requests,
      data_transfer_protocol=data_transfer_protocol,
      cross_trainer_cache=cross_trainer_cache,
      target_workers=target_workers)