1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
17
18 #include <atomic>
19
20 #include "absl/strings/str_split.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
26
27 namespace tensorflow {
28
// Process-wide sequence number that prefixes every dump file name, so dumps
// sort in the order they were written. It is a std::atomic, so concurrent
// post-increments each observe a distinct value.
static std::atomic<int> log_counter{0};
31
BridgeLoggerConfig(bool print_module_scope,bool print_after_only_on_change)32 BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope,
33 bool print_after_only_on_change)
34 : mlir::PassManager::IRPrinterConfig(print_module_scope,
35 print_after_only_on_change),
36 pass_filter_(GetFilter("MLIR_BRIDGE_LOG_PASS_FILTER")),
37 string_filter_(GetFilter("MLIR_BRIDGE_LOG_STRING_FILTER")) {}
38
39 // Logs op to file with name of format
40 // `<log_counter>_mlir_bridge_<pass_name>_<file_suffix>.mlir`.
Log(BridgeLoggerConfig::PrintCallbackFn print_callback,mlir::Pass * pass,mlir::Operation * op,llvm::StringRef file_suffix)41 inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback,
42 mlir::Pass* pass, mlir::Operation* op,
43 llvm::StringRef file_suffix) {
44 std::string pass_name = pass->getName().str();
45
46 // Add 4-digit counter as prefix so the order of the passes is obvious.
47 std::string name = llvm::formatv("{0,0+4}_mlir_bridge_{1}_{2}", log_counter++,
48 pass_name, file_suffix);
49
50 std::unique_ptr<llvm::raw_ostream> os;
51 std::string filepath;
52 if (CreateFileForDumping(name, &os, &filepath).ok()) {
53 print_callback(*os);
54 LOG(INFO) << "Dumped MLIR module to " << filepath;
55 }
56 }
57
printBeforeIfEnabled(mlir::Pass * pass,mlir::Operation * op,PrintCallbackFn print_callback)58 void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass,
59 mlir::Operation* op,
60 PrintCallbackFn print_callback) {
61 if (ShouldPrint(pass, op)) Log(print_callback, pass, op, "before");
62 }
63
printAfterIfEnabled(mlir::Pass * pass,mlir::Operation * op,PrintCallbackFn print_callback)64 void BridgeLoggerConfig::printAfterIfEnabled(mlir::Pass* pass,
65 mlir::Operation* op,
66 PrintCallbackFn print_callback) {
67 if (ShouldPrint(pass, op)) Log(print_callback, pass, op, "after");
68 }
69
GetFilter(const std::string & env_var)70 std::vector<std::string> BridgeLoggerConfig::GetFilter(
71 const std::string& env_var) {
72 std::vector<std::string> filter;
73 const char* filter_str = getenv(env_var.c_str());
74 if (filter_str) {
75 filter = absl::StrSplit(filter_str, ';', absl::SkipWhitespace());
76 }
77 return filter;
78 }
79
MatchesFilter(const std::string & str,const std::vector<std::string> & filter,bool exact_match)80 bool BridgeLoggerConfig::MatchesFilter(const std::string& str,
81 const std::vector<std::string>& filter,
82 bool exact_match) {
83 if (filter.empty()) return true;
84 for (const std::string& filter_str : filter) {
85 if (str == filter_str) return true;
86 if (!exact_match && str.find(filter_str) != std::string::npos) return true;
87 }
88 return false;
89 }
90
ShouldPrint(mlir::Pass * pass,mlir::Operation * op)91 bool BridgeLoggerConfig::ShouldPrint(mlir::Pass* pass, mlir::Operation* op) {
92 // Check pass filter first since it's cheaper.
93 std::string pass_name = pass->getName().str();
94 if (!MatchesFilter(pass_name, pass_filter_, /*exact_match=*/true)) {
95 // No string in filter matches pass name.
96 VLOG(1) << "Not logging invocation of pass `" << pass_name
97 << "` because the pass name does not match any string in "
98 "`MLIR_BRIDGE_LOG_PASS_FILTER`";
99 return false;
100 }
101 if (!string_filter_.empty()) {
102 std::string serialized_op;
103 llvm::raw_string_ostream os(serialized_op);
104 op->print(os);
105 if (!MatchesFilter(serialized_op, string_filter_, /*exact_match=*/false)) {
106 // No string in filter was found in serialized `op`.
107 VLOG(1) << "Not logging invocation of pass `" << pass_name
108 << "` because the serialized operation on which the pass is "
109 "invoked does not contain any of the strings specified by "
110 "MLIR_BRIDGE_LOG_STRING_FILTER";
111 return false;
112 }
113 }
114 return true;
115 }
116
117 } // namespace tensorflow
118