/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This program prints out a summary of a GraphDef file's contents, listing
// things that are useful for debugging and reusing the model it contains. For
// example it looks at the graph structure and op types to figure out likely
// input and output nodes, and shows which ops are used by the graph. To use it,
// run something like this:
//
// bazel build tensorflow/tools/graph_transforms:summarize_graph
// bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
// --in_graph=my_graph.pb

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/file_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
namespace {

PrintNodeInfo(const NodeDef * node)43 void PrintNodeInfo(const NodeDef* node) {
44 string shape_description = "None";
45 if (node->attr().count("shape")) {
46 TensorShapeProto shape_proto = node->attr().at("shape").shape();
47 Status shape_status = PartialTensorShape::IsValidShape(shape_proto);
48 if (shape_status.ok()) {
49 shape_description = PartialTensorShape(shape_proto).DebugString();
50 } else {
51 shape_description = shape_status.error_message();
52 }
53 }
54 DataType dtype = DT_INVALID;
55 if (node->attr().count("dtype")) {
56 dtype = node->attr().at("dtype").type();
57 }
58 std::cout << "(name=" << node->name();
59 std::cout << ", type=" << DataTypeString(dtype) << "(" << dtype << ")";
60 std::cout << ", shape=" << shape_description << ") ";
61 }
62
PrintBenchmarkUsage(const std::vector<const NodeDef * > & placeholders,const std::vector<const NodeDef * > & variables,const std::vector<const NodeDef * > outputs,const string & graph_path)63 void PrintBenchmarkUsage(const std::vector<const NodeDef*>& placeholders,
64 const std::vector<const NodeDef*>& variables,
65 const std::vector<const NodeDef*> outputs,
66 const string& graph_path) {
67 std::vector<const NodeDef*> all_inputs(placeholders);
68 all_inputs.insert(all_inputs.end(), variables.begin(), variables.end());
69
70 std::vector<string> input_layers;
71 std::vector<string> input_layer_types;
72 std::vector<string> input_layer_shapes;
73 for (const NodeDef* node : all_inputs) {
74 input_layers.push_back(node->name());
75 DataType dtype = DT_INVALID;
76 if (node->attr().count("dtype")) {
77 dtype = node->attr().at("dtype").type();
78 }
79 input_layer_types.push_back(DataTypeString(dtype));
80 std::vector<int64_t> sizes;
81 PartialTensorShape shape;
82 if (node->attr().count("shape")) {
83 TensorShapeProto shape_proto = node->attr().at("shape").shape();
84 if (PartialTensorShape::IsValid(shape_proto)) {
85 shape = PartialTensorShape(shape_proto);
86 }
87 }
88 string sizes_string;
89 if (shape.dims() == -1) {
90 // Unknown shapes can have -1 for dims, so leave these blank.
91 sizes_string = "";
92 } else {
93 sizes.reserve(shape.dims());
94 for (int i = 0; i < shape.dims(); ++i) {
95 sizes.push_back(shape.dim_size(i));
96 }
97 sizes_string = absl::StrJoin(sizes, ",");
98 }
99 input_layer_shapes.push_back(sizes_string);
100 }
101 std::vector<string> output_layers;
102 output_layers.reserve(outputs.size());
103 for (const NodeDef* node : outputs) {
104 output_layers.push_back(node->name());
105 }
106 string input_layer_value = absl::StrJoin(input_layers, ",");
107 string input_layer_type_value = absl::StrJoin(input_layer_types, ",");
108 string input_layer_shape_value = absl::StrJoin(input_layer_shapes, ":");
109 string output_layer_value = absl::StrJoin(output_layers, ",");
110
111 std::cout << "To use with tensorflow/tools/benchmark:benchmark_model try "
112 "these arguments:"
113 << std::endl;
114 std::cout << "bazel run tensorflow/tools/benchmark:benchmark_model --";
115 std::cout << " --graph=" << graph_path;
116 std::cout << " --show_flops";
117 std::cout << " --input_layer=" << input_layer_value;
118 std::cout << " --input_layer_type=" << input_layer_type_value;
119 std::cout << " --input_layer_shape=" << input_layer_shape_value;
120 std::cout << " --output_layer=" << output_layer_value;
121 std::cout << std::endl;
122 }
123
PrintStructure(const GraphDef & graph)124 Status PrintStructure(const GraphDef& graph) {
125 GraphDef sorted_graph;
126 TF_RETURN_IF_ERROR(SortByExecutionOrder(graph, &sorted_graph));
127 for (const NodeDef& node : sorted_graph.node()) {
128 std::cout << node.name() << " (" << node.op() << "): ["
129 << absl::StrJoin(node.input(), ", ") << "]";
130 if (node.op() == "Const") {
131 Tensor tensor;
132 if (node.attr().count("value") &&
133 tensor.FromProto(node.attr().at("value").tensor())) {
134 std::cout << ", value=" << tensor.DebugString();
135 } else {
136 LOG(WARNING) << "Decoding Tensor failed for node" << node.name();
137 }
138 }
139 std::cout << std::endl;
140 }
141 return OkStatus();
142 }
143
SummarizeGraph(const GraphDef & graph,const string & graph_path,bool print_structure)144 Status SummarizeGraph(const GraphDef& graph, const string& graph_path,
145 bool print_structure) {
146 std::vector<const NodeDef*> placeholders;
147 std::vector<const NodeDef*> variables;
148 for (const NodeDef& node : graph.node()) {
149 if (node.op() == "Placeholder") {
150 placeholders.push_back(&node);
151 }
152 if (node.op() == "Variable" || node.op() == "VariableV2") {
153 variables.push_back(&node);
154 }
155 }
156
157 if (placeholders.empty()) {
158 std::cout << "No inputs spotted." << std::endl;
159 } else {
160 std::cout << "Found " << placeholders.size() << " possible inputs: ";
161 for (const NodeDef* node : placeholders) {
162 PrintNodeInfo(node);
163 }
164 std::cout << std::endl;
165 }
166
167 if (variables.empty()) {
168 std::cout << "No variables spotted." << std::endl;
169 } else {
170 std::cout << "Found " << variables.size() << " variables: ";
171 for (const NodeDef* node : variables) {
172 PrintNodeInfo(node);
173 }
174 std::cout << std::endl;
175 }
176
177 std::map<string, std::vector<const NodeDef*>> output_map;
178 MapNodesToOutputs(graph, &output_map);
179 std::vector<const NodeDef*> outputs;
180 std::unordered_set<string> unlikely_output_types = {"Const", "Assign", "NoOp",
181 "Placeholder"};
182 for (const NodeDef& node : graph.node()) {
183 if ((output_map.count(node.name()) == 0) &&
184 (unlikely_output_types.count(node.op()) == 0)) {
185 outputs.push_back(&node);
186 }
187 }
188
189 if (outputs.empty()) {
190 std::cout << "No outputs spotted." << std::endl;
191 } else {
192 std::cout << "Found " << outputs.size() << " possible outputs: ";
193 for (const NodeDef* node : outputs) {
194 std::cout << "(name=" << node->name();
195 std::cout << ", op=" << node->op() << ") ";
196 }
197 std::cout << std::endl;
198 }
199
200 int64_t const_parameter_count = 0;
201 int64_t variable_parameter_count = 0;
202 int control_edge_count = 0;
203 std::map<string, int> device_counts;
204 for (const NodeDef& node : graph.node()) {
205 for (const string& input : node.input()) {
206 if (input.substr(0, 1) == "^") {
207 ++control_edge_count;
208 }
209 }
210 if (!node.device().empty()) {
211 ++device_counts[node.device()];
212 }
213 if ((node.op() == "Const") || (node.op() == "Variable") ||
214 (node.op() == "VariableV2")) {
215 Tensor tensor;
216 if (node.attr().count("value") &&
217 tensor.FromProto(node.attr().at("value").tensor())) {
218 const size_t num_elements = tensor.NumElements();
219 if (node.op() == "Const") {
220 const_parameter_count += num_elements;
221 } else {
222 variable_parameter_count += num_elements;
223 }
224 } else {
225 LOG(WARNING) << "Decoding Tensor failed for node" << node.name();
226 }
227 }
228 }
229
230 std::cout << "Found " << const_parameter_count << " ("
231 << strings::HumanReadableNum(const_parameter_count)
232 << ") const parameters, " << variable_parameter_count << " ("
233 << strings::HumanReadableNum(variable_parameter_count)
234 << ") variable parameters, and " << control_edge_count
235 << " control_edges" << std::endl;
236 if (!device_counts.empty()) {
237 for (const auto& device_info : device_counts) {
238 std::cout << device_info.second << " nodes assigned to device '"
239 << device_info.first << "'";
240 }
241 }
242
243 std::vector<std::pair<string, string>> invalid_inputs;
244 FindInvalidInputs(graph, &invalid_inputs);
245 if (!invalid_inputs.empty()) {
246 for (const std::pair<string, string>& invalid_input : invalid_inputs) {
247 std::cout << "Invalid input " << invalid_input.second << " for node "
248 << invalid_input.first << std::endl;
249 }
250 return errors::Internal(
251 "Invalid graph with inputs referring to nonexistent nodes");
252 }
253
254 std::map<string, int> op_counts;
255 for (const NodeDef& node : graph.node()) {
256 ++op_counts[node.op()];
257 }
258 for (const FunctionDef& function : graph.library().function()) {
259 for (const NodeDef& node : function.node_def()) {
260 ++op_counts[node.op()];
261 }
262 }
263 std::vector<std::pair<string, int>> op_counts_vec(op_counts.begin(),
264 op_counts.end());
265 std::sort(op_counts_vec.begin(), op_counts_vec.end(),
266 [](std::pair<string, int> a, std::pair<string, int> b) {
267 return (a.second > b.second);
268 });
269 std::cout << "Op types used: ";
270 bool is_first = true;
271 for (const std::pair<string, int>& op_count : op_counts_vec) {
272 if (!is_first) {
273 std::cout << ", ";
274 } else {
275 is_first = false;
276 }
277 std::cout << op_count.second << " " << op_count.first;
278 }
279 std::cout << std::endl;
280
281 PrintBenchmarkUsage(placeholders, variables, outputs, graph_path);
282
283 if (print_structure) {
284 TF_RETURN_IF_ERROR(PrintStructure(graph));
285 }
286
287 return OkStatus();
288 }
289
ParseFlagsAndSummarizeGraph(int argc,char * argv[])290 int ParseFlagsAndSummarizeGraph(int argc, char* argv[]) {
291 string in_graph = "";
292 bool print_structure = false;
293 std::vector<Flag> flag_list = {
294 Flag("in_graph", &in_graph, "input graph file name"),
295 Flag("print_structure", &print_structure,
296 "whether to print the network connections of the graph"),
297 };
298 string usage = Flags::Usage(argv[0], flag_list);
299
300 const bool parse_result = Flags::Parse(&argc, argv, flag_list);
301 // We need to call this to set up global state for TensorFlow.
302 port::InitMain(argv[0], &argc, &argv);
303
304 if (!parse_result) {
305 LOG(ERROR) << usage;
306 return -1;
307 }
308 if (argc > 1) {
309 LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
310 return -1;
311 }
312 if (in_graph.empty()) {
313 LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
314 return -1;
315 }
316
317 GraphDef graph_def;
318 Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
319 if (!load_status.ok()) {
320 LOG(ERROR) << "Loading graph '" << in_graph << "' failed with "
321 << load_status.error_message();
322 LOG(ERROR) << usage;
323 return -1;
324 }
325
326 Status summarize_result =
327 SummarizeGraph(graph_def, in_graph, print_structure);
328 if (!summarize_result.ok()) {
329 LOG(ERROR) << summarize_result.error_message() << "\n" << usage;
330 return -1;
331 }
332
333 return 0;
334 }
335
}  // namespace
}  // namespace graph_transforms
}  // namespace tensorflow

// Entry point: delegates straight to the namespaced implementation so the
// real logic can live alongside the graph_transforms helpers.
int main(int argc, char* argv[]) {
  return tensorflow::graph_transforms::ParseFlagsAndSummarizeGraph(argc, argv);
}