xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
17 
18 #include <atomic>
19 
20 #include "absl/strings/str_split.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
26 
27 namespace tensorflow {
28 
// File-local sequence number for dump files. Log() post-increments it and
// renders the value as a zero-padded, 4-wide filename prefix so dump files sort
// in the order they were produced. Being std::atomic, each increment is a
// single atomic read-modify-write.
static std::atomic<int> log_counter{0};
31 
BridgeLoggerConfig(bool print_module_scope,bool print_after_only_on_change)32 BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
33                                        bool print_after_only_on_change)
34     : mlir::PassManager::IRPrinterConfig(print_module_scope,
35                                          print_after_only_on_change),
36       pass_filter_(GetFilter("MLIR_BRIDGE_LOG_PASS_FILTER")),
37       string_filter_(GetFilter("MLIR_BRIDGE_LOG_STRING_FILTER")) {}
38 
39 // Logs op to file with name of format
40 // `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
Log(BridgeLoggerConfig::PrintCallbackFn print_callback,mlir::Pass * pass,mlir::Operation * op,llvm::StringRef file_suffix)41 inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
42                        mlir::Pass* pass, mlir::Operation* op,
43                        llvm::StringRef file_suffix) {
44   std::string pass_name = pass->getName().str();
45 
46   // Add 4-digit counter as prefix so the order of the passes is obvious.
47   std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
48                                    pass_name, file_suffix);
49 
50   std::unique_ptr<llvm::raw_ostream> os;
51   std::string filepath;
52   if (CreateFileForDumping(name, &os, &filepath).ok()) {
53     print_callback(*os);
54     LOG(INFO) << "Dumped MLIR module to " << filepath;
55   }
56 }
57 
printBeforeIfEnabled(mlir::Pass * pass,mlir::Operation * op,PrintCallbackFn print_callback)58 void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
59                                               mlir::Operation* op,
60                                               PrintCallbackFn print_callback) {
61   if (ShouldPrint(pass, op)) Log(print_callback, pass, op, "before");
62 }
63 
printAfterIfEnabled(mlir::Pass * pass,mlir::Operation * op,PrintCallbackFn print_callback)64 void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
65                                              mlir::Operation* op,
66                                              PrintCallbackFn print_callback) {
67   if (ShouldPrint(pass, op)) Log(print_callback, pass, op, "after");
68 }
69 
GetFilter(const std::string & env_var)70 std::vector<std::string> BridgeLoggerConfig::GetFilter(
71     const std::string& env_var) {
72   std::vector<std::string> filter;
73   const char* filter_str = getenv(env_var.c_str());
74   if (filter_str) {
75     filter = absl::StrSplit(filter_str, ';', absl::SkipWhitespace());
76   }
77   return filter;
78 }
79 
MatchesFilter(const std::string & str,const std::vector<std::string> & filter,bool exact_match)80 bool BridgeLoggerConfig::MatchesFilter(const std::string& str,
81                                        const std::vector<std::string>& filter,
82                                        bool exact_match) {
83   if (filter.empty()) return true;
84   for (const std::string& filter_str : filter) {
85     if (str == filter_str) return true;
86     if (!exact_match && str.find(filter_str) != std::string::npos) return true;
87   }
88   return false;
89 }
90 
ShouldPrint(mlir::Pass * pass,mlir::Operation * op)91 bool BridgeLoggerConfig::ShouldPrint(mlir::Pass* pass, mlir::Operation* op) {
92   // Check pass filter first since it's cheaper.
93   std::string pass_name = pass->getName().str();
94   if (!MatchesFilter(pass_name, pass_filter_, /*exact_match=*/true)) {
95     // No string in filter matches pass name.
96     VLOG(1) << "Not logging invocation of pass `" << pass_name
97             << "` because the pass name does not match any string in "
98                "`MLIR_BRIDGE_LOG_PASS_FILTER`";
99     return false;
100   }
101   if (!string_filter_.empty()) {
102     std::string serialized_op;
103     llvm::raw_string_ostream os(serialized_op);
104     op->print(os);
105     if (!MatchesFilter(serialized_op, string_filter_, /*exact_match=*/false)) {
106       // No string in filter was found in serialized `op`.
107       VLOG(1) << "Not logging invocation of pass `" << pass_name
108               << "` because the serialized operation on which the pass is "
109                  "invoked does not contain any of the strings specified by "
110                  "MLIR_BRIDGE_LOG_STRING_FILTER";
111       return false;
112     }
113   }
114   return true;
115 }
116 
117 }  // namespace tensorflow
118