xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/jit_log.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <cstdlib>
2 #include <iomanip>
3 #include <iostream>
4 #include <sstream>
5 #include <string>
6 #include <unordered_map>
7 #include <vector>
8 
9 #include <ATen/core/function.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/StringUtil.h>
12 #include <torch/csrc/jit/api/function_impl.h>
13 #include <torch/csrc/jit/frontend/error_report.h>
14 #include <torch/csrc/jit/ir/ir.h>
15 #include <torch/csrc/jit/jit_log.h>
16 #include <torch/csrc/jit/serialization/python_print.h>
17 
18 namespace torch::jit {
19 
// Holds the JIT logging configuration: the raw level specification string,
// its parsed per-file levels, and the stream log output is directed to.
class JitLoggingConfig {
 public:
  // Returns the single shared instance; it is constructed on first call.
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }

  // Not copyable: there is exactly one configuration object.
  JitLoggingConfig(JitLoggingConfig const&) = delete;
  void operator=(JitLoggingConfig const&) = delete;

  // Returns a copy of the raw level specification string.
  std::string getLoggingLevels() const {
    return logging_levels;
  }

  // Replaces the raw specification and immediately re-parses it.
  void setLoggingLevels(std::string levels) {
    logging_levels = std::move(levels);
    parse();
  }

  // Parsed map from file name (extension stripped) to numeric level.
  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return files_to_levels;
  }

  // Stores only the address of `out_stream`; the caller keeps it alive.
  void setOutputStream(std::ostream& out_stream) {
    out = &out_stream;
  }

  std::ostream& getOutputStream() {
    return *out;
  }

 private:
  // Seeds the specification from PYTORCH_JIT_LOG_LEVEL (empty when unset).
  JitLoggingConfig() {
    const char* env_spec = std::getenv("PYTORCH_JIT_LOG_LEVEL");
    if (env_spec != nullptr) {
      logging_levels = env_spec;
    }
    parse();
  }

  // Rebuilds `files_to_levels` from `logging_levels` (defined out of line).
  void parse();

  std::string logging_levels;
  std::unordered_map<std::string, size_t> files_to_levels;
  // Non-owning; defaults to std::cerr.
  std::ostream* out = &std::cerr;
};
63 
get_jit_logging_levels()64 std::string get_jit_logging_levels() {
65   return JitLoggingConfig::getInstance().getLoggingLevels();
66 }
67 
set_jit_logging_levels(std::string level)68 void set_jit_logging_levels(std::string level) {
69   JitLoggingConfig::getInstance().setLoggingLevels(std::move(level));
70 }
71 
set_jit_logging_output_stream(std::ostream & stream)72 void set_jit_logging_output_stream(std::ostream& stream) {
73   JitLoggingConfig::getInstance().setOutputStream(stream);
74 }
75 
get_jit_logging_output_stream()76 std::ostream& get_jit_logging_output_stream() {
77   return JitLoggingConfig::getInstance().getOutputStream();
78 }
79 
80 // gets a string representation of a node header
81 // (e.g. outputs, a node kind and outputs)
getHeader(const Node * node)82 std::string getHeader(const Node* node) {
83   std::stringstream ss;
84   node->print(ss, 0, {}, false, false, false, false);
85   return ss.str();
86 }
87 
parse()88 void JitLoggingConfig::parse() {
89   std::stringstream in_ss;
90   in_ss << "function:" << this->logging_levels;
91 
92   files_to_levels.clear();
93   std::string line;
94   while (std::getline(in_ss, line, ':')) {
95     if (line.empty()) {
96       continue;
97     }
98 
99     auto index_at = line.find_last_of('>');
100     auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
101     size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
102     auto end_index = line.find_last_of('.') == std::string::npos
103         ? line.size()
104         : line.find_last_of('.');
105     auto filename = line.substr(begin_index, end_index - begin_index);
106     files_to_levels.insert({filename, logging_level});
107   }
108 }
109 
is_enabled(const char * cfname,JitLoggingLevels level)110 bool is_enabled(const char* cfname, JitLoggingLevels level) {
111   const auto& files_to_levels =
112       JitLoggingConfig::getInstance().getFilesToLevels();
113   std::string fname{cfname};
114   fname = c10::detail::StripBasename(fname);
115   const auto end_index = fname.find_last_of('.') == std::string::npos
116       ? fname.size()
117       : fname.find_last_of('.');
118   const auto fname_no_ext = fname.substr(0, end_index);
119 
120   const auto it = files_to_levels.find(fname_no_ext);
121   if (it == files_to_levels.end()) {
122     return false;
123   }
124 
125   return level <= static_cast<JitLoggingLevels>(it->second);
126 }
127 
128 // Unfortunately, in `GraphExecutor` where `log_function` is invoked
129 // we won't have access to an original function, so we have to construct
130 // a dummy function to give to PythonPrint
log_function(const std::shared_ptr<torch::jit::Graph> & graph)131 std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
132   torch::jit::GraphFunction func("source_dump", graph, nullptr);
133   std::vector<at::IValue> constants;
134   PrintDepsTable deps;
135   PythonPrint pp(constants, deps);
136   pp.printFunction(func);
137   return pp.str();
138 }
139 
// Prepends `prefix` to every line of `in_str`. Each emitted line ends in
// '\n', including the last one even when `in_str` lacks a trailing newline.
// An empty input yields an empty result.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::string result;
  std::string::size_type start = 0;
  while (start < in_str.size()) {
    const auto newline = in_str.find('\n', start);
    const auto stop = newline == std::string::npos ? in_str.size() : newline;
    result += prefix;
    result.append(in_str, start, stop - start);
    result += '\n';
    start = stop + 1;
  }
  return result;
}
152 
jit_log_prefix(JitLoggingLevels level,const char * fn,int l,const std::string & in_str)153 std::string jit_log_prefix(
154     JitLoggingLevels level,
155     const char* fn,
156     int l,
157     const std::string& in_str) {
158   std::stringstream prefix_ss;
159   prefix_ss << "[";
160   prefix_ss << level << " ";
161   prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
162   prefix_ss << std::setfill('0') << std::setw(3) << l;
163   prefix_ss << "] ";
164 
165   return jit_log_prefix(prefix_ss.str(), in_str);
166 }
167 
operator <<(std::ostream & out,JitLoggingLevels level)168 std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
169   switch (level) {
170     case JitLoggingLevels::GRAPH_DUMP:
171       out << "DUMP";
172       break;
173     case JitLoggingLevels::GRAPH_UPDATE:
174       out << "UPDATE";
175       break;
176     case JitLoggingLevels::GRAPH_DEBUG:
177       out << "DEBUG";
178       break;
179     default:
180       TORCH_INTERNAL_ASSERT(false, "Invalid level");
181   }
182 
183   return out;
184 }
185 
186 } // namespace torch::jit
187