# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutomaticControlDependencies and related functionality."""

import collections
import enum

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps_utils as utils
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import registry
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator

# LINT.IfChange
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.

ASYNC_STATEFUL_OPS = frozenset((
    "CollectiveGather",
    "CollectiveReduce",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control output
    # in order to avoid being pruned.
    "Recv",
    "CollectiveInitializeCommunicator",
    "CollectiveAssignGroupV2",
))

LEGACY_RANDOM_OPS = frozenset((
    # These may be used in variable initializers -- thus their execution should
    # not be dependent on other stateful operations.  This is because although
    # according to program order, tf.Variables may be created in sequence,
    # their initialization happens outside of the program order (specifically,
    # in graph mode their initialization happens by calling a grouped
    # initializer operation or in eager mode, where initialization is lifted
    # out of the tf.function and executed the first time the function is
    # executed).
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value depends
    # on another initializer), the initialization can happen in any order so
    # long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically not
    # guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop.  Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this denylist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this denylist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respect this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
))

MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset((
    "InfeedEnqueue",
    "InfeedEnqueueTuple",
))

# These ops are order-insensitive and should in theory run, but at the moment
# they either always have the necessary data dependencies, or have workarounds
# in existing code that would break when adding new control deps. This
# inconsistency should eventually be fixed, but it would be more effective to
# retire the list instead.
SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset((
    "CudnnRNN",
    "CudnnRNNBackprop",
    "CudnnRNNV2",
    "CudnnRNNV3",
    "CudnnRNNBackpropV2",
    "CudnnRNNBackpropV3",
    "EnqueueTPUEmbeddingSparseBatch",
    "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch",
    "EnqueueTPUEmbeddingArbitraryTensorBatch",
    "RestoreV2",
    "SaveV2",
))
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)

# Op types that are marked as stateless, but should be allowlisted to add auto
# control dependencies.
_ALLOWLIST_STATELESS_OPS = [
    # As TPU collective ops are blocking, if there is more than one collective
    # op in the function, we need to make sure different collective ops are
    # scheduled in certain orders. Otherwise, if at the same time all the
    # replicas are launching different collective ops/programs, it may cause
    # deadlock.
    "AllToAll",
    "CrossReplicaSum",
    "CollectivePermute",
]


def op_is_stateful(op):
  """Returns whether `op` should be treated as stateful by autodeps."""
  # pylint: disable=protected-access
  ret = ((op._is_stateful and
          ((op.type not in ASYNC_STATEFUL_OPS) and
           (op.type not in LEGACY_RANDOM_OPS) and
           (op.type not in SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS))) or
         (op.type in _ALLOWLIST_STATELESS_OPS))
  return ret


class ResourceType(enum.Enum):
  READ_ONLY = "read-only"
  READ_WRITE = "read-write"


def collective_manager_ids_from_op(op):
  """Returns CollectiveManager IDs used by the op, if any.

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts that ID, returning an
  empty list if the node was not generated by a CollectiveManager.

  Args:
    op: `Operation` to get the collective manager ID from.

  Returns:
    List of CollectiveManager IDs used by the op.
  """
  if op.type == "CollectiveReduce":
    try:
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      pass
  elif op.type == "StatefulPartitionedCall":
    try:
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      pass
  return []


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
    1. All stateful ops in the scope will execute (with the exception of ops in
       ASYNC_STATEFUL_OPS, LEGACY_RANDOM_OPS and
       SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS)
    2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

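  A minimal usage sketch (illustrative only; assumes a `tf.Variable` `v`
  created outside this scope):

  ```
    with AutomaticControlDependencies() as a:
      v.assign(1.0)        # stateful write; forced to run
      t = v.read_value()   # ordered after the write above
      t = a.mark_as_return(t)
  ```
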
  NOT THREAD SAFE
  """

  def __init__(self,
               record_initial_resource_uses=False,
               record_uses_of_resource_ids=None):
    self._returned_tensors = object_identity.ObjectIdentitySet()
    self.ops_which_must_run = set()
    self.record_initial_resource_uses = record_initial_resource_uses
    self.record_uses_of_resource_ids = record_uses_of_resource_ids
    self._independent_ops = []

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
        ...
        t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, indexed_slices.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return indexed_slices.IndexedSlices(
          values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def run_independently(self, op):
    """Marks the given op as independent.

    Overrides any other rule for the op.

    Independent ops are guaranteed to execute before the return values, but
    are allowed to run in parallel with everything else. Use in programs which
    can guarantee that an op has side effects that don't affect any other op.

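    Example (illustrative sketch; `log_op` stands for a hypothetical op whose
    side effects don't interact with any other op):

    ```
      with AutomaticControlDependencies() as a:
        log_op = ...  # e.g. an op that only appends to its own log file
        a.run_independently(log_op)
    ```
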
    Args:
      op: An operation
    """
    self._independent_ops.append(op)
    op._set_attr("_independent_side_effects", attr_value_pb2.AttrValue(b=True))  # pylint: disable=protected-access

  def __enter__(self):
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    g = ops.get_default_graph()
    self._graph = g
    g._add_control_dependencies = True  # pylint: disable=protected-access
    g.experimental_acd_manager = self
    self._n_operations = len(g.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to last op updating
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    if output_id in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Within the automatic control dependency context, the default graph"
          f" cannot change. Upon entry it was {self._graph}, but on exit it"
          f" changed to {ops.get_default_graph()}")

    outer_graph = getattr(self._graph, "outer_graph", None)
    if outer_graph is not None:
      self._graph._add_control_dependencies = outer_graph._add_control_dependencies
    else:
      self._graph._add_control_dependencies = False
    self._graph.experimental_acd_manager = None

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call. So we keep them separate (though
    # the general idea of the maps is the same; in the future, we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]
    first_use_for_res = {}
    resources_by_op = {}

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch), that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()

      if op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS:
        # This will add it to self._independent_ops, but also mark it with an
        # attribute.
        self.run_independently(op)

      if op in self._independent_ops:
        ops_which_must_run.add(op)
        continue

      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if ((op_def_registry.get(op.type) is None) or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)

      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)

        do_record = (
            self.record_initial_resource_uses and
            input_id not in first_use_for_res)

        if is_read:
          reads_list = reads_since_last_write_to_resource[input_id]
          reads_list.append(op)

          if do_record:
            # Note: this will track the entire list that
            # reads_since_last_write_to_resource maintains. Updates to it will
            # and should be tracked, until the first write is encountered. At
            # that point, reads_since_last_write_to_resource will contain a new
            # empty list. This logic relies on that behavior.
            first_use_for_res[input_id] = reads_list

        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

          if do_record:
            first_use_for_res[input_id] = [op]

      if self.record_initial_resource_uses and op_is_stateful(op):
        if resource_inputs:
          resources_by_op[op] = tuple(resource_inputs)
        else:
          if None not in first_use_for_res:
            first_use_for_res[None] = [op]
          resources_by_op[op] = (None,)

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Record the ops which first use resources touched by "ops which must run".
    if self.record_initial_resource_uses:
      first_uses_by_output_ops = {}
      for op in ops_which_must_run:
        if op not in resources_by_op:
          # This may happen with Merge/Switch nodes which are special cased
          # above.
          continue
        for r in resources_by_op[op]:
          if op not in first_uses_by_output_ops:
            first_uses_by_output_ops[op] = set()
          first_uses_by_output_ops[op].update(first_use_for_res[r])
      # For each "op which must run", set a private attr indicating the ops that
      # used the same resources it did.
      for op in first_uses_by_output_ops:
        others = [
            other.name.encode() for other in first_uses_by_output_ops[op]
        ]
        l = attr_value_pb2.AttrValue.ListValue(s=others)
        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op
          # to chain the control dependencies between stateful ops and
          # function outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._set_attr("_acd_function_control_output",
                                        attr_value_pb2.AttrValue(b=True))
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used


_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
  """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in-place and return True if it updated either of the
  sets, False otherwise.

  Example:
  @register_acd_resource_resolver
  def identity_resolver(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.
    # resource_writes: An `ObjectIdentitySet` of read-write resources.
    def update(resource_inputs):
      to_remove = []
      to_add = []
      for resource in resource_inputs:
        if resource.op.type == "Identity":
          to_remove.append(resource)
          to_add.extend(resource.op.inputs)
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return to_add or to_remove
    return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
      (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The function `f` after adding it to the registry.
  """
  _acd_resource_resolvers_registry.register(f)
  return f

@register_acd_resource_resolver
def _identity_resolver(op, resource_reads, resource_writes):
  """Replaces Identity output with its input in resource_inputs."""
  del op
  def update(resource_inputs):
    to_remove = []
    to_add = []
    for resource in resource_inputs:
      if resource.op.type == "Identity":
        to_remove.append(resource)
        to_add.extend(resource.op.inputs)
    for t in to_remove:
      resource_inputs.discard(t)
    resource_inputs.update(to_add)
    return to_add or to_remove

  return update(resource_reads) or update(resource_writes)


def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`."""
  reads, writes = utils.get_read_write_resource_inputs(op)
  saturated = False
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
      if updated:
        # Conservatively remove any resources from `reads` that are also writes.
        reads = reads.difference(writes)
      saturated = saturated and not updated

  # Note: A resource handle that is not written to is treated as read-only. We
  # don't have a special way of denoting an unused resource.
  for t in reads:
    yield (t, ResourceType.READ_ONLY)
  for t in writes:
    yield (t, ResourceType.READ_WRITE)


def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The inserted dependencies ensure that:
    1. All stateful ops in f run when the result of f runs
    2. Updates to the same resources happen in order.

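  Example (a minimal sketch; assumes a `tf.Variable` `v` created outside `f`):

  ```
  @automatic_control_dependencies
  def f(x):
    v.assign(x)  # stateful write; runs whenever f's result is computed
    return v.read_value()
  ```
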
  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as a:
      result = f(*args, **kwargs)
      result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
      return nest.pack_sequence_as(result, result_flat)

  return tf_decorator.make_decorator(f, wrapper)