xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/source_range.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/frontend/source_range.h>
3 #include <torch/csrc/jit/serialization/source_range_serialization.h>
4 #include <iostream>
5 #include <regex>
6 
7 namespace torch::jit {
8 
// A string-like class backed by a vector of string_views. The string it
// represents is logically the concatenation of those string_views, which
// has the advantage of not needing contiguous memory.
StringCordView()12 StringCordView::StringCordView() {
13   accumulated_sizes_.push_back(0);
14 }
15 
StringCordView(std::vector<c10::string_view> inputs,std::vector<std::shared_ptr<std::string>> ownerships)16 StringCordView::StringCordView(
17     std::vector<c10::string_view> inputs,
18     std::vector<std::shared_ptr<std::string>> ownerships)
19     : pieces_(std::move(inputs)), owned_strings_(std::move(ownerships)) {
20   accumulated_sizes_.push_back(0);
21   size_t running_sum = 0;
22   for (auto& s : pieces_) {
23     if (!s.empty()) {
24       running_sum += s.size();
25       accumulated_sizes_.push_back(running_sum);
26     }
27   }
28 }
29 
find(const std::string & tok,size_t start) const30 size_t StringCordView::find(const std::string& tok, size_t start) const {
31   if (tok.empty()) {
32     return 0;
33   }
34 
35   if ((size() - start) < tok.size()) {
36     return std::string::npos;
37   }
38 
39   Iterator begin = iter_for_pos(start);
40   Iterator end_iter = end();
41   size_t offset = start;
42   for (; begin != end_iter; ++begin, ++offset) {
43     if (*begin == tok[0]) {
44       auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end());
45       if (mis.second == tok.end()) {
46         // no mismatch, and second string (tok) is exhausted.
47         return offset;
48       }
49       if (mis.first == end_iter) {
50         // this str is exhausted but tok is not
51         return std::string::npos;
52       }
53     }
54   }
55   return std::string::npos;
56 }
57 
find_regex(const std::string & tok,size_t start) const58 size_t StringCordView::find_regex(const std::string& tok, size_t start) const {
59   if (tok.empty()) {
60     return 0;
61   }
62 
63   const std::string& target = this->substr(start, this->size()).str();
64   std::smatch sm;
65   const std::regex re(tok);
66 
67   auto regex_found = std::regex_search(target, sm, re);
68 
69   return regex_found ? sm.position(0) : std::string::npos;
70 }
71 
substr(size_t start,size_t size) const72 StringCordView StringCordView::substr(size_t start, size_t size) const {
73   std::vector<c10::string_view> pieces;
74   std::vector<std::shared_ptr<std::string>> ownerships;
75   if (start >= this->size()) {
76     // out of bounds
77     return StringCordView();
78   }
79   if (start + size >= this->size()) {
80     size = this->size() - start;
81   }
82   Iterator begin = iter_for_pos(start);
83   Iterator end = iter_for_pos(start + size);
84 
85   if (begin.line_ == end.line_) {
86     // same line
87     pieces.push_back(pieces_[begin.line_].substr(begin.pos_, size));
88   } else {
89     pieces.push_back(pieces_[begin.line_].substr(begin.pos_));
90 
91     size_t last_line = pieces_.size();
92     if (end != this->end() && end.line_ < last_line) {
93       // end is within the string
94       last_line = end.line_;
95     }
96     for (size_t i = begin.line_ + 1; i < last_line; i++) {
97       pieces.push_back(pieces_[i]);
98     }
99     if (end != this->end()) {
100       pieces.push_back(pieces_[end.line_].substr(0, end.pos_));
101     }
102   }
103 
104   // share ownership
105   std::copy(
106       owned_strings_.begin(),
107       owned_strings_.end(),
108       std::back_inserter(ownerships));
109 
110   return StringCordView(std::move(pieces), std::move(ownerships));
111 }
112 
operator ==(const std::string & rhs) const113 bool StringCordView::operator==(const std::string& rhs) const {
114   if (size() != rhs.size()) {
115     return false;
116   }
117   auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
118   // both need to exhaust
119   return res.first == end() && res.second == rhs.end();
120 }
121 
operator ==(const StringCordView & rhs) const122 bool StringCordView::operator==(const StringCordView& rhs) const {
123   if (size() != rhs.size()) {
124     return false;
125   }
126   auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
127   // both need to exhaust
128   return res.first == end() && res.second == rhs.end();
129 }
130 
iter_for_pos(size_t pos) const131 StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
132   if (pos == 0) {
133     return begin();
134   }
135   if (pos >= size()) {
136     return end();
137   }
138   auto upper = std::upper_bound(
139       accumulated_sizes_.begin(), accumulated_sizes_.end(), pos);
140   if (upper == accumulated_sizes_.end()) {
141     return end();
142   }
143   size_t line = upper - accumulated_sizes_.begin() - 1;
144   assert(accumulated_sizes_[line] <= pos);
145   assert(accumulated_sizes_[line + 1] > pos);
146   return Iterator(this, line, pos - accumulated_sizes_[line], size() - pos);
147 }
148 
operator ()(const torch::jit::SourceRange & key) const149 size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
150   return (
151       std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
152       std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
153 }
154 
findSourceRangeThatGenerated(const SourceRange & range)155 std::optional<SourceRange> Source::findSourceRangeThatGenerated(
156     const SourceRange& range) {
157   if (!gen_ranges_) {
158     return std::nullopt;
159   }
160   return gen_ranges_->findSourceRangeThatGenerated(range);
161 }
162 
highlight(std::ostream & out) const163 void SourceRange::highlight(std::ostream& out) const {
164   // Retrieve original SourceRange, if present.
165   if (auto orig_source_range = findSourceRangeThatGenerated()) {
166     orig_source_range->highlight(out);
167     out << "Serialized ";
168   }
169   print_with_context(out, CONTEXT, true, "");
170 }
171 
format_stack_trace(std::ostream & out,const std::vector<StackEntry> & entries)172 void format_stack_trace(
173     std::ostream& out,
174     const std::vector<StackEntry>& entries) {
175   bool has_orig_ranges = false;
176   std::vector<SourceRange> orig_ranges;
177   // gather original ranges. if we have a situation where we do not have orig
178   // ranges for some frames, we still want to report them for the frames we do
179   // have,
180   //  so substitute the current range for that frame
181   for (const StackEntry& entry : entries) {
182     if (auto orig_source_range = entry.range.findSourceRangeThatGenerated()) {
183       orig_ranges.emplace_back(std::move(orig_source_range.value()));
184       has_orig_ranges = true;
185     } else {
186       orig_ranges.emplace_back(entry.range);
187     }
188   }
189   out << "Traceback of TorchScript";
190   if (has_orig_ranges) {
191     out << ", serialized code";
192   }
193   out << " (most recent call last):\n";
194   for (const StackEntry& entry : entries) {
195     entry.range.print_with_context(
196         out, SourceRange::CONTEXT, true, entry.filename);
197   }
198   if (has_orig_ranges) {
199     out << "\nTraceback of TorchScript, original code (most recent call last):\n";
200     auto it = entries.begin();
201     for (const SourceRange& range : orig_ranges) {
202       range.print_with_context(
203           out, SourceRange::CONTEXT, true, (*it++).filename);
204     }
205   }
206 }
207 
print_with_context(std::ostream & out,size_t context,bool highlight,const std::string & funcname) const208 void SourceRange::print_with_context(
209     std::ostream& out,
210     size_t context,
211     bool highlight,
212     const std::string& funcname) const {
213   // This is an empty SourceRange, used as a sentinel value.
214   if (!source_view_) {
215     return;
216   }
217 
218   auto str = source_view_->text_str().str();
219   if (size() == str.size()) {
220     // this is just the entire file, not a subset, so print it out.
221     // primarily used to print out python stack traces
222     out << str;
223     return;
224   }
225 
226   size_t range_end =
227       (str.size() < end()
228            ? str.size()
229            : end()); // use instead of 'end()' because some ranges extend past
230                      // the length of the source
231 
232   // determine CONTEXT line range
233   size_t begin_line = start(); // beginning of lines to highlight
234   size_t end_line = range_end;
235   if (begin_line > str.size()) {
236     return;
237   }
238   while (begin_line > 0 && str[begin_line - 1] != '\n')
239     --begin_line;
240   while (end_line < str.size() && str[end_line] != '\n')
241     ++end_line;
242   AT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
243   AT_ASSERT(end_line == str.size() || str[end_line] == '\n');
244 
245   size_t begin_context = begin_line; // beginning of context, CONTEXT lines
246                                      // before the highlight lines
247   for (size_t i = 0; begin_context > 0; --begin_context) {
248     if (str[begin_context - 1] == '\n') {
249       ++i;
250     }
251     if (i >= context) {
252       break;
253     }
254   }
255   AT_ASSERT(begin_context == 0 || str[begin_context - 1] == '\n');
256 
257   size_t end_context =
258       end_line; // end of context, CONTEXT lines after the highlight lines
259   for (size_t i = 0; end_context < str.size(); ++end_context) {
260     if (str[end_context] == '\n') {
261       ++i;
262     }
263     if (i >= context) {
264       break;
265     }
266   }
267   AT_ASSERT(end_context == str.size() || str[end_context] == '\n');
268 
269   // print out location information
270   if (auto flc = file_line_col()) {
271     auto [filename, line, col] = *flc;
272     out << "  File \"" << filename << "\", line " << line;
273     if (!funcname.empty()) {
274       out << ", in " << funcname;
275     }
276     out << "\n";
277   }
278   // print out inital context
279   out << str.substr(begin_context, start() - begin_context);
280   size_t line_start = start();
281   size_t line_end = range_end;
282   if (highlight) {
283     line_end = start();
284     while (line_start < range_end) {
285       // move line_end to end of line
286       while (line_end < str.size() && str[line_end] != '\n') {
287         ++line_end;
288       }
289       // print line of code
290       auto actual_line = str.substr(line_start, (line_end - line_start) + 1);
291       out << actual_line;
292       if (actual_line.back() != '\n') {
293         out << "\n";
294       }
295 
296       size_t empty_space = 0;
297       size_t highlight_space = 0;
298       size_t hightlight_begin = line_start;
299       size_t highlight_end = line_start;
300       // determine length of line which is being highlighted
301       while (hightlight_begin > 0 && str[hightlight_begin - 1] != '\n') {
302         --hightlight_begin;
303       }
304       while (highlight_end < range_end && str[highlight_end] != '\n') {
305         ++highlight_end;
306       }
307       AT_ASSERT(hightlight_begin == 0 || str[hightlight_begin - 1] == '\n');
308       AT_ASSERT(highlight_end == range_end || str[highlight_end] == '\n');
309       // determine amount of empty space vs highlighted space
310       for (const auto i : c10::irange(hightlight_begin, highlight_end)) {
311         if (str[i] == ' ' || i < start()) {
312           empty_space++;
313         } else {
314           break;
315         }
316       }
317       highlight_space = highlight_end - hightlight_begin - empty_space;
318       if (highlight_space > 0) {
319         // some ranges are off and include empty white space on new lines which
320         // don't need to be printed
321         bool more_lines = false;
322         for (size_t i = line_end; i <= range_end; i++) {
323           if (str[i] != '\n' && str[i] != ' ') {
324             more_lines = true;
325           }
326         }
327         out << std::string(empty_space, ' ');
328         out << std::string(highlight_space, '~');
329         out << (more_lines && line_end != range_end ? "\n" : " <--- HERE\n");
330       }
331       ++line_end;
332       line_start = line_end;
333     }
334   } else {
335     // print out code with no highlight
336     out << str.substr(start(), range_end - start());
337   }
338   // print out ending context
339   if (line_end <= str.size()) {
340     auto line_substr = str.substr(line_end, end_context - line_end);
341     out << line_substr;
342     if (!line_substr.empty() && line_substr.back() != '\n') {
343       out << "\n";
344     }
345   }
346 }
347 
348 } // namespace torch::jit
349