xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/dump_graphviz.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/dump_graphviz.h"
16 
#include <algorithm>
#include <cmath>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_replace.h"
26 #include "absl/strings/str_split.h"
27 #include "absl/strings/strip.h"
28 #include "re2/re2.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/lite/toco/model_flags.pb.h"
31 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
32 #include "tensorflow/lite/toco/toco_port.h"
33 #include "tensorflow/lite/toco/toco_types.h"
34 #include "tensorflow/lite/toco/tooling_util.h"
35 
36 using toco::port::AppendF;
37 using toco::port::StringF;
38 
39 namespace toco {
40 namespace {
41 
// Graphviz (dot) templates. Every *Fmt constant below is a printf-style
// format string consumed by AppendF/StringF, and each ends in a newline so
// that consecutive appends produce one dot statement per line.
// Note: tooltips are only supported on SVGs in Chrome.

// Opens the top-level graph. Argument: the HTML-like graph label.
// 'nslimit' is a graphviz (dot) parameter that bounds the iterations of the
// layout phase. Omitting it allows unbounded iterations, which can keep some
// complex graphs from ever finishing; 125 still produces good layouts while
// letting complex graphs finish.
constexpr char kGraphFmt[] =
    "digraph Computegraph { tooltip = \"/\"\n"
    "    nslimit=125 margin=36 ranksep = 2 labelloc=\"t\" label=%s\n";

// Opens a cluster subgraph. Arguments: cluster name, background color,
// HTML-like label.
constexpr char kSubgraphFmt[] =
    "    subgraph \"cluster_%s\" { style=rounded bgcolor=\"%s\" penwidth=0.0 "
    "label=%s\n";

// Array node. Arguments: array id, HTML-like label, tooltip, shape, fill
// color, text color (the literal "DD" appended to it becomes the alpha byte
// of an #RRGGBBAA color).
constexpr char kArrayNodeFmt[] =
    "        \"%s\" [label=%s tooltip=\"%s\" shape=%s style=filled "
    "fillcolor=\"%s\" fontcolor=\"%sDD\"];\n";

// Operator node. Arguments: operator id, HTML-like label, fill color, text
// color (with "DD" appended as above).
constexpr char kOpNodeFmt[] =
    "        %s [label=%s tooltip=\" \" shape=box margin=0 style=filled "
    "fillcolor=\"%s\" fontcolor=\"%sDD\"];\n";

// Edge from an array into an operator's input port. Arguments: array id,
// array compass point, operator id, input port index, pen width, weight.
constexpr char kInputEdgeFmt[] =
    "        \"%s\"%s -> %s:i%d:n [penwidth=%f weight=%f];\n";

// Edge from an operator's output port to an array. Arguments: operator id,
// output port index, array id, array compass point, pen width, weight.
constexpr char kOutputEdgeFmt[] =
    "        %s:o%d:s -> \"%s\"%s [penwidth=%f weight=%f];\n";

// RNN back-edge. Arguments: back-edge source array, RNN state array.
constexpr char kRNNBackEdgeFmt[] =
    "        \"%s\":s -> \"%s\":n [color=\"#0F9D58\" constraint=false];\n";

// Multiplication sign, used between dimensions ("2×3").
constexpr char kUnicodeMult[] = "\u00D7";
// Padded horizontal ellipsis, used inside abbreviated buffer samples.
constexpr char kUnicodeEllipsis[] = " \u2026 ";
70 
71 class Color {
72  public:
Color()73   Color() {}
Color(uint8 r,uint8 g,uint8 b)74   Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {}
Color(uint32 word)75   explicit Color(uint32 word)
76       : r_((word & 0x00FF0000) >> 16),
77         g_((word & 0x0000FF00) >> 8),
78         b_((word & 0x000000FF) >> 0) {}
79 
80   // Returns the string serialization of this color in graphviz format,
81   // for use as 'fillcolor' in boxes.
AsHexString() const82   std::string AsHexString() const {
83     return StringF("#%.2X%.2X%.2X", r_, g_, b_);
84   }
85   // The color to use for this node; will be used as 'fillcolor'
86   // for its box. See Color::AsHexString. A suitable, different
87   // color will be chosen for the 'fontcolor' for the inside text
88   // label, see Color::TextColorString.
89   // Returns the serialization in graphviz format of a suitable color to use
90   // 'fontcolor' in the same boxes. It should black or white, whichever offers
91   // the better contrast from AsHexString().
TextColorString() const92   std::string TextColorString() const {
93     // https://en.wikipedia.org/wiki/Relative_luminance
94     const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_;
95     const uint8 l = luminance > 128.f ? 0 : 255;
96     return StringF("#%.2X%.2X%.2X", l, l, l);
97   }
98 
99  private:
100   uint8 r_ = 0, g_ = 0, b_ = 0;
101 };
102 
HashStringToColor(std::string s)103 Color HashStringToColor(std::string s) {
104   // Return a unique color for a name.
105   //
106   // This function removes Tensorflow anti-collision suffixes (eg "_2"), hashes
107   // the string to a uint_32, then twiddles some bits to get a light and subtle
108   // color. This seems to be a good heuristic for keeping enough of the name to
109   // hash to a unique color while still revealing structure through naming
110   // similarities.
111   //
112   // The regular expression "_\d+" matches any underscore followed by numbers,
113   // which we strip out. Examples:
114   //
115   //     "Conv"      -> "Conv"
116   //     "Conv_2"    -> "Conv"
117   //     "Conv_72"   -> "Conv"
118   //     "Pad_1_bias -> "Pad_bias"
119   //     "Conv_abc"  -> "Conv_abc"
120 
121   RE2::GlobalReplace(&s, R"CODE(_\d+)CODE", "");
122   uint32 color_word = std::hash<std::string>{}(s);
123   color_word |= 0x00E0E0E0;
124   return Color(color_word);
125 }
126 
GetArrayColorAndShape(const Model & model,const std::string & array_name,Color * color,std::string * shape)127 void GetArrayColorAndShape(const Model& model, const std::string& array_name,
128                            Color* color, std::string* shape) {
129   // All colors in this file are from:
130   // https://material.io/guidelines/style/color.html
131   // Arrays involved in RNN back-edges have a different color
132   for (const auto& rnn_state : model.flags.rnn_states()) {
133     // RNN state, fed by a back-edge. Bold color.
134     if (array_name == rnn_state.state_array()) {
135       *color = Color(0x0F, 0x9D, 0x58);
136       *shape = "invhouse";
137       return;
138     }
139     // RNN back-edge source, feeding a RNN state.
140     // Light tone of the same color as RNN states.
141     if (array_name == rnn_state.back_edge_source_array()) {
142       *color = Color(0xB7, 0xE1, 0xCD);
143       *shape = "house";
144       return;
145     }
146   }
147   // Constant parameter arrays have their own bold color
148   if (model.GetArray(array_name).buffer) {
149     *color = Color(0x42, 0x85, 0xF4);
150     *shape = "cylinder";
151     return;
152   }
153   // Remaining arrays are activations.
154   // We use gray colors for them because they are the majority
155   // of arrays so we want to highlight other arrays instead of them.
156   // First, we use a bolder gray for input/output arrays:
157   if (IsInputArray(model, array_name)) {
158     *color = Color(0x9E, 0x9E, 0x9E);
159     *shape = "invhouse";
160     return;
161   }
162   if (IsOutputArray(model, array_name)) {
163     *color = Color(0x9E, 0x9E, 0x9E);
164     *shape = "house";
165     return;
166   }
167   // Remaining arrays are intermediate activation arrays.
168   // Lighter tone of the same grey as for input/output arrays:
169   // We want these to be very discrete.
170   *color = Color(0xF5, 0xF5, 0xF5);
171   *shape = "box";
172 }
173 
GetArrayCompassPt(const Model & model,const std::string & array_name)174 std::string GetArrayCompassPt(const Model& model,
175                               const std::string& array_name) {
176   // The "compass point" is the point on the node where edge connections are
177   // made. For most arrays we don't care, but input's and outputs look better
178   // connected at the tip of the "house" and "invhouse" shapes used. So we
179   // append ":n" and ":s" respectively for those.
180   for (const auto& rnn_state : model.flags.rnn_states()) {
181     // RNN state is essentially an input
182     if (array_name == rnn_state.state_array()) {
183       return ":s";
184     }
185     // RNN back-edge source is essentially an output
186     if (array_name == rnn_state.back_edge_source_array()) {
187       return ":n";
188     }
189   }
190   if (IsInputArray(model, array_name)) {
191     return ":s";
192   }
193   if (IsOutputArray(model, array_name)) {
194     return ":n";
195   }
196   return "";
197 }
198 
AppendArrayVal(std::string * string,Array const & array,int index)199 void AppendArrayVal(std::string* string, Array const& array, int index) {
200   if (array.buffer->type == ArrayDataType::kFloat) {
201     const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
202     if (index >= data.size()) {
203       return;
204     }
205     AppendF(string, "%.3f", data[index]);
206   } else if (array.buffer->type == ArrayDataType::kUint8) {
207     const auto& data = array.GetBuffer<ArrayDataType::kUint8>().data;
208     if (index >= data.size()) {
209       return;
210     }
211     AppendF(string, "%d", data[index]);
212   } else if (array.buffer->type == ArrayDataType::kInt16) {
213     const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
214     if (index >= data.size()) {
215       return;
216     }
217     AppendF(string, "%d", data[index]);
218   } else if (array.buffer->type == ArrayDataType::kInt32) {
219     const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
220     if (index >= data.size()) {
221       return;
222     }
223     AppendF(string, "%d", data[index]);
224   } else if (array.buffer->type == ArrayDataType::kInt64) {
225     const auto& data = array.GetBuffer<ArrayDataType::kInt64>().data;
226     if (index >= data.size()) {
227       return;
228     }
229     AppendF(string, "%d", data[index]);
230   } else if (array.buffer->type == ArrayDataType::kBool) {
231     const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
232     if (index >= data.size()) {
233       return;
234     }
235     AppendF(string, "%d", data[index]);
236   }
237 }
238 
// Attribute name -> displayed value. std::map keeps the rendered rows sorted
// by attribute name.
using Attributes = std::map<std::string, std::string>;

// Renders `attributes` as HTML-like table rows for a graphviz label table:
// one right-aligned "name:" cell and one left-aligned value cell per entry,
// in key order. Returns "" for an empty map. Takes the map by const reference
// to avoid copying it.
std::string AttributesToHtml(const Attributes& attributes) {
  std::string html;
  for (const auto& attr : attributes) {
    html += R"CODE(<TR><TD CELLPADDING="1" ALIGN="RIGHT">)CODE";
    html += attr.first;
    html += R"CODE(:</TD><TD CELLPADDING="1" ALIGN="LEFT">)CODE";
    html += attr.second;
    html += "</TD></TR>";
  }
  return html;
}
252 
GetArrayLabel(const Model & model,const std::string & array_id)253 std::string GetArrayLabel(const Model& model, const std::string& array_id) {
254   std::string html;
255 
256   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
257   html += "<";
258 
259   // Begin Table
260   html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
261   html += R"CODE(<TABLE BORDER="0" CELLSPACING="2" CELLPADDING="0">)CODE";
262 
263   auto& array = model.GetArray(array_id);
264   if (array.buffer) {
265     // "cylinder" shapes require some extra head room.
266     html += R"CODE(<TR><TD COLSPAN="2"> </TD></TR>)CODE";
267   }
268 
269   // "Primary" name of array (last non-slash delimited group of characters).
270   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
271   html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><I>)CODE";
272   AppendF(&html, R"CODE(%s)CODE",
273           std::vector<std::string>(absl::StrSplit(array_id, '/')).back());
274   html += R"CODE(</I></FONT>)CODE";
275   html += "</TD></TR>";
276 
277   // Array data type and dimensions
278   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
279   html += R"CODE(<FONT POINT-SIZE="14" FACE="Courier"><B>)CODE";
280   // Type
281   html += ArrayDataTypeName(array.data_type);
282   // Shape
283   if (array.has_shape()) {
284     auto& array_shape = array.shape();
285     html += "[";
286     for (int dim = 0; dim < array_shape.dimensions_count(); dim++) {
287       AppendF(&html, "%d", array_shape.dims(dim));
288       if (dim + 1 < array_shape.dimensions_count()) {
289         html += kUnicodeMult;
290       }
291     }
292     html += "]";
293   }
294 
295   // Small buffer sample
296   int buffer_size = 0;
297   if (array.buffer) {
298     buffer_size = RequiredBufferSizeForShape(array.shape());
299   }
300   if ((buffer_size > 0) && (buffer_size <= 4)) {
301     html += " = ";
302     if (array.shape().dimensions_count() > 0) {
303       html += "{";
304     }
305     for (int i = 0; i < buffer_size; i++) {
306       AppendArrayVal(&html, array, i);
307       if (i + 1 < buffer_size) {
308         html += ", ";
309       }
310     }
311     if (array.shape().dimensions_count() > 0) {
312       html += "}";
313     }
314   }
315   html += R"CODE(</B></FONT>)CODE";
316   html += "</TD></TR>";
317 
318   // Large buffer samples get their own line
319   if (buffer_size > 4) {
320     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER"> = {)CODE";
321     AppendArrayVal(&html, array, 0);
322     html += ", ";
323     AppendArrayVal(&html, array, 1);
324     html += kUnicodeEllipsis;
325     AppendArrayVal(&html, array, buffer_size - 2);
326     html += ", ";
327     AppendArrayVal(&html, array, buffer_size - 1);
328     html += "}</TD></TR>";
329   }
330 
331   // Other array properties
332   Attributes attrs;
333   if (array.minmax) {
334     attrs["minmax"] =
335         StringF("[%.7g, %.7g]", array.minmax->min, array.minmax->max);
336   }
337   if (array.quantization_params) {
338     attrs["quant"] = StringF("%7g\u00B7(x-%d)",  // Unicode "cdot"
339                              array.quantization_params->scale,
340                              array.quantization_params->zero_point);
341   }
342   if (array.alloc) {
343     attrs["alloc"] = StringF("[%d, %d)", array.alloc->start, array.alloc->end);
344   }
345   html += AttributesToHtml(attrs);
346 
347   // output array_id in ultra-small font so it can be searched and copied.
348   html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
349   html += R"CODE(<FONT POINT-SIZE="3" FACE="">)CODE";
350   AppendF(&html, R"CODE("%s")CODE", array_id);
351   html += R"CODE(</FONT>)CODE";
352   html += "</TD></TR>";
353 
354   // End Table and HTML-like label
355   html += R"CODE(</TABLE></FONT>)CODE";
356   html += ">";
357   return html;
358 }
359 
GetOpAttributes(const Model & model,const Operator & op)360 Attributes GetOpAttributes(const Model& model, const Operator& op) {
361   Attributes attrs;
362   switch (op.fused_activation_function) {
363     case FusedActivationFunctionType::kRelu:
364       attrs["func"] = "ReLU";
365       break;
366     case FusedActivationFunctionType::kRelu6:
367       attrs["func"] = "ReLU6";
368       break;
369     case FusedActivationFunctionType::kRelu1:
370       attrs["func"] = "ReLU1";
371       break;
372     default:
373       break;
374   }
375   // Output state of member vars on derived operators.
376   switch (op.type) {
377     case OperatorType::kConv: {
378       const auto& conv_op = static_cast<const ConvOperator&>(op);
379       std::string stride;
380       AppendF(&stride, "%d", conv_op.stride_width);
381       stride += kUnicodeMult;
382       AppendF(&stride, "%d", conv_op.stride_height);
383       attrs["stride"] = stride;
384       attrs["padding"] =
385           (conv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
386       break;
387     }
388     case OperatorType::kDepthwiseConv: {
389       const auto& depthconv_op = static_cast<const ConvOperator&>(op);
390       std::string stride;
391       AppendF(&stride, "%d", depthconv_op.stride_width);
392       stride += kUnicodeMult;
393       AppendF(&stride, "%d", depthconv_op.stride_height);
394       attrs["stride"] = stride;
395       attrs["padding"] =
396           (depthconv_op.padding.type == PaddingType::kSame) ? "same" : "valid";
397       break;
398     }
399     case OperatorType::kFakeQuant: {
400       const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
401       attrs["bits"] = StringF("%d", fakequant_op.num_bits);
402       if (fakequant_op.minmax) {
403         attrs["range"] = StringF("[%g,%g]", fakequant_op.minmax->min,
404                                  fakequant_op.minmax->max);
405       } else {
406         attrs["range"] = "[?,?]";
407       }
408       break;
409     }
410     default:
411       break;
412   }
413   int64_t math_ops_count;
414   if (EstimateArithmeticOpsCount(model, op, &math_ops_count) &&
415       (math_ops_count != 0)) {
416     attrs["math"] = FormattedNumber(math_ops_count) + "ops";
417   }
418 
419   return attrs;
420 }
421 
GetOpColor(const Operator & op)422 Color GetOpColor(const Operator& op) {
423   if ((op.type == OperatorType::kDepthwiseConv) ||
424       (op.type == OperatorType::kConv) ||
425       (op.type == OperatorType::kFullyConnected) ||
426       (op.type == OperatorType::kFakeQuant)) {
427     // Give some ops a bolder red
428     return Color(0xC5, 0x39, 0x29);
429   } else {
430     return Color(0xDB, 0x44, 0x37);
431   }
432 }
433 
GetOpLabel(const Model & model,const Operator & op)434 std::string GetOpLabel(const Model& model, const Operator& op) {
435   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
436   std::string html;
437   html += "<";
438 
439   // Begin Table
440   html += R"CODE(<FONT POINT-SIZE="10" FACE="Courier">)CODE";
441   html +=
442       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
443 
444   // Input Ports
445   if (!op.inputs.empty()) {
446     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
447     // Distribute evenly using a sub-table
448     html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
449     html += R"CODE(<TR>)CODE";
450     for (int i = 0; i < op.inputs.size(); i++) {
451       html += R"CODE(<TD PORT=")CODE";
452       AppendF(&html, "i%d", i);
453       html += R"CODE(">)CODE";
454       if (op.inputs.size() > 1) {
455         // Only number inputs when op has two or more inputs
456         AppendF(&html, "%d", i);
457       }
458       html += "</TD>";
459     }
460     html += "</TR>";
461     html += R"CODE(</TABLE></TD></TR>)CODE";
462   }
463 
464   // Name
465   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
466   html += R"CODE(<FONT POINT-SIZE="16" FACE="Helvetica"><B>)CODE";
467   if (op.type == OperatorType::kUnsupported) {
468     html += static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
469   } else {
470     html +=
471         std::string(absl::StripPrefix(OperatorTypeName(op.type), "TensorFlow"));
472   }
473   html += R"CODE(</B></FONT>)CODE";
474   html += "</TD></TR>";
475 
476   // Attributes
477   Attributes attrs = GetOpAttributes(model, op);
478   html += AttributesToHtml(attrs);
479 
480   // Output Ports
481   if (!op.outputs.empty()) {
482     html += R"CODE(<TR><TD COLSPAN="2" ALIGN="CENTER">)CODE";
483     // Distribute evenly using a sub-table
484     html += R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">)CODE";
485     html += R"CODE(<TR>)CODE";
486     for (int i = 0; i < op.outputs.size(); i++) {
487       html += R"CODE(<TD PORT=")CODE";
488       AppendF(&html, "o%d", i);
489       html += R"CODE(">)CODE";
490       if (op.outputs.size() > 1) {
491         // Only number outputs when op has two or more outputs
492         AppendF(&html, "%d", i);
493       }
494       html += "</TD>";
495     }
496     html += "</TR>";
497     html += R"CODE(</TABLE></TD></TR>)CODE";
498   }
499 
500   // End Table and HTML-like label
501   html += R"CODE(</TABLE></FONT>)CODE";
502   html += ">";
503 
504   return html;
505 }
506 
GetLog2BufferSize(const Model & model,const std::string & array_id)507 float GetLog2BufferSize(const Model& model, const std::string& array_id) {
508   auto& array = model.GetArray(array_id);
509   if (array.has_shape()) {
510     int buffer_size = 0;
511     if (IsNonEmpty(array.shape())) {
512       buffer_size = RequiredBufferSizeForShape(array.shape());
513       return std::log2(static_cast<float>(buffer_size));
514     }
515   }
516   return 0.0f;
517 }
518 
GetOpId(int op_index)519 std::string GetOpId(int op_index) { return StringF("op%05d", op_index); }
520 
DumpOperator(const Model & model,std::string * output_file,int op_index)521 void DumpOperator(const Model& model, std::string* output_file, int op_index) {
522   // Dump node for operator.
523   const Operator& op = *model.operators[op_index];
524   Color color = GetOpColor(op);
525   std::string label = GetOpLabel(model, op);
526   std::string op_id = GetOpId(op_index);
527   AppendF(output_file, kOpNodeFmt, op_id, label, color.AsHexString(),
528           color.TextColorString());
529 }
530 
DumpOperatorEdges(const Model & model,std::string * output_file,int op_index)531 void DumpOperatorEdges(const Model& model, std::string* output_file,
532                        int op_index) {
533   // Inputs
534   const Operator& op = *model.operators[op_index];
535   std::string op_id = GetOpId(op_index);
536   for (int i = 0; i < op.inputs.size(); i++) {
537     const auto& input = op.inputs[i];
538     if (!model.HasArray(input)) {
539       // Connected arrays should _always_ exist. Except, perhaps, during
540       // development.
541       continue;
542     }
543     float log2_buffer_size = GetLog2BufferSize(model, input);
544     // Draw lines that transport more data thicker (Otherwise, where would the
545     // data fit? right?).
546     float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
547     // Keep edges that transport more data shorter than those with less.
548     float weight = std::max(1.0f, log2_buffer_size);
549     if (!IsInputArray(model, input) &&
550         GetOpWithOutput(model, input) == nullptr) {
551       // Give the main line of data flow a straighter path by penalizing edges
552       // to standalone buffers. Weights are generally very large buffers that
553       // would otherwise skew the layout.
554       weight = 1.0f;
555     }
556     std::string compass_pt = GetArrayCompassPt(model, input);
557     AppendF(output_file, kInputEdgeFmt, input, compass_pt, op_id, i, line_width,
558             weight);
559   }
560   // Outputs
561   for (int i = 0; i < op.outputs.size(); i++) {
562     const auto& output = op.outputs[i];
563     if (!model.HasArray(output)) {
564       continue;
565     }
566     float log2_buffer_size = GetLog2BufferSize(model, output);
567     // See comments above regarding weight and line_width calculations.
568     float line_width = std::max(0.5f, log2_buffer_size / 3.0f);
569     float weight = std::max(1.0f, log2_buffer_size);
570     if (!IsArrayConsumed(model, output)) {
571       weight = 1.0f;
572     }
573     std::string compass_pt = GetArrayCompassPt(model, output);
574     AppendF(output_file, kOutputEdgeFmt, op_id, i, output, compass_pt,
575             line_width, weight);
576   }
577 }
578 
// A node in the tree of '/'-separated array names that is used to nest arrays
// into graphviz cluster subgraphs.
struct Node {
  // Name used as a key in the model's array map. Empty when no array's full
  // name ends at this node.
  std::string array_id;

