1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
16
17 #include <memory>
18 #include <string>
19 #include <vector>
20
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_split.h"
23 #include "mlir/IR/Diagnostics.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25
26 namespace mlir {
27 namespace TFL {
28 namespace {
29
// Pass signatures are fully qualified, e.g.
// "mlir::TFL::(anonymous namespace)::MyPass". Returns the text after the last
// "::" separator (scanning separators left to right), i.e. just the pass name.
// A signature without any "::" is returned unchanged.
inline std::string extract_pass_name(const std::string &signature) {
  static constexpr char kScopeSeparator[] = "::";
  constexpr std::string::size_type kSeparatorLength =
      sizeof(kScopeSeparator) - 1;

  std::string::size_type name_begin = 0;
  for (std::string::size_type hit = signature.find(kScopeSeparator);
       hit != std::string::npos;
       hit = signature.find(kScopeSeparator, name_begin)) {
    name_begin = hit + kSeparatorLength;
  }
  return signature.substr(name_begin);
}
36
// Errors raised by emitOpError start with "'<dialect>.<op>' op". Extracts
// "<dialect>.<op>" from such a message, e.g. "'tfl.Add' op ..." -> "tfl.Add".
// Returns an empty string if the message does not start with a tf or tfl op
// name or the "' op" terminator is missing.
inline std::string extract_op_name_from_error_message(
    const std::string &error_message) {
  // rfind(prefix, 0) == 0 is a prefix test that only inspects position 0.
  const bool starts_with_tf_op = error_message.rfind("'tf.", 0) == 0 ||
                                 error_message.rfind("'tfl.", 0) == 0;
  if (!starts_with_tf_op) {
    return "";
  }
  // Keep the unsigned size_type: storing find()'s result in an int narrows it
  // and only compares equal to npos by accident of sign conversion.
  const std::string::size_type end_pos = error_message.find("' op");
  if (end_pos == std::string::npos) {
    return "";
  }
  // Skip the leading quote; stop right before the closing quote.
  return error_message.substr(1, end_pos - 1);
}
50
// Maximum number of characters of a single diagnostic note that is appended to
// a collected error message. Longer notes are truncated to this many
// characters and followed by "...".
// Typed as the string size type so comparisons with std::string::size() are
// not signed/unsigned mixes.
constexpr std::string::size_type kMaxAcceptedNoteSize = 1024;
54 } // namespace
55
ErrorCollectorInstrumentation(MLIRContext * context)56 ErrorCollectorInstrumentation::ErrorCollectorInstrumentation(
57 MLIRContext *context)
58 : error_collector_(ErrorCollector::GetErrorCollector()) {
59 handler_ = std::make_unique<ScopedDiagnosticHandler>(
60 context, [this](Diagnostic &diag) {
61 if (diag.getSeverity() == DiagnosticSeverity::Error) {
62 Location loc = diag.getLocation();
63 std::string error_message = diag.str();
64 std::string op_name, error_code;
65 if (loc_to_name_.count(loc)) {
66 op_name = loc_to_name_[loc];
67 } else {
68 op_name = extract_op_name_from_error_message(diag.str());
69 }
70
71 for (const auto ¬e : diag.getNotes()) {
72 const std::string note_str = note.str();
73 if (note_str.rfind(kErrorCodePrefix, 0) == 0) {
74 error_code = note_str.substr(sizeof(kErrorCodePrefix) - 1);
75 }
76
77 error_message += "\n";
78 if (note_str.size() <= kMaxAcceptedNoteSize) {
79 error_message += note_str;
80 } else {
81 error_message += note_str.substr(0, kMaxAcceptedNoteSize);
82 error_message += "...";
83 }
84 }
85
86 ErrorCode error_code_enum = ConverterErrorData::UNKNOWN;
87 bool has_valid_error_code =
88 ConverterErrorData::ErrorCode_Parse(error_code, &error_code_enum);
89 if (!op_name.empty() || has_valid_error_code) {
90 error_collector_->ReportError(NewConverterErrorData(
91 pass_name_, error_message, error_code_enum, op_name, loc));
92 } else {
93 common_error_message_ += diag.str();
94 common_error_message_ += "\n";
95 }
96 }
97 return failure();
98 });
99 }
100
runBeforePass(Pass * pass,Operation * module)101 void ErrorCollectorInstrumentation::runBeforePass(Pass *pass,
102 Operation *module) {
103 // Find the op names with tf or tfl dialect prefix, Ex: "tf.Abs" or "tfl.Abs".
104 auto collectOps = [this](Operation *op) {
105 const auto &op_name = op->getName().getStringRef().str();
106 if (absl::StartsWith(op_name, "tf.") || absl::StartsWith(op_name, "tfl.")) {
107 loc_to_name_.emplace(op->getLoc(), op_name);
108 }
109 };
110
111 for (auto ®ion : module->getRegions()) {
112 region.walk(collectOps);
113 }
114
115 pass_name_ = extract_pass_name(pass->getName().str());
116 error_collector_->Clear();
117 }
118
runAfterPass(Pass * pass,Operation * module)119 void ErrorCollectorInstrumentation::runAfterPass(Pass *pass,
120 Operation *module) {
121 loc_to_name_.clear();
122 pass_name_.clear();
123 common_error_message_.clear();
124 error_collector_->Clear();
125 }
126
runAfterPassFailed(Pass * pass,Operation * module)127 void ErrorCollectorInstrumentation::runAfterPassFailed(Pass *pass,
128 Operation *module) {
129 // Create a new error if no errors collected yet.
130 if (error_collector_->CollectedErrors().empty() &&
131 !common_error_message_.empty()) {
132 error_collector_->ReportError(NewConverterErrorData(
133 pass_name_, common_error_message_, ConverterErrorData::UNKNOWN,
134 /*op_name=*/"", module->getLoc()));
135 }
136
137 loc_to_name_.clear();
138 pass_name_.clear();
139 common_error_message_.clear();
140 }
141
142 } // namespace TFL
143 } // namespace mlir
144