xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/source_range.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/util/Exception.h>
3 #include <optional>
4 
5 #include <algorithm>
6 #include <iterator>
7 #include <memory>
8 #include <ostream>
9 #include <sstream>
10 #include <unordered_map>
11 
12 namespace torch::jit {
13 
14 class SourceRangeUnpickler;
15 struct SourceRange;
16 
17 // A stringlike class backed by a vector of string_view
18 // the string represented are logically the concatenation of  the string_views
19 // This has advantage of not needing continues memory.
20 struct TORCH_API StringCordView {
21   StringCordView();
22   StringCordView(const StringCordView&) = default;
23   StringCordView(StringCordView&&) noexcept = default;
24   StringCordView(
25       std::vector<c10::string_view> inputs,
26       std::vector<std::shared_ptr<std::string>> ownerships);
27 
28   StringCordView& operator=(const StringCordView&) = default;
29   StringCordView& operator=(StringCordView&&) noexcept = default;
30 
sizeStringCordView31   size_t size() const {
32     return accumulated_sizes_.back();
33   }
34 
35   size_t find(const std::string& tok, size_t start) const;
36   size_t find_regex(const std::string& tok, size_t start) const;
37   StringCordView substr(size_t start, size_t size) const;
38 
atStringCordView39   char at(size_t index) const {
40     return *iter_for_pos(index);
41   }
42   char operator[](size_t index) const {
43     return at(index);
44   }
45 
strStringCordView46   std::string str() const {
47     std::stringstream ss;
48     for (auto s : pieces_) {
49       ss << std::string(s);
50     }
51     return ss.str();
52   }
53 
54   bool operator==(const std::string& rhs) const;
55 
56   bool operator==(const StringCordView& rhs) const;
57 
pieceStringCordView58   c10::string_view piece(size_t index) const {
59     return pieces_[index];
60   }
61 
62   struct Iterator {
IteratorStringCordView::Iterator63     Iterator(
64         const StringCordView* str,
65         size_t start_line,
66         size_t start_pos,
67         size_t size)
68         : line_(start_line), pos_(start_pos), str_(str), size_(size) {}
IteratorStringCordView::Iterator69     explicit Iterator(const StringCordView* str)
70         : Iterator(str, 0, 0, str->size()) {}
71 
IteratorStringCordView::Iterator72     Iterator() : Iterator(nullptr, 0, 0, 0) {}
73 
74     Iterator(const Iterator&) = default;
75     Iterator(Iterator&&) = default;
76     Iterator& operator=(const Iterator&) = default;
77     Iterator& operator=(Iterator&&) = default;
78 
79     Iterator operator++() {
80       if (size_ == 0) {
81         return *this;
82       }
83       if ((pos_ + 1) < str_->pieces_[line_].size()) {
84         pos_++;
85       } else {
86         line_++;
87         pos_ = 0;
88       }
89       return *this;
90     }
91 
92     Iterator operator++(int) {
93       Iterator prev(*this);
94       ++(*this);
95       return prev;
96     }
97 
next_iterStringCordView::Iterator98     Iterator next_iter() const {
99       Iterator next(*this);
100       ++next;
101       return next;
102     }
103 
104     Iterator& operator+=(size_t num) {
105       if (!has_next()) {
106         return *this;
107       }
108       size_t target_pos = pos_ + num;
109       if (target_pos >= str_->accumulated_sizes_[line_] &&
110           (line_ + 1) < str_->accumulated_sizes_.size() &&
111           target_pos < str_->accumulated_sizes_[line_ + 1]) {
112         pos_ = target_pos;
113         return *this;
114       }
115 
116       size_t target_abs_pos = pos() + num;
117       *this = str_->iter_for_pos(target_abs_pos);
118       return *this;
119     }
120 
121     bool operator==(const Iterator& rhs) const {
122       if (!has_next() && !rhs.has_next()) {
123         return true;
124       }
125       return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_);
126     }
127     bool operator!=(const Iterator& rhs) {
128       return !((*this) == rhs);
129     }
has_nextStringCordView::Iterator130     bool has_next() const {
131       return size_ > 0 && (line_ < str_->pieces_.size());
132     }
133 
134     char operator*() const {
135       TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size());
136       TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size());
137       return str_->pieces_[line_].at(pos_);
138     }
139 
140     // returns rest of the line of the current iterator
rest_lineStringCordView::Iterator141     c10::string_view rest_line() const {
142       if (line_ >= str_->pieces_.size()) {
143         return "";
144       }
145 
146       c10::string_view cur_line = str_->pieces_[line_];
147       return cur_line.substr(pos_, std::string::npos);
148     }
149 
posStringCordView::Iterator150     size_t pos() const {
151       if (size_ == 0) {
152         return 0;
153       }
154       return str_->accumulated_sizes_[line_] + pos_;
155     }
156 
157    private:
158     size_t line_;
159     size_t pos_;
160     const StringCordView* str_;
161     size_t size_;
162     friend struct StringCordView;
163   };
164 
beginStringCordView165   Iterator begin() const {
166     return Iterator(this, 0, 0, size());
167   }
endStringCordView168   Iterator end() const {
169     return Iterator(this, pieces_.size(), 0, 0);
170   }
171   Iterator iter_for_pos(size_t pos) const;
172 
173  private:
174   std::vector<c10::string_view> pieces_;
175   std::vector<size_t> accumulated_sizes_;
176   std::vector<std::shared_ptr<std::string>> owned_strings_;
177 };
178 
179 // Source represents a code segment. It keeps track of:
180 //  - text_view : the view into text of the code segment
181 //  - filename (optional) : if present, represents the name of the file from
182 //                          which the code segment originated.
183 //  - starting_line_no : represents the line in the original file where the
184 //                       code segment started.
185 struct TORCH_API Source {
186   // Whether or not Source should copy the string passed in the constructor.
187   enum CopiesString { COPIES_STRING, DONT_COPY };
188 
189   explicit Source(
190       c10::string_view text_view,
191       std::optional<std::string> filename = std::nullopt,
192       size_t starting_line_no = 0,
193       std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr,
194       CopiesString copies_str = COPIES_STRING)
filename_Source195       : filename_(std::move(filename)),
196         starting_line_no_(starting_line_no),
197         gen_ranges_(std::move(gen_ranges)) {
198     if (copies_str == COPIES_STRING) {
199       std::shared_ptr<std::string> allocated_str =
200           std::make_shared<std::string>(text_view.data(), text_view.size());
201       text_view_ = StringCordView({*allocated_str}, {allocated_str});
202     } else {
203       text_view_ = StringCordView({text_view}, {});
204     }
205 
206     calc_line_start_offsets();
207   }
208 
209   explicit Source(
210       StringCordView str,
211       std::optional<std::string> filename = std::nullopt,
212       size_t starting_line_no = 0,
213       std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
text_view_Source214       : text_view_(std::move(str)),
215         filename_(std::move(filename)),
216         starting_line_no_(starting_line_no),
217         gen_ranges_(std::move(gen_ranges)) {
218     calc_line_start_offsets();
219   }
220   // Given a line number (within source_), return the byte offset of the
221   // beginning of that line.
offset_for_lineSource222   size_t offset_for_line(size_t line) const {
223     return line_starting_offsets_.at(line);
224   }
225 
226   // Returns number of lines present.
num_linesSource227   size_t num_lines() const {
228     return line_starting_offsets_.size();
229   }
230 
231   // Calculate the line (within the code segment) on which `offset` resides.
lineno_for_offsetSource232   size_t lineno_for_offset(size_t offset) const {
233     auto iter = std::upper_bound(
234         line_starting_offsets_.begin(), line_starting_offsets_.end(), offset);
235     return iter - line_starting_offsets_.begin() - 1;
236   }
237 
238   // Calculate the line (within the original source file, if present) on which
239   // `lineno` resides.
lineno_to_source_linenoSource240   size_t lineno_to_source_lineno(size_t lineno) const {
241     if (filename_) {
242       return lineno + starting_line_no_;
243     } else {
244       return lineno;
245     }
246   }
247 
get_lineSource248   StringCordView get_line(size_t lineno) const {
249     auto start = offset_for_line(lineno);
250     auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start
251                                            : text_view_.size() - start;
252     return text_view_.substr(start, size);
253   }
254 
text_strSource255   const StringCordView& text_str() const {
256     return text_view_;
257   }
258 
char_atSource259   char char_at(size_t index) const {
260     return text_view_.at(index);
261   }
262 
sizeSource263   size_t size() const {
264     return text_view_.size();
265   }
266 
filenameSource267   std::optional<std::string>& filename() {
268     return filename_;
269   }
270 
starting_line_noSource271   size_t starting_line_no() const {
272     return starting_line_no_;
273   }
274 
275   std::optional<SourceRange> findSourceRangeThatGenerated(
276       const SourceRange& range);
277 
278   ~Source() = default;
279 
280  private:
calc_line_start_offsetsSource281   void calc_line_start_offsets() {
282     line_starting_offsets_.clear();
283     line_starting_offsets_.push_back(0);
284     size_t pos = 0;
285     while ((pos = text_view_.find("\n", pos)) != std::string::npos) {
286       line_starting_offsets_.push_back(++pos);
287     }
288   }
289 
290   StringCordView text_view_;
291 
292   std::optional<std::string> filename_;
293   // If filename_ is not present, starting_line_no_ is don't care
294   size_t starting_line_no_;
295   // Starting offsets for lines into the source. e.g. line 0 starts at
296   // line_starting_offsets_[0], etc.
297   std::vector<size_t> line_starting_offsets_;
298 
299   std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
300 };
301 
302 // A SourceRange is a reference to subset of a Source, specified by `start` and
303 // `end` byte offsets into the source text.
304 struct TORCH_API SourceRange {
SourceRangeSourceRange305   SourceRange(std::shared_ptr<Source> source_view, size_t start_, size_t end_)
306       : source_view_(std::move(source_view)), start_(start_), end_(end_) {
307     if (source_view_) {
308       start_iter_ = source_view_->text_str().iter_for_pos(start_);
309     }
310   }
311 
SourceRangeSourceRange312   SourceRange() : source_view_(nullptr), start_(0), end_(0) {}
313 
SourceRangeSourceRange314   SourceRange(
315       std::shared_ptr<Source> source_view_,
316       StringCordView::Iterator start_iter,
317       size_t end_)
318       : source_view_(std::move(source_view_)),
319         start_(start_iter.pos()),
320         end_(end_),
321         start_iter_(start_iter) {}
322 
token_textSourceRange323   const c10::string_view token_text() const {
324     size_t size = end() - start();
325     return start_iter_.rest_line().substr(0, size);
326   }
327 
textSourceRange328   const StringCordView text() const {
329     return source_view_->text_str().substr(start(), end() - start());
330   }
sizeSourceRange331   size_t size() const {
332     return end() - start();
333   }
334   static const size_t CONTEXT = 3;
335   void highlight(std::ostream& out) const;
336 
337   // Customizable version of 'highlight' method.
338   void print_with_context(
339       std::ostream& out,
340       size_t context,
341       bool highlight,
342       const std::string& funcname) const;
343 
sourceSourceRange344   const std::shared_ptr<Source>& source() const {
345     return source_view_;
346   }
startSourceRange347   size_t start() const {
348     return start_;
349   }
endSourceRange350   size_t end() const {
351     return end_;
352   }
strSourceRange353   std::string str() const {
354     std::stringstream ss;
355     highlight(ss);
356     return ss.str();
357   }
358 
file_line_colSourceRange359   std::optional<std::tuple<std::string, size_t, size_t>> file_line_col() const {
360     if (!source_view_ || !source()->filename()) {
361       return std::nullopt;
362     }
363 
364     auto lineno = source_view_->lineno_for_offset(start_);
365     auto col_offset = (int)start_ - (int)source_view_->offset_for_line(lineno);
366     // TODO: std::optional<>::value returns an rvalue ref so can't use it here??
367     return std::make_tuple<std::string, size_t, size_t>(
368         source_view_->filename().value_or(""),
369         source_view_->lineno_to_source_lineno(lineno),
370         (size_t)col_offset);
371   }
372 
373   bool operator==(const SourceRange& rhs) const {
374     return start() == rhs.start() && end() == rhs.end() &&
375         source() == rhs.source();
376   }
377 
378   bool operator!=(const SourceRange& rhs) const {
379     return !(*this == rhs);
380   }
381 
findSourceRangeThatGeneratedSourceRange382   std::optional<SourceRange> findSourceRangeThatGenerated() const {
383     if (!source_view_) {
384       return std::nullopt;
385     }
386     return source_view_->findSourceRangeThatGenerated(*this);
387   }
388 
389  protected:
390   std::shared_ptr<Source> source_view_;
391 
392  private:
393   size_t start_;
394   size_t end_;
395   StringCordView::Iterator start_iter_;
396 };
397 
398 // OwnedSourceRange is just like a SourceRange except that it owns a `Source`
399 // instead of `Source`. Thus OwnedSourceRange owns a copy of source text.
400 struct OwnedSourceRange : public SourceRange {
OwnedSourceRangeOwnedSourceRange401   explicit OwnedSourceRange(const SourceRange& source_range)
402       : SourceRange(source_range) {
403     const auto& source = source_range.source();
404     if (source) {
405       source_view_ = std::make_shared<Source>(
406           source->text_str().str(),
407           source->filename(),
408           source->starting_line_no());
409     }
410   }
411 };
412 
413 struct TORCH_API SourceRangeHasher {
414  public:
415   size_t operator()(const torch::jit::SourceRange& key) const;
416 };
417 
// One entry of a stack trace as consumed by format_stack_trace: a filename
// string paired with a SourceRange.
struct StackEntry {
  std::string filename;
  SourceRange range;
};
422 
// Declared here, defined elsewhere.
// NOTE(review): presumably writes a formatted rendering of `entries` to
// `out` -- confirm against the definition.
TORCH_API void format_stack_trace(
    std::ostream& out,
    const std::vector<StackEntry>& entries);
426 
// Streams a SourceRange by delegating to SourceRange::highlight, then
// returns the stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) {
  range.highlight(out);
  return out;
}
431 
432 // A pair of (byte offset, SourceRange) describing a specific segment
433 // of the output stream
434 struct TaggedRange {
TaggedRangeTaggedRange435   TaggedRange(size_t bytes, SourceRange range)
436       : bytes(bytes), range(std::move(range)) {}
437   size_t bytes;
438   SourceRange range;
439 };
// An ordered list of TaggedRange entries.
using SourceRangeRecords = std::vector<TaggedRange>;
// Maps a SourceRange to an int64_t tag, hashed with SourceRangeHasher.
using SourceRangeTagMap =
    std::unordered_map<SourceRange, int64_t, SourceRangeHasher>;
443 
444 } // namespace torch::jit
445 
446 namespace std {
447 template <>
448 struct iterator_traits<torch::jit::StringCordView::Iterator> {
449   using value_type = char;
450   using difference_type = ptrdiff_t;
451   using pointer = char*;
452   using reference = char&;
453   using iterator_category = std::forward_iterator_tag;
454 };
455 } // namespace std
456