xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/utils/vis_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-import-not-at-top
17"""Utilities related to model visualization."""
18
19import os
20import sys
21from tensorflow.python.keras.utils.io_utils import path_to_string
22from tensorflow.python.util import nest
23from tensorflow.python.util.tf_export import keras_export
24
25
26try:
27  # pydot-ng is a fork of pydot that is better maintained.
28  import pydot_ng as pydot
29except ImportError:
30  # pydotplus is an improved version of pydot
31  try:
32    import pydotplus as pydot
33  except ImportError:
34    # Fall back on pydot if necessary.
35    try:
36      import pydot
37    except ImportError:
38      pydot = None
39
40
def check_pydot():
  """Reports whether a pydot flavour was imported and can render a graph.

  Returns:
    False when no pydot module could be imported, or when creating an image
    of an empty graph raises `OSError` or `pydot.InvocationException`;
    True otherwise.
  """
  if pydot is None:
    return False
  try:
    # Rendering an empty graph exercises the pydot -> graphviz invocation.
    blank_graph = pydot.Dot()
    pydot.Dot.create(blank_graph)
  except (OSError, pydot.InvocationException):
    return False
  return True
52
53
def is_wrapped_model(layer):
  """Tells whether `layer` is a `Wrapper` around a `Functional` model.

  Args:
    layer: the object to inspect.

  Returns:
    True if `layer` is a `wrappers.Wrapper` instance and its `layer`
    attribute is a `functional.Functional` instance, False otherwise.
  """
  # Imported lazily, matching the other helpers in this module.
  from tensorflow.python.keras.engine import functional
  from tensorflow.python.keras.layers import wrappers
  if not isinstance(layer, wrappers.Wrapper):
    return False
  return isinstance(layer.layer, functional.Functional)
59
60
def add_edge(dot, src, dst):
  """Adds a `src` -> `dst` edge to `dot` unless one is already reported.

  Args:
    dot: graph object exposing `get_edge(src, dst)` and `add_edge(edge)`.
    src: name of the source node.
    dst: name of the destination node.
  """
  if dot.get_edge(src, dst):
    # An edge between these two nodes already exists; do not duplicate it.
    return
  dot.add_edge(pydot.Edge(src, dst))
