1 #include <cstdlib>
2 #include <iomanip>
3 #include <iostream>
4 #include <sstream>
5 #include <string>
6 #include <unordered_map>
7 #include <vector>
8
9 #include <ATen/core/function.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/StringUtil.h>
12 #include <torch/csrc/jit/api/function_impl.h>
13 #include <torch/csrc/jit/frontend/error_report.h>
14 #include <torch/csrc/jit/ir/ir.h>
15 #include <torch/csrc/jit/jit_log.h>
16 #include <torch/csrc/jit/serialization/python_print.h>
17
18 namespace torch::jit {
19
// Process-wide JIT logging configuration: the raw level specification, the
// per-file levels parsed from it, and the stream log output is written to.
// Only reachable through the getInstance() singleton; copying is disabled.
class JitLoggingConfig {
 public:
  // Returns the singleton, constructing it on first use.
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }

  JitLoggingConfig(JitLoggingConfig const&) = delete;
  void operator=(JitLoggingConfig const&) = delete;

  // The raw ':'-separated specification currently in effect.
  std::string getLoggingLevels() const {
    return logging_levels;
  }

  // Replaces the specification and rebuilds the per-file level table.
  void setLoggingLevels(std::string levels) {
    logging_levels = std::move(levels);
    parse();
  }

  // Parsed table mapping a file name (without extension) to its level.
  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return files_to_levels;
  }

  // Only the address is kept, so `out_stream` must outlive later logging.
  void setOutputStream(std::ostream& out_stream) {
    out = &out_stream;
  }

  // The stream log output is written to (std::cerr unless redirected).
  std::ostream& getOutputStream() {
    return *out;
  }

 private:
  // Seeds the specification from PYTORCH_JIT_LOG_LEVEL (empty when unset)
  // and parses it immediately.
  JitLoggingConfig() {
    if (const char* env_levels = std::getenv("PYTORCH_JIT_LOG_LEVEL")) {
      logging_levels = env_levels;
    }
    parse();
  }

  // Rebuilds files_to_levels from logging_levels; defined out of line.
  void parse();

  std::string logging_levels;
  std::unordered_map<std::string, size_t> files_to_levels;
  std::ostream* out{&std::cerr};
};
63
get_jit_logging_levels()64 std::string get_jit_logging_levels() {
65 return JitLoggingConfig::getInstance().getLoggingLevels();
66 }
67
set_jit_logging_levels(std::string level)68 void set_jit_logging_levels(std::string level) {
69 JitLoggingConfig::getInstance().setLoggingLevels(std::move(level));
70 }
71
set_jit_logging_output_stream(std::ostream & stream)72 void set_jit_logging_output_stream(std::ostream& stream) {
73 JitLoggingConfig::getInstance().setOutputStream(stream);
74 }
75
get_jit_logging_output_stream()76 std::ostream& get_jit_logging_output_stream() {
77 return JitLoggingConfig::getInstance().getOutputStream();
78 }
79
80 // gets a string representation of a node header
81 // (e.g. outputs, a node kind and outputs)
getHeader(const Node * node)82 std::string getHeader(const Node* node) {
83 std::stringstream ss;
84 node->print(ss, 0, {}, false, false, false, false);
85 return ss.str();
86 }
87
parse()88 void JitLoggingConfig::parse() {
89 std::stringstream in_ss;
90 in_ss << "function:" << this->logging_levels;
91
92 files_to_levels.clear();
93 std::string line;
94 while (std::getline(in_ss, line, ':')) {
95 if (line.empty()) {
96 continue;
97 }
98
99 auto index_at = line.find_last_of('>');
100 auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
101 size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
102 auto end_index = line.find_last_of('.') == std::string::npos
103 ? line.size()
104 : line.find_last_of('.');
105 auto filename = line.substr(begin_index, end_index - begin_index);
106 files_to_levels.insert({filename, logging_level});
107 }
108 }
109
is_enabled(const char * cfname,JitLoggingLevels level)110 bool is_enabled(const char* cfname, JitLoggingLevels level) {
111 const auto& files_to_levels =
112 JitLoggingConfig::getInstance().getFilesToLevels();
113 std::string fname{cfname};
114 fname = c10::detail::StripBasename(fname);
115 const auto end_index = fname.find_last_of('.') == std::string::npos
116 ? fname.size()
117 : fname.find_last_of('.');
118 const auto fname_no_ext = fname.substr(0, end_index);
119
120 const auto it = files_to_levels.find(fname_no_ext);
121 if (it == files_to_levels.end()) {
122 return false;
123 }
124
125 return level <= static_cast<JitLoggingLevels>(it->second);
126 }
127
128 // Unfortunately, in `GraphExecutor` where `log_function` is invoked
129 // we won't have access to an original function, so we have to construct
130 // a dummy function to give to PythonPrint
log_function(const std::shared_ptr<torch::jit::Graph> & graph)131 std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
132 torch::jit::GraphFunction func("source_dump", graph, nullptr);
133 std::vector<at::IValue> constants;
134 PrintDepsTable deps;
135 PythonPrint pp(constants, deps);
136 pp.printFunction(func);
137 return pp.str();
138 }
139
// Returns `in_str` with `prefix` prepended to every line. Each output line,
// including the last one, is terminated with '\n'. An empty input yields an
// empty result.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream source(in_str);
  std::string prefixed;
  for (std::string current; std::getline(source, current);) {
    prefixed += prefix;
    prefixed += current;
    prefixed += '\n';
  }
  return prefixed;
}
152
jit_log_prefix(JitLoggingLevels level,const char * fn,int l,const std::string & in_str)153 std::string jit_log_prefix(
154 JitLoggingLevels level,
155 const char* fn,
156 int l,
157 const std::string& in_str) {
158 std::stringstream prefix_ss;
159 prefix_ss << "[";
160 prefix_ss << level << " ";
161 prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
162 prefix_ss << std::setfill('0') << std::setw(3) << l;
163 prefix_ss << "] ";
164
165 return jit_log_prefix(prefix_ss.str(), in_str);
166 }
167
operator <<(std::ostream & out,JitLoggingLevels level)168 std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
169 switch (level) {
170 case JitLoggingLevels::GRAPH_DUMP:
171 out << "DUMP";
172 break;
173 case JitLoggingLevels::GRAPH_UPDATE:
174 out << "UPDATE";
175 break;
176 case JitLoggingLevels::GRAPH_DEBUG:
177 out << "DEBUG";
178 break;
179 default:
180 TORCH_INTERNAL_ASSERT(false, "Invalid level");
181 }
182
183 return out;
184 }
185
186 } // namespace torch::jit
187