1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python API for executing a tf.data.Dataset using a tf.data service."""
16
17import enum
18import functools
19
20from tensorflow.core.protobuf import data_service_pb2
21from tensorflow.python import tf2
22from tensorflow.python.compat import compat
23from tensorflow.python.data.experimental.ops import compression_ops
24from tensorflow.python.data.experimental.service import _pywrap_server_lib
25from tensorflow.python.data.experimental.service import _pywrap_utils
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.data.ops import options as options_lib
28from tensorflow.python.data.ops import structured_function
29from tensorflow.python.data.ops.options import AutoShardPolicy
30from tensorflow.python.data.ops.options import ExternalStatePolicy
31from tensorflow.python.eager import context
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.ops import gen_experimental_dataset_ops
37from tensorflow.python.ops import string_ops
38from tensorflow.python.util import lazy_loader
39from tensorflow.python.util.tf_export import tf_export
40
41COMPRESSION_AUTO = "AUTO"
42COMPRESSION_NONE = None
43_PARALLEL_EPOCHS = "parallel_epochs"
44_DISTRIBUTED_EPOCH = "distributed_epoch"
45
46# TODO(b/176933539): Use the regular import.
47# TODO(b/238903802): Use TypeSpec serialization methods directly.
48nested_structure_coder = lazy_loader.LazyLoader(
49    "nested_structure_coder", globals(),
50    "tensorflow.python.saved_model.nested_structure_coder")
51
52
53@tf_export("data.experimental.service.ShardingPolicy")
54class ShardingPolicy(enum.IntEnum):
55  """Specifies how to shard data among tf.data service workers.
56
57  OFF: No sharding will be performed. Each worker produces the entire dataset
58  without any sharding. With this mode, the best practice is to shuffle the
59  dataset nondeterministically so that workers process the dataset in different
60  orders. If workers are restarted or join the cluster mid-job, they will begin
61  processing the dataset from the beginning.
62
63  DYNAMIC: The input dataset is dynamically split among workers at runtime. Each
64  worker gets the next split when it reads data from the dispatcher. Data is
65  produced non-deterministically in this mode. Dynamic sharding works well with
66  varying-sized tf.data service clusters, e.g., when you need to auto-scale your
  workers. Dynamic sharding provides at-most-once visitation guarantees. No
68  examples will be repeated, but some may be missed if a tf.data service worker
69  gets restarted while processing a file.
70
71  The following are static sharding policies. The semantics are similar to
72  `tf.data.experimental.AutoShardPolicy`. These policies require:
73  * The tf.data service cluster is configured with a fixed list of workers
74    in DispatcherConfig.
75  * Each client only reads from the local tf.data service worker.
76
77  If a worker is restarted while performing static sharding, the worker will
78  begin processing its shard again from the beginning.
79
80  FILE: Shards by input files (i.e. each worker will get a fixed set of files to
  process). When this option is selected, make sure that there are at least as
82  many files as workers. If there are fewer input files than workers, a runtime
83  error will be raised.
84
85  DATA: Shards by elements produced by the dataset. Each worker will process the
86  whole dataset and discard the portion that is not for itself. Note that for
87  this mode to correctly partition the dataset elements, the dataset needs to
88  produce elements in a deterministic order.
89
90  FILE_OR_DATA: Attempts FILE-based sharding, falling back to DATA-based
91  sharding on failure.
92
93  HINT: Looks for the presence of `shard(SHARD_HINT, ...)` which is treated as a
94  placeholder to replace with `shard(num_workers, worker_index)`.
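
  For example, dynamic sharding can be requested like this (a minimal sketch;
  the dispatcher address is a placeholder):

  ```
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode=tf.data.experimental.service.ShardingPolicy.DYNAMIC,
      service="grpc://dispatcher_address:5000"))
  ```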
95  """
96
97  # LINT.IfChange(tf_data_service_sharding_policy)
98  OFF = 0
99  DYNAMIC = 1
100  FILE = 2
101  DATA = 3
102  FILE_OR_DATA = 4
103  HINT = 5
104  # LINT.ThenChange()
105
106  def _to_proto(self):
107    """Converts the policy to ProcessingModeDef proto enum."""
108
109    if self == ShardingPolicy.OFF:
110      return data_service_pb2.ProcessingModeDef.OFF
111    if self == ShardingPolicy.DYNAMIC:
112      return data_service_pb2.ProcessingModeDef.DYNAMIC
113    if self == ShardingPolicy.FILE:
114      return data_service_pb2.ProcessingModeDef.FILE
115    if self == ShardingPolicy.DATA:
116      return data_service_pb2.ProcessingModeDef.DATA
117    if self == ShardingPolicy.FILE_OR_DATA:
118      return data_service_pb2.ProcessingModeDef.FILE_OR_DATA
119    if self == ShardingPolicy.HINT:
120      return data_service_pb2.ProcessingModeDef.HINT
121    raise ValueError(f"Unable to convert sharding policy {self!r} to proto.")
122
123
124@tf_export("data.experimental.service.CrossTrainerCache")
125class CrossTrainerCache:
126  """Options related to the tf.data service cross trainer cache.
127
128  This is used to enable cross-trainer cache when distributing a dataset. For
129  example:
130
131  ```
132  dataset = dataset.apply(tf.data.experimental.service.distribute(
133      processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
134      service=FLAGS.tf_data_service_address,
135      job_name="job",
136      cross_trainer_cache=data_service_ops.CrossTrainerCache(
137          trainer_id=trainer_id())))
138  ```
139
140  For more details, refer to
141  https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
142  """
143
144  def __init__(self, trainer_id):
145    """Constructs a CrossTrainerCache.
146
147    Args:
      trainer_id: Each training job has a unique ID. Once a job has consumed
        data, the data remains in the cache and is re-used by jobs with
        different `trainer_id`s. Requests with the same `trainer_id` do not
        re-use data.

    Raises:
      ValueError: If `trainer_id` is empty.
154    """
155    if not trainer_id:
156      raise ValueError(
157          "tf.data service cross-trainer cache requires a non-empty trainer ID."
158      )
159    self.trainer_id = trainer_id
160
161  def _to_proto(self):
162    return data_service_pb2.CrossTrainerCacheOptions(trainer_id=self.trainer_id)
163
164
165def _get_validated_sharding_policy(processing_mode):
166  """Validates `processing_mode` and converts it to ShardingPolicy."""
167
168  if isinstance(processing_mode, ShardingPolicy):
169    return processing_mode
170  if processing_mode == _PARALLEL_EPOCHS:
171    return ShardingPolicy.OFF
172  if processing_mode == _DISTRIBUTED_EPOCH:
173    return ShardingPolicy.DYNAMIC
174
175  raise ValueError("tf.data service processing mode should be a "
176                   "`tf.data.experimental.service.ShardingPolicy`, "
177                   "`\"parallel_epochs\"`, or `\"distributed_epoch\"`. Got "
178                   f"{processing_mode!r}.")
179
180
181def _validate_job_name(job_name):
182  if job_name is None:
183    return
184  if not isinstance(job_name, str):
185    raise ValueError("`job_name` must be a string, but `job_name` was of type "
186                     f"{type(job_name)}. job_name={job_name}")
187  if not job_name:
188    raise ValueError("`job_name` must not be empty")
189
190
191def _validate_compression(compression):
192  valid_compressions = [COMPRESSION_AUTO, COMPRESSION_NONE]
193  if compression not in valid_compressions:
194    raise ValueError(f"Invalid `compression` argument: {compression}. "
195                     f"Must be one of {valid_compressions}.")
196
197
198def _get_compression_proto(compression):
199  if compression == COMPRESSION_AUTO:
200    return data_service_pb2.DataServiceMetadata.COMPRESSION_SNAPPY
201  if compression == COMPRESSION_NONE:
202    return data_service_pb2.DataServiceMetadata.COMPRESSION_OFF
203  raise ValueError(f"Invalid `compression` argument: {compression}. "
204                   f"Must be one of {[COMPRESSION_AUTO, COMPRESSION_NONE]}.")
205
206
207def _decide_compression(compression, data_transfer_protocol):
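  """Returns the compression to use, disabling `AUTO` compression when a
  non-gRPC `data_transfer_protocol` is explicitly requested."""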
208  if (compression == COMPRESSION_AUTO and data_transfer_protocol != "grpc" and
209      data_transfer_protocol is not None):
210    return COMPRESSION_NONE
211  return compression
212
213
214def _to_tensor(dataset_id):
215  """Converts `dataset_id` to Tensor."""
216
217  if isinstance(dataset_id, ops.Tensor):
218    return dataset_id
219  if isinstance(dataset_id, str) or isinstance(dataset_id, bytes):
220    return ops.convert_to_tensor(
221        dataset_id, dtype=dtypes.string, name="dataset_id")
222  return ops.convert_to_tensor(
223      dataset_id, dtype=dtypes.int64, name="dataset_id")
224
225
226def _to_string(dataset_id):
227  """Converts `dataset_id` to string."""
228
229  if isinstance(dataset_id, ops.Tensor):
230    return (dataset_id if dataset_id.dtype == dtypes.string else
231            string_ops.as_string(dataset_id))
232  return (dataset_id.decode()
233          if isinstance(dataset_id, bytes) else str(dataset_id))
234
235
236class _DataServiceDatasetV2(dataset_ops.DatasetSource):
237  """A `Dataset` that reads elements from the tf.data service."""
238
239  def __init__(self,
240               dataset_id,
241               processing_mode,
242               address,
243               element_spec,
244               protocol,
245               data_transfer_protocol,
246               job_name=None,
247               consumer_index=None,
248               num_consumers=None,
249               max_outstanding_requests=None,
250               task_refresh_interval_hint_ms=None,
251               cross_trainer_cache=None,
252               target_workers="AUTO"):
253    """Constructs a _DataServiceDatasetV2.
254
255    Args:
256      dataset_id: The dataset id for the dataset to read from.
257      processing_mode: A `tf.data.experimental.service.ShardingPolicy`
258        specifying how to shard the dataset among tf.data workers. See
259        `tf.data.experimental.service.ShardingPolicy` for details. For backwards
260        compatibility, `processing_mode` may also be set to the strings
261        `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
262        equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
263      address: The tf.data service address, e.g. "localhost:5000".
264      element_spec: The dataset element spec for the dataset to read from.
265      protocol: The protocol to use for communicating with the tf.data service,
266        e.g. "grpc".
267      data_transfer_protocol: (Optional.) The protocol to use for transferring
268        data with the tf.data service. By default, data is transferred using
269        gRPC.
270      job_name: (Optional.) The name of the job. If provided, it must be a
271        non-empty string or Tensor. This argument makes it possible for multiple
272        datasets to share the same job. The default behavior is that the dataset
273        creates anonymous, exclusively owned jobs.
274      consumer_index: (Optional.) The index of the consumer in the range from
275        `0` to `num_consumers`. Must be specified alongside `num_consumers`.
276        When specified, consumers will read from the job in a strict round-robin
277        order, instead of the default first-come-first-served order.
278      num_consumers: (Optional.) The number of consumers which will consume from
279        the job. Must be specified alongside `consumer_index`. When specified,
280        consumers will read from the job in a strict round-robin order, instead
281        of the default first-come-first-served order. When `num_consumers` is
282        specified, the dataset must have infinite cardinality to prevent a
283        producer from running out of data early and causing consumers to go out
284        of sync.
285      max_outstanding_requests: (Optional.) A limit on how many elements may be
286        requested at the same time. You can use this option to control the
287        amount of memory used, since `distribute` won't use more than
288        `element_size` * `max_outstanding_requests` of memory.
289      task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
290        the dispatcher for task changes.
291      cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
292        provided, dataset iteration will be shared across concurrently running
293        trainers. See
294        https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
295        for details.
296      target_workers: (Optional.) Which workers to read from. If `"AUTO"`,
297        tf.data runtime decides which workers to read from. If `"ANY"`, reads
298        from any tf.data service workers. If `"LOCAL"`, only reads from local
        in-process tf.data service workers. `"AUTO"` works well for most cases,
300        while users can specify other targets. For example, `"LOCAL"` helps
301        avoid RPCs and data copy if every TF worker colocates with a tf.data
302        service worker. Consumers of a shared job must use the same
303        `target_workers`. Defaults to `"AUTO"`.
304    """
    if (consumer_index is None) != (num_consumers is None):
      raise ValueError(
          "Must either set both `consumer_index` and `num_consumers`, "
          "or neither. "
          f"consumer_index={consumer_index}, num_consumers={num_consumers}")
310    if num_consumers is not None and job_name is None:
311      raise ValueError("`job_name` must be set when setting `num_consumers`. "
312                       f"num_consumers was set to {num_consumers}.")
313
314    processing_mode_def = data_service_pb2.ProcessingModeDef(
315        sharding_policy=_get_validated_sharding_policy(
316            processing_mode)._to_proto())
317    if job_name is None:
318      job_name = ""
319    if max_outstanding_requests is None:
320      max_outstanding_requests = dataset_ops.AUTOTUNE
321    if task_refresh_interval_hint_ms is None:
322      task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
323
324    self._dataset_id = _to_tensor(dataset_id)
325    self._processing_mode = ops.convert_to_tensor(
326        processing_mode_def.SerializeToString(),
327        dtype=dtypes.string,
328        name="processing_mode")
329    self._address = ops.convert_to_tensor(
330        address, dtype=dtypes.string, name="address")
331    self._protocol = ops.convert_to_tensor(
332        protocol, dtype=dtypes.string, name="protocol")
333    self._job_name = ops.convert_to_tensor(
334        job_name, dtype=dtypes.string, name="job_name")
335    self._consumer_index = ops.convert_to_tensor(
336        -1 if consumer_index is None else consumer_index,
337        dtype=dtypes.int64,
338        name="consumer_index")
339    self._num_consumers = ops.convert_to_tensor(
340        -1 if num_consumers is None else num_consumers,
341        dtype=dtypes.int64,
342        name="num_consumers")
343    self._max_outstanding_requests = ops.convert_to_tensor(
344        max_outstanding_requests,
345        dtype=dtypes.int64,
346        name="max_outstanding_requests")
347    self._element_spec = element_spec
348    uncompress_func = structured_function.StructuredFunctionWrapper(
349        lambda x: compression_ops.uncompress(x, output_spec=element_spec),
350        transformation_name="DataServiceDataset.uncompress()",
351        input_structure=tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant))
352    cross_trainer_cache_options = (
353        cross_trainer_cache._to_proto().SerializeToString()
354        if cross_trainer_cache else None)
355
356    compat_kwargs = {}
357    if data_transfer_protocol is not None:
358      compat_kwargs["data_transfer_protocol"] = data_transfer_protocol
359
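    # String dataset ids are only supported by the newer DataServiceDatasetV4
    # op; otherwise keep emitting V3 until the forward compatibility window
    # (2022-08-31) has passed.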
360    if (compat.forward_compatible(2022, 8, 31) or
361        self._dataset_id.dtype == dtypes.string):
362      data_service_dataset = (
363          gen_experimental_dataset_ops.data_service_dataset_v4)
364    else:
365      data_service_dataset = (
366          gen_experimental_dataset_ops.data_service_dataset_v3)
367
368    # If `uncompress` is `True`, the dataset will query the servers to find
369    # out the actual compression used. It is always set to `True` the first
    # time the graph is built, and set to `False` when serializing, so we will
371    # uncompress at most once.
372    uncompress = True
373    variant_tensor = data_service_dataset(
374        dataset_id=self._dataset_id,
375        processing_mode=self._processing_mode,
376        address=self._address,
377        protocol=self._protocol,
378        job_name=self._job_name,
379        consumer_index=self._consumer_index,
380        num_consumers=self._num_consumers,
381        max_outstanding_requests=self._max_outstanding_requests,
382        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
383        iteration_counter=(
384            gen_experimental_dataset_ops.dummy_iteration_counter()),
385        target_workers=target_workers,
386        uncompress=uncompress,
387        uncompress_fn=uncompress_func.function,
388        cross_trainer_cache_options=cross_trainer_cache_options,
389        **compat_kwargs,
390        **self._flat_structure)
391    super(_DataServiceDatasetV2, self).__init__(variant_tensor)
392
393  @property
394  def element_spec(self):
395    return self._element_spec
396
397
398class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
399  """A `Dataset` that executes its input through the tf.data service."""
400
401  @functools.wraps(_DataServiceDatasetV2.__init__)
402  def __init__(self, dataset_id, processing_mode, address, element_spec,
403               protocol, data_transfer_protocol, job_name, consumer_index,
404               num_consumers, max_outstanding_requests,
405               task_refresh_interval_hint_ms, cross_trainer_cache,
406               target_workers):
407
408    self._wrapped = _DataServiceDatasetV2(
409        dataset_id=dataset_id,
410        processing_mode=processing_mode,
411        address=address,
412        element_spec=element_spec,
413        protocol=protocol,
414        data_transfer_protocol=data_transfer_protocol,
415        job_name=job_name,
416        consumer_index=consumer_index,
417        num_consumers=num_consumers,
418        max_outstanding_requests=max_outstanding_requests,
419        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
420        cross_trainer_cache=cross_trainer_cache,
421        target_workers=target_workers)
422    super(_DataServiceDatasetV1, self).__init__(self._wrapped)
423
424
425if tf2.enabled():
426  _DataServiceDataset = _DataServiceDatasetV2
427else:
428  _DataServiceDataset = _DataServiceDatasetV1
429
430
431def _parse_service(service):
432  """Converts a tf.data service string into a (protocol, address) tuple.
433
434  Args:
435    service: A string in the format "protocol://address" or just "address". If
436      the string is only an address, the default protocol will be used.
437
438  Returns:
439    The (protocol, address) tuple
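
  For example, "grpc://localhost:5000" is parsed as ("grpc", "localhost:5000"),
  while a bare "localhost:5000" is returned with the default protocol.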
440  """
441  if not isinstance(service, str):
442    raise ValueError("`service` must be a string, but `service` was of type "
443                     f"{type(service)}. service={service}")
444  if not service:
445    raise ValueError("`service` must not be empty")
446  parts = service.split("://")
447  if len(parts) == 2:
448    protocol, address = parts
449  elif len(parts) == 1:
450    address = parts[0]
451    protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
452  else:
453    raise ValueError("Malformed `service` string has multiple '://': "
454                     f"{service}.")
  # TODO(aaudibert): Consider validating reachability of address here.
456  return (protocol, address)
457
458
459def _distribute(processing_mode,
460                service,
461                job_name=None,
462                consumer_index=None,
463                num_consumers=None,
464                max_outstanding_requests=None,
465                task_refresh_interval_hint_ms=None,
466                data_transfer_protocol=None,
467                compression="AUTO",
468                cross_trainer_cache=None,
469                target_workers="AUTO"):
470  """A transformation that moves dataset processing to the tf.data service.
471
472  This transformation is similar to `distribute`, but supports additional
473  parameters which we do not yet want to add to the public Python API.
474
475  Args:
476    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
477      how to shard the dataset among tf.data workers. See
478      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
479      compatibility, `processing_mode` may also be set to the strings
480      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
481      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
482    service: A string or a tuple indicating how to connect to the tf.data
483      service. If it's a string, it should be in the format
484      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
485        address and `<protocol>` can optionally be used to override the default
486        protocol to use. If it's a tuple, it should be (protocol, address).
487    job_name: (Optional.) The name of the job. If provided, it must be a
488      non-empty string. This argument makes it possible for multiple datasets to
489      share the same job. The default behavior is that the dataset creates
490      anonymous, exclusively owned jobs.
491    consumer_index: (Optional.) The index of the consumer in the range from `0`
492      to `num_consumers`. Must be specified alongside `num_consumers`. When
493      specified, consumers will read from the job in a strict round-robin order,
494      instead of the default first-come-first-served order.
495    num_consumers: (Optional.) The number of consumers which will consume from
496      the job. Must be specified alongside `consumer_index`. When specified,
497      consumers will read from the job in a strict round-robin order, instead of
498      the default first-come-first-served order. When `num_consumers` is
499      specified, the dataset must have infinite cardinality to prevent a
500      producer from running out of data early and causing consumers to go out of
501      sync.
502    max_outstanding_requests: (Optional.) A limit on how many elements may be
503      requested at the same time. You can use this option to control the amount
504      of memory used, since `distribute` won't use more than `element_size` *
505      `max_outstanding_requests` of memory.
506    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
507      dispatcher for task changes.
508    data_transfer_protocol: (Optional.) The protocol to use for transferring
509      data with the tf.data service. By default, data is transferred using gRPC.
510    compression: How to compress the dataset's elements before transferring them
511      over the network. "AUTO" leaves the decision of how to compress up to the
512      tf.data service runtime. `None` indicates not to compress.
513    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
514      provided, dataset iteration will be shared across concurrently running
515      trainers. See
516      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
517      for details.
518    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
519      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
521      tf.data service workers. `"AUTO"` works well for most cases, while users
522      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
523      data copy if every TF worker colocates with a tf.data service worker.
524      Consumers of a shared job must use the same `target_workers`. Defaults to
525      `"AUTO"`.
526
527  Returns:
528    Dataset: A `Dataset` of the elements produced by the data service.
529  """
530  processing_mode = _get_validated_sharding_policy(processing_mode)
531  _validate_compression(compression)
532  compression = _decide_compression(compression, data_transfer_protocol)
533
534  def _apply_fn(dataset):  # pylint: disable=missing-docstring
535    dataset_id = _register_dataset(service, dataset, compression=compression)
536    return _from_dataset_id(
537        processing_mode,
538        service,
539        dataset_id,
540        dataset.element_spec,
541        job_name=job_name,
542        consumer_index=consumer_index,
543        num_consumers=num_consumers,
544        max_outstanding_requests=max_outstanding_requests,
545        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
546        data_transfer_protocol=data_transfer_protocol,
547        compression=compression,
548        cross_trainer_cache=cross_trainer_cache,
549        target_workers=target_workers)
550
551  return _apply_fn
552
553
554@tf_export("data.experimental.service.distribute")
555def distribute(processing_mode,
556               service,
557               job_name=None,
558               consumer_index=None,
559               num_consumers=None,
560               max_outstanding_requests=None,
561               data_transfer_protocol=None,
562               compression="AUTO",
563               cross_trainer_cache=None,
564               target_workers="AUTO"):
565  """A transformation that moves dataset processing to the tf.data service.
566
567  When you iterate over a dataset containing the `distribute` transformation,
568  the tf.data service creates a "job" which produces data for the dataset
569  iteration.
570
571  The tf.data service uses a cluster of workers to prepare data for training
572  your model.
573  The `processing_mode` argument to `tf.data.experimental.service.distribute`
574  describes how to leverage multiple workers to process the input dataset.
575  Currently, there are two processing modes to choose from: "distributed_epoch"
576  and "parallel_epochs".
577
578  "distributed_epoch" means that the dataset will be split across all tf.data
579  service workers.
580  The dispatcher produces "splits" for the dataset and sends them to workers for
581  further processing. For example, if a dataset begins with a list of filenames,
582  the dispatcher will iterate through the filenames and send the filenames to
583  tf.data workers, which will perform the rest of the dataset transformations on
584  those files. "distributed_epoch" is useful when your model needs to see each
585  element of the dataset exactly once, or if it needs to see the data in a
586  generally-sequential order. "distributed_epoch" only works for datasets with
587  splittable sources, such as `Dataset.from_tensor_slices`,
588  `Dataset.list_files`, or `Dataset.range`.
589
590  "parallel_epochs" means that the entire input dataset will be processed
591  independently by each of the tf.data service workers.
592  For this reason, it is important to shuffle data (e.g. filenames)
593  non-deterministically, so that each worker will process the elements of the
594  dataset in a different order. "parallel_epochs" can be used to distribute
595  datasets that aren't splittable.
596
597  With two workers, "parallel_epochs" will produce every element of the dataset
598  twice:
599
600  >>> dispatcher = tf.data.experimental.service.DispatchServer()
601  >>> dispatcher_address = dispatcher.target.split("://")[1]
602  >>> # Start two workers
603  >>> workers = [
604  ...     tf.data.experimental.service.WorkerServer(
605  ...         tf.data.experimental.service.WorkerConfig(
606  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
607  ... ]
608  >>> dataset = tf.data.Dataset.range(10)
609  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
610  ...     processing_mode="parallel_epochs", service=dispatcher.target))
611  >>> print(sorted(list(dataset.as_numpy_iterator())))
612  [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]
613
614  "distributed_epoch", on the other hand, will still produce each element once:
615
616  >>> dispatcher = tf.data.experimental.service.DispatchServer()
617  >>> dispatcher_address = dispatcher.target.split("://")[1]
618  >>> workers = [
619  ...     tf.data.experimental.service.WorkerServer(
620  ...         tf.data.experimental.service.WorkerConfig(
621  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
622  ... ]
623  >>> dataset = tf.data.Dataset.range(10)
624  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
625  ...     processing_mode="distributed_epoch", service=dispatcher.target))
626  >>> print(sorted(list(dataset.as_numpy_iterator())))
627  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
628
629  When using `apply(tf.data.experimental.service.distribute(...))`, the dataset
630  before the `apply` transformation executes within the tf.data service, while
631  the operations after `apply` happen within the local process.
632
633  >>> dispatcher = tf.data.experimental.service.DispatchServer()
634  >>> dispatcher_address = dispatcher.target.split("://")[1]
635  >>> workers = [
636  ...     tf.data.experimental.service.WorkerServer(
637  ...         tf.data.experimental.service.WorkerConfig(
638  ...             dispatcher_address=dispatcher_address)) for _ in range(2)
639  ... ]
640  >>> dataset = tf.data.Dataset.range(5)
641  >>> dataset = dataset.map(lambda x: x*x)
642  >>> dataset = dataset.apply(
643  ...    tf.data.experimental.service.distribute("parallel_epochs",
644  ...                                            dispatcher.target))
645  >>> dataset = dataset.map(lambda x: x+1)
646  >>> print(sorted(list(dataset.as_numpy_iterator())))
647  [1, 1, 2, 2, 5, 5, 10, 10, 17, 17]
648
649  In the above example, the dataset operations (before applying the `distribute`
650  function on the elements) will be executed on the tf.data workers,
651  and the elements are provided over RPC. The remaining transformations
652  (after the call to `distribute`) will be executed locally. The dispatcher
  and the workers will bind to unused free ports (which are chosen at random),
654  in order to communicate with each other. However, to bind them to specific
655  ports, the `port` parameter can be passed.
656
657  The `job_name` argument allows jobs to be shared across multiple
658  datasets. Instead of each dataset creating its own job, all
659  datasets with the same `job_name` will consume from the same job. A new job
660  will be created for each iteration of the dataset (with each repetition of
661  `Dataset.repeat` counting as a new iteration). Suppose the `DispatchServer`
662  is serving on `localhost:5000` and two training workers (in either a single
663  client or multi-client setup) iterate over the below dataset, and there is a
664  single tf.data worker:
665
666  ```
667  range5_dataset = tf.data.Dataset.range(5)
668  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
669      "parallel_epochs", "localhost:5000", job_name="my_job_name"))
670  for iteration in range(3):
671    print(list(dataset))
672  ```
673
674  The elements of each job will be split between the two processes, with
675  elements being consumed by the processes on a first-come first-served basis.
676  One possible result is that process 1 prints
677
678  ```
679  [0, 2, 4]
680  [0, 1, 3]
681  [1]
682  ```
683
684  and process 2 prints
685
686  ```
687  [1, 3]
688  [2, 4]
689  [0, 2, 3, 4]
690  ```
691
692  Job names must not be re-used across different training jobs within the
693  lifetime of the tf.data service. In general, the tf.data service is expected
694  to live for the duration of a single training job.
695  To use the tf.data service with multiple training jobs, make sure to use
696  different job names to avoid conflicts. For example, suppose a training job
697  calls `distribute` with `job_name="job"` and reads until end of input. If
698  another independent job connects to the same tf.data service and tries to read
699  from `job_name="job"`, it will immediately receive end of input, without
700  getting any data.
701
702  **Coordinated data read**
703
704  By default, when multiple consumers read from the same job, they receive data
705  on a first-come first-served basis. In some use cases, it is advantageous to
706  coordinate the consumers. At each step, consumers read data from the same
707  worker.
708
709  For example, the tf.data service can be used to coordinate example sizes
710  across a cluster during synchronous training, so that during each step all
711  replicas train on similar-sized elements. To achieve this, define a dataset
712  which generates rounds of `num_consumers` consecutive similar-sized batches,
713  then enable coordinated reads by setting `consumer_index` and `num_consumers`.
714
  NOTE: To keep consumers in sync, round-robin data consumption requires that
716  the dataset have infinite cardinality. You can get this by adding `.repeat()`
717  at the end of the dataset definition.
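
  A sketch of a coordinated read, assuming two consumers whose `consumer_index`
  comes from each training worker's task index and a placeholder dispatcher
  address:

  ```
  dataset = dataset.repeat()  # Coordinated reads need infinite cardinality.
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode="parallel_epochs",
      service="grpc://dispatcher_address:5000",
      job_name="shared_job",
      consumer_index=consumer_index,
      num_consumers=2))
  ```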
718
719  **Keras and Distribution Strategies**
720
721  The dataset produced by the `distribute` transformation can be passed to
722  Keras' `Model.fit` or Distribution Strategy's
723  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
724  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
725  `distribute` so that if there are multiple workers, they read data from the
726  same job. Note that the autosharding normally performed by
727  `experimental_distribute_dataset` will be disabled when setting a `job_name`,
728  since sharing the job already results in splitting data across the workers.
729  When using a shared job, data will be dynamically balanced across workers, so
730  that they reach end of input about the same time. This results in better
731  worker utilization than with autosharding, where each worker processes an
732  independent set of files, and some workers may run out of data earlier than
733  others.
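
  A minimal sketch of feeding a distributed dataset to Keras (assumes an
  existing `dataset`, a compiled Keras `model`, and a placeholder dispatcher
  address):

  ```
  dataset = dataset.apply(tf.data.experimental.service.distribute(
      processing_mode="distributed_epoch",
      service="grpc://dispatcher_address:5000",
      job_name="shared_job"))
  model.fit(dataset, epochs=10)
  ```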
734
735  Args:
736    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
737      how to shard the dataset among tf.data workers. See
738      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
739      compatibility, `processing_mode` may also be set to the strings
740      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
741      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
742    service: A string or a tuple indicating how to connect to the tf.data
743      service. If it's a string, it should be in the format
744      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
745        address and `<protocol>` can optionally be used to override the default
746        protocol to use. If it's a tuple, it should be (protocol, address).
747    job_name: (Optional.) The name of the job. If provided, it must be a
748      non-empty string. This argument makes it possible for multiple datasets to
749      share the same job. The default behavior is that the dataset creates
750      anonymous, exclusively owned jobs.
751    consumer_index: (Optional.) The index of the consumer in the range from `0`
752      to `num_consumers`. Must be specified alongside `num_consumers`. When
753      specified, consumers will read from the job in a strict round-robin order,
754      instead of the default first-come-first-served order.
755    num_consumers: (Optional.) The number of consumers which will consume from
756      the job. Must be specified alongside `consumer_index`. When specified,
757      consumers will read from the job in a strict round-robin order, instead of
758      the default first-come-first-served order. When `num_consumers` is
759      specified, the dataset must have infinite cardinality to prevent a
760      producer from running out of data early and causing consumers to go out of
761      sync.
762    max_outstanding_requests: (Optional.) A limit on how many elements may be
763      requested at the same time. You can use this option to control the amount
764      of memory used, since `distribute` won't use more than `element_size` *
765      `max_outstanding_requests` of memory.
766    data_transfer_protocol: (Optional.) The protocol to use for transferring
767      data with the tf.data service. By default, data is transferred using gRPC.
768    compression: How to compress the dataset's elements before transferring them
769      over the network. "AUTO" leaves the decision of how to compress up to the
770      tf.data service runtime. `None` indicates not to compress.
771    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
772      provided, dataset iteration will be shared across concurrently running
773      trainers. See
774      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
775      for details.
776    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
777      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
779      tf.data service workers. `"AUTO"` works well for most cases, while users
780      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
781      data copy if every TF worker colocates with a tf.data service worker.
782      Consumers of a shared job must use the same `target_workers`. Defaults to
783      `"AUTO"`.
784
785  Returns:
786    Dataset: A `Dataset` of the elements produced by the data service.
787  """
788  _validate_job_name(job_name)
789  return _distribute(
790      processing_mode=processing_mode,
791      service=service,
792      job_name=job_name,
793      consumer_index=consumer_index,
794      num_consumers=num_consumers,
795      max_outstanding_requests=max_outstanding_requests,
796      data_transfer_protocol=data_transfer_protocol,
797      compression=compression,
798      cross_trainer_cache=cross_trainer_cache,
799      target_workers=target_workers)
800
801
802def _register_dataset(service, dataset, compression, dataset_id=None):
803  """Registers a dataset with the tf.data service.
804
805  This transformation is similar to `register_dataset`, but supports additional
806  parameters which we do not yet want to add to the public Python API.
807
808  Args:
809    service: A string or a tuple indicating how to connect to the tf.data
810      service. If it's a string, it should be in the format
811      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
812        address and `<protocol>` can optionally be used to override the default
813        protocol to use. If it's a tuple, it should be (protocol, address).
814    dataset: A `tf.data.Dataset` to register with the tf.data service.
815    compression: How to compress the dataset's elements before transferring them
816      over the network. "AUTO" leaves the decision of how to compress up to the
817      tf.data service runtime. `None` indicates not to compress.
818    dataset_id: (Optional.) By default, tf.data service generates a unique
819      (string) ID for each registered dataset. If a `dataset_id` is provided, it
820      will use the specified ID. If a dataset with a matching ID already exists,
821      no new dataset is registered. This is useful if multiple training jobs
822      want to (re)use the same dataset for training. In this case, they can
823      register the dataset with the same dataset ID.
824
825  Returns:
826    A scalar string tensor representing the dataset ID.
827  """
828  _validate_compression(compression)
829  if isinstance(service, tuple):
830    protocol, address = service
831  else:
832    protocol, address = _parse_service(service)
833  external_state_policy = dataset.options().experimental_external_state_policy
834  if external_state_policy is None:
835    external_state_policy = ExternalStatePolicy.WARN
836
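  # The element spec can only be encoded when executing eagerly. In graph mode
  # it is omitted from the metadata, and readers must pass `element_spec` to
  # `from_dataset_id` explicitly.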
837  encoded_spec = None
838  if context.executing_eagerly():
839    encoded_spec = nested_structure_coder.encode_structure(
840        dataset.element_spec).SerializeToString()
841
842  if compression == COMPRESSION_AUTO:
843    dataset = dataset.map(
844        lambda *x: compression_ops.compress(x),
845        num_parallel_calls=dataset_ops.AUTOTUNE)
846  dataset = dataset.prefetch(dataset_ops.AUTOTUNE)
847  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
848
849  metadata = data_service_pb2.DataServiceMetadata(
850      element_spec=encoded_spec,
851      compression=_get_compression_proto(compression))
852
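  # A caller-provided `dataset_id` requires RegisterDatasetV2, which accepts a
  # `requested_dataset_id`; otherwise keep emitting the original op until the
  # forward compatibility window (2022-08-31) has passed.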
853  if compat.forward_compatible(2022, 8, 31) or dataset_id:
854    return gen_experimental_dataset_ops.register_dataset_v2(
855        dataset._variant_tensor,  # pylint: disable=protected-access
856        address=address,
857        protocol=protocol,
858        external_state_policy=external_state_policy.value,
859        requested_dataset_id=dataset_id,
860        metadata=metadata.SerializeToString())
861  else:
862    return gen_experimental_dataset_ops.register_dataset(
863        dataset._variant_tensor,  # pylint: disable=protected-access
864        address=address,
865        protocol=protocol,
866        external_state_policy=external_state_policy.value,
867        metadata=metadata.SerializeToString())
868
869
870@tf_export("data.experimental.service.register_dataset")
871def register_dataset(service, dataset, compression="AUTO", dataset_id=None):
872  """Registers a dataset with the tf.data service.
873
874  `register_dataset` registers a dataset with the tf.data service so that
875  datasets can be created later with
  `tf.data.experimental.service.from_dataset_id`. This is useful when the
  dataset is registered by one process, then used in another process. When the
  same process is both registering and reading from the dataset, it is simpler
  to use `tf.data.experimental.service.distribute` instead.
881
882  If the dataset is already registered with the tf.data service,
883  `register_dataset` returns the already-registered dataset's id.
884
885  >>> dispatcher = tf.data.experimental.service.DispatchServer()
886  >>> dispatcher_address = dispatcher.target.split("://")[1]
887  >>> worker = tf.data.experimental.service.WorkerServer(
888  ...     tf.data.experimental.service.WorkerConfig(
889  ...         dispatcher_address=dispatcher_address))
890  >>> dataset = tf.data.Dataset.range(10)
891  >>> dataset_id = tf.data.experimental.service.register_dataset(
892  ...     dispatcher.target, dataset)
893  >>> dataset = tf.data.experimental.service.from_dataset_id(
894  ...     processing_mode="parallel_epochs",
895  ...     service=dispatcher.target,
896  ...     dataset_id=dataset_id,
897  ...     element_spec=dataset.element_spec)
898  >>> print(list(dataset.as_numpy_iterator()))
899  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
900
901  Args:
902    service: A string or a tuple indicating how to connect to the tf.data
903      service. If it's a string, it should be in the format
904      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
905        address and `<protocol>` can optionally be used to override the default
906        protocol to use. If it's a tuple, it should be (protocol, address).
907    dataset: A `tf.data.Dataset` to register with the tf.data service.
908    compression: (Optional.) How to compress the dataset's elements before
909      transferring them over the network. "AUTO" leaves the decision of how to
910      compress up to the tf.data service runtime. `None` indicates not to
911      compress.
912    dataset_id: (Optional.) By default, tf.data service generates a unique
913      (string) ID for each registered dataset. If a `dataset_id` is provided, it
914      will use the specified ID. If a dataset with a matching ID already exists,
915      no new dataset is registered. This is useful if multiple training jobs
916      want to (re)use the same dataset for training. In this case, they can
917      register the dataset with the same dataset ID.
918
919  Returns:
920    A scalar string tensor representing the dataset ID.
921  """
922  return _register_dataset(service, dataset, compression, dataset_id)
923
924
925def _from_dataset_id(processing_mode,
926                     service,
927                     dataset_id,
928                     element_spec,
929                     job_name=None,
930                     consumer_index=None,
931                     num_consumers=None,
932                     max_outstanding_requests=None,
933                     task_refresh_interval_hint_ms=None,
934                     data_transfer_protocol=None,
935                     compression="AUTO",
936                     cross_trainer_cache=None,
937                     target_workers="AUTO"):
938  """Creates a dataset which reads data from the tf.data service.
939
940  This transformation is similar to `from_dataset_id`, but supports additional
941  parameters which we do not yet want to add to the public Python API.
942
943  Args:
944    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
945      how to shard the dataset among tf.data workers. See
946      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
947      compatibility, `processing_mode` may also be set to the strings
948      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
949      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
950    service: A string or a tuple indicating how to connect to the tf.data
951      service. If it's a string, it should be in the format
952      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
953        address and `<protocol>` can optionally be used to override the default
954        protocol to use. If it's a tuple, it should be (protocol, address).
955    dataset_id: The id of the dataset to read from. This id is returned by
956      `register_dataset` when the dataset is registered with the tf.data
957      service.
958    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
959      elements produced by the dataset. This argument is only required inside a
960      tf.function. Use `tf.data.Dataset.element_spec` to get the element spec
961      for a given dataset.
962    job_name: (Optional.) The name of the job. If provided, it must be a
963      non-empty string or tensor. This argument makes it possible for multiple
964      datasets to share the same job. The default behavior is that the dataset
965      creates anonymous, exclusively owned jobs.
966    consumer_index: (Optional.) The index of the consumer in the range from `0`
967      to `num_consumers`. Must be specified alongside `num_consumers`. When
968      specified, consumers will read from the job in a strict round-robin order,
969      instead of the default first-come-first-served order.
970    num_consumers: (Optional.) The number of consumers which will consume from
971      the job. Must be specified alongside `consumer_index`. When specified,
972      consumers will read from the job in a strict round-robin order, instead of
973      the default first-come-first-served order. When `num_consumers` is
974      specified, the dataset must have infinite cardinality to prevent a
975      producer from running out of data early and causing consumers to go out of
976      sync.
977    max_outstanding_requests: (Optional.) A limit on how many elements may be
978      requested at the same time. You can use this option to control the amount
979      of memory used, since `distribute` won't use more than `element_size` *
980      `max_outstanding_requests` of memory.
981    task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
982      dispatcher for task changes.
983    data_transfer_protocol: (Optional.) The protocol to use for transferring
984      data with the tf.data service. By default, data is transferred using gRPC.
985    compression: An indication of how the dataset's elements were compressed, so
986      that `from_dataset_id` can uncompress them if necessary.
987    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
988      provided, dataset iteration will be shared across concurrently running
989      trainers. See
990      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
991      for details.
992    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
993      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
995      tf.data service workers. `"AUTO"` works well for most cases, while users
996      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
997      data copy if every TF worker colocates with a tf.data service worker.
998      Consumers of a shared job must use the same `target_workers`. Defaults to
999      `"AUTO"`.
1000
1001  Returns:
1002    A `tf.data.Dataset` which reads from the tf.data service.
1003  """
1004  def _get_element_spec():
1005    """Fetches the element spec from the server."""
1006    data_service_metadata = None
1007    dataset_id_val = tensor_util.constant_value(dataset_id)
1008    try:
1009      if isinstance(dataset_id_val, str) or isinstance(dataset_id_val, bytes):
1010        data_service_metadata = (
1011            _pywrap_server_lib.TF_DATA_GetDataServiceMetadataByID(
1012                dataset_id_val, address, protocol))
1013      else:
1014        # TODO(b/236725000): Remove this after the forward compatibility window
1015        # has passed.
1016        data_service_metadata = (
1017            _pywrap_server_lib.TF_DATA_GetDataServiceMetadata(
1018                dataset_id_val, address, protocol))
1019    except NotImplementedError as err:
1020      raise ValueError(
1021          "The tf.data service is running an earlier version of TensorFlow "
1022          "that requires specifying `element_spec` as an argument to "
1023          "`from_dataset_id`. Please either supply an element spec or update "
1024          "the tf.data service to the latest version.") from err
1025    except RuntimeError:
1026      # This error results from dataset ID not found. A more appropriate error
1027      # will be raised when the dataset is created.
1028      pass
1029
1030    if not data_service_metadata or not data_service_metadata.element_spec:
1031      dataset_id_val = tensor_util.constant_value(dataset_id)
1032      raise ValueError(
1033          f"Failed to fetch element spec for dataset id {dataset_id_val} from "
1034          "tf.data service. If the dataset was registered in graph mode or "
1035          "inside a tf.function, the `element_spec` must be specified as an "
1036          "argument to `from_dataset_id`.")
1037
1038    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
1039    struct_pb.ParseFromString(data_service_metadata.element_spec)
1040    return nested_structure_coder.decode_proto(struct_pb)
1041
1042  processing_mode = _get_validated_sharding_policy(processing_mode)
1043  if isinstance(service, tuple):
1044    protocol, address = service
1045  else:
1046    protocol, address = _parse_service(service)
1047  _validate_compression(compression)
1048  if job_name is not None:
1049    if not isinstance(job_name, str) and not isinstance(job_name, ops.Tensor):
1050      raise ValueError(
1051          "`job_name` must be a string or Tensor, but `job_name` was of type "
1052          f"{type(job_name)}. job_name={job_name}.")
1053
1054  if not element_spec:
1055    if not context.executing_eagerly():
1056      raise ValueError(
1057          "In graph mode `element_spec` must be provided manually.")
1058    element_spec = _get_element_spec()
1059
1060  dataset = _DataServiceDataset(
1061      dataset_id=dataset_id,
1062      processing_mode=processing_mode,
1063      address=address,
1064      element_spec=element_spec,
1065      protocol=protocol,
1066      data_transfer_protocol=data_transfer_protocol,
1067      job_name=job_name,
1068      consumer_index=consumer_index,
1069      num_consumers=num_consumers,
1070      max_outstanding_requests=max_outstanding_requests,
1071      task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
1072      cross_trainer_cache=cross_trainer_cache,
1073      target_workers=target_workers)
1074
1075  # Disable autosharding for shared jobs.
1076  if job_name is not None:
1077    options = options_lib.Options()
1078    options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
1079    dataset = dataset.with_options(options)
1080  return dataset
1081
1082
1083@tf_export("data.experimental.service.from_dataset_id")
1084def from_dataset_id(processing_mode,
1085                    service,
1086                    dataset_id,
1087                    element_spec=None,
1088                    job_name=None,
1089                    consumer_index=None,
1090                    num_consumers=None,
1091                    max_outstanding_requests=None,
1092                    data_transfer_protocol=None,
1093                    cross_trainer_cache=None,
1094                    target_workers="AUTO"):
1095  """Creates a dataset which reads data from the tf.data service.
1096
1097  This is useful when the dataset is registered by one process, then used in
1098  another process. When the same process is both registering and reading from
1099  the dataset, it is simpler to use `tf.data.experimental.service.distribute`
1100  instead.
1101
1102  Before using `from_dataset_id`, the dataset must have been registered with the
1103  tf.data service using `tf.data.experimental.service.register_dataset`.
1104  `register_dataset` returns a dataset id for the registered dataset. That is
1105  the `dataset_id` which should be passed to `from_dataset_id`.
1106
1107  The `element_spec` argument indicates the `tf.TypeSpec`s for the elements
  produced by the dataset. It must match the dataset registered under
  `dataset_id`. `element_spec` defaults to `None`, in which case it is fetched
  from the tf.data service; this only works in eager mode, so in graph mode or
  inside a `tf.function` the `element_spec` must be specified explicitly.
1112
1113  `tf.data.experimental.service.distribute` is a convenience method which
1114  combines `register_dataset` and `from_dataset_id` into a dataset
1115  transformation.
1116  See the documentation for `tf.data.experimental.service.distribute` for more
1117  detail about how `from_dataset_id` works.
1118
1119  >>> dispatcher = tf.data.experimental.service.DispatchServer()
1120  >>> dispatcher_address = dispatcher.target.split("://")[1]
1121  >>> worker = tf.data.experimental.service.WorkerServer(
1122  ...     tf.data.experimental.service.WorkerConfig(
1123  ...         dispatcher_address=dispatcher_address))
1124  >>> dataset = tf.data.Dataset.range(10)
1125  >>> dataset_id = tf.data.experimental.service.register_dataset(
1126  ...     dispatcher.target, dataset)
1127  >>> dataset = tf.data.experimental.service.from_dataset_id(
1128  ...     processing_mode="parallel_epochs",
1129  ...     service=dispatcher.target,
1130  ...     dataset_id=dataset_id,
1131  ...     element_spec=dataset.element_spec)
1132  >>> print(list(dataset.as_numpy_iterator()))
1133  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
1134
1135  Args:
1136    processing_mode: A `tf.data.experimental.service.ShardingPolicy` specifying
1137      how to shard the dataset among tf.data workers. See
1138      `tf.data.experimental.service.ShardingPolicy` for details. For backwards
1139      compatibility, `processing_mode` may also be set to the strings
1140      `"parallel_epochs"` or `"distributed_epoch"`, which are respectively
1141      equivalent to `ShardingPolicy.OFF` and `ShardingPolicy.DYNAMIC`.
1142    service: A string or a tuple indicating how to connect to the tf.data
1143      service. If it's a string, it should be in the format
1144      `[<protocol>://]<address>`, where `<address>` identifies the dispatcher
1145        address and `<protocol>` can optionally be used to override the default
1146        protocol to use. If it's a tuple, it should be (protocol, address).
1147    dataset_id: The id of the dataset to read from. This id is returned by
1148      `register_dataset` when the dataset is registered with the tf.data
1149      service.
1150    element_spec: A nested structure of `tf.TypeSpec`s representing the type of
1151      elements produced by the dataset. This argument is only required inside a
1152      tf.function. Use `tf.data.Dataset.element_spec` to get the element spec
1153      for a given dataset.
1154    job_name: (Optional.) The name of the job. If provided, it must be a
1155      non-empty string. This argument makes it possible for multiple datasets to
1156      share the same job. The default behavior is that the dataset creates
1157      anonymous, exclusively owned jobs.
1158    consumer_index: (Optional.) The index of the consumer in the range from `0`
1159      to `num_consumers`. Must be specified alongside `num_consumers`. When
1160      specified, consumers will read from the job in a strict round-robin order,
1161      instead of the default first-come-first-served order.
1162    num_consumers: (Optional.) The number of consumers which will consume from
1163      the job. Must be specified alongside `consumer_index`. When specified,
1164      consumers will read from the job in a strict round-robin order, instead of
1165      the default first-come-first-served order. When `num_consumers` is
1166      specified, the dataset must have infinite cardinality to prevent a
1167      producer from running out of data early and causing consumers to go out of
1168      sync.
1169    max_outstanding_requests: (Optional.) A limit on how many elements may be
1170      requested at the same time. You can use this option to control the amount
1171      of memory used, since `distribute` won't use more than `element_size` *
1172      `max_outstanding_requests` of memory.
1173    data_transfer_protocol: (Optional.) The protocol to use for transferring
1174      data with the tf.data service. By default, data is transferred using gRPC.
1175    cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is
1176      provided, dataset iteration will be shared across concurrently running
1177      trainers. See
1178      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
1179      for details.
1180    target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data
1181      runtime decides which workers to read from. If `"ANY"`, reads from any
      tf.data service workers. If `"LOCAL"`, only reads from local in-process
1183      tf.data service workers. `"AUTO"` works well for most cases, while users
1184      can specify other targets. For example, `"LOCAL"` helps avoid RPCs and
1185      data copy if every TF worker colocates with a tf.data service worker.
1186      Consumers of a shared job must use the same `target_workers`. Defaults to
1187      `"AUTO"`.
1188
1189  Returns:
1190    A `tf.data.Dataset` which reads from the tf.data service.
1191  """
1192  _validate_job_name(job_name)
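  # Prefix the job name with the dataset id so that the same user-supplied
  # `job_name` does not collide across different registered datasets.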
1193  if job_name is not None:
1194    job_name = string_ops.string_join(
1195        ["dataset_id=", _to_string(dataset_id), job_name], "/")
1196
1197  return _from_dataset_id(
1198      processing_mode=processing_mode,
1199      service=service,
1200      dataset_id=dataset_id,
1201      element_spec=element_spec,
1202      job_name=job_name,
1203      consumer_index=consumer_index,
1204      num_consumers=num_consumers,
1205      max_outstanding_requests=max_outstanding_requests,
1206      data_transfer_protocol=data_transfer_protocol,
1207      cross_trainer_cache=cross_trainer_cache,
1208      target_workers=target_workers)
1209