1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""AutomaticControlDependencies and related functionality.""" 16 17import collections 18import enum 19 20from tensorflow.core.framework import attr_value_pb2 21from tensorflow.python.eager import context 22from tensorflow.python.framework import auto_control_deps_utils as utils 23from tensorflow.python.framework import dtypes as dtypes_module 24from tensorflow.python.framework import indexed_slices 25from tensorflow.python.framework import op_def_registry 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import registry 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import control_flow_util 32from tensorflow.python.ops import tensor_array_ops 33from tensorflow.python.util import nest 34from tensorflow.python.util import object_identity 35from tensorflow.python.util import tf_decorator 36 37# LINT.IfChange 38# Op types that should not run in program order, e.g. because they need to run 39# asynchronously to avoid deadlock. 
ASYNC_STATEFUL_OPS = frozenset((
    "CollectiveGather",
    "CollectiveReduce",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control output
    # in order to avoid being pruned.
    "Recv",
    "CollectiveInitializeCommunicator",
    "CollectiveAssignGroupV2",
))

LEGACY_RANDOM_OPS = frozenset((
    # These may be used in variable initializers -- thus their execution should
    # not be dependent on other stateful operations. This is because although
    # according to program order, tf.Variables may be created in sequence,
    # their initialization happens outside of the program order (specifically,
    # in graph mode their initialization happens by calling a grouped
    # initializer operation or in eager mode, where initialization is lifted
    # out of the tf.function and executed the first time the function is
    # executed).
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value depends
    # on another initializer), the initialization can happen in any order so
    # long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically not
    # guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop. Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this denylist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this denylist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respects this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
))

# Stateful op types that must always execute, but whose execution need not be
# ordered relative to other ops: AutomaticControlDependencies.__exit__ marks
# ops of these types via run_independently(), which guarantees they run before
# the function's return values without being chained to other stateful ops.
MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset((
    "InfeedEnqueue",
    "InfeedEnqueueTuple",
))

# These ops are order-insensitive and should in theory run, but at the moment
# they either always have the necessary data dependencies, or have workarounds
# in existing code that would break when adding new control deps. This
# inconsistency should be eventually fixed, but it would be more effective to
# retire the list instead.
SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS = frozenset((
    "CudnnRNN",
    "CudnnRNNBackprop",
    "CudnnRNNV2",
    "CudnnRNNV3",
    "CudnnRNNBackpropV2",
    "CudnnRNNBackpropV3",
    "EnqueueTPUEmbeddingSparseBatch",
    "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch",
    "EnqueueTPUEmbeddingArbitraryTensorBatch",
    "RestoreV2",
    "SaveV2",
))
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)

# Op types that are marked as stateless, but should be allowlisted to add auto
# control dependencies.
_ALLOWLIST_STATELESS_OPS = [
    # As TPU collective ops are blocking, if there are more than one collective
    # op in the function, we need to make sure different collectives ops are
    # scheduled in certain orders. Otherwise if at the same time all the
    # replicas are launching different collective ops/programs, it may cause
    # deadlock.
    "AllToAll",
    "CrossReplicaSum",
    "CollectivePermute",
]


def op_is_stateful(op):
  """Returns True if `op` should be treated as stateful for auto control deps.

  An op counts as stateful here if its kernel is registered as stateful and
  its type is not in any of the denylists above (async collectives, legacy
  random ops, skipped order-insensitive ops), or if its type is a nominally
  stateless op explicitly allowlisted in _ALLOWLIST_STATELESS_OPS.

  Args:
    op: an `Operation`.

  Returns:
    A bool.
  """
  # pylint: disable=protected-access
  ret = ((op._is_stateful and
          ((op.type not in ASYNC_STATEFUL_OPS) and
           (op.type not in LEGACY_RANDOM_OPS) and
           (op.type not in SKIPPED_ORDER_INSENSITIVE_STATEFUL_OPS))) or
         (op.type in _ALLOWLIST_STATELESS_OPS))
  return ret


class ResourceType(enum.Enum):
  """How an op touches a resource input: reading it only, or mutating it."""
  READ_ONLY = "read-only"
  READ_WRITE = "read-write"


def collective_manager_ids_from_op(op):
  """Returns CollectiveManager IDs from the op if any exist, else [].

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts that ID, or an empty
  list, if the node was not generated by a CollectiveManager.

  Args:
    op: `Operation` to get the collective manager ID from.

  Returns:
    List of CollectiveManager IDs used by the op.
  """
  if op.type == "CollectiveReduce":
    try:
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      # The op has no _collective_manager_id attr: not manager-generated.
      pass
  elif op.type == "StatefulPartitionedCall":
    try:
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      pass
  return []


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
  1. All stateful ops in the scope will execute (with the exception of ops in
     ASYNC_STATEFUL_OPS and LEGACY_RANDOM_OPS)
  2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  NOT THREAD SAFE
  """

  def __init__(self,
               record_initial_resource_uses=False,
               record_uses_of_resource_ids=None):
    """Initializes the context manager.

    Args:
      record_initial_resource_uses: if True, `__exit__` additionally records
        on every "op which must run" a `_res_first_used_by` attribute listing
        the names of the ops that first used the same resources.
      record_uses_of_resource_ids: stored on the instance but not read in this
        module; presumably consumed by external callers — TODO confirm.
    """
    self._returned_tensors = object_identity.ObjectIdentitySet()
    self.ops_which_must_run = set()
    self.record_initial_resource_uses = record_initial_resource_uses
    self.record_uses_of_resource_ids = record_uses_of_resource_ids
    # Ops explicitly marked via run_independently(); they must run but are
    # exempt from the usual serialization rules.
    self._independent_ops = []

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
       ...
       t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, indexed_slices.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return indexed_slices.IndexedSlices(
          values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def run_independently(self, op):
    """Marks the given op as independent.

    Overrides any other rule for the op.

    Independent ops are guaranteed to execute before the return values, but
    are allowed to run in parallel with everything else. Use in programs which
    can guarantee that an op has side effects that don't affect any other op.

    Args:
      op: An operation
    """
    self._independent_ops.append(op)
    op._set_attr("_independent_side_effects", attr_value_pb2.AttrValue(b=True))  # pylint: disable=protected-access

  def __enter__(self):
    """Starts recording: remembers how many ops the default graph has now.

    In eager mode this is a no-op; in graph mode, `__exit__` will process only
    the operations added to the graph after this point.
    """
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    g = ops.get_default_graph()
    self._graph = g
    g._add_control_dependencies = True  # pylint: disable=protected-access
    g.experimental_acd_manager = self
    self._n_operations = len(g.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor get created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to last op updating
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    # Recurse on chained switches (a cond nested in a cond routes the resource
    # through one switch per level).
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    # Already processed: a merge exists for this switch's output.
    if output_id in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Analyzes the ops added inside the context and adds control deps.

    Walks every operation created while the context was active, serializing
    uses of each resource tensor and forcing stateful ops (and ops with no
    registered op-def) to run, per the class docstring. The three exception
    arguments are ignored; exceptions propagate normally.
    """
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Within the automatic control dependency context, the default graph"
          f" cannot change. Upon entry it was {self._graph}, but on exit it"
          f" changed to {ops.get_default_graph()}")

    # Restore the graph's dependency-tracking flag to the enclosing graph's
    # setting (or off, for a root graph).
    outer_graph = getattr(self._graph, "outer_graph", None)
    if outer_graph is not None:
      self._graph._add_control_dependencies = outer_graph._add_control_dependencies
    else:
      self._graph._add_control_dependencies = False
    self._graph.experimental_acd_manager = None

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call. So we keep them separate (though
    # the general idea of the maps is the same, in the future, we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    # Only the ops created inside this context (see __enter__).
    new_operations = self._graph.get_operations()[self._n_operations:]
    first_use_for_res = {}
    resources_by_op = {}

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()

      if op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS:
        # This will add it to self._independent_ops, but also mark it with an
        # attribute.
        self.run_independently(op)

      if op in self._independent_ops:
        ops_which_must_run.add(op)
        continue

      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if ((op_def_registry.get(op.type) is None) or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)

      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          # Not a CollectiveManager-generated NoOp; nothing to record.
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        # The merge now transitively dominates everything collected so far;
        # it becomes the sole "must run" op going forward.
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)

        do_record = (
            self.record_initial_resource_uses and
            input_id not in first_use_for_res)

        if is_read:
          reads_list = reads_since_last_write_to_resource[input_id]
          reads_list.append(op)

          if do_record:
            # Note: this will track the entire list that
            # reads_since_last_write_to_resource maintains. Updates to it will
            # and should be tracked, until the first write is encountered. At
            # that point, reads_since_last_write_to_resource will contain a new
            # empty list. This logic relies on that behavior.
            first_use_for_res[input_id] = reads_list

        else:
          # A write must come after every read since the previous write.
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

          if do_record:
            first_use_for_res[input_id] = [op]

      if self.record_initial_resource_uses and op_is_stateful(op):
        if resource_inputs:
          resources_by_op[op] = tuple(resource_inputs)
        else:
          # Stateful ops with no resource inputs are keyed under None,
          # mirroring the serialization chain below.
          if None not in first_use_for_res:
            first_use_for_res[None] = [op]
          resources_by_op[op] = (None,)

      # Serialize stateful ops that touch no resources at all into one chain
      # (keyed by None), but only outside control flow contexts.
      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      # `is_building_function` is necessarily bound here: `control_inputs`
      # only gains elements inside the resource-input loop that assigns it.
      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Record the ops which first use resources touched by "ops which must run".
    if self.record_initial_resource_uses:
      first_uses_by_output_ops = {}
      for op in ops_which_must_run:
        if op not in resources_by_op:
          # This may happen with Merge/Switch nodes which are special cased
          # above.
          continue
        for r in resources_by_op[op]:
          if op not in first_uses_by_output_ops:
            first_uses_by_output_ops[op] = set()
          first_uses_by_output_ops[op].update(first_use_for_res[r])
      # For each "op which must run", set a private attr indicating the ops that
      # used the same resources it did.
      for op in first_uses_by_output_ops:
        others = [
            other.name.encode() for other in first_uses_by_output_ops[op]
        ]
        l = attr_value_pb2.AttrValue.ListValue(s=others)
        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op
          # to chain the control dependencies between stateful ops and
          # function outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._set_attr("_acd_function_control_output",
                                        attr_value_pb2.AttrValue(b=True))
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used


_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
  """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in-place and return True if it updated either of the
  sets, False otherwise.

  Example:
  @register_acd_resource_resolver
  def identity_resolver(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.
    # resource_writes: An `ObjectIdentitySet` of read-write resources.
    def update(resource_inputs):
      to_remove = []
      to_add = []
      for resource in resource_inputs:
        if resource.op.type == "Identity":
          to_remove.append(resource)
          to_add.extend(resource.op.inputs)
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return to_add or to_remove
    return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
    (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The function `f` after adding it to the registry.
  """
  _acd_resource_resolvers_registry.register(f)
  return f


@register_acd_resource_resolver
def _identity_resolver(op, resource_reads, resource_writes):
  """Replaces Identity output with its input in resource_inputs."""
  del op
  def update(resource_inputs):
    # Collect Identity outputs and swap in their inputs; mutating the set
    # while iterating would be unsafe, hence the two-phase to_remove/to_add.
    to_remove = []
    to_add = []
    for resource in resource_inputs:
      if resource.op.type == "Identity":
        to_remove.append(resource)
        to_add.extend(resource.op.inputs)
    for t in to_remove:
      resource_inputs.discard(t)
    resource_inputs.update(to_add)
    return to_add or to_remove

  return update(resource_reads) or update(resource_writes)


def _get_resource_inputs(op):
  """Returns an iterable of (tensor, ResourceType) resources touched by `op`.

  Runs every registered resolver to a fixed point before yielding, so that
  resolvers which rewrite resource sets (e.g. through Identity ops) are fully
  applied.
  """
  reads, writes = utils.get_read_write_resource_inputs(op)
  saturated = False
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
      if updated:
        # Conservatively remove any resources from `reads` that are also writes.
        reads = reads.difference(writes)
      saturated = saturated and not updated

  # Note: A resource handle that is not written to is treated as read-only. We
  # don't have a special way of denoting an unused resource.
  for t in reads:
    yield (t, ResourceType.READ_ONLY)
  for t in writes:
    yield (t, ResourceType.READ_WRITE)


def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The inserted dependencies ensure that:
  1. All stateful ops in f run when the result of f runs
  2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as a:
      result = f(*args, **kwargs)
      # Route every flattened return value through mark_as_return so it picks
      # up the control dependencies added on context exit.
      result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
      return nest.pack_sequence_as(result, result_flat)

  return tf_decorator.make_decorator(f, wrapper)