1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/logging/conversion_log_util.h"
16
17 #include <string>
18
19 #ifdef __linux__
20 #include <sys/utsname.h>
21 #endif
22
23 #include <vector>
24
25 #include "absl/strings/str_cat.h"
26 #include "absl/time/clock.h"
27 #include "absl/time/time.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/lite/toco/model.h"
30 #include "tensorflow/lite/toco/tflite/export.h"
31 #include "tensorflow/lite/toco/tflite/operator.h"
32 #include "tensorflow/lite/toco/tooling_util.h"
33 #include "tensorflow/lite/version.h"
34
35 namespace toco {
36
37 namespace {
38
TryGetOperatorName(const Operator & op)39 std::string TryGetOperatorName(const Operator& op) {
40 std::string op_name;
41 if (!op.tensorflow_node_def.empty()) {
42 // Parse op name from serialized NodeDef.
43 tensorflow::NodeDef node_def;
44 if (!node_def.ParseFromString(op.tensorflow_node_def)) {
45 LOG(ERROR) << "Failed to parse Tensorflow NodeDef";
46 } else {
47 op_name = node_def.op();
48 if (!op_name.empty()) return op_name;
49 }
50 }
51 if (op.type == OperatorType::kUnsupported) {
52 // If we failed to get op name from serialized NodeDef (either because
53 // the tensorflow_node_def is an empty string, or we failed to parse
54 // from it), fall back to use 'tensorflow_op' field if this op is a
55 // TensorflowUnsupportedOperator.
56 const TensorFlowUnsupportedOperator& unsupported_op =
57 static_cast<const TensorFlowUnsupportedOperator&>(op);
58 if (!unsupported_op.tensorflow_op.empty()) {
59 op_name = unsupported_op.tensorflow_op;
60 return op_name;
61 }
62 }
63 // If this is a built-in op.
64 op_name = OperatorTypeName(op.type);
65 return op_name;
66 }
67
GetOSVersion()68 std::string GetOSVersion() {
69 std::string os_info;
70 #ifdef __linux__
71 utsname info;
72 if (uname(&info)) {
73 // Failed
74 LOG(ERROR) << "Cannot get OS info.";
75 return "";
76 }
77 os_info =
78 std::string(info.sysname) + ";OSVer=" + std::string(info.release) + ";";
79 #endif
80 return os_info;
81 }
82
ShapeToStringNoSpace(const Shape & shape)83 std::string ShapeToStringNoSpace(const Shape& shape) {
84 if (shape.dimensions_count() == 0) {
85 return "[]";
86 }
87
88 return absl::StrCat("[", absl::StrJoin(shape.dims(), ","), "]");
89 }
90
GetOperatorSignature(const Model & model,const Operator & op,const std::map<OperatorType,std::unique_ptr<tflite::BaseOperator>> & op_types_map)91 std::string GetOperatorSignature(
92 const Model& model, const Operator& op,
93 const std::map<OperatorType, std::unique_ptr<tflite::BaseOperator>>&
94 op_types_map) {
95 // The signature of an op has the following schema:
96 // INPUT:SHAPE::TYPE::OUTPUT:SHAPE::TYPE::NAME:VERSION:
97 std::string op_signature;
98 constexpr char delimiter[] = "::";
99
100 // Get input shapes and types.
101 op_signature.append("INPUT:");
102 for (const auto& input : op.inputs) {
103 const auto& array = model.GetArray(input);
104 if (array.has_shape()) {
105 op_signature.append(ShapeToStringNoSpace(array.shape()));
106 } else {
107 op_signature.append("None");
108 }
109 op_signature.append(delimiter);
110 op_signature.append(ArrayDataTypeName(array.data_type) + delimiter);
111 }
112 // Get output shapes and types.
113 op_signature.append("OUTPUT:");
114 for (const auto& output : op.outputs) {
115 const auto& array = model.GetArray(output);
116 if (array.has_shape()) {
117 op_signature.append(ShapeToStringNoSpace(array.shape()));
118 } else {
119 op_signature.append("None");
120 }
121 op_signature.append(delimiter);
122 op_signature.append(ArrayDataTypeName(array.data_type) + delimiter);
123 }
124 // Append Op name.
125 op_signature.append("NAME:");
126 op_signature.append(TryGetOperatorName(op) + delimiter);
127 // Append Op version.
128 op_signature.append("VERSION:");
129 OperatorSignature toco_op_signature;
130 toco_op_signature.op = &op;
131 toco_op_signature.model = &model;
132 if (op_types_map.find(op.type) != op_types_map.end()) {
133 const int version = op_types_map.at(op.type)->GetVersion(toco_op_signature);
134 op_signature.append(std::to_string(version));
135 } else {
136 op_signature.append("None");
137 }
138 return op_signature;
139 }
140
141 } // namespace
142
GetOperatorNames(const Model & model)143 std::vector<std::string> GetOperatorNames(const Model& model) {
144 std::vector<std::string> op_names;
145 op_names.reserve(model.operators.size());
146 for (const auto& op : model.operators) {
147 op_names.push_back(TryGetOperatorName(*op));
148 }
149 return op_names;
150 }
151
CountOperatorsByType(const Model & model,std::map<std::string,int> * built_in_ops,std::map<std::string,int> * custom_ops,std::map<std::string,int> * select_ops)152 void CountOperatorsByType(const Model& model,
153 std::map<std::string, int>* built_in_ops,
154 std::map<std::string, int>* custom_ops,
155 std::map<std::string, int>* select_ops) {
156 for (const auto& op : model.operators) {
157 OperatorSignature op_signature = {op.get(), &model};
158 const auto ops_by_type =
159 tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/);
160 tflite::details::OperatorKey op_key(op_signature, ops_by_type,
161 true /*enable_select_tf_ops*/);
162
163 const std::string op_name = TryGetOperatorName(*op);
164 if (op_key.is_custom_op()) {
165 (*custom_ops)[op_name]++;
166 } else if (op_key.is_flex_op()) {
167 (*select_ops)[op_name]++;
168 } else {
169 (*built_in_ops)[op_name]++;
170 }
171 }
172 }
173
GetInputAndOutputTypes(const Model & model,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * input_types,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * output_types)174 void GetInputAndOutputTypes(
175 const Model& model,
176 TFLITE_PROTO_NS::RepeatedPtrField<std::string>* input_types,
177 TFLITE_PROTO_NS::RepeatedPtrField<std::string>* output_types) {
178 for (const auto& input_array : model.flags.input_arrays()) {
179 const Array& array = model.GetArray(input_array.name());
180 input_types->Add(ArrayDataTypeName(array.data_type));
181 }
182 for (const auto& output_array : model.flags.output_arrays()) {
183 const Array& array = model.GetArray(output_array);
184 output_types->Add(ArrayDataTypeName(array.data_type));
185 }
186 }
187
GetTfLiteVersion()188 std::string GetTfLiteVersion() { return TFLITE_VERSION_STRING; }
189
GetCachedOSVersion()190 std::string GetCachedOSVersion() {
191 static std::string* version = new std::string(GetOSVersion());
192 return *version;
193 }
194
GetOpSignatures(const Model & model,TFLITE_PROTO_NS::RepeatedPtrField<std::string> * op_signatures)195 void GetOpSignatures(
196 const Model& model,
197 TFLITE_PROTO_NS::RepeatedPtrField<std::string>* op_signatures) {
198 const auto& op_types_map =
199 tflite::BuildOperatorByTypeMap(true /*enable_select_tf_ops*/);
200 for (const auto& op : model.operators) {
201 op_signatures->Add(GetOperatorSignature(model, *op, op_types_map));
202 }
203 }
204
// Placeholder: always returns an empty string; `model` is currently unused.
std::string GetModelHash(const Model& model) {
  // TODO(b/123519920): Implement the hash function for Model.
  // Need to consider different implementations for public/private models.
  return std::string();
}
210
// Extracts the missing-op summaries from a converter error message and drops
// everything else (e.g. tracebacks and file paths).
//
// For each marker phrase below that occurs in `error_message`, the text from
// its first occurrence up to and including the next '.' is appended to the
// result. The flex-op clause comes first, then the custom-op clause. If no '.'
// follows a marker, the text through the end of the message is kept. Returns
// an empty string when neither marker is present.
std::string SanitizeErrorMessage(const std::string& error_message) {
  const std::string s1 = "Ops that can be supported by the flex runtime";
  const std::string s2 = "Ops that need custom implementation";
  std::string pruned_message;

  // Appends the clause starting at `marker` and ending at the next period.
  const auto append_clause = [&error_message,
                              &pruned_message](const std::string& marker) {
    const size_t start = error_message.find(marker);
    if (start == std::string::npos) return;
    const size_t period = error_message.find('.', start);
    // Without a terminating period, keep the rest of the message. Computing
    // `period - start + 1` with period == npos would wrap around; when
    // start == 0 it wraps to 0 and the clause would be silently dropped.
    const size_t length =
        period == std::string::npos ? std::string::npos : period - start + 1;
    pruned_message.append(error_message, start, length);
  };

  append_clause(s1);
  append_clause(s2);
  return pruned_message;
}
231
PopulateConversionLog(const Model & model,TocoConversionLog * log)232 void PopulateConversionLog(const Model& model, TocoConversionLog* log) {
233 // Get the list of ops after conversion.
234 const std::vector<std::string> op_names = GetOperatorNames(model);
235 for (const auto& op_name : op_names) {
236 log->add_op_list(op_name);
237 }
238
239 // Get op signatures.
240 TFLITE_PROTO_NS::RepeatedPtrField<std::string> op_signatures;
241 GetOpSignatures(model, &op_signatures);
242 log->mutable_op_signatures()->CopyFrom(op_signatures);
243
244 // Get op counts by category: custom, built-in or select.
245 std::map<std::string, int> custom_ops, select_ops, built_in_ops;
246 CountOperatorsByType(model, &built_in_ops, &custom_ops, &select_ops);
247 log->mutable_custom_ops()->insert(custom_ops.cbegin(), custom_ops.cend());
248 log->mutable_built_in_ops()->insert(built_in_ops.cbegin(),
249 built_in_ops.cend());
250 log->mutable_select_ops()->insert(select_ops.cbegin(), select_ops.cend());
251
252 // Get the model's input and output types.
253 TFLITE_PROTO_NS::RepeatedPtrField<std::string> input_types, output_types;
254 GetInputAndOutputTypes(model, &input_types, &output_types);
255 log->mutable_input_tensor_types()->CopyFrom(input_types);
256 log->mutable_output_tensor_types()->CopyFrom(output_types);
257
258 log->set_log_generation_ts(absl::ToUnixMicros(absl::Now()));
259
260 log->set_model_size(model.operators.size());
261 log->set_tf_lite_version(GetTfLiteVersion());
262 log->set_os_version(GetCachedOSVersion());
263 log->set_model_hash(GetModelHash(model));
264 // TODO(b/123519920): Populate TOCO error logs.
265 // Currently we will focus on external installation of TOCO via pip, where
266 // the C++ TOCO binary is invoked via subprocess command, this will make our
267 // life easier collecting the error logs emitted by TOCO. However, note that
268 // if a user directly invokes the C++ TOCO binary, this log might not be
269 // available.
270 }
271
272 } // namespace toco
273