# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility to convert FunctionDef to GraphDef and Graph."""

import itertools

from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import resource_variable_ops


def function_def_to_graph(fdef,
                          structured_input_signature=None,
                          structured_outputs=None,
                          input_shapes=None):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs`, `outputs` and `control_outputs`
  fields will be set. The input tensors are represented as placeholders.

  Note: `FuncGraph.captures` is not set and may be set by the caller.

  Args:
    fdef: FunctionDef.
    structured_input_signature: Optional. The structured input signature to
      use for initializing the FuncGraph. See the docstring for FuncGraph for
      more information.
    structured_outputs: Optional. The structured outputs to use for
      initializing the FuncGraph. See the docstring for FuncGraph for more
      information.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute,
      with the shape of every resource input that carries handle data replaced
      by None. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name,
                         structured_input_signature=structured_input_signature,
                         structured_outputs=structured_outputs)
  if input_shapes is None:
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      raw_input_shapes = input_shapes_attr.list.shape

      # Replace resource handle shapes in the inputs to disable shape inference.
      # Setting the shape to either the variable handle shape (which is always
      # `[]`) or the variable shape can cause shape inference issues.
      input_shapes = []
      for input_shape, arg_def in zip(raw_input_shapes,
                                      fdef.signature.input_arg):
        if arg_def.type == types_pb2.DT_RESOURCE and arg_def.handle_data:
          input_shapes.append(None)
        else:
          input_shapes.append(input_shape)

  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def_for_function(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    # Look the output tensors up once; they are reused for `_output_names`.
    output_tensors = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    func_graph.outputs = output_tensors
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

    _set_handle_data(func_graph, fdef)

    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # _output_shapes for functions can sometimes be too long because the
        # output-intermediates-for-gradients version of the function was
        # substituted before saving. We'll accept that here. (See b/133666530).
        for output_index, shape in enumerate(
            output_shapes.list.shape[:len(op.outputs)]):
          op.outputs[output_index].set_shape(shape)

    # Map each output tensor's id to the name of its signature output arg.
    func_graph._output_names = {  # pylint: disable=protected-access
        ops.tensor_id(tensor): ret_arg_def.name
        for tensor, ret_arg_def in zip(output_tensors,
                                       fdef.signature.output_arg)
    }
  return func_graph


def is_function(fname):
  """Checks for a function definition with `fname` in the current context.

  When executing eagerly, the eager context is queried. Otherwise the default
  graph is searched first, followed by each enclosing `outer_graph`; the search
  stops at a graph without an `outer_graph` attribute or whose `outer_graph`
  is None.

  Args:
    fname: Name of the function to look up.

  Returns:
    True if a function named `fname` was found, False otherwise.
  """
  if context.executing_eagerly():
    return context.context().has_function(fname)
  graph = ops.get_default_graph()
  while graph is not None:
    if graph._is_function(fname):  # pylint: disable=protected-access
      return True
    graph = getattr(graph, "outer_graph", None)
  # Explicit False (rather than falling off the loop and returning None) when
  # the chain ends.
  return False


def _find_function(graph, name):
  """Returns the function `name` from `graph` or its outer graphs, else None."""
  while graph is not None:
    f = graph._functions.get(name, None)  # pylint: disable=protected-access
    if f is not None:
      return f
    # Stop at graphs with no `outer_graph` attribute or a None outer graph.
    graph = getattr(graph, "outer_graph", None)
  return None


def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Any node whose op type names a function found in the default graph (or one
  of its outer graphs) causes that function, and its gradient function name if
  set, to be copied into the returned GraphDef's library.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid (a referenced function is
      missing or an output arg definition is malformed).
    KeyError: If a node input names a tensor that is produced by no input arg
      and no node of `fdef`.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  default_graph = ops.get_default_graph()

  copied_functions = set()

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of `input_shapes` must match the number "
                     f"of `input_arg`s in `fdef`. Got "
                     f"{len(input_shapes)} `input_shapes` and "
                     f"{len(fdef.signature.input_arg)} `input_arg`s.")

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      input_shape = input_shapes[i]
      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
        input_shape = input_shape.as_proto()
      node_def.attr["shape"].shape.CopyFrom(input_shape)
    # Check membership first: indexing a protobuf message map with a missing
    # key inserts an empty entry, which would mutate the caller's `fdef`.
    if i not in fdef.arg_attr:
      continue
    arg_attrs = fdef.arg_attr[i].attr
    for k in arg_attrs:
      # Only copy internal attributes. Normal attributes for nodes cannot be
      # applied to these Placeholder nodes.
      if k == "_output_shapes":
        attr_value = arg_attrs[k]
        kind = attr_value.WhichOneof("value")
        if kind == "list":
          # An empty shape list carries no shape information; skip it instead
          # of failing with IndexError.
          if attr_value.list.shape:
            node_def.attr["shape"].shape.CopyFrom(attr_value.list.shape[0])
        elif kind == "shape":
          node_def.attr["shape"].shape.CopyFrom(attr_value.shape)
      elif k.startswith("_"):
        node_def.attr[k].CopyFrom(arg_attrs[k])

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    f = _find_function(default_graph, node_def.op)

    if f is not None:
      # Use a distinct name so the `fdef` argument is not rebound.
      func_def = f.definition
      op_def = func_def.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice but
        # to copy it into the GraphDef if we want downstream tools to process
        # it.
        graph_def.library.function.add().CopyFrom(func_def)
        copied_functions.add(node_def.op)
        if f.grad_func_name:
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

    # Every function referenced through a func / list(func) attr must exist.
    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        if not is_function(fname):
          raise ValueError(f"Function {fname} was not found. Please make sure "
                           "the FunctionDef `fdef` is correct.")
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          if not is_function(fname):
            raise ValueError(f"Function {fname} was not found. Please make "
                             "sure the FunctionDef `fdef` is correct.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph. Iterate a snapshot so the repeated
  # field is not read while it is being written.
  for node_def in graph_def.node:
    for i, input_name in enumerate(list(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[input_name]

  return graph_def, nested_to_flat_tensor_name


# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
def _get_num_args(arg_def, node_def):
  """Returns how many tensors `arg_def` expands to on `node_def`.

  Args:
    arg_def: An OpDef.ArgDef describing one output arg.
    node_def: The NodeDef whose attrs determine the arg's length.

  Returns:
    The value of the `number_attr` attr if set; else the length of the
    `type_list_attr` type list if set; else 1 for a single typed arg.

  Raises:
    ValueError: If `arg_def` has none of number_attr, type_list_attr,
      type_attr, or a valid `type`.
  """
  if arg_def.number_attr:
    return node_def.attr[arg_def.number_attr].i
  elif arg_def.type_list_attr:
    return len(node_def.attr[arg_def.type_list_attr].list.type)
  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
    return 1
  else:
    raise ValueError(f"Invalid arg_def:\n\n{arg_def}. Please make sure the "
                     "FunctionDef `fdef` is correct.")


def _set_handle_data(func_graph, fdef):
  """Adds handle data for resource type inputs and outputs.

  For each input/output tensor whose signature arg carries `handle_data`, the
  tensor's static shape is set to `[]` and the shape and dtype from the first
  `handle_data` entry are attached to it as handle data.

  Args:
    func_graph: FuncGraph whose `inputs` and `outputs` are already populated.
    fdef: The FunctionDef `func_graph` was built from.
  """
  # The shape of the handle itself is [], while the variable shape is
  # saved in `handle_data`. Previously, the shape of the resource handle
  # was set to `None`. Correct both shapes here.
  for tensor, arg_def in itertools.chain(
      zip(func_graph.inputs, fdef.signature.input_arg),
      zip(func_graph.outputs, fdef.signature.output_arg)):
    if arg_def.handle_data:
      tensor.set_shape([])

      shape_and_dtype = arg_def.handle_data[0]
      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
      handle_data.is_set = True
      handle_data.shape_and_type.append(
          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
          tensor, handle_data, True)