64
65
@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False):
  """Convert a Keras model to dot format.

  Args:
    model: A Keras model instance.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: whether to expand nested models into clusters.
    dpi: Dots per inch.
    subgraph: whether to return a `pydot.Cluster` instance.

  Returns:
    A `pydot.Dot` instance representing the Keras model or
    a `pydot.Cluster` instance representing nested model if
    `subgraph=True`. When pydot/graphviz are unavailable and
    `IPython.core.magics.namespace` has been imported, the installation
    hint is printed and `None` is returned instead.

  Raises:
    ImportError: if graphviz or pydot are not available.
  """
  from tensorflow.python.keras.layers import wrappers
  from tensorflow.python.keras.engine import sequential
  from tensorflow.python.keras.engine import functional

  if not check_pydot():
    # Adjacent literals are implicitly concatenated into ONE string. There
    # must be no comma between them, otherwise `message` becomes a tuple and
    # both the printed hint and the ImportError text show a tuple repr.
    message = (
        'You must install pydot (`pip install pydot`) '
        'and install graphviz '
        '(see instructions at https://graphviz.gitlab.io/download/) '
        'for plot_model/model_to_dot to work.')
    if 'IPython.core.magics.namespace' in sys.modules:
      # We don't raise an exception here in order to avoid crashing notebook
      # tests where graphviz is not available.
      print(message)
      return
    else:
      raise ImportError(message)

  if subgraph:
    # Nested models are drawn as a dashed, left-labelled cluster.
    dot = pydot.Cluster(style='dashed', graph_name=model.name)
    dot.set('label', model.name)
    dot.set('labeljust', 'l')
  else:
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', dpi)
    dot.set_node_defaults(shape='record')

  # First/last pydot node of every expanded nested model, keyed by the nested
  # model's name. `sub_n_*` is for nested models used directly as layers,
  # `sub_w_*` for nested models held inside a Wrapper layer. They are used
  # below to attach edges that enter or leave a cluster.
  sub_n_first_node = {}
  sub_n_last_node = {}
  sub_w_first_node = {}
  sub_w_last_node = {}

  layers = model.layers
  if not model._is_graph_network:
    # Not a graph network: represent the whole model as a single node.
    node = pydot.Node(str(id(model)), label=model.name)
    dot.add_node(node)
    return dot
  elif isinstance(model, sequential.Sequential):
    if not model.built:
      model.build()
    # Use the parent class' `layers` rather than Sequential's own property.
    layers = super(sequential.Sequential, model).layers

  def format_dtype(dtype):
    if dtype is None:
      return '?'
    else:
      return str(dtype)

  def format_shape(shape):
    return str(shape).replace(str(None), 'None')

  # Create graph nodes.
  for layer in layers:
    layer_id = str(id(layer))

    # Append a wrapped layer's label to node's label, if it exists.
    layer_name = layer.name
    class_name = layer.__class__.__name__

    if isinstance(layer, wrappers.Wrapper):
      if expand_nested and isinstance(layer.layer,
                                      functional.Functional):
        submodel_wrapper = model_to_dot(
            layer.layer,
            show_shapes,
            show_dtype,
            show_layer_names,
            rankdir,
            expand_nested,
            subgraph=True)
        # sub_w : submodel_wrapper
        sub_w_nodes = submodel_wrapper.get_nodes()
        sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
        sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
        dot.add_subgraph(submodel_wrapper)
      else:
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    if expand_nested and isinstance(layer, functional.Functional):
      submodel_not_wrapper = model_to_dot(
          layer,
          show_shapes,
          show_dtype,
          show_layer_names,
          rankdir,
          expand_nested,
          subgraph=True)
      # sub_n : submodel_not_wrapper
      sub_n_nodes = submodel_not_wrapper.get_nodes()
      sub_n_first_node[layer.name] = sub_n_nodes[0]
      sub_n_last_node[layer.name] = sub_n_nodes[-1]
      dot.add_subgraph(submodel_not_wrapper)

    # Create node's label.
    if show_layer_names:
      label = '{}: {}'.format(layer_name, class_name)
    else:
      label = class_name

    # Rebuild the label as a table including the layer's dtype.
    if show_dtype:
      label = '%s|%s' % (label, format_dtype(layer.dtype))

    # Rebuild the label as a table including input/output shapes.
    if show_shapes:
      try:
        outputlabels = format_shape(layer.output_shape)
      except AttributeError:
        outputlabels = '?'
      if hasattr(layer, 'input_shape'):
        inputlabels = format_shape(layer.input_shape)
      elif hasattr(layer, 'input_shapes'):
        inputlabels = ', '.join(
            [format_shape(ishape) for ishape in layer.input_shapes])
      else:
        inputlabels = '?'
      label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                     inputlabels,
                                                     outputlabels)

    # An expanded nested model is represented by its cluster only, so it
    # gets no standalone node of its own.
    if not expand_nested or not isinstance(
        layer, functional.Functional):
      node = pydot.Node(layer_id, label=label)
      dot.add_node(node)

  # Connect nodes with edges.
  for layer in layers:
    layer_id = str(id(layer))
    for i, node in enumerate(layer._inbound_nodes):
      node_key = layer.name + '_ib-' + str(i)
      # Only draw connections that belong to this model's own graph.
      if node_key in model._network_nodes:
        for inbound_layer in nest.flatten(node.inbound_layers):
          inbound_layer_id = str(id(inbound_layer))
          if not expand_nested:
            assert dot.get_node(inbound_layer_id)
            assert dot.get_node(layer_id)
            add_edge(dot, inbound_layer_id, layer_id)
          else:
            # if inbound_layer is not Model or wrapped Model
            if (not isinstance(inbound_layer,
                               functional.Functional) and
                not is_wrapped_model(inbound_layer)):
              # if current layer is not Model or wrapped Model
              if (not isinstance(layer, functional.Functional) and
                  not is_wrapped_model(layer)):
                assert dot.get_node(inbound_layer_id)
                assert dot.get_node(layer_id)
                add_edge(dot, inbound_layer_id, layer_id)
              # if current layer is Model
              elif isinstance(layer, functional.Functional):
                add_edge(dot, inbound_layer_id,
                         sub_n_first_node[layer.name].get_name())
              # if current layer is wrapped Model
              elif is_wrapped_model(layer):
                # Link into the wrapper node, then from the wrapper node to
                # the first node of the wrapped model's cluster.
                add_edge(dot, inbound_layer_id, layer_id)
                name = sub_w_first_node[layer.layer.name].get_name()
                add_edge(dot, layer_id, name)
            # if inbound_layer is Model
            elif isinstance(inbound_layer, functional.Functional):
              name = sub_n_last_node[inbound_layer.name].get_name()
              if isinstance(layer, functional.Functional):
                output_name = sub_n_first_node[layer.name].get_name()
                add_edge(dot, name, output_name)
              else:
                add_edge(dot, name, layer_id)
            # if inbound_layer is wrapped Model
            elif is_wrapped_model(inbound_layer):
              inbound_layer_name = inbound_layer.layer.name
              add_edge(dot,
                       sub_w_last_node[inbound_layer_name].get_name(),
                       layer_id)
  return dot
276
277
@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_dtype=False,
               show_layer_names=True,
               rankdir='TB',
               expand_nested=False,
               dpi=96):
  """Converts a Keras model to dot format and saves it to a file.

  Example:

  ```python
  input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
  x = tf.keras.layers.Embedding(
      output_dim=512, input_dim=10000, input_length=100)(input)
  x = tf.keras.layers.LSTM(32)(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
  model = tf.keras.Model(inputs=[input], outputs=[output])
  dot_img_file = '/tmp/model_1.png'
  tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
  ```

  Args:
    model: A Keras model instance
    to_file: File name of the plot image. The output format is taken from
        its extension; 'png' is used when there is no extension.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: Whether to expand nested models into clusters.
    dpi: Dots per inch.

  Returns:
    A Jupyter notebook Image object if Jupyter is installed and the output
    format is not 'pdf'. This enables in-line display of the model plots in
    notebooks. Otherwise `None`, including when `model_to_dot` produced no
    graph.
  """
  graph = model_to_dot(
      model,
      show_shapes=show_shapes,
      show_dtype=show_dtype,
      show_layer_names=show_layer_names,
      rankdir=rankdir,
      expand_nested=expand_nested,
      dpi=dpi)
  to_file = path_to_string(to_file)
  if graph is None:
    # model_to_dot returned nothing, so there is no graph to render.
    return None

  # Derive the output format from the file suffix, without the leading dot.
  suffix = os.path.splitext(to_file)[1]
  image_format = suffix[1:] if suffix else 'png'

  # Save image to disk.
  graph.write(to_file, format=image_format)

  if image_format == 'pdf':
    return None
  # We cannot easily detect whether the code is running in a notebook, so the
  # Image is returned whenever IPython can be imported.
  try:
    from IPython import display
    return display.Image(filename=to_file)
  except ImportError:
    return None
348      pass
349