#!/usr/bin/env python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This tool creates an html visualization of a TensorFlow Lite graph.

Example usage:

python visualize.py foo.tflite foo.html
"""

import json
import os
import re
import sys
import numpy as np

# pylint: disable=g-import-not-at-top
if not os.path.splitext(__file__)[0].endswith(
    os.path.join("tflite_runtime", "visualize")):
  # This file is part of tensorflow package.
  from tensorflow.lite.python import schema_py_generated as schema_fb
else:
  # This file is part of tflite_runtime package.
  from tflite_runtime import schema_py_generated as schema_fb

# A CSS description for making the visualizer
_CSS = """
<html>
<head>
<meta charset="utf-8">
<style>
body {font-family: sans-serif; background-color: #fa0;}
table {background-color: #eca;}
th {background-color: black; color: white;}
h1 {
  background-color: #ffaa00;
  padding:5px;
  color: black;
}

svg {
  margin: 10px;
  border: 2px;
  border-style: solid;
  border-color: black;
  background: white;
}

div {
  border-radius: 5px;
  background-color: #fec;
  padding:5px;
  margin:5px;
}

.tooltip {color: blue;}
.tooltip .tooltipcontent  {
  visibility: hidden;
  color: black;
  background-color: yellow;
  padding: 5px;
  border-radius: 4px;
  position: absolute;
  z-index: 1;
}
.tooltip:hover .tooltipcontent {
    visibility: visible;
}

.edges line {
  stroke: #333;
}

text {
  font-weight: bold;
}

.nodes text {
  color: black;
  pointer-events: none;
  font-family: sans-serif;
  font-size: 11px;
}
</style>

<script src="https://d3js.org/d3.v4.min.js"></script>

</head>
<body>
"""

# Per-subgraph d3 script. Formatted with `%` as (graph_json, subgraph_index),
# so any literal percent sign added to this template must be written as `%%`.
_D3_HTML_TEMPLATE = """
  <script>
    function buildGraph() {
      // Build graph data
      var graph = %s;

      var svg = d3.select("#subgraph%d")
      var width = svg.attr("width");
      var height = svg.attr("height");
      // Make the graph scrollable.
      svg = svg.call(d3.zoom().on("zoom", function() {
        svg.attr("transform", d3.event.transform);
      })).append("g");


      var color = d3.scaleOrdinal(d3.schemeDark2);

      var simulation = d3.forceSimulation()
          .force("link", d3.forceLink().id(function(d) {return d.id;}))
          .force("charge", d3.forceManyBody())
          .force("center", d3.forceCenter(0.5 * width, 0.5 * height));

      var edge = svg.append("g").attr("class", "edges").selectAll("line")
        .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")

      // Make the node group
      var node = svg.selectAll(".nodes")
        .data(graph.nodes)
        .enter().append("g")
        .attr("x", function(d){return d.x})
        .attr("y", function(d){return d.y})
        .attr("transform", function(d) {
          return "translate( " + d.x + ", " + d.y + ")"
        })
        .attr("class", "nodes")
          .call(d3.drag()
              .on("start", function(d) {
                if(!d3.event.active) simulation.alphaTarget(1.0).restart();
                d.fx = d.x;d.fy = d.y;
              })
              .on("drag", function(d) {
                d.fx = d3.event.x; d.fy = d3.event.y;
              })
              .on("end", function(d) {
                if (!d3.event.active) simulation.alphaTarget(0);
                d.fx = d.fy = null;
              }));
      // Within the group, draw a box for the node position and text
      // on the side.

      var node_width = 150;
      var node_height = 30;

      node.append("rect")
          .attr("r", "5px")
          .attr("width", node_width)
          .attr("height", node_height)
          .attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
          .attr("stroke", "#000000")
          .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
      node.append("text")
          .text(function(d) { return d.name; })
          .attr("x", 5)
          .attr("y", 20)
          .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
      // Setup force parameters and update position callback


      var node = svg.selectAll(".nodes")
        .data(graph.nodes);

      // Bind the links
      var name_to_g = {}
      node.each(function(data, index, nodes) {
        name_to_g[data.id] = this;
      });

      function proc(w, t) {
        return parseInt(w.getAttribute(t));
      }
      edge.attr("d", function(d) {
        function lerp(t, a, b) {
          return (1.0-t) * a + t * b;
        }
        var x1 = proc(name_to_g[d.source],"x") + node_width /2;
        var y1 = proc(name_to_g[d.source],"y") + node_height;
        var x2 = proc(name_to_g[d.target],"x") + node_width /2;
        var y2 = proc(name_to_g[d.target],"y");
        var s = "M " + x1 + " " + y1
            + " C " + x1 + " " + lerp(.5, y1, y2)
            + " " + x2 + " " + lerp(.5, y1, y2)
            + " " + x2 + " " + y2
      return s;
    });

  }
  buildGraph()
</script>
"""


def _EnumValueToName(enum_cls, value):
  """Returns the public attribute name of `enum_cls` whose value is `value`.

  Private/dunder attributes (e.g. `__doc__`, which is None on the generated
  enum classes) are skipped so that a `None` value does not map to a dunder
  name.

  Args:
    enum_cls: a flatbuffers-generated enum class (plain class of constants).
    value: the value to look up.

  Returns:
    The matching attribute name, or None if no public attribute matches.
  """
  for name, candidate in vars(enum_cls).items():
    if not name.startswith("_") and candidate == value:
      return name
  return None


def TensorTypeToName(tensor_type):
  """Converts a numerical enum to a readable tensor type."""
  return _EnumValueToName(schema_fb.TensorType, tensor_type)


def BuiltinCodeToName(code):
  """Converts a builtin op code enum to a readable name."""
  return _EnumValueToName(schema_fb.BuiltinOperator, code)


def NameListToString(name_list):
  """Converts a list of byte values to the equivalent string.

  Args:
    name_list: a `str` (returned unchanged), None (becomes ""), or an iterable
      of integer byte values as produced by `FlatbufferToDict`.

  Returns:
    The decoded string. Byte values are decoded as UTF-8; if they are not
    valid UTF-8 (or not valid bytes), each value is mapped with `chr()`.
  """
  if isinstance(name_list, str):
    return name_list
  if name_list is None:
    return ""
  codes = [int(val) for val in name_list]
  try:
    return bytes(codes).decode("utf-8")
  except ValueError:  # Includes UnicodeDecodeError and values > 255.
    return "".join(chr(code) for code in codes)


class OpCodeMapper:
  """Maps an opcode index to an op name."""

  def __init__(self, data):
    self.code_to_name = {}
    for idx, d in enumerate(data.get("operator_codes") or []):
      # flatc JSON output omits default-valued (0) fields, hence the defaults.
      self.code_to_name[idx] = BuiltinCodeToName(d.get("builtin_code", 0))
      if self.code_to_name[idx] == "CUSTOM":
        self.code_to_name[idx] = NameListToString(d.get("custom_code"))

  def __call__(self, x):
    if x not in self.code_to_name:
      s = "<UNKNOWN>"
    else:
      s = self.code_to_name[x]
    return "%s (%d)" % (s, x)


class DataSizeMapper:
  """For buffers, report the number of bytes."""

  def __call__(self, x):
    if x is not None:
      return "%d bytes" % len(x)
    else:
      return "--"


class TensorMapper:
  """Maps a list of tensor indices to a tooltip hoverable indicator of more."""

  def __init__(self, subgraph_data):
    self.data = subgraph_data

  def __call__(self, x):
    if x is None:
      return ""

    tensors = self.data.get("tensors") or []
    lines = []
    for i in x:
      if not 0 <= i < len(tensors):
        # Optional inputs/outputs are encoded as -1. Python's negative
        # indexing would otherwise silently describe the last tensor.
        lines.append("%d (no tensor)<br>" % i)
        continue
      tensor = tensors[i]
      shape = repr(tensor["shape"]) if "shape" in tensor else "[]"
      signature = (repr(tensor["shape_signature"])
                   if "shape_signature" in tensor else "[]")
      lines.append("%d %s %s %s %s<br>" % (
          i, NameListToString(tensor.get("name")),
          TensorTypeToName(tensor.get("type")), shape, signature))
    return ("<span class='tooltip'><span class='tooltipcontent'>" +
            "".join(lines) + "</span>" + repr(x) + "</span>")


def GenerateGraph(subgraph_idx, g, opcode_mapper):
  """Produces the HTML required to have a d3 visualization of the dag.

  Args:
    subgraph_idx: index of the subgraph; the script draws into the svg element
      with id `subgraph<subgraph_idx>`.
    g: subgraph dict with "operators" and "tensors" entries.
    opcode_mapper: callable mapping an opcode index to a display name.

  Returns:
    A `<script>` snippet (from `_D3_HTML_TEMPLATE`) drawing the graph.
  """

  def TensorName(idx):
    return "t%d" % idx

  def OpName(idx):
    return "o%d" % idx

  edges = []
  nodes = []
  first = {}
  second = {}
  pixel_mult = 200  # TODO(aselle): multiplier for initial placement
  width_mult = 170  # TODO(aselle): multiplier for initial placement
  for op_index, op in enumerate(g.get("operators") or []):
    # Negative tensor indices (-1) mark omitted optional tensors. No node is
    # created for them, so an edge would point at a missing node and break
    # the edge layout in the browser.
    op_inputs = op.get("inputs")
    if op_inputs is not None:
      for tensor_input_position, tensor_index in enumerate(op_inputs):
        if tensor_index < 0:
          continue
        if tensor_index not in first:
          first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
                                 (tensor_input_position + 1) * width_mult)
        edges.append({
            "source": TensorName(tensor_index),
            "target": OpName(op_index)
        })
    op_outputs = op.get("outputs")
    if op_outputs is not None:
      for tensor_output_position, tensor_index in enumerate(op_outputs):
        if tensor_index < 0:
          continue
        if tensor_index not in second:
          second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
                                  (tensor_output_position + 1) * width_mult)
        edges.append({
            "target": TensorName(tensor_index),
            "source": OpName(op_index)
        })

    nodes.append({
        "id": OpName(op_index),
        # flatc JSON output omits opcode_index when it is 0 (the default).
        "name": opcode_mapper(op.get("opcode_index", 0)),
        "group": 2,
        "x": pixel_mult,
        "y": (op_index + 1) * pixel_mult
    })
  for tensor_index, tensor in enumerate(g.get("tensors") or []):
    initial_y = first.get(tensor_index, second.get(tensor_index, (0, 0)))
    # Tensors are dicts here, so read the key (getattr always missed it).
    shape = tensor.get("shape")
    nodes.append({
        "id": TensorName(tensor_index),
        "name": "%r (%d)" % ([] if shape is None else shape, tensor_index),
        "group": 1,
        "x": initial_y[1],
        "y": initial_y[0]
    })
  # Escape "</" so names containing "</script>" cannot end the inline script
  # early; "<\/" is a valid escape in both JSON and JavaScript strings.
  graph_str = json.dumps({"nodes": nodes, "edges": edges}).replace("</", "<\\/")

  html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
  return html


def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: Items to print an array of dicts. None is treated as empty.
    keys_to_print: (key, display_fn). `key` is a key in the object. i.e.
      items[0][key] should exist. display_fn is the mapping function on display.
      i.e. the displayed html cell will have the string returned by
      `mapping_fn(items[0][key])`. A missing key is passed as None.
    display_index: add a column which is the index of each row in `items`.

  Returns:
    An html table.
  """
  parts = ["<table>\n", "<tr>\n"]
  if display_index:
    parts.append("<th>index</th>")
  for h, _ in keys_to_print:
    parts.append("<th>%s</th>" % h)
  parts.append("</tr>\n")
  for idx, item in enumerate(items or []):
    parts.append("<tr>\n")
    if display_index:
      parts.append("<td>%d</td>" % idx)
    for h, mapper in keys_to_print:
      val = item.get(h)
      if mapper is not None:
        val = mapper(val)
      # Wrap in a 1-tuple so tuple values are formatted, not unpacked.
      parts.append("<td>%s</td>\n" % (val,))
    parts.append("</tr>\n")
  parts.append("</table>\n")
  return "".join(parts)


def CamelCaseToSnakeCase(camel_case_input):
  """Converts an identifier in CamelCase to snake_case."""
  s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
  return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()


def FlatbufferToDict(fb, preserve_as_numpy):
  """Converts a hierarchy of FB objects into a nested dict.

  We avoid transforming big parts of the flat buffer into python arrays. This
  speeds conversion from ten minutes to a few seconds on big graphs.

  Args:
    fb: a flat buffer structure. (i.e. ModelT)
    preserve_as_numpy: true if all downstream np.arrays should be preserved.
      false if all downstream np.array should become python arrays
  Returns:
    A dictionary representing the flatbuffer rather than a flatbuffer object.
  """
  if isinstance(fb, (int, float, str)):
    return fb
  elif hasattr(fb, "__dict__"):
    result = {}
    for attribute_name in dir(fb):
      attribute = fb.__getattribute__(attribute_name)
      if not callable(attribute) and attribute_name[0] != "_":
        snake_name = CamelCaseToSnakeCase(attribute_name)
        # Buffer contents can be huge; keep everything under "buffers" numpy.
        preserve = attribute_name == "buffers" or preserve_as_numpy
        result[snake_name] = FlatbufferToDict(attribute, preserve)
    return result
  elif isinstance(fb, np.ndarray):
    return fb if preserve_as_numpy else fb.tolist()
  elif hasattr(fb, "__len__"):
    return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
  else:
    return fb


def CreateDictFromFlatbuffer(buffer_data):
  """Parses serialized TFLite model bytes into a nested dict."""
  model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0)
  model = schema_fb.ModelT.InitFromObj(model_obj)
  return FlatbufferToDict(model, preserve_as_numpy=False)


def create_html(tflite_input, input_is_filepath=True):  # pylint: disable=invalid-name
  """Returns html description with the given tflite model.

  Args:
    tflite_input: TFLite flatbuffer model path or model object.
    input_is_filepath: Tells if tflite_input is a model path or a model object.

  Returns:
    Dump of the given tflite model in HTML format.

  Raises:
    RuntimeError: If the input is not valid.
  """

  # Load the model into a nested dict, either by parsing the flatbuffer or by
  # reading a JSON dump of it.
  if input_is_filepath:
    path = os.fspath(tflite_input)
    if not os.path.exists(path):
      raise RuntimeError("Invalid filename %r" % tflite_input)
    if path.endswith((".tflite", ".bin")):
      with open(path, "rb") as file_handle:
        file_data = bytearray(file_handle.read())
      data = CreateDictFromFlatbuffer(file_data)
    elif path.endswith(".json"):
      with open(path, encoding="utf-8") as file_handle:
        data = json.load(file_handle)
    else:
      raise RuntimeError("Input file was not .tflite, .bin or .json")
  else:
    data = CreateDictFromFlatbuffer(tflite_input)

  parts = [_CSS, "<h1>TensorFlow Lite Model</h1>"]

  data["filename"] = tflite_input if input_is_filepath else (
      "Null (used model object)")  # Avoid special case

  # The description is stored as bytes in the flatbuffer, so decode it.
  toplevel_stuff = [("filename", None), ("version", None),
                    ("description", NameListToString)]

  parts.append("<table>\n")
  for key, mapping in toplevel_stuff:
    value = data.get(key)
    if mapping is not None:
      value = mapping(value)
    parts.append("<tr><th>%s</th><td>%s</td></tr>\n" % (key, value))
  parts.append("</table>\n")

  # Spec on what keys to display
  buffer_keys_to_display = [("data", DataSizeMapper())]
  operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
                              ("custom_code", NameListToString),
                              ("version", None)]

  # Update builtin code fields. Newer models store large op codes in
  # builtin_code and keep a clamped copy in deprecated_builtin_code; older
  # ones only set the latter, so the max recovers the real code. flatc JSON
  # output omits fields equal to their default (0).
  operator_codes = data.get("operator_codes") or []
  for d in operator_codes:
    d["builtin_code"] = max(d.get("builtin_code", 0),
                            d.get("deprecated_builtin_code", 0))

  # The opcode table is model-wide, so build the mapper once.
  opcode_mapper = OpCodeMapper(data)
  for subgraph_idx, g in enumerate(data.get("subgraphs") or []):
    # Subgraph local specs on what to display
    parts.append("<div class='subgraph'>")
    tensor_mapper = TensorMapper(g)
    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                          ("builtin_options", None),
                          ("opcode_index", opcode_mapper)]
    tensor_keys_to_display = [("name", NameListToString),
                              ("type", TensorTypeToName), ("shape", None),
                              ("shape_signature", None), ("buffer", None),
                              ("quantization", None)]

    parts.append("<h2>Subgraph %d</h2>\n" % subgraph_idx)

    # Inputs and outputs.
    parts.append("<h3>Inputs/Outputs</h3>\n")
    parts.append(GenerateTableHtml([{
        "inputs": g.get("inputs"),
        "outputs": g.get("outputs")
    }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                                   display_index=False))

    # Print the tensors.
    parts.append("<h3>Tensors</h3>\n")
    parts.append(GenerateTableHtml(g.get("tensors"), tensor_keys_to_display))

    # Print the ops.
    if g.get("operators"):
      parts.append("<h3>Ops</h3>\n")
      parts.append(GenerateTableHtml(g["operators"], op_keys_to_display))

    # Visual graph.
    parts.append("<svg id='subgraph%d' width='1600' height='900'></svg>\n" %
                 (subgraph_idx,))
    parts.append(GenerateGraph(subgraph_idx, g, opcode_mapper))
    parts.append("</div>")

  # Buffers have no data, but maybe in the future they will
  parts.append("<h2>Buffers</h2>\n")
  parts.append(GenerateTableHtml(data.get("buffers"), buffer_keys_to_display))

  # Operator codes
  parts.append("<h2>Operator Codes</h2>\n")
  parts.append(GenerateTableHtml(operator_codes, operator_keys_to_display))

  parts.append("</body></html>\n")

  return "".join(parts)


def main(argv):
  """Command-line entry point: `visualize.py <input tflite> <output html>`.

  Args:
    argv: argument vector; argv[1] is the model path, argv[2] the html output.

  Returns:
    0 on success, 1 if the required arguments are missing.
  """
  try:
    tflite_input = argv[1]
    html_output = argv[2]
  except IndexError:
    print("Usage: %s <input tflite> <output html>" % (argv[0]),
          file=sys.stderr)
    return 1
  html = create_html(tflite_input)
  # The page declares <meta charset="utf-8">, so write it as UTF-8 regardless
  # of the platform's default encoding.
  with open(html_output, "w", encoding="utf-8") as output_file:
    output_file.write(html)
  return 0


if __name__ == "__main__":
  sys.exit(main(sys.argv))