  // Estimated number of math ops incurred by this node (the sum of the op
  // with this array as 1st output, plus all children nodes).
  int64_t math_ops = 0;

  // A map of child nodes keyed by name. Map keys are already immutable, so
  // the key type does not need to be const.
  std::map<std::string, std::unique_ptr<Node>> children;
};
591 
GetSubgraphLabel(Node const & node,const std::string & subgraph)592 std::string GetSubgraphLabel(Node const& node, const std::string& subgraph) {
593   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
594   std::string html;
595   html += "<";
596 
597   // Begin Table
598   html += R"CODE(<FONT POINT-SIZE="12" FACE="Courier">)CODE";
599   html +=
600       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
601 
602   // Name
603   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
604   html += R"CODE(<FONT POINT-SIZE="18" FACE="Helvetica"><I>)CODE";
605   html += subgraph;
606   html += R"CODE(</I></FONT>)CODE";
607   html += "</TD></TR>";
608 
609   // Attributes
610   Attributes attrs;
611   if (node.math_ops > 0) {
612     attrs["math"] = FormattedNumber(node.math_ops) + "ops";
613   }
614   html += AttributesToHtml(attrs);
615 
616   // End Table and HTML-like label
617   html += R"CODE(</TABLE></FONT>)CODE";
618   html += ">";
619 
620   return html;
621 }
622 
DumpSubgraphHeader(std::string * output_file,Node const & node,const std::string & node_name)623 void DumpSubgraphHeader(std::string* output_file, Node const& node,
624                         const std::string& node_name) {
625   Color color = HashStringToColor(node_name);
626   std::string label = GetSubgraphLabel(node, node_name);
627   AppendF(output_file, kSubgraphFmt, node_name, color.AsHexString(), label);
628 }
629 
DumpArray(const Model & model,std::string * output_file,const std::string & array_id)630 void DumpArray(const Model& model, std::string* output_file,
631                const std::string& array_id) {
632   Color color;
633   std::string shape;
634   GetArrayColorAndShape(model, array_id, &color, &shape);
635   std::string label = GetArrayLabel(model, array_id);
636   AppendF(output_file, kArrayNodeFmt, array_id, label, array_id, shape,
637           color.AsHexString(), color.TextColorString());
638 
639   // Ops are placed in the same subgraph as their first output.
640   for (int op_index = 0; op_index < model.operators.size(); op_index++) {
641     const Operator& op = *model.operators[op_index];
642     if (!op.outputs.empty() && (op.outputs[0] == array_id)) {
643       DumpOperator(model, output_file, op_index);
644     }
645   }
646 }
647 
DumpNode(const Model & model,std::string * output_file,const std::string & node_name,Node const & node)648 void DumpNode(const Model& model, std::string* output_file,
649               const std::string& node_name, Node const& node) {
650   bool not_root = !node_name.empty();
651   if (not_root) {
652     DumpSubgraphHeader(output_file, node, node_name);
653   }
654 
655   for (const auto& child : node.children) {
656     if (!child.second->array_id.empty()) {
657       // Dump array if this node possesses one.
658       DumpArray(model, output_file, child.second->array_id);
659     }
660     // Note that it is always possible to have children. Unlike a filesystem,
661     // the existence of array "foo/bar" does _not_ prevent other arrays, such as
662     // and "foo/bar/baz", from being nested beneath it.
663     DumpNode(model, output_file, child.first, *child.second);
664   }
665 
666   if (not_root) {
667     // End subgraph
668     AppendF(output_file, "    }\n");
669   }
670 }
671 
GetArithmeticOpsCount(const Model & model,const std::string & array_id)672 int64_t GetArithmeticOpsCount(const Model& model, const std::string& array_id) {
673   for (const auto& op : model.operators) {
674     if (!op->outputs.empty() && op->outputs[0] == array_id) {
675       int64_t count;
676       if (EstimateArithmeticOpsCount(model, *op, &count)) {
677         return count;
678       } else {
679         return 0;
680       }
681     }
682   }
683   return 0;
684 }
685 
InsertNode(const Model & model,const std::string & array_id,Node * node,std::vector<std::string> prefixes,int64_t * math_ops)686 void InsertNode(const Model& model, const std::string& array_id, Node* node,
687                 std::vector<std::string> prefixes, int64_t* math_ops) {
688   if (prefixes.empty()) {
689     // Base case: store array in this node.
690     node->array_id = array_id;
691     *math_ops = GetArithmeticOpsCount(model, array_id);
692   } else {
693     // Insert into the sub-tree for that prefix.
694     std::string prefix = prefixes.back();
695     prefixes.pop_back();
696     if (node->children.count(prefix) == 0) {
697       // Create a new node if this prefix is unseen.
698       node->children[prefix] = std::make_unique<Node>();
699     }
700     InsertNode(model, array_id, node->children[prefix].get(), prefixes,
701                math_ops);
702   }
703   // Sum estimated math ops into all nodes.
704   node->math_ops += *math_ops;
705 }
706 
BuildArrayTree(const Model & model,Node * tree)707 void BuildArrayTree(const Model& model, Node* tree) {
708   // Delimit array names by path "/", then place into a tree based on this path.
709   for (const auto& array_id : model.GetArrayMap()) {
710     std::vector<std::string> prefixes = absl::StrSplit(array_id.first, '/');
711     std::reverse(prefixes.begin(), prefixes.end());
712     int64_t math_ops;  // Temporary storage for math ops used during recursion.
713     InsertNode(model, array_id.first, tree, prefixes, &math_ops);
714   }
715 }
716 
GetGraphLabel(const Model & model,const std::string & graph_name)717 std::string GetGraphLabel(const Model& model, const std::string& graph_name) {
718   // Use HTML-like labels (http://www.graphviz.org/doc/info/shapes.html#html)
719   std::string html;
720   html += "<";
721 
722   // Begin Table
723   html += R"CODE(<FONT POINT-SIZE="36" FACE="Courier">)CODE";
724   html +=
725       R"CODE(<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">)CODE";
726 
727   // Name
728   html += R"CODE(<TR><TD COLSPAN="2" CELLPADDING="3" ALIGN="CENTER">)CODE";
729   html += R"CODE(<FONT POINT-SIZE="64" FACE="Helvetica"><B><I>)CODE";
730   html += graph_name;
731   html += R"CODE(</I></B></FONT>)CODE";
732   html += "</TD></TR>";
733 
734   // Attributes
735   Attributes attrs;
736   attrs["arrays"] = StringF("%d", model.GetArrayMap().size());
737   if (!model.optional_arrays.empty()) {
738     attrs["optional arrays"] = StringF("%d", model.optional_arrays.size());
739   }
740   attrs["operators"] = StringF("%d", model.operators.size());
741   int64_t ops_count;
742   if (EstimateArithmeticOpsCount(model, &ops_count) && (ops_count > 0)) {
743     attrs["math"] = FormattedNumber(ops_count) + "ops";
744   }
745   if (model.transient_data_size > 0) {
746     attrs["transient data size"] =
747         StringF("%d KiB", model.transient_data_size / 1024);
748   }
749   if (model.transient_data_alignment > 0) {
750     attrs["transient data alignment"] =
751         StringF("%d bytes", model.transient_data_alignment);
752   }
753   html += AttributesToHtml(attrs);
754 
755   // End Table and HTML-like label
756   html += R"CODE(</TABLE></FONT>)CODE";
757   html += ">";
758 
759   return html;
760 }
761 }  // namespace
762 
DumpGraphviz(const Model & model,std::string * output_file,const std::string & graph_name)763 void DumpGraphviz(const Model& model, std::string* output_file,
764                   const std::string& graph_name) {
765   // Start graphviz format
766   AppendF(output_file, kGraphFmt, GetGraphLabel(model, graph_name));
767 
768   // Organize arrays into a tree for subgraphing
769   Node tree;
770   BuildArrayTree(model, &tree);
771   DumpNode(model, output_file, "", tree);
772 
773   // Dump edges outside all subgraphs (otherwise the referred-to nodes are
774   // implicitly included in that subgraph).
775   for (int op_index = 0; op_index < model.operators.size(); op_index++) {
776     DumpOperatorEdges(model, output_file, op_index);
777   }
778 
779   // Dump RNN Backedges
780   for (const auto& rnn_state : model.flags.rnn_states()) {
781     AppendF(output_file, kRNNBackEdgeFmt, rnn_state.back_edge_source_array(),
782             rnn_state.state_array());
783   }
784   // End graphviz format
785   AppendF(output_file, "}\n");
786 }
787 }  // namespace toco
788