# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility to convert FunctionDef to GraphDef and Graph."""

import itertools


from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import resource_variable_ops


def function_def_to_graph(fdef,
                          structured_input_signature=None,
                          structured_outputs=None,
                          input_shapes=None):
38  """Converts a FunctionDef to a FuncGraph (sub-class Graph).
39
40  The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
41  The input tensors are represented as placeholders.
42
  Note: `FuncGraph.captures` is not set and may be set by the caller.

  Args:
    fdef: FunctionDef.
    structured_input_signature: Optional. The structured input signature to
      use for initializing the FuncGraph. See the docstring for FuncGraph for
      more information.
    structured_outputs: Optional. The structured outputs to use for
      initializing the FuncGraph. See the docstring for FuncGraph for more
      information.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute. If
      specified, its length must match the length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A FuncGraph.
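
  Example:

    An illustrative sketch (assumes `fdef` was obtained elsewhere, e.g. from a
    concrete function; the tensor layout follows the code below):

      func_graph = function_def_to_graph(fdef)
      func_graph.inputs   # placeholder tensors, one per signature.input_arg
      func_graph.outputs  # tensors named by fdef.ret, in output_arg order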
62  """
63  func_graph = FuncGraph(fdef.signature.name,
64                         structured_input_signature=structured_input_signature,
65                         structured_outputs=structured_outputs)
66  if input_shapes is None:
67    input_shapes_attr = fdef.attr.get("_input_shapes", None)
68    if input_shapes_attr is not None:
69      raw_input_shapes = input_shapes_attr.list.shape
70
71      # Replace resource handle shapes in the inputs to disable shape inference.
72      # Setting the shape to either the variable handle shape (which is always
73      # `[]`) or the variable shape can cause shape inference issues.
74      input_shapes = []
75      for input_shape, arg_def in zip(raw_input_shapes,
76                                      fdef.signature.input_arg):
77        if arg_def.type == types_pb2.DT_RESOURCE and arg_def.handle_data:
78          input_shapes.append(None)
79        else:
80          input_shapes.append(input_shape)
81
82  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
83      fdef, input_shapes)
84
85  with func_graph.as_default():
86    # Add all function nodes to the graph.
87    importer.import_graph_def_for_function(graph_def, name="")
88
89    # Initialize fields specific to FuncGraph.
90
91    # inputs
92    input_tensor_names = [
93        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
94    ]
95    func_graph.inputs = [
96        func_graph.get_tensor_by_name(name) for name in input_tensor_names
97    ]
98
99    # outputs
100    output_tensor_names = [
101        nested_to_flat_tensor_name[fdef.ret[arg.name]]
102        for arg in fdef.signature.output_arg
103    ]
104    func_graph.outputs = [
105        func_graph.get_tensor_by_name(name) for name in output_tensor_names
106    ]
107    func_graph.control_outputs = [
108        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
109        for ret_name in fdef.signature.control_output
110    ]
111
112    _set_handle_data(func_graph, fdef)
113
114    for node in graph_def.node:
115      output_shapes = node.attr.get("_output_shapes", None)
116      if output_shapes is not None:
117        op = func_graph.get_operation_by_name(node.name)
118        # _output_shapes for functions can sometimes be too long because the
119        # output-intermediates-for-gradients version of the function was
120        # substituted before saving. We'll accept that here. (See b/133666530).
121        for output_index, shape in enumerate(
122            output_shapes.list.shape[:len(op.outputs)]):
123          op.outputs[output_index].set_shape(shape)
124    output_names = {}
125    for ret_arg_def, tensor_name in zip(
126        fdef.signature.output_arg, output_tensor_names):
127      output_names[ops.tensor_id(
128          func_graph.get_tensor_by_name(tensor_name))] = (
129              ret_arg_def.name)
130    func_graph._output_names = output_names  # pylint: disable=protected-access
131  return func_graph


def is_function(fname):
  """Checks for a function definition with `fname` in the current context.

  In graph mode, this also searches all enclosing (outer) graphs.
  """
  if context.executing_eagerly():
    return context.context().has_function(fname)
  else:
    graph = ops.get_default_graph()
    while graph is not None:
      if graph._is_function(fname):  # pylint: disable=protected-access
        return True
      if hasattr(graph, "outer_graph"):
        graph = graph.outer_graph
      else:
        return False


def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the GraphDef convention instead of the
     FunctionDef one. See the comment on `FunctionDef.node_def` for how tensor
     naming in FunctionDefs differs from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match the length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
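
  Example:

    An illustrative sketch of the renaming (`x`, `n` and `z` are hypothetical
    names): a function input `x` maps to the flat name "x:0", and output `i`
    of the output arg `z` on a body node `n` maps from the nested name
    "n:z:i" to "n:<flattened_index>":

      graph_def, name_map = function_def_to_graph_def(fdef)
      name_map["x"]      # -> "x:0"
      name_map["n:z:0"]  # -> e.g. "n:0"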
174  """
175  graph_def = graph_pb2.GraphDef()
176  graph_def.versions.CopyFrom(
177      versions_pb2.VersionDef(
178          producer=versions.GRAPH_DEF_VERSION,
179          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))
180
181  default_graph = ops.get_default_graph()
182
183  copied_functions = set()
184
185  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
186    raise ValueError("Length of `input_shapes` must match the number "
187                     f"of `input_arg`s in `fdef`. Got "
188                     f"{len(input_shapes)} `input_shapes` and "
189                     f"{len(fdef.signature.input_arg)} `input_arg`s.")
190
191  # 1. Create placeholders for input nodes.
192  for i, arg_def in enumerate(fdef.signature.input_arg):
193    node_def = graph_def.node.add()
194    node_def.name = arg_def.name
195    node_def.op = "Placeholder"
196    node_def.attr["dtype"].type = arg_def.type
197    if input_shapes and input_shapes[i] is not None:
198      input_shape = input_shapes[i]
199      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
200        input_shape = input_shape.as_proto()
201      node_def.attr["shape"].shape.CopyFrom(input_shape)
202    arg_attrs = fdef.arg_attr[i].attr
203    for k in arg_attrs:
204      # Only copy internal attributes. Normal attributes for nodes cannot be
205      # applied to these Placeholder nodes.
206      if k == "_output_shapes":
207        if arg_attrs[k].WhichOneof("value") == "list":
208          node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].list.shape[0])
209        elif arg_attrs[k].WhichOneof("value") == "shape":
210          node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].shape)
211      elif k.startswith("_"):
212        node_def.attr[k].CopyFrom(arg_attrs[k])

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    graph = default_graph
    while True:
      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
      if f is not None or not hasattr(graph, "outer_graph"):
        break
      graph = graph.outer_graph

    if f is not None:
      # Use a local name to avoid shadowing the `fdef` argument.
      func_def = f.definition
      op_def = func_def.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice but
        # to copy it into the GraphDef if we want downstream tools to process
        # it.
        graph_def.library.function.add().CopyFrom(func_def)
        copied_functions.add(node_def.op)
        if f.grad_func_name:
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        if not is_function(fname):
          raise ValueError(f"Function {fname} was not found. Please make sure "
                           "the FunctionDef `fdef` is correct.")
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          if not is_function(fname):
            raise ValueError(f"Function {fname} was not found. Please make "
                             "sure the FunctionDef `fdef` is correct.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name


# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
def _get_num_args(arg_def, node_def):
  """Returns the number of tensors `arg_def` expands to for `node_def`."""
  if arg_def.number_attr:
    return node_def.attr[arg_def.number_attr].i
  elif arg_def.type_list_attr:
    return len(node_def.attr[arg_def.type_list_attr].list.type)
  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
    return 1
  else:
    raise ValueError(f"Invalid arg_def:\n\n{arg_def}. Please make sure the "
                     "FunctionDef `fdef` is correct.")
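
# For example (an illustrative sketch): an output arg declared with
# `number_attr = "N"` expands to as many tensors as the node's `N` attr holds,
# so with `node_def.attr["N"].i == 3` the node contributes three flattened
# outputs "node:0", "node:1" and "node:2" to the name mapping above.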


def _set_handle_data(func_graph, fdef):
  """Adds handle data for resource type inputs and outputs."""
  # The shape of the handle itself is [], while the variable shape is
  # saved in `handle_data`. Previously, the shape of the resource handle
  # was set to `None`. Correct both shapes here.
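  # For example (illustrative): a resource input backed by a float32 variable
  # of shape [2, 3] gets its handle tensor shape set to [] below, while the
  # variable shape [2, 3] and dtype are recorded in the tensor's handle data.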
  for tensor, arg_def in itertools.chain(
      zip(func_graph.inputs, fdef.signature.input_arg),
      zip(func_graph.outputs, fdef.signature.output_arg)):
    if arg_def.handle_data:
      tensor.set_shape([])

      shape_and_dtype = arg_def.handle_data[0]
      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
      handle_data.is_set = True
      handle_data.shape_and_type.append(
          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
          tensor, handle_data, True)
323