# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-import-not-at-top
"""Utilities related to model visualization."""

import os
import sys
from tensorflow.python.keras.utils.io_utils import path_to_string
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


try:
  # pydot-ng is a fork of pydot that is better maintained.
  import pydot_ng as pydot
except ImportError:
  # pydotplus is an improved version of pydot
  try:
    import pydotplus as pydot
  except ImportError:
    # Fall back on pydot if necessary.
    try:
      import pydot
    except ImportError:
      pydot = None


def check_pydot():
  """Returns True if PyDot and Graphviz are available."""
  if pydot is None:
    return False
  try:
    # Attempt to create an image of a blank graph
    # to check the pydot/graphviz installation.
    pydot.Dot.create(pydot.Dot())
    return True
  # Look the exception type up defensively so that this except clause cannot
  # itself fail with AttributeError if the installed pydot variant does not
  # define `InvocationException`.
  except (OSError, getattr(pydot, 'InvocationException', OSError)):
    return False


def is_wrapped_model(layer):
  """Returns True if `layer` is a `Wrapper` whose inner layer is Functional."""
  from tensorflow.python.keras.engine import functional
  from tensorflow.python.keras.layers import wrappers
  return (isinstance(layer, wrappers.Wrapper) and
          isinstance(layer.layer, functional.Functional))


def add_edge(dot, src, dst):
  """Adds an edge `src -> dst` to `dot` unless that edge already exists."""
  if not dot.get_edge(src, dst):
    dot.add_edge(pydot.Edge(src, dst))


@keras_export('keras.utils.model_to_dot')
def model_to_dot(model,
                 show_shapes=False,
                 show_dtype=False,
                 show_layer_names=True,
                 rankdir='TB',
                 expand_nested=False,
                 dpi=96,
                 subgraph=False):
  """Convert a Keras model to dot format.

  Args:
    model: A Keras model instance.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: whether to expand nested models into clusters.
    dpi: Dots per inch. Only applied to the top-level graph; ignored when
        `subgraph=True`.
    subgraph: whether to return a `pydot.Cluster` instance.

  Returns:
    A `pydot.Dot` instance representing the Keras model or
    a `pydot.Cluster` instance representing nested model if
    `subgraph=True`. Returns `None` (after printing an explanatory message)
    when pydot/graphviz are unavailable and IPython is loaded.

  Raises:
    ImportError: if graphviz or pydot are not available (and IPython is not
      loaded).
  """
  from tensorflow.python.keras.layers import wrappers
  from tensorflow.python.keras.engine import sequential
  from tensorflow.python.keras.engine import functional

  if not check_pydot():
    # The literals below are implicitly concatenated into ONE string; a comma
    # between them would silently turn `message` into a tuple.
    message = (
        'You must install pydot (`pip install pydot`) '
        'and install graphviz '
        '(see instructions at https://graphviz.gitlab.io/download/) '
        'for plot_model/model_to_dot to work.')
    if 'IPython.core.magics.namespace' in sys.modules:
      # We don't raise an exception here in order to avoid crashing notebook
      # tests where graphviz is not available.
      print(message)
      return
    else:
      raise ImportError(message)

  if subgraph:
    dot = pydot.Cluster(style='dashed', graph_name=model.name)
    dot.set('label', model.name)
    dot.set('labeljust', 'l')
  else:
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set('dpi', dpi)
    dot.set_node_defaults(shape='record')

  # First/last node of each expanded nested model's cluster, used to route
  # edges into and out of that cluster. `sub_n_*` is keyed by the name of a
  # nested Functional model ("not wrapper"); `sub_w_*` is keyed by the name of
  # a Functional model held inside a `Wrapper` layer.
  sub_n_first_node = {}
  sub_n_last_node = {}
  sub_w_first_node = {}
  sub_w_last_node = {}

  layers = model.layers
  if not model._is_graph_network:
    # Models that are not graph networks are drawn as a single node labelled
    # with the model name.
    node = pydot.Node(str(id(model)), label=model.name)
    dot.add_node(node)
    return dot
  elif isinstance(model, sequential.Sequential):
    if not model.built:
      model.build()
    # NOTE(review): presumably bypasses `Sequential.layers` so that the
    # auto-generated input layer is included in the plot — confirm.
    layers = super(sequential.Sequential, model).layers

  def format_dtype(dtype):
    return '?' if dtype is None else str(dtype)

  def format_shape(shape):
    return str(shape)

  # Create graph nodes.
  for layer in layers:
    layer_id = str(id(layer))

    # Append a wrapped layer's label to node's label, if it exists.
    layer_name = layer.name
    class_name = layer.__class__.__name__

    if isinstance(layer, wrappers.Wrapper):
      if expand_nested and isinstance(layer.layer, functional.Functional):
        submodel_wrapper = model_to_dot(
            layer.layer,
            show_shapes=show_shapes,
            show_dtype=show_dtype,
            show_layer_names=show_layer_names,
            rankdir=rankdir,
            expand_nested=expand_nested,
            subgraph=True)
        # sub_w : submodel_wrapper
        sub_w_nodes = submodel_wrapper.get_nodes()
        sub_w_first_node[layer.layer.name] = sub_w_nodes[0]
        sub_w_last_node[layer.layer.name] = sub_w_nodes[-1]
        dot.add_subgraph(submodel_wrapper)
      else:
        layer_name = '{}({})'.format(layer_name, layer.layer.name)
        child_class_name = layer.layer.__class__.__name__
        class_name = '{}({})'.format(class_name, child_class_name)

    if expand_nested and isinstance(layer, functional.Functional):
      submodel_not_wrapper = model_to_dot(
          layer,
          show_shapes=show_shapes,
          show_dtype=show_dtype,
          show_layer_names=show_layer_names,
          rankdir=rankdir,
          expand_nested=expand_nested,
          subgraph=True)
      # sub_n : submodel_not_wrapper
      sub_n_nodes = submodel_not_wrapper.get_nodes()
      sub_n_first_node[layer.name] = sub_n_nodes[0]
      sub_n_last_node[layer.name] = sub_n_nodes[-1]
      dot.add_subgraph(submodel_not_wrapper)

    # Create node's label.
    if show_layer_names:
      label = '{}: {}'.format(layer_name, class_name)
    else:
      label = class_name

    # Rebuild the label as a table including the layer's dtype.
    if show_dtype:
      label = '%s|%s' % (label, format_dtype(layer.dtype))

    # Rebuild the label as a table including input/output shapes.
    if show_shapes:
      try:
        outputlabels = format_shape(layer.output_shape)
      except AttributeError:
        outputlabels = '?'
      if hasattr(layer, 'input_shape'):
        inputlabels = format_shape(layer.input_shape)
      elif hasattr(layer, 'input_shapes'):
        inputlabels = ', '.join(
            [format_shape(ishape) for ishape in layer.input_shapes])
      else:
        inputlabels = '?'
      label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                     inputlabels,
                                                     outputlabels)

    # An expanded nested Functional model is represented only by its cluster,
    # so it gets no node of its own.
    if not expand_nested or not isinstance(layer, functional.Functional):
      node = pydot.Node(layer_id, label=label)
      dot.add_node(node)

  # Connect nodes with edges.
  for layer in layers:
    layer_id = str(id(layer))
    for i, node in enumerate(layer._inbound_nodes):
      node_key = layer.name + '_ib-' + str(i)
      if node_key not in model._network_nodes:
        continue
      for inbound_layer in nest.flatten(node.inbound_layers):
        inbound_layer_id = str(id(inbound_layer))
        if not expand_nested:
          assert dot.get_node(inbound_layer_id)
          assert dot.get_node(layer_id)
          add_edge(dot, inbound_layer_id, layer_id)
          continue

        # Where edges leave the inbound layer: the last node of its cluster
        # when it was expanded, otherwise its own node.
        if isinstance(inbound_layer, functional.Functional):
          src = sub_n_last_node[inbound_layer.name].get_name()
        elif is_wrapped_model(inbound_layer):
          src = sub_w_last_node[inbound_layer.layer.name].get_name()
        else:
          src = inbound_layer_id

        # Where edges enter the current layer.
        if isinstance(layer, functional.Functional):
          # An expanded model has no node of its own; enter its cluster.
          add_edge(dot, src, sub_n_first_node[layer.name].get_name())
        elif is_wrapped_model(layer):
          # The wrapper keeps its own node, which then feeds the cluster of
          # the model it wraps (regardless of what kind of layer feeds it).
          add_edge(dot, src, layer_id)
          add_edge(dot, layer_id,
                   sub_w_first_node[layer.layer.name].get_name())
        else:
          if src == inbound_layer_id:
            # Plain layer -> plain layer: both must be nodes of this graph.
            assert dot.get_node(inbound_layer_id)
            assert dot.get_node(layer_id)
          add_edge(dot, src, layer_id)
  return dot


