# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Important value classes relevant to `ClusterCoordinator`.
16
17This is currently under development and the API is subject to change.
18"""
19
20import enum
21import threading
22
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.ops.options import ExternalStatePolicy
25from tensorflow.python.distribute import input_lib
26from tensorflow.python.eager import context
27from tensorflow.python.eager import def_function
28from tensorflow.python.eager import function as tf_function
29from tensorflow.python.framework import composite_tensor
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import type_spec as type_spec_lib
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import gen_dataset_ops
35from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
36from tensorflow.python.ops import variable_scope
37from tensorflow.python.util import nest
38from tensorflow.python.util.tf_export import tf_export
39
40
41class RemoteValueStatus(enum.Enum):
42  """The status of a `RemoteValue` object.
43
44  A `RemoteValue` object can have three states:
45    1) not ready: no value, no non-retryable error and not aborted;
46    2) aborted: i.e. the execution of function was aborted because of task
47       failure, but can be retried;
48    3) ready: i.e. has value or has non-tryable error;
49
50  The initial state of a `RemoteValue` is "not ready". When its corresponding
51  closure has
52  been executed at least once, it will become aborted or ready. The state
53  transitions are:
54    1) not ready -> 2) aborted:
55      when the corresponding closure is aborted due to worker failure, and the
56      worker failure is not immediately handled.
57    1) not ready -> 3) ready:
58      when the corresponding closure has been executed successfully.
59    2) aborted -> 3) ready:
60      when the `RemoteValue` is rebuilt by rerunning the corresponding closure
61      and the closure has been executed successfully.
62    3) ready -> 2) aborted:
63      when the corresponding closure had been executed successfully but later
64      the corresponding remote worker failed. This is currently only implemented
65      for resource `RemoteValue` like iterators.
66  """
67  NOT_READY = "NOT_READY"
68  ABORTED = "ABORTED"
69  READY = "READY"
70
71
72@tf_export("distribute.experimental.coordinator.RemoteValue",
73           "distribute.coordinator.RemoteValue", v1=[])
class RemoteValue(object):
  """An asynchronously available value of a scheduled function.

  This class is used as the return value of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` where
  the underlying value becomes available at a later time once the function has
  been executed.

  Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to
  a subsequent function scheduled with
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is
  currently not supported.

  Example:

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))

  with strategy.scope():
    v1 = tf.Variable(initial_value=0.0)
    v2 = tf.Variable(initial_value=1.0)

  @tf.function
  def worker_fn():
    v1.assign_add(0.1)
    v2.assign_sub(0.2)
    return v1.read_value() / v2.read_value()

  result = coordinator.schedule(worker_fn)
  # Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
  assert result.fetch() == 0.125

  for _ in range(10):
    # `worker_fn` will be run on arbitrary workers that are available. The
    # `result` value will be available later.
    result = coordinator.schedule(worker_fn)
  ```
  """

  def fetch(self):
    """Wait for the result of `RemoteValue` and return the numpy result.

    This makes the value concrete by copying the remote value to local.

    Returns:
      The numpy array structure of the actual output of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
      This can be a single value, or a structure of values, depending on the
      output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this `RemoteValue`
        is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def get(self):
    """Wait for the result of `RemoteValue` and return the tensor result.

    This makes the value concrete by copying the remote tensor to local.

    Returns:
      The actual output (in the form of `tf.Tensor`s) of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call.
      This can be a single Tensor, or a structure of Tensors, depending on the
      output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this `RemoteValue`
        is aborted or cancelled due to failure.
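
    Example (a minimal sketch; it reuses `coordinator` and `worker_fn` from the
    class docstring above, so the exact value shown is illustrative):

    ```python
    result = coordinator.schedule(worker_fn)
    tensor_result = result.get()   # a `tf.Tensor`, e.g. with value 0.125
    numpy_result = result.fetch()  # the same value as a numpy scalar
    ```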
149    """
150    raise NotImplementedError("Must be implemented in subclasses.")
151
152
153# TODO(yuefengz): create an implementation for resource RemoteValue which needs
154# to remember the closure object while a normal RemoteValue doesn't.
155class RemoteValueImpl(RemoteValue):
156  """Implementation of `RemoteValue`."""
157
158  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
159    """Initializes a `RemoteValueImpl`.
160
161    Args:
162      closure: The closure from which the `RemoteValue` is created.
163      type_spec: The type spec for this `RemoteValue` which is used to trace
164        functions that take this `RemoteValue` as input.
165    """
166    self._closure = closure
167    self._type_spec = type_spec
168    self._values = None
169    self._has_fetched_to_local = False
170    self._has_fetched_to_local_lock = threading.Lock()
171    self._fetched_tensors = None
172    self._error = None
173    self._status_available_event = threading.Event()
174    self._status = RemoteValueStatus.NOT_READY
175
176  def _set_aborted(self, error):
177    self._status = RemoteValueStatus.ABORTED
178    self._values = None
179    self._error = error
180
181    # Wake up any waiting thread and clear the event.
182    self._status_available_event.set()
183
184  def _rebuild_on(self, worker):
    self._status_available_event.clear()
    # TODO(yuefengz): we may need to rebuild its inputs as well.
    self._closure.execute_on(worker)

  def _set_values(self, tensors):
    self._status = RemoteValueStatus.READY
    self._values = tensors
    self._error = None
    self._status_available_event.set()

  def _set_error(self, error):
    self._status = RemoteValueStatus.READY
    self._values = None
    self._error = error
    self._status_available_event.set()

  def _get_values(self):
    self._status_available_event.wait()
    return self._values

  def _get_error(self):
    self._status_available_event.wait()
    return self._error

  def _wait_and_maybe_error(self):
    self._status_available_event.wait()
    if self._status is RemoteValueStatus.ABORTED:
      raise errors.CancelledError(
          None, None,
          "The corresponding function is aborted. Please reschedule the "
          "function.")
    if self._error is not None:
      raise self._error

  def fetch(self):
    # TODO(rchao): Discuss the possibility of letting users perform `numpy`
    # themselves at API graduation.
    return nest.map_structure(
        lambda x: x.numpy() if hasattr(x, "numpy") else x, self.get())

  def get(self):
    self._wait_and_maybe_error()

    with self._has_fetched_to_local_lock:
      if not self._has_fetched_to_local:

        def copy_tensor(composite_tensor_obj):
          """Copy a remote tensor to local (coordinator)."""
          if isinstance(composite_tensor_obj, input_lib.DistributedIterator):
            # A DistributedIterator cannot be copied to local; users should not
            # access that anyway.
            return composite_tensor_obj

          with ops.device("/job:%s" % context.get_server_def().job_name):
            # Copying to local (the coordinator) with `tf.device`.
            return array_ops.identity(composite_tensor_obj)

        if self._values is not None:
          # When `self._values` is `None`, it indicates the associated function
          # does not have a return value.
          self._fetched_tensors = nest.map_structure(copy_tensor, self._values)
        self._has_fetched_to_local = True

    return self._fetched_tensors


@tf_export("distribute.experimental.coordinator.PerWorkerValues",
           "distribute.coordinator.PerWorkerValue", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
254  """A container that holds a list of values, one value per worker.
255
256  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection
257  of values, where each of the values is located on its corresponding worker,
258  and upon being used as one of the `args` or `kwargs` of
259  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
260  value specific to a worker will be passed into the function being executed at
261  that corresponding worker.
262
263  Currently, the only supported path to create an object of
264  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
265  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
266  distributed dataset instance. The mechanism to create a custom
267  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported.
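
  Example (a minimal sketch; `strategy` and `coordinator` are assumed to be set
  up as in the `RemoteValue` docstring above):

  ```python
  def dataset_fn():
    return tf.data.Dataset.from_tensor_slices([1, 2, 3])

  # `create_per_worker_dataset` creates a dataset on each worker; calling
  # `iter` on it returns a `PerWorkerValues` of per-worker iterators.
  per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  per_worker_iterator = iter(per_worker_dataset)
  ```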
268  """
269
270  def __init__(self, values):
271    for v in values:
272      if not isinstance(v, RemoteValue):
273        raise AssertionError(
274            "`PerWorkerValues` should only take `RemoteValue`s.")
275    self._values = tuple(values)
276
277  @property
278  def _type_spec(self):
279    return PerWorkerValuesTypeSpec(
280        self._values[0]._type_spec,  # pylint: disable=protected-access
281        type(self))
282
283
284class PerWorkerValuesTypeSpec(type_spec_lib.TypeSpec):
285  """TypeSpec for PerWorkerValues.
286
287  It only support tracing a function using a PerWorkerValues.
288  """
289
290  def __init__(self, value_spec, descendant_type):
    assert value_spec
    self._value_spec = value_spec
    self._descendant_type = descendant_type

  def _serialize(self):
    return (self._value_spec,)

  @property
  def value_type(self):
    return self._descendant_type

  def most_specific_common_supertype(self, others):
    raise NotImplementedError(
        "most_specific_common_supertype is not implemented")

  @property
  def _component_specs(self):
    return self._value_spec

  def _to_components(self, value):
    return self._value_spec

  def _from_components(self, value):
    return value


class PerWorkerDatasetFromDatasetFunction(object):
318  """Represents worker-distributed datasets created from dataset function."""
319
320  def __init__(self, dataset_fn, coordinator):
321    """Makes an iterable from datasets created by the given function.
322
323    Args:
324      dataset_fn: A function that returns a `Dataset`.
325      coordinator: a `ClusterCoordinator` object, used to create dataset
326        resources.
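
    Example (a minimal sketch; instances are normally obtained indirectly
    through `ClusterCoordinator.create_per_worker_dataset` rather than
    constructed directly, `coordinator` is a `ClusterCoordinator`, and
    `train_step` is a stand-in `tf.function` that takes an iterator):

    ```python
    def dataset_fn():
      return tf.data.Dataset.range(10).batch(2)

    per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    per_worker_iterator = iter(per_worker_dataset)
    result = coordinator.schedule(train_step, args=(per_worker_iterator,))
    ```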
327    """
328
329    def disallow_variable_creation(next_creator, **kwargs):
330      raise ValueError("Creating variables in `dataset_fn` is not allowed.")
331
332    if isinstance(dataset_fn, def_function.Function):
333      with variable_scope.variable_creator_scope(disallow_variable_creation):
334        dataset_fn = dataset_fn.get_concrete_function()
335    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
336      with variable_scope.variable_creator_scope(disallow_variable_creation):
337        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
338    self._dataset_fn = dataset_fn
339    self._coordinator = coordinator
340    self._element_spec = None
341
342  def __iter__(self):
343    # We would like users to create iterators outside `tf.function`s so that we
344    # can track them.
345    if (not context.executing_eagerly() or
346        ops.get_default_graph().building_function):
347      raise RuntimeError(
348          "__iter__() is not supported inside of tf.function or in graph mode.")
349
350    def _create_per_worker_iterator():
351      dataset = self._dataset_fn()
352      return iter(dataset)
353
354    # If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple
355    # times, for the same object it should only create and register resource
356    # once. Using object id to distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))

    return PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    if not isinstance(self._dataset_fn, tf_function.ConcreteFunction):
      raise NotImplementedError(
          "`element_spec` is not supported when the `dataset_fn` is not "
          "a `ConcreteFunction`.")
    return self._dataset_fn.structured_outputs.element_spec


def serialize_dataset_to_graph(dataset):
  """Serializes `dataset` into its graph representation."""
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
  graph_def = gen_dataset_ops.dataset_to_graph_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      external_state_policy=ExternalStatePolicy.WARN.value,
      strip_device_assignment=True)
  return graph_def


class _RemoteDataset(dataset_ops.DatasetSource):
392  """Creates a dataset given a graph def."""
393
394  def __init__(self, graph_def, element_spec):
395    self._elem_spec = element_spec
396    variant_tensor = ged_ops.dataset_from_graph(graph_def)
397    super(_RemoteDataset, self).__init__(variant_tensor)
398
399  @property
400  def element_spec(self):
401    return self._elem_spec
402
403
404def deserialize_dataset_from_graph(graph_def, element_spec):
  """Deserializes a dataset from its graph representation."""
  return _RemoteDataset(graph_def, element_spec)
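

# A minimal sketch of how the two helpers above pair up (illustrative only;
# the real callers live in `PerWorkerDatasetFromDataset` below):
#
#   dataset = dataset_ops.DatasetV2.range(4)
#   graph_def = serialize_dataset_to_graph(dataset)
#   restored = deserialize_dataset_from_graph(graph_def, dataset.element_spec)
#   # `restored` yields the same elements as `dataset`, with device assignments
#   # stripped so it can be re-instantiated on a worker.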


class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
    """Makes an iterable of per-worker datasets from the given dataset.

    Under the hood, it creates a dataset_fn that deserializes the dataset from
    a graph.

    Args:
      dataset: A tf.data.Dataset, a DistributedDataset or a
        DistributedDatasetsFromFunction.
      coordinator: A `ClusterCoordinator` object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)

      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)

      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
    else:
      raise ValueError("Unexpected dataset type!")

    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)


def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  """Returns a per-worker dataset from a dataset or a dataset function."""
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)
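

# A minimal usage sketch for the dispatcher above (`make_dataset` is a
# hypothetical helper returning a `tf.data.Dataset`; `coordinator` is a
# `ClusterCoordinator`):
#
#   per_worker_ds = get_per_worker_dataset(make_dataset, coordinator)    # fn
#   per_worker_ds = get_per_worker_dataset(make_dataset(), coordinator)  # dataset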


class PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError("Iterating over a `PerWorkerDistributedIterator` "
                              "is not supported right now.")