xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
16 
17 #include <memory>
18 #include <string>
19 #include <vector>
20 
21 #include "absl/strings/match.h"
22 #include "absl/strings/str_split.h"
23 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 
26 namespace mlir {
27 namespace TFL {
28 namespace {
29 
// The signature contains namespaces (Ex: mlir::TFL::(anonymous namespace)::).
// So only extract the function name as the pass name.
//
// Returns the segment after the last "::" separator, or `signature` itself
// when it contains no separator. Separators are matched left to right without
// overlap, exactly like absl::StrSplit(signature, "::").back() (so "a:::b"
// yields ":b" and "a::" yields ""), but without materializing every segment
// into a vector just to keep the last one.
inline std::string extract_pass_name(const std::string &signature) {
  constexpr char kSeparator[] = "::";
  constexpr std::string::size_type kSeparatorSize = sizeof(kSeparator) - 1;

  // Start of the segment following the most recent separator seen so far.
  std::string::size_type segment_start = 0;
  for (std::string::size_type pos = signature.find(kSeparator, segment_start);
       pos != std::string::npos;
       pos = signature.find(kSeparator, segment_start)) {
    segment_start = pos + kSeparatorSize;
  }
  return signature.substr(segment_start);
}
36 
// Errors raised by emitOpError start with "'<dialect>.<op>' op". Returns the
// "<dialect>.<op>" part when the message starts with that pattern and the
// dialect is tf or tfl; otherwise returns an empty string.
//
// The op name ends at the first closing quote, which must be immediately
// followed by " op". Searching the whole message for "' op" instead would
// match a later occurrence (e.g. "'tf.A' has 'x' op") and return a bogus name
// containing quotes.
inline std::string extract_op_name_from_error_message(
    const std::string &error_message) {
  const bool has_tf_prefix = error_message.rfind("'tf.", 0) == 0 ||
                             error_message.rfind("'tfl.", 0) == 0;
  if (!has_tf_prefix) return "";

  // size_type rather than int: find() reports failure as std::string::npos.
  const std::string::size_type end_pos = error_message.find('\'', 1);
  static constexpr char kOpMarker[] = "' op";
  if (end_pos == std::string::npos ||
      error_message.compare(end_pos, sizeof(kOpMarker) - 1, kOpMarker) != 0) {
    return "";
  }
  return error_message.substr(1, end_pos - 1);
}
50 
// Maximum number of characters of a single diagnostic note that is copied
// into a reported error message. Notes longer than this are truncated to this
// many characters and followed by "..."; shorter notes are copied whole.
// Typed as std::string::size_type so comparisons against std::string::size()
// do not mix signed and unsigned operands.
constexpr std::string::size_type kMaxAcceptedNoteSize = 1024;
54 }  // namespace
55 
ErrorCollectorInstrumentation(MLIRContext * context)56 ErrorCollectorInstrumentation::ErrorCollectorInstrumentation(
57     MLIRContext *context)
58     : error_collector_(ErrorCollector::GetErrorCollector()) {
59   handler_ = std::make_unique<ScopedDiagnosticHandler>(
60       context, [this](Diagnostic &diag) {
61         if (diag.getSeverity() == DiagnosticSeverity::Error) {
62           Location loc = diag.getLocation();
63           std::string error_message = diag.str();
64           std::string op_name, error_code;
65           if (loc_to_name_.count(loc)) {
66             op_name = loc_to_name_[loc];
67           } else {
68             op_name = extract_op_name_from_error_message(diag.str());
69           }
70 
71           for (const auto &note : diag.getNotes()) {
72             const std::string note_str = note.str();
73             if (note_str.rfind(kErrorCodePrefix, 0) == 0) {
74               error_code = note_str.substr(sizeof(kErrorCodePrefix) - 1);
75             }
76 
77             error_message += "\n";
78             if (note_str.size() <= kMaxAcceptedNoteSize) {
79               error_message += note_str;
80             } else {
81               error_message += note_str.substr(0, kMaxAcceptedNoteSize);
82               error_message += "...";
83             }
84           }
85 
86           ErrorCode error_code_enum = ConverterErrorData::UNKNOWN;
87           bool has_valid_error_code =
88               ConverterErrorData::ErrorCode_Parse(error_code, &error_code_enum);
89           if (!op_name.empty() || has_valid_error_code) {
90             error_collector_->ReportError(NewConverterErrorData(
91                 pass_name_, error_message, error_code_enum, op_name, loc));
92           } else {
93             common_error_message_ += diag.str();
94             common_error_message_ += "\n";
95           }
96         }
97         return failure();
98       });
99 }
100 
runBeforePass(Pass * pass,Operation * module)101 void ErrorCollectorInstrumentation::runBeforePass(Pass *pass,
102                                                   Operation *module) {
103   // Find the op names with tf or tfl dialect prefix, Ex: "tf.Abs" or "tfl.Abs".
104   auto collectOps = [this](Operation *op) {
105     const auto &op_name = op->getName().getStringRef().str();
106     if (absl::StartsWith(op_name, "tf.") || absl::StartsWith(op_name, "tfl.")) {
107       loc_to_name_.emplace(op->getLoc(), op_name);
108     }
109   };
110 
111   for (auto &region : module->getRegions()) {
112     region.walk(collectOps);
113   }
114 
115   pass_name_ = extract_pass_name(pass->getName().str());
116   error_collector_->Clear();
117 }
118 
runAfterPass(Pass * pass,Operation * module)119 void ErrorCollectorInstrumentation::runAfterPass(Pass *pass,
120                                                  Operation *module) {
121   loc_to_name_.clear();
122   pass_name_.clear();
123   common_error_message_.clear();
124   error_collector_->Clear();
125 }
126 
runAfterPassFailed(Pass * pass,Operation * module)127 void ErrorCollectorInstrumentation::runAfterPassFailed(Pass *pass,
128                                                        Operation *module) {
129   // Create a new error if no errors collected yet.
130   if (error_collector_->CollectedErrors().empty() &&
131       !common_error_message_.empty()) {
132     error_collector_->ReportError(NewConverterErrorData(
133         pass_name_, common_error_message_, ConverterErrorData::UNKNOWN,
134         /*op_name=*/"", module->getLoc()));
135   }
136 
137   loc_to_name_.clear();
138   pass_name_.clear();
139   common_error_message_.clear();
140 }
141 
142 }  // namespace TFL
143 }  // namespace mlir
144