@keras_export('keras.utils.plot_model')
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_dtype=False,
               show_layer_names=True,
               rankdir='TB',
               expand_nested=False,
               dpi=96):
  """Converts a Keras model to dot format and save to a file.

  Example:

  ```python
  input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
  x = tf.keras.layers.Embedding(
      output_dim=512, input_dim=10000, input_length=100)(input)
  x = tf.keras.layers.LSTM(32)(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  x = tf.keras.layers.Dense(64, activation='relu')(x)
  output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
  model = tf.keras.Model(inputs=[input], outputs=[output])
  dot_img_file = '/tmp/model_1.png'
  tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
  ```

  Args:
    model: A Keras model instance
    to_file: File name of the plot image. Its extension (default `png` when
      there is none) selects the output format.
    show_shapes: whether to display shape information.
    show_dtype: whether to display layer dtypes.
    show_layer_names: whether to display layer names.
    rankdir: `rankdir` argument passed to PyDot,
        a string specifying the format of the plot:
        'TB' creates a vertical plot;
        'LR' creates a horizontal plot.
    expand_nested: Whether to expand nested models into clusters.
    dpi: Dots per inch.

  Returns:
    A Jupyter notebook Image object if Jupyter is installed.
    This enables in-line display of the model plots in notebooks.
    Returns `None` when the output format is `pdf`, when IPython's `display`
    cannot be imported, or when `model_to_dot` returned `None`.
  """
  dot = model_to_dot(
      model,
      show_shapes=show_shapes,
      show_dtype=show_dtype,
      show_layer_names=show_layer_names,
      rankdir=rankdir,
      expand_nested=expand_nested,
      dpi=dpi)
  to_file = path_to_string(to_file)
  if dot is None:
    return
  _, extension = os.path.splitext(to_file)
  if not extension:
    extension = 'png'
  else:
    extension = extension[1:]
  # Save image to disk.
  dot.write(to_file, format=extension)
  # Return the image as a Jupyter Image object, to be displayed in-line.
  # Note that we cannot easily detect whether the code is running in a
  # notebook, and thus we always return the Image if Jupyter is available.
  if extension != 'pdf':
    try:
      from IPython import display
      return display.Image(filename=to_file)
    except ImportError:
      pass