# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Important value classes relevant to `ClusterCoordinator`.

This is currently under development and the API is subject to change.
"""

import enum
import threading

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops.options import ExternalStatePolicy
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec as type_spec_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class RemoteValueStatus(enum.Enum):
  """The status of a `RemoteValue` object.

  A `RemoteValue` object can have three states:
    1) not ready: no value, no non-retryable error, and not aborted;
    2) aborted: the execution of the function was aborted because of task
       failure, but it can be retried;
    3) ready: has a value or has a non-retryable error.

  The initial state of a `RemoteValue` is "not ready". Once its corresponding
  closure has been executed at least once, it becomes either aborted or ready.
  The state transitions are:
    1) not ready -> 2) aborted:
      when the corresponding closure is aborted due to worker failure, and the
      worker failure is not immediately handled.
    1) not ready -> 3) ready:
      when the corresponding closure has been executed successfully.
    2) aborted -> 3) ready:
      when the `RemoteValue` is rebuilt by rerunning the corresponding closure
      and the closure has been executed successfully.
    3) ready -> 2) aborted:
      when the corresponding closure had been executed successfully but later
      the corresponding remote worker failed. This is currently only
      implemented for resource `RemoteValue`s like iterators.
  """
  NOT_READY = "NOT_READY"
  ABORTED = "ABORTED"
  READY = "READY"


@tf_export("distribute.experimental.coordinator.RemoteValue",
           "distribute.coordinator.RemoteValue", v1=[])
class RemoteValue(object):
  """An asynchronously available value of a scheduled function.

  This class is used as the return value of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`, where
  the underlying value becomes available at a later time once the function has
  been executed.

  Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to
  a subsequent function scheduled with
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is
  currently not supported.

  Example:

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))

  with strategy.scope():
    v1 = tf.Variable(initial_value=0.0)
    v2 = tf.Variable(initial_value=1.0)

  @tf.function
  def worker_fn():
    v1.assign_add(0.1)
    v2.assign_sub(0.2)
    return v1.read_value() / v2.read_value()

  result = coordinator.schedule(worker_fn)
  # Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
  assert result.fetch() == 0.125

  for _ in range(10):
    # `worker_fn` will be run on arbitrary workers that are available. The
    # `result` value will be available later.
    result = coordinator.schedule(worker_fn)
  ```
  """

  def fetch(self):
    """Wait for the result of `RemoteValue` and return the numpy result.

    This makes the value concrete by copying the remote value to local.

    Returns:
      The numpy array structure of the actual output of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`
      call. This can be a single value, or a structure of values, depending on
      the output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this
        `RemoteValue` is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")

  def get(self):
    """Wait for the result of `RemoteValue` and return the tensor result.

    This makes the value concrete by copying the remote tensor to local.

    Returns:
      The actual output (in the form of `tf.Tensor`s) of the `tf.function`
      associated with this `RemoteValue`, previously returned by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`
      call. This can be a single Tensor, or a structure of Tensors, depending
      on the output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this
        `RemoteValue` is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")


# TODO(yuefengz): create an implementation for resource RemoteValue which needs
# to remember the closure object while a normal RemoteValue doesn't.
class RemoteValueImpl(RemoteValue):
  """Implementation of `RemoteValue`."""

  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
    """Initializes a `RemoteValueImpl`.

    Args:
      closure: The closure from which the `RemoteValue` is created.
      type_spec: The type spec for this `RemoteValue` which is used to trace
        functions that take this `RemoteValue` as input.
    """
    self._closure = closure
    self._type_spec = type_spec
    self._values = None
    self._has_fetched_to_local = False
    self._has_fetched_to_local_lock = threading.Lock()
    self._fetched_tensors = None
    self._error = None
    self._status_available_event = threading.Event()
    self._status = RemoteValueStatus.NOT_READY

  def _set_aborted(self, error):
    self._status = RemoteValueStatus.ABORTED
    self._values = None
    self._error = error

    # Wake up any waiting thread and clear the event.
    self._status_available_event.set()

  def _rebuild_on(self, worker):
    self._status_available_event.clear()
    # TODO(yuefengz): we may need to rebuild its inputs as well.
    self._closure.execute_on(worker)

  def _set_values(self, tensors):
    self._status = RemoteValueStatus.READY
    self._values = tensors
    self._error = None
    self._status_available_event.set()

  def _set_error(self, error):
    self._status = RemoteValueStatus.READY
    self._values = None
    self._error = error
    self._status_available_event.set()

  def _get_values(self):
    self._status_available_event.wait()
    return self._values

  def _get_error(self):
    self._status_available_event.wait()
    return self._error

  def _wait_and_maybe_error(self):
    self._status_available_event.wait()
    if self._status is RemoteValueStatus.ABORTED:
      raise errors.CancelledError(
          None, None,
          "The corresponding function is aborted. Please reschedule the "
          "function.")
    if self._error is not None:
      raise self._error

  def fetch(self):
    # TODO(rchao): Discuss the possibility of letting users perform `numpy`
    # themselves at API graduation.
    return nest.map_structure(
        lambda x: x.numpy() if hasattr(x, "numpy") else x, self.get())

  def get(self):
    self._wait_and_maybe_error()

    with self._has_fetched_to_local_lock:
      if not self._has_fetched_to_local:

        def copy_tensor(composite_tensor_obj):
          """Copy a remote tensor to local (coordinator)."""
          if isinstance(composite_tensor_obj, input_lib.DistributedIterator):
            # A DistributedIterator cannot be copied to local; users should not
            # access that anyway.
            return composite_tensor_obj

          with ops.device("/job:%s" % context.get_server_def().job_name):
            # Copying to local (the coordinator) with `tf.device`.
            return array_ops.identity(composite_tensor_obj)

        if self._values is not None:
          # When `self._values` is `None`, it indicates the associated function
          # does not have a return value.
          self._fetched_tensors = nest.map_structure(copy_tensor, self._values)
        self._has_fetched_to_local = True

    return self._fetched_tensors


@tf_export("distribute.experimental.coordinator.PerWorkerValues",
           "distribute.coordinator.PerWorkerValue", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
  """A container that holds a list of values, one value per worker.

  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a
  collection of values, where each of the values is located on its
  corresponding worker, and upon being used as one of the `args` or `kwargs` of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.
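
  For example, the following sketch (which assumes a `ClusterCoordinator`
  named `coordinator` has already been created, as in the `RemoteValue`
  example above, and uses a toy `dataset_fn`) obtains a `PerWorkerValues` of
  dataset iterators and passes it to a scheduled function:

  ```python
  @tf.function
  def dataset_fn():
    return tf.data.Dataset.from_tensor_slices([3.0, 4.0, 5.0])

  per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  # Calling `iter` on the per-worker dataset gives a `PerWorkerValues` of
  # dataset iterators, one located on each worker.
  per_worker_iterator = iter(per_worker_dataset)

  @tf.function
  def worker_fn(iterator):
    return next(iterator)

  # Each scheduled execution receives, from the `PerWorkerValues` passed via
  # `args`, the iterator that lives on the worker running the function.
  result = coordinator.schedule(worker_fn, args=(per_worker_iterator,))
  ```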

  Currently, the only supported path to create an object of
  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
  distributed dataset instance. The mechanism to create a custom
  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet
  supported.
  """

  def __init__(self, values):
    for v in values:
      if not isinstance(v, RemoteValue):
        raise AssertionError(
            "`PerWorkerValues` should only take `RemoteValue`s.")
    self._values = tuple(values)

  @property
  def _type_spec(self):
    return PerWorkerValuesTypeSpec(
        self._values[0]._type_spec,  # pylint: disable=protected-access
        type(self))


class PerWorkerValuesTypeSpec(type_spec_lib.TypeSpec):
  """TypeSpec for PerWorkerValues.

  It only supports tracing a function using a `PerWorkerValues`.
  """

  def __init__(self, value_spec, descendant_type):
    assert value_spec
    self._value_spec = value_spec
    self._descendant_type = descendant_type

  def _serialize(self):
    return (self._value_spec,)

  @property
  def value_type(self):
    return self._descendant_type

  def most_specific_common_supertype(self, others):
    raise NotImplementedError(
        "most_specific_common_supertype is not implemented")

  @property
  def _component_specs(self):
    return self._value_spec

  def _to_components(self, value):
    return self._value_spec

  def _from_components(self, value):
    return value


class PerWorkerDatasetFromDatasetFunction(object):
  """Represents worker-distributed datasets created from a dataset function."""

  def __init__(self, dataset_fn, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """

    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in `dataset_fn` is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._coordinator = coordinator
    self._element_spec = None

  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph "
          "mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple times
    # for the same object, it should only create and register the resources
    # once. The object id is used to distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
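    # For instance, once the spec is set, a `tf.function` such as
    # `def worker_fn(iterator): return next(iterator)` can be traced with the
    # per-worker iterator `RemoteValue`s passed as `args` to `schedule`.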
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (
          input_lib.get_iterator_spec_from_dataset(
              self._coordinator.strategy, self._dataset_fn.structured_outputs))

    return PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset.

    This property is subject to change without notice.
    """
    if not isinstance(self._dataset_fn, tf_function.ConcreteFunction):
      raise NotImplementedError(
          "`element_spec` is not supported when the `dataset_fn` is not "
          "a `ConcreteFunction`.")
    return self._dataset_fn.structured_outputs.element_spec


def serialize_dataset_to_graph(dataset):
  dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
  graph_def = gen_dataset_ops.dataset_to_graph_v2(
      dataset._variant_tensor,  # pylint: disable=protected-access
      external_state_policy=ExternalStatePolicy.WARN.value,
      strip_device_assignment=True)
  return graph_def


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset given a graph def."""

  def __init__(self, graph_def, element_spec):
    self._elem_spec = element_spec
    variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def deserialize_dataset_from_graph(graph_def, element_spec):
  return _RemoteDataset(graph_def, element_spec)


class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
  """Represents worker-distributed datasets created from a dataset."""

  def __init__(self, dataset, coordinator):
    """Makes an iterable from the given dataset.

    Under the hood, it creates a `dataset_fn` that deserializes a dataset from
    a serialized graph.

    Args:
      dataset: A `tf.data.Dataset`, a `DistributedDataset` or a
        `DistributedDatasetsFromFunction`.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """
    if isinstance(dataset, input_lib.DistributedDataset):
      original_dataset = dataset._original_dataset
      serialized = serialize_dataset_to_graph(original_dataset)

      def dataset_fn():
        deserialized = deserialize_dataset_from_graph(
            serialized, original_dataset.element_spec)
        dataset.build(dataset_to_replace=deserialized)
        return dataset
    elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
      def dataset_fn():
        dataset.build()
        return dataset
    elif isinstance(dataset, dataset_ops.Dataset):
      serialized = serialize_dataset_to_graph(dataset)

      def dataset_fn():
        return deserialize_dataset_from_graph(serialized, dataset.element_spec)
    else:
      raise ValueError("Unexpected dataset type!")

    super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)


def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
  """Returns a per-worker dataset from a dataset or a dataset function."""
  if callable(dataset_or_dataset_fn):
    return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
                                               coordinator)
  else:
    return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)


class PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError(
        "Iterating over a `PerWorkerDistributedIterator` is not supported "
        "right now.")