1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/frontend/source_range.h>
3 #include <torch/csrc/jit/serialization/source_range_serialization.h>
4 #include <iostream>
5 #include <regex>
6
7 namespace torch::jit {
8
// A string-like class backed by a vector of string_views.
// The string it represents is logically the concatenation of those
// string_views, which has the advantage of not needing contiguous memory.
StringCordView()12 StringCordView::StringCordView() {
13 accumulated_sizes_.push_back(0);
14 }
15
StringCordView(std::vector<c10::string_view> inputs,std::vector<std::shared_ptr<std::string>> ownerships)16 StringCordView::StringCordView(
17 std::vector<c10::string_view> inputs,
18 std::vector<std::shared_ptr<std::string>> ownerships)
19 : pieces_(std::move(inputs)), owned_strings_(std::move(ownerships)) {
20 accumulated_sizes_.push_back(0);
21 size_t running_sum = 0;
22 for (auto& s : pieces_) {
23 if (!s.empty()) {
24 running_sum += s.size();
25 accumulated_sizes_.push_back(running_sum);
26 }
27 }
28 }
29
find(const std::string & tok,size_t start) const30 size_t StringCordView::find(const std::string& tok, size_t start) const {
31 if (tok.empty()) {
32 return 0;
33 }
34
35 if ((size() - start) < tok.size()) {
36 return std::string::npos;
37 }
38
39 Iterator begin = iter_for_pos(start);
40 Iterator end_iter = end();
41 size_t offset = start;
42 for (; begin != end_iter; ++begin, ++offset) {
43 if (*begin == tok[0]) {
44 auto mis = std::mismatch(begin, end_iter, tok.begin(), tok.end());
45 if (mis.second == tok.end()) {
46 // no mismatch, and second string (tok) is exhausted.
47 return offset;
48 }
49 if (mis.first == end_iter) {
50 // this str is exhausted but tok is not
51 return std::string::npos;
52 }
53 }
54 }
55 return std::string::npos;
56 }
57
find_regex(const std::string & tok,size_t start) const58 size_t StringCordView::find_regex(const std::string& tok, size_t start) const {
59 if (tok.empty()) {
60 return 0;
61 }
62
63 const std::string& target = this->substr(start, this->size()).str();
64 std::smatch sm;
65 const std::regex re(tok);
66
67 auto regex_found = std::regex_search(target, sm, re);
68
69 return regex_found ? sm.position(0) : std::string::npos;
70 }
71
substr(size_t start,size_t size) const72 StringCordView StringCordView::substr(size_t start, size_t size) const {
73 std::vector<c10::string_view> pieces;
74 std::vector<std::shared_ptr<std::string>> ownerships;
75 if (start >= this->size()) {
76 // out of bounds
77 return StringCordView();
78 }
79 if (start + size >= this->size()) {
80 size = this->size() - start;
81 }
82 Iterator begin = iter_for_pos(start);
83 Iterator end = iter_for_pos(start + size);
84
85 if (begin.line_ == end.line_) {
86 // same line
87 pieces.push_back(pieces_[begin.line_].substr(begin.pos_, size));
88 } else {
89 pieces.push_back(pieces_[begin.line_].substr(begin.pos_));
90
91 size_t last_line = pieces_.size();
92 if (end != this->end() && end.line_ < last_line) {
93 // end is within the string
94 last_line = end.line_;
95 }
96 for (size_t i = begin.line_ + 1; i < last_line; i++) {
97 pieces.push_back(pieces_[i]);
98 }
99 if (end != this->end()) {
100 pieces.push_back(pieces_[end.line_].substr(0, end.pos_));
101 }
102 }
103
104 // share ownership
105 std::copy(
106 owned_strings_.begin(),
107 owned_strings_.end(),
108 std::back_inserter(ownerships));
109
110 return StringCordView(std::move(pieces), std::move(ownerships));
111 }
112
operator ==(const std::string & rhs) const113 bool StringCordView::operator==(const std::string& rhs) const {
114 if (size() != rhs.size()) {
115 return false;
116 }
117 auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
118 // both need to exhaust
119 return res.first == end() && res.second == rhs.end();
120 }
121
operator ==(const StringCordView & rhs) const122 bool StringCordView::operator==(const StringCordView& rhs) const {
123 if (size() != rhs.size()) {
124 return false;
125 }
126 auto res = std::mismatch(begin(), end(), rhs.begin(), rhs.end());
127 // both need to exhaust
128 return res.first == end() && res.second == rhs.end();
129 }
130
iter_for_pos(size_t pos) const131 StringCordView::Iterator StringCordView::iter_for_pos(size_t pos) const {
132 if (pos == 0) {
133 return begin();
134 }
135 if (pos >= size()) {
136 return end();
137 }
138 auto upper = std::upper_bound(
139 accumulated_sizes_.begin(), accumulated_sizes_.end(), pos);
140 if (upper == accumulated_sizes_.end()) {
141 return end();
142 }
143 size_t line = upper - accumulated_sizes_.begin() - 1;
144 assert(accumulated_sizes_[line] <= pos);
145 assert(accumulated_sizes_[line + 1] > pos);
146 return Iterator(this, line, pos - accumulated_sizes_[line], size() - pos);
147 }
148
operator ()(const torch::jit::SourceRange & key) const149 size_t SourceRangeHasher::operator()(const torch::jit::SourceRange& key) const {
150 return (
151 std::hash<uintptr_t>()(reinterpret_cast<uintptr_t>(key.source().get())) ^
152 std::hash<size_t>()(key.start()) ^ std::hash<size_t>()(key.end()));
153 }
154
findSourceRangeThatGenerated(const SourceRange & range)155 std::optional<SourceRange> Source::findSourceRangeThatGenerated(
156 const SourceRange& range) {
157 if (!gen_ranges_) {
158 return std::nullopt;
159 }
160 return gen_ranges_->findSourceRangeThatGenerated(range);
161 }
162
highlight(std::ostream & out) const163 void SourceRange::highlight(std::ostream& out) const {
164 // Retrieve original SourceRange, if present.
165 if (auto orig_source_range = findSourceRangeThatGenerated()) {
166 orig_source_range->highlight(out);
167 out << "Serialized ";
168 }
169 print_with_context(out, CONTEXT, true, "");
170 }
171
format_stack_trace(std::ostream & out,const std::vector<StackEntry> & entries)172 void format_stack_trace(
173 std::ostream& out,
174 const std::vector<StackEntry>& entries) {
175 bool has_orig_ranges = false;
176 std::vector<SourceRange> orig_ranges;
177 // gather original ranges. if we have a situation where we do not have orig
178 // ranges for some frames, we still want to report them for the frames we do
179 // have,
180 // so substitute the current range for that frame
181 for (const StackEntry& entry : entries) {
182 if (auto orig_source_range = entry.range.findSourceRangeThatGenerated()) {
183 orig_ranges.emplace_back(std::move(orig_source_range.value()));
184 has_orig_ranges = true;
185 } else {
186 orig_ranges.emplace_back(entry.range);
187 }
188 }
189 out << "Traceback of TorchScript";
190 if (has_orig_ranges) {
191 out << ", serialized code";
192 }
193 out << " (most recent call last):\n";
194 for (const StackEntry& entry : entries) {
195 entry.range.print_with_context(
196 out, SourceRange::CONTEXT, true, entry.filename);
197 }
198 if (has_orig_ranges) {
199 out << "\nTraceback of TorchScript, original code (most recent call last):\n";
200 auto it = entries.begin();
201 for (const SourceRange& range : orig_ranges) {
202 range.print_with_context(
203 out, SourceRange::CONTEXT, true, (*it++).filename);
204 }
205 }
206 }
207
print_with_context(std::ostream & out,size_t context,bool highlight,const std::string & funcname) const208 void SourceRange::print_with_context(
209 std::ostream& out,
210 size_t context,
211 bool highlight,
212 const std::string& funcname) const {
213 // This is an empty SourceRange, used as a sentinel value.
214 if (!source_view_) {
215 return;
216 }
217
218 auto str = source_view_->text_str().str();
219 if (size() == str.size()) {
220 // this is just the entire file, not a subset, so print it out.
221 // primarily used to print out python stack traces
222 out << str;
223 return;
224 }
225
226 size_t range_end =
227 (str.size() < end()
228 ? str.size()
229 : end()); // use instead of 'end()' because some ranges extend past
230 // the length of the source
231
232 // determine CONTEXT line range
233 size_t begin_line = start(); // beginning of lines to highlight
234 size_t end_line = range_end;
235 if (begin_line > str.size()) {
236 return;
237 }
238 while (begin_line > 0 && str[begin_line - 1] != '\n')
239 --begin_line;
240 while (end_line < str.size() && str[end_line] != '\n')
241 ++end_line;
242 AT_ASSERT(begin_line == 0 || str[begin_line - 1] == '\n');
243 AT_ASSERT(end_line == str.size() || str[end_line] == '\n');
244
245 size_t begin_context = begin_line; // beginning of context, CONTEXT lines
246 // before the highlight lines
247 for (size_t i = 0; begin_context > 0; --begin_context) {
248 if (str[begin_context - 1] == '\n') {
249 ++i;
250 }
251 if (i >= context) {
252 break;
253 }
254 }
255 AT_ASSERT(begin_context == 0 || str[begin_context - 1] == '\n');
256
257 size_t end_context =
258 end_line; // end of context, CONTEXT lines after the highlight lines
259 for (size_t i = 0; end_context < str.size(); ++end_context) {
260 if (str[end_context] == '\n') {
261 ++i;
262 }
263 if (i >= context) {
264 break;
265 }
266 }
267 AT_ASSERT(end_context == str.size() || str[end_context] == '\n');
268
269 // print out location information
270 if (auto flc = file_line_col()) {
271 auto [filename, line, col] = *flc;
272 out << " File \"" << filename << "\", line " << line;
273 if (!funcname.empty()) {
274 out << ", in " << funcname;
275 }
276 out << "\n";
277 }
278 // print out inital context
279 out << str.substr(begin_context, start() - begin_context);
280 size_t line_start = start();
281 size_t line_end = range_end;
282 if (highlight) {
283 line_end = start();
284 while (line_start < range_end) {
285 // move line_end to end of line
286 while (line_end < str.size() && str[line_end] != '\n') {
287 ++line_end;
288 }
289 // print line of code
290 auto actual_line = str.substr(line_start, (line_end - line_start) + 1);
291 out << actual_line;
292 if (actual_line.back() != '\n') {
293 out << "\n";
294 }
295
296 size_t empty_space = 0;
297 size_t highlight_space = 0;
298 size_t hightlight_begin = line_start;
299 size_t highlight_end = line_start;
300 // determine length of line which is being highlighted
301 while (hightlight_begin > 0 && str[hightlight_begin - 1] != '\n') {
302 --hightlight_begin;
303 }
304 while (highlight_end < range_end && str[highlight_end] != '\n') {
305 ++highlight_end;
306 }
307 AT_ASSERT(hightlight_begin == 0 || str[hightlight_begin - 1] == '\n');
308 AT_ASSERT(highlight_end == range_end || str[highlight_end] == '\n');
309 // determine amount of empty space vs highlighted space
310 for (const auto i : c10::irange(hightlight_begin, highlight_end)) {
311 if (str[i] == ' ' || i < start()) {
312 empty_space++;
313 } else {
314 break;
315 }
316 }
317 highlight_space = highlight_end - hightlight_begin - empty_space;
318 if (highlight_space > 0) {
319 // some ranges are off and include empty white space on new lines which
320 // don't need to be printed
321 bool more_lines = false;
322 for (size_t i = line_end; i <= range_end; i++) {
323 if (str[i] != '\n' && str[i] != ' ') {
324 more_lines = true;
325 }
326 }
327 out << std::string(empty_space, ' ');
328 out << std::string(highlight_space, '~');
329 out << (more_lines && line_end != range_end ? "\n" : " <--- HERE\n");
330 }
331 ++line_end;
332 line_start = line_end;
333 }
334 } else {
335 // print out code with no highlight
336 out << str.substr(start(), range_end - start());
337 }
338 // print out ending context
339 if (line_end <= str.size()) {
340 auto line_substr = str.substr(line_end, end_context - line_end);
341 out << line_substr;
342 if (!line_substr.empty() && line_substr.back() != '\n') {
343 out << "\n";
344 }
345 }
346 }
347
348 } // namespace torch::jit
349