# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Utility to lift subgraphs."""

import collections

from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


UnliftableError = op_selector.UnliftableError


def _as_operation(op_or_tensor):
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def _constant_inputs(op_or_tensor):
  return all(_as_operation(i).type == u"Const"
             and not _as_operation(i).control_inputs
             for i in op_selector.graph_inputs(_as_operation(op_or_tensor)))


# Represents an input to `copied_op` which must be updated once
# `old_graph_tensor` has been copied.
_InputMutation = collections.namedtuple(
    "_InputMutation",
    ["copied_op", "input_index", "old_graph_tensor"])


# Represents a control input to `copied_op` which must be added once
# `old_graph_op` has been copied.
_ControlMutation = collections.namedtuple(
    "_ControlMutation",
    ["copied_op", "old_graph_op"])


def _copy_non_source(op, graph, op_map, base_graph):
  """Copy an op directly to a given graph.

  Generally `op`'s inputs should already have been copied. If this is not the
  case, for example with v1 while_loops, then `_copy_non_source` inserts
  placeholders for the unavailable Tensors and returns a list of required
  mutations.

  Args:
    op: The op to be copied.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    base_graph: The graph we're copying from, for any necessary functions.

  Returns:
    A tuple of (required_inputs, required_control_inputs):
      required_inputs:
        A list of `_InputMutation` tuples containing inputs to `copied_op`
        which must be updated once `old_graph_tensor` has been copied.
      required_control_inputs:
        A list of `_ControlMutation` tuples containing control inputs to
        `copied_op` which must be added once `old_graph_op` has been copied.
  """
  input_mutations = []
  control_mutations = []
  copied_inputs = []
  for input_index, original_input in enumerate(op.inputs):
    copied_input = op_map.get(original_input, None)
    if copied_input is None:
      # An input for this op is missing due to a loop in the graph. We'll
      # insert a placeholder for now and return information about the required
      # post-hoc mutation.
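      # (The `_InputMutation` recorded below is resolved at the end of
      # `lift_to_graph`, which patches in the real input via
      # `Operation._update_input` once every op has been copied.)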
      copied_input = array_ops.placeholder(
          name="unused_control_flow_input",
          shape=original_input.shape,
          dtype=original_input.dtype)
      input_mutations.append(
          # `copied_op` is filled in below, after we've created it.
          _InputMutation(copied_op=None,
                         input_index=input_index,
                         old_graph_tensor=original_input))
    copied_inputs.append(copied_input)

  copied_control_inputs = []
  for original_control_input in op.control_inputs:
    copied_control_input = op_map.get(original_control_input, None)
    if copied_control_input is None:
      control_mutations.append(
          _ControlMutation(copied_op=None,
                           old_graph_op=original_control_input))
    else:
      copied_control_inputs.append(copied_control_input)

  # Don't copy over nodes with a _tpu_replicate attribute. This attribute is
  # used to signal that the op was built inside a tpu_replicate context; if
  # we're lifting it to another graph we're similarly lifting it into another
  # context.
  with ops.control_dependencies(copied_control_inputs), ops.device(op.device):
    # pylint: disable=protected-access
    f = base_graph._functions.get(op.type, None)
    if f is not None and compat.as_str(f.name) not in graph._functions:
      f.add_to_graph(graph)
    # pylint: enable=protected-access

    # Create the new op in the destination graph (callers only reach here for
    # ops that haven't been copied yet).
    copied_op = graph.create_op(
        op_type=op.type,
        inputs=copied_inputs,
        dtypes=[x.dtype for x in op.outputs],
        attrs={
            key: value for key, value in op.node_def.attr.items()
            if not key.startswith("_class") and
            not key.startswith("_tpu_replicate")
        },  # b/128981532.
        name=op.name)
  op_map[op] = copied_op
  for i, o in enumerate(op.outputs):
    op_map[o] = copied_op.outputs[i]

  return ([mutation._replace(copied_op=copied_op)
           for mutation in input_mutations],
          [mutation._replace(copied_op=copied_op)
           for mutation in control_mutations])


def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
                 base_graph):
  """Create a source in a graph based on a Tensor from a different graph.

  This function creates a placeholder analog of `s` in a graph with the
  following behavior:

  1) If s is a captured Tensor or Variable and handle_captures is set to True,
     simply capture it in the new graph as well.

  2) If s is a PlaceholderWithDefault whose default is a constant, preserve
     said default in the new graph.

  3) When applicable, copy resource variable metadata from `s` to the newly
     created placeholder.

  Args:
    s: The source of interest.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    inverse_captures: A dict mapping s back to the Tensor or Variable that it
      captures.
    base_graph: The graph being copied from.
  """
  if handle_captures and s in inverse_captures:
    copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Copy the default value to the graph.
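    # (A PlaceholderWithDefault op takes its default value as input 0, so
    # copying that constant input first lets the new placeholder keep the same
    # default.)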
    default_value = s.op.inputs[0]
    unavailable_inputs, unavailable_control_inputs = _copy_non_source(
        op=default_value.op, graph=graph, op_map=op_map,
        base_graph=base_graph)
    if unavailable_inputs or unavailable_control_inputs:
      raise AssertionError(
          "Could not copy source node {} because it has inputs."
          .format(default_value))

    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=s.op.name)
  else:
    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=s.op.name)

  base_handle = resource_variable_ops.get_resource_handle_data(s)
  if base_handle.shape_and_type:
    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        copied_placeholder,
        base_handle,
        graph_mode=True)

  op_map[s] = copied_placeholder
  # Add an entry for the op of the source tensor so that any nodes depending
  # on that op via control dependencies work correctly.
  op_map[s.op] = copied_placeholder.op


@tf_export("__internal__.lift_to_graph", v1=[])
def lift_to_graph(tensors,
                  graph,
                  sources=None,
                  disallowed_placeholders=None,
                  add_sources=False,
                  handle_captures=False,
                  base_graph=None,
                  op_map=None):
  """Copies the tensors and all of their inputs recursively to `graph`.

  Args:
    tensors: The Tensors to lift.
    graph: The graph to lift to.
    sources: Optional sequence of nodes to start from. If omitted, the whole
      subgraph which feeds into `tensors` is lifted.
    disallowed_placeholders: An optional set of ops which may not appear in
      the lifted graph. Defaults to all placeholders.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.
    handle_captures: A boolean indicating whether to re-capture captured
      tensors in the new graph or simply create vanilla placeholders.
    base_graph: The graph from which to lift ops. This will be inferred if
      not specified.
    op_map: A map containing all the existing nodes that have already been
      lifted to the destination graph, so they won't be lifted and copied
      again.

  Returns:
    A mapping from ops in the current default graph to ops in `graph`.

  Raises:
    UnliftableError: If a placeholder blocks lifting.
  """
  variable_init_tensors = []
  init_tensors = []
  for tensor in tensors:
    if isinstance(tensor, resource_variable_ops.ResourceVariable):
      variable_init_tensors.append(tensor)
    else:
      init_tensors.append(tensor)
  base_graph = base_graph or init_tensors[0].graph
  op_map = op_map or object_identity.ObjectIdentityDictionary()

  # Check that the initializer does not depend on any placeholders.
  sources = object_identity.ObjectIdentitySet(sources or [])
  visited_ops = set(x.op for x in sources)
  op_outputs = collections.defaultdict(set)

  # First we extract the subgraph between init_tensors and sources.
  for init_tensor in init_tensors:
    sources.update(op_selector.map_subgraph(
        init_tensor=init_tensor,
        sources=sources,
        disallowed_placeholders=disallowed_placeholders,
        visited_ops=visited_ops,
        op_outputs=op_outputs,
        add_sources=add_sources))

  # Try to topologically sort the nodes we've extracted. Now we know how many
  # of their outputs are part of this subgraph.
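  # The walk below starts from ops whose outputs are all outside the subgraph
  # and works backwards towards the sources; an input is only visited once all
  # of its consumers in the subgraph have been marked, so `ops_to_copy` ends
  # up reverse-topologically sorted whenever the graph is acyclic.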
  ops_to_copy = []
  marked_ops = set([])
  ops_to_visit = [_as_operation(t) for t in init_tensors
                  if not op_outputs[_as_operation(t)]]
  unvisited_ops = set(ops_to_visit)
  while unvisited_ops:
    while ops_to_visit:
      op = ops_to_visit.pop()
      if op in marked_ops:
        continue
      marked_ops.add(op)
      ops_to_copy.append(op)
      for inp in op_selector.graph_inputs(op):
        # Don't lift TPUReplicateMetadata nodes out of the function, because
        # they have no registered kernels.
        if inp.type == "TPUReplicateMetadata":
          continue
        unvisited_ops.add(inp)
        if (all(x in marked_ops for x in op_outputs[inp]) and
            inp not in sources):
          ops_to_visit.append(inp)
    unvisited_ops.difference_update(marked_ops)
    if unvisited_ops:
      # `unvisited_ops` should only have elements if the graph has a loop. In
      # this case we want to keep copying and there's no topological ordering;
      # we'll do ugly post-hoc mutations instead.
      ops_to_visit.append(next(iter(unvisited_ops)))

  # When the topological sort fails due to loops, it can result in exceptions
  # later when copying a node whose inputs haven't been copied yet. We can
  # improve that pseudo-topological order slightly by putting the ops without
  # inputs, such as constants, at the start of the topological order (i.e. at
  # the end of ops_to_copy).
  ops_to_copy.sort(key=lambda op: len(op_selector.graph_inputs(op)) == 0)

  # When lifting from one FuncGraph to another, we will need to capture the
  # relevant tensors as well.
  captures = []
  inverse_captures = object_identity.ObjectIdentityDictionary()
  internal_captures = []
  if (isinstance(base_graph, func_graph.FuncGraph) and
      isinstance(graph, func_graph.FuncGraph)):
    captures = base_graph.captures
    for external_capture, internal_capture in captures:
      inverse_captures[internal_capture] = external_capture
    internal_captures = base_graph.internal_captures

  # ops_to_copy now holds a reverse topologically sorted list of ops which
  # ends in the initializer. We copy those to the outermost graph and
  # build the initialization op there.
  with graph.as_default():
    for i in variable_init_tensors:
      op_map[i] = i
    source_ops = set()
    # Add the sources in the same order as the original graph.
    for s in internal_captures:
      if s in sources:
        sources.remove(s)
        source_ops.add(s.op)
        _copy_source(
            s=s,
            graph=graph,
            op_map=op_map,
            handle_captures=handle_captures,
            inverse_captures=inverse_captures,
            base_graph=base_graph)
    for s in sources:
      source_ops.add(s.op)
      _copy_source(
          s=s,
          graph=graph,
          op_map=op_map,
          handle_captures=handle_captures,
          inverse_captures=inverse_captures,
          base_graph=base_graph)

    input_mutations = []
    control_mutations = []
    for op in reversed(ops_to_copy):
      if op in source_ops or op in op_map:
        continue
      new_input_mutations, new_control_mutations = _copy_non_source(
          op=op, graph=graph, op_map=op_map, base_graph=base_graph)
      input_mutations.extend(new_input_mutations)
      control_mutations.extend(new_control_mutations)

    # Mutate the new graph to insert any loops which existed in the source
    # graph due to v1 while_loops.
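    # These are the deferred `_InputMutation` and `_ControlMutation` records
    # returned by `_copy_non_source` for inputs which had not been copied at
    # the time their consumer was created.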
    #
    # pylint: disable=protected-access
    with graph._mutation_lock():
      for mutation in input_mutations:
        mutation.copied_op._update_input(
            mutation.input_index, op_map[mutation.old_graph_tensor])
      for mutation in control_mutations:
        # Don't lift TPUReplicateMetadata nodes out of the function, because
        # they have no registered kernels.
        if mutation.old_graph_op.type == "TPUReplicateMetadata":
          continue
        mutation.copied_op._add_control_input(op_map[mutation.old_graph_op])
    # pylint: enable=protected-access

  return op_map
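

# A minimal usage sketch (illustrative only, not part of the library):
# `lift_to_graph` is exported as `tf.__internal__.lift_to_graph` in TF2.
# Assuming a trivial `tf.function` with no captures or placeholders, lifting
# its output into a fresh graph looks roughly like:
#
#   import tensorflow as tf
#
#   @tf.function
#   def f():
#     return tf.constant(2.0) * tf.constant(3.0)
#
#   concrete = f.get_concrete_function()
#   output = concrete.graph.outputs[0]
#   target = tf.Graph()
#   op_map = tf.__internal__.lift_to_graph([output], target)
#   lifted = op_map[output]  # The copy of `output` now living in `target`.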