xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/visualize.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1#!/usr/bin/env python
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""This tool creates an html visualization of a TensorFlow Lite graph.
17
18Example usage:
19
20python visualize.py foo.tflite foo.html
21"""
22
23import json
24import os
25import re
26import sys
27import numpy as np
28
29# pylint: disable=g-import-not-at-top
30if not os.path.splitext(__file__)[0].endswith(
31    os.path.join("tflite_runtime", "visualize")):
32  # This file is part of tensorflow package.
33  from tensorflow.lite.python import schema_py_generated as schema_fb
34else:
35  # This file is part of tflite_runtime package.
36  from tflite_runtime import schema_py_generated as schema_fb
37
38# A CSS description for making the visualizer
39_CSS = """
40<html>
41<head>
42<style>
43body {font-family: sans-serif; background-color: #fa0;}
44table {background-color: #eca;}
45th {background-color: black; color: white;}
46h1 {
47  background-color: ffaa00;
48  padding:5px;
49  color: black;
50}
51
52svg {
53  margin: 10px;
54  border: 2px;
55  border-style: solid;
56  border-color: black;
57  background: white;
58}
59
60div {
61  border-radius: 5px;
62  background-color: #fec;
63  padding:5px;
64  margin:5px;
65}
66
67.tooltip {color: blue;}
68.tooltip .tooltipcontent  {
69    visibility: hidden;
70    color: black;
71    background-color: yellow;
72    padding: 5px;
73    border-radius: 4px;
74    position: absolute;
75    z-index: 1;
76}
77.tooltip:hover .tooltipcontent {
78    visibility: visible;
79}
80
81.edges line {
82  stroke: #333;
83}
84
85text {
86  font-weight: bold;
87}
88
89.nodes text {
90  color: black;
91  pointer-events: none;
92  font-family: sans-serif;
93  font-size: 11px;
94}
95</style>
96
97<script src="https://d3js.org/d3.v4.min.js"></script>
98
99</head>
100<body>
101"""
102
103_D3_HTML_TEMPLATE = """
104  <script>
105    function buildGraph() {
106      // Build graph data
107      var graph = %s;
108
109      var svg = d3.select("#subgraph%d")
110      var width = svg.attr("width");
111      var height = svg.attr("height");
112      // Make the graph scrollable.
113      svg = svg.call(d3.zoom().on("zoom", function() {
114        svg.attr("transform", d3.event.transform);
115      })).append("g");
116
117
118      var color = d3.scaleOrdinal(d3.schemeDark2);
119
120      var simulation = d3.forceSimulation()
121          .force("link", d3.forceLink().id(function(d) {return d.id;}))
122          .force("charge", d3.forceManyBody())
123          .force("center", d3.forceCenter(0.5 * width, 0.5 * height));
124
125      var edge = svg.append("g").attr("class", "edges").selectAll("line")
126        .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
127
128      // Make the node group
129      var node = svg.selectAll(".nodes")
130        .data(graph.nodes)
131        .enter().append("g")
132        .attr("x", function(d){return d.x})
133        .attr("y", function(d){return d.y})
134        .attr("transform", function(d) {
135          return "translate( " + d.x + ", " + d.y + ")"
136        })
137        .attr("class", "nodes")
138          .call(d3.drag()
139              .on("start", function(d) {
140                if(!d3.event.active) simulation.alphaTarget(1.0).restart();
141                d.fx = d.x;d.fy = d.y;
142              })
143              .on("drag", function(d) {
144                d.fx = d3.event.x; d.fy = d3.event.y;
145              })
146              .on("end", function(d) {
147                if (!d3.event.active) simulation.alphaTarget(0);
148                d.fx = d.fy = null;
149              }));
150      // Within the group, draw a box for the node position and text
151      // on the side.
152
153      var node_width = 150;
154      var node_height = 30;
155
156      node.append("rect")
157          .attr("r", "5px")
158          .attr("width", node_width)
159          .attr("height", node_height)
160          .attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
161          .attr("stroke", "#000000")
162          .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
163      node.append("text")
164          .text(function(d) { return d.name; })
165          .attr("x", 5)
166          .attr("y", 20)
167          .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
168      // Setup force parameters and update position callback
169
170
171      var node = svg.selectAll(".nodes")
172        .data(graph.nodes);
173
174      // Bind the links
175      var name_to_g = {}
176      node.each(function(data, index, nodes) {
177        console.log(data.id)
178        name_to_g[data.id] = this;
179      });
180
181      function proc(w, t) {
182        return parseInt(w.getAttribute(t));
183      }
184      edge.attr("d", function(d) {
185        function lerp(t, a, b) {
186          return (1.0-t) * a + t * b;
187        }
188        var x1 = proc(name_to_g[d.source],"x") + node_width /2;
189        var y1 = proc(name_to_g[d.source],"y") + node_height;
190        var x2 = proc(name_to_g[d.target],"x") + node_width /2;
191        var y2 = proc(name_to_g[d.target],"y");
192        var s = "M " + x1 + " " + y1
193            + " C " + x1 + " " + lerp(.5, y1, y2)
194            + " " + x2 + " " + lerp(.5, y1, y2)
195            + " " + x2  + " " + y2
196      return s;
197    });
198
199  }
200  buildGraph()
201</script>
202"""
203
204
def TensorTypeToName(tensor_type):
  """Converts a numerical TensorType enum value to its readable name.

  Args:
    tensor_type: Value to look up among the constants of
      `schema_fb.TensorType`.

  Returns:
    The name of the first public `TensorType` constant equal to
    `tensor_type`, or None if there is none.
  """
  for name, value in schema_fb.TensorType.__dict__.items():
    # Skip class internals (`__module__`, `__doc__`, ...). `__doc__` may be
    # None, which would otherwise "match" a None tensor_type.
    if not name.startswith("_") and value == tensor_type:
      return name
  return None
211
212
def BuiltinCodeToName(code):
  """Converts a builtin op code enum value to its readable name.

  Args:
    code: Value to look up among the constants of
      `schema_fb.BuiltinOperator`.

  Returns:
    The name of the first public `BuiltinOperator` constant equal to `code`,
    or None if there is none.
  """
  for name, value in schema_fb.BuiltinOperator.__dict__.items():
    # Skip class internals (`__module__`, `__doc__`, ...). `__doc__` may be
    # None, which would otherwise "match" a None code.
    if not name.startswith("_") and value == code:
      return name
  return None
219
220
def NameListToString(name_list):
  """Converts a string field from the model dict to a Python str.

  `FlatbufferToDict` turns non-str sequences into lists, so string fields can
  arrive here as lists of integer byte values; JSON input yields str directly.

  Args:
    name_list: A str (returned unchanged), None, or an iterable of integers.

  Returns:
    The string; "" for None. Integer values are decoded as UTF-8 bytes so
    non-ASCII names display correctly. If they are not valid UTF-8 (or not
    all within 0..255), each integer is instead mapped to the code point of
    the same value.
  """
  if isinstance(name_list, str):
    return name_list
  if name_list is None:
    return ""
  codes = [int(val) for val in name_list]
  try:
    return bytes(codes).decode("utf-8")
  except ValueError:
    # Also covers UnicodeDecodeError, which subclasses ValueError.
    return "".join(chr(code) for code in codes)
231
232
class OpCodeMapper:
  """Maps an opcode index to a display label such as "CONV_2D (3)"."""

  def __init__(self, data):
    """Builds the opcode-index -> name table.

    Args:
      data: Model dict whose "operator_codes" list is indexed by opcode
        index. Each entry's "builtin_code" is resolved with
        BuiltinCodeToName; for "CUSTOM" ops the entry's "custom_code" is used
        as the name instead.
    """
    self.code_to_name = {}
    for idx, d in enumerate(data["operator_codes"]):
      self.code_to_name[idx] = BuiltinCodeToName(d["builtin_code"])
      if self.code_to_name[idx] == "CUSTOM":
        self.code_to_name[idx] = NameListToString(d["custom_code"])

  def __call__(self, x):
    """Returns "<name> (<x>)", or "UNKNOWN (<x>)" for an unmapped index.

    The placeholder contains no angle brackets: the label is inserted into
    HTML table cells unescaped, where "<UNKNOWN>" would be parsed as a tag
    and not displayed.
    """
    if x not in self.code_to_name:
      s = "UNKNOWN"
    else:
      s = self.code_to_name[x]
    return "%s (%d)" % (s, x)
249
250
class DataSizeMapper:
  """Formats a buffer payload as its size in bytes for display."""

  def __call__(self, x):
    """Returns e.g. "12 bytes" for a sized payload, or "--" when x is None."""
    if x is None:
      return "--"
    return "%d bytes" % len(x)
259
260
class TensorMapper:
  """Maps a list of tensor indices to a tooltip hoverable indicator of more."""

  def __init__(self, subgraph_data):
    # Subgraph dict whose "tensors" list is indexed by the tensor indices
    # passed to __call__.
    self.data = subgraph_data

  def __call__(self, x):
    """Renders tensor indices as an HTML span with a hover tooltip.

    Args:
      x: Iterable of tensor indices into `self.data["tensors"]`, or None.

    Returns:
      "" if x is None. Otherwise a span showing repr(x) whose tooltip has one
      line per index: the index, tensor name, type, shape and shape
      signature.
    """
    from html import escape  # pylint: disable=g-import-not-at-top

    if x is None:
      return ""

    tensors = self.data["tensors"]
    tooltip_lines = []
    for i in x:
      if i < 0:
        # A negative index names no tensor (presumably an omitted optional
        # input; TFLite uses -1 for that). Indexing with it would silently
        # describe a tensor counted from the end of the list.
        tooltip_lines.append("%d (no tensor)" % i)
        continue
      tensor = tensors[i]
      shape = repr(tensor["shape"]) if "shape" in tensor else "[]"
      signature = (repr(tensor["shape_signature"])
                   if "shape_signature" in tensor else "[]")
      tooltip_lines.append(" ".join([
          str(i),
          # Names come from the model file; escape them so they render as
          # text rather than being interpreted as markup.
          escape(NameListToString(tensor["name"])),
          # str() so an unrecognized type shows as "None" instead of raising.
          str(TensorTypeToName(tensor["type"])),
          shape,
          signature,
      ]))
    return ("<span class='tooltip'><span class='tooltipcontent'>" +
            "".join(line + "<br>" for line in tooltip_lines) + "</span>" +
            repr(x) + "</span>")
285
286
def GenerateGraph(subgraph_idx, g, opcode_mapper):
  """Produces the HTML required to have a d3 visualization of the dag.

  Args:
    subgraph_idx: Subgraph index; the script draws into the element with id
      "subgraph<subgraph_idx>".
    g: Subgraph dict with "operators" (may be None) and "tensors" lists.
    opcode_mapper: Callable mapping an operator's "opcode_index" to a label.

  Returns:
    A <script> snippet (_D3_HTML_TEMPLATE filled in) embedding the graph as
    JSON.
  """

  def TensorName(idx):
    return "t%d" % idx

  def OpName(idx):
    return "o%d" % idx

  edges = []
  nodes = []
  # Initial (y, x) position per tensor index, taken from the first op that
  # reads it (`first`) or, failing that, the first op that writes it
  # (`second`).
  first = {}
  second = {}
  pixel_mult = 200  # TODO(aselle): multiplier for initial placement
  width_mult = 170  # TODO(aselle): multiplier for initial placement
  for op_index, op in enumerate(g["operators"] or []):
    if op["inputs"] is not None:
      for tensor_input_position, tensor_index in enumerate(op["inputs"]):
        # Negative indices name no tensor node; an edge to e.g. "t-1" would
        # make the script throw while laying out the edges.
        if tensor_index < 0:
          continue
        if tensor_index not in first:
          first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
                                 (tensor_input_position + 1) * width_mult)
        edges.append({
            "source": TensorName(tensor_index),
            "target": OpName(op_index)
        })
    if op["outputs"] is not None:
      for tensor_output_position, tensor_index in enumerate(op["outputs"]):
        if tensor_index < 0:
          continue
        if tensor_index not in second:
          second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
                                  (tensor_output_position + 1) * width_mult)
        edges.append({
            "target": TensorName(tensor_index),
            "source": OpName(op_index)
        })

    nodes.append({
        "id": OpName(op_index),
        "name": opcode_mapper(op["opcode_index"]),
        "group": 2,
        "x": pixel_mult,
        "y": (op_index + 1) * pixel_mult
    })
  for tensor_index, tensor in enumerate(g["tensors"] or []):
    initial_pos = first.get(tensor_index, second.get(tensor_index, (0, 0)))
    # `tensor` is a dict: read "shape" as a key. (An attribute lookup such as
    # getattr(tensor, "shape", []) always falls back to the default.)
    shape = tensor.get("shape")
    nodes.append({
        "id": TensorName(tensor_index),
        "name": "%r (%d)" % (shape if shape is not None else [], tensor_index),
        "group": 1,
        "x": initial_pos[1],
        "y": initial_pos[0]
    })
  # The JSON is embedded in a <script> element; escape "<" so a model-supplied
  # name containing "</script>" cannot terminate the element early.
  graph_str = json.dumps({"nodes": nodes, "edges": edges}).replace(
      "<", "\\u003c")

  return _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
345
346
def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: The rows to print, as a list of dicts (None is treated as empty).
    keys_to_print: (key, display_fn) pairs. `key` is both the column header
      and the dict key read from each item; a missing key yields None.
      `display_fn` maps that value to the cell content, i.e. the cell holds
      `display_fn(item[key])`; if `display_fn` is None the raw value is
      formatted with %s.
    display_index: add a column which is the index of each row in `items`.

  Returns:
    An html table.
  """
  parts = ["<table>\n", "<tr>\n"]
  if display_index:
    parts.append("<th>index</th>")
  for header, _ in keys_to_print:
    parts.append("<th>%s</th>" % (header,))
  parts.append("</tr>\n")
  for idx, item in enumerate(items if items is not None else []):
    parts.append("<tr>\n")
    if display_index:
      parts.append("<td>%d</td>" % idx)
    for key, mapper in keys_to_print:
      val = item[key] if key in item else None
      if mapper is not None:
        val = mapper(val)
      # Wrap in a 1-tuple so tuple values are shown, not used as %-args.
      parts.append("<td>%s</td>\n" % (val,))
    parts.append("</tr>\n")
  parts.append("</table>\n")
  return "".join(parts)
383
384
def CamelCaseToSnakeCase(camel_case_input):
  """Returns the snake_case spelling of a CamelCase identifier.

  Example: "deprecatedBuiltinCode" -> "deprecated_builtin_code".
  """
  # Pass 1: underscore before every capitalized word ("FooBar" -> "Foo_Bar").
  words_split = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
  # Pass 2: underscore at any remaining lower/digit -> upper boundary.
  boundaries_split = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", words_split)
  return boundaries_split.lower()
389
390
def FlatbufferToDict(fb, preserve_as_numpy):
  """Recursively converts flatbuffer object-API values into plain Python data.

  Large parts of the flat buffer are deliberately not expanded into Python
  lists, which keeps conversion of big graphs to seconds rather than minutes.

  Args:
    fb: a flat buffer structure. (i.e. ModelT)
    preserve_as_numpy: if true, numpy arrays below this point are returned
      as-is; if false they are converted with `tolist()`. The `buffers`
      attribute always preserves numpy arrays regardless.
  Returns:
    A dictionary representing the flatbuffer rather than a flatbuffer object.
  """
  # Scalars and strings are already plain Python values.
  if isinstance(fb, (int, float, str)):
    return fb

  # Objects: recurse into every public, non-callable attribute, renaming it
  # to snake_case.
  if hasattr(fb, "__dict__"):
    fields = {}
    for attr_name in dir(fb):
      attr_value = fb.__getattribute__(attr_name)
      if callable(attr_value) or attr_name.startswith("_"):
        continue
      keep_numpy = (attr_name == "buffers") or preserve_as_numpy
      fields[CamelCaseToSnakeCase(attr_name)] = FlatbufferToDict(
          attr_value, keep_numpy)
    return fields

  if isinstance(fb, np.ndarray):
    return fb if preserve_as_numpy else fb.tolist()

  # Any other sized container (lists, bytes, ...) becomes a list.
  if hasattr(fb, "__len__"):
    return [FlatbufferToDict(item, preserve_as_numpy) for item in fb]

  return fb
421
422
def CreateDictFromFlatbuffer(buffer_data):
  """Parses a serialized TFLite model into nested Python dicts and lists.

  Args:
    buffer_data: The serialized flatbuffer (e.g. a bytearray read from a
      .tflite file).

  Returns:
    The model unpacked through the flatbuffer object API (`ModelT`) and
    converted by FlatbufferToDict, with numpy arrays turned into lists
    except under `buffers`.
  """
  root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
  return FlatbufferToDict(schema_fb.ModelT.InitFromObj(root),
                          preserve_as_numpy=False)
427
428
def _LoadModelDict(tflite_input, input_is_filepath):
  """Loads the model passed to create_html into nested dicts/lists.

  Args:
    tflite_input: Model path, or the serialized model itself when
      `input_is_filepath` is False.
    input_is_filepath: Whether `tflite_input` is a path.

  Returns:
    The model as nested dicts/lists.

  Raises:
    RuntimeError: If the path does not exist or has an unsupported extension.
  """
  if not input_is_filepath:
    return CreateDictFromFlatbuffer(tflite_input)
  if not os.path.exists(tflite_input):
    raise RuntimeError("Invalid filename %r" % tflite_input)
  if tflite_input.endswith((".tflite", ".bin")):
    with open(tflite_input, "rb") as file_handle:
      file_data = bytearray(file_handle.read())
    return CreateDictFromFlatbuffer(file_data)
  if tflite_input.endswith(".json"):
    # `with` closes the handle (the bare open() it replaces never did).
    with open(tflite_input, encoding="utf-8") as file_handle:
      return json.load(file_handle)
  raise RuntimeError("Input file was not .tflite, .bin or .json")


def create_html(tflite_input, input_is_filepath=True):  # pylint: disable=invalid-name
  """Returns html description with the given tflite model.

  Args:
    tflite_input: TFLite flatbuffer model path or model object.
    input_is_filepath: Tells if tflite_input is a model path or a model object.

  Returns:
    Dump of the given tflite model in HTML format.

  Raises:
    RuntimeError: If the input is not valid.
  """
  from html import escape  # pylint: disable=g-import-not-at-top

  def EscapedName(name_list):
    # Names come from the model file; escape them so they render as text.
    return escape(NameListToString(name_list))

  data = _LoadModelDict(tflite_input, input_is_filepath)

  parts = [_CSS, "<h1>TensorFlow Lite Model</h1>"]

  data["filename"] = tflite_input if input_is_filepath else (
      "Null (used model object)")  # Avoid special case

  # "description" is a flatbuffer string; like other string fields it may
  # arrive as a list of byte values, so render it as text.
  toplevel_stuff = [("filename", None), ("version", None),
                    ("description", EscapedName)]

  parts.append("<table>\n")
  for key, mapping in toplevel_stuff:
    value = data.get(key)
    if mapping is not None:
      value = mapping(value)
    parts.append("<tr><th>%s</th><td>%s</td></tr>\n" % (key, value))
  parts.append("</table>\n")

  # Spec on what keys to display
  buffer_keys_to_display = [("data", DataSizeMapper())]
  operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
                              ("custom_code", EscapedName),
                              ("version", None)]

  # Update builtin code fields: keep the larger of builtin_code and
  # deprecated_builtin_code, treating a missing field as 0.
  for d in data["operator_codes"]:
    d["builtin_code"] = max(d.get("builtin_code", 0),
                            d.get("deprecated_builtin_code", 0))

  # The operator-code table is model-wide, so one mapper serves all subgraphs.
  opcode_mapper = OpCodeMapper(data)

  for subgraph_idx, g in enumerate(data["subgraphs"]):
    # Subgraph local specs on what to display
    parts.append("<div class='subgraph'>")
    tensor_mapper = TensorMapper(g)
    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                          ("builtin_options", None),
                          ("opcode_index",
                           lambda x: escape(opcode_mapper(x)))]
    tensor_keys_to_display = [("name", EscapedName),
                              ("type", TensorTypeToName), ("shape", None),
                              ("shape_signature", None), ("buffer", None),
                              ("quantization", None)]

    parts.append("<h2>Subgraph %d</h2>\n" % subgraph_idx)

    # Inputs and outputs.
    parts.append("<h3>Inputs/Outputs</h3>\n")
    parts.append(GenerateTableHtml([{
        "inputs": g["inputs"],
        "outputs": g["outputs"]
    }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                                   display_index=False))

    # Print the tensors.
    parts.append("<h3>Tensors</h3>\n")
    parts.append(GenerateTableHtml(g["tensors"], tensor_keys_to_display))

    # Print the ops.
    if g["operators"]:
      parts.append("<h3>Ops</h3>\n")
      parts.append(GenerateTableHtml(g["operators"], op_keys_to_display))

    # Visual graph.
    parts.append("<svg id='subgraph%d' width='1600' height='900'></svg>\n" %
                 (subgraph_idx,))
    parts.append(GenerateGraph(subgraph_idx, g, opcode_mapper))
    parts.append("</div>")

  # Buffers: only the payload size of each buffer is shown.
  parts.append("<h2>Buffers</h2>\n")
  parts.append(GenerateTableHtml(data["buffers"], buffer_keys_to_display))

  # Operator codes
  parts.append("<h2>Operator Codes</h2>\n")
  parts.append(GenerateTableHtml(data["operator_codes"],
                                 operator_keys_to_display))

  parts.append("</body></html>\n")

  return "".join(parts)
534
535
def main(argv):
  """Command-line entry point: `visualize.py <input tflite> <output html>`.

  Args:
    argv: Argument vector; argv[0] is the program name.

  Returns:
    Process exit status: 0 on success, 1 if arguments are missing (a usage
    message is then printed to stderr).
  """
  prog = argv[0] if argv else "visualize.py"
  try:
    tflite_input = argv[1]
    html_output = argv[2]
  except IndexError:
    print("Usage: %s <input tflite> <output html>" % prog, file=sys.stderr)
    return 1
  html = create_html(tflite_input)
  # Write UTF-8 explicitly: model names may contain non-ASCII characters the
  # platform default encoding cannot represent.
  with open(html_output, "w", encoding="utf-8") as output_file:
    output_file.write(html)
  return 0
546
547
if __name__ == "__main__":
  # Propagate main's return value as the process exit status (None -> 0).
  sys.exit(main(sys.argv))
550