1 #pragma once 2 #include <c10/util/Exception.h> 3 #include <optional> 4 5 #include <algorithm> 6 #include <iterator> 7 #include <memory> 8 #include <ostream> 9 #include <sstream> 10 #include <unordered_map> 11 12 namespace torch::jit { 13 14 class SourceRangeUnpickler; 15 struct SourceRange; 16 17 // A stringlike class backed by a vector of string_view 18 // the string represented are logically the concatenation of the string_views 19 // This has advantage of not needing continues memory. 20 struct TORCH_API StringCordView { 21 StringCordView(); 22 StringCordView(const StringCordView&) = default; 23 StringCordView(StringCordView&&) noexcept = default; 24 StringCordView( 25 std::vector<c10::string_view> inputs, 26 std::vector<std::shared_ptr<std::string>> ownerships); 27 28 StringCordView& operator=(const StringCordView&) = default; 29 StringCordView& operator=(StringCordView&&) noexcept = default; 30 sizeStringCordView31 size_t size() const { 32 return accumulated_sizes_.back(); 33 } 34 35 size_t find(const std::string& tok, size_t start) const; 36 size_t find_regex(const std::string& tok, size_t start) const; 37 StringCordView substr(size_t start, size_t size) const; 38 atStringCordView39 char at(size_t index) const { 40 return *iter_for_pos(index); 41 } 42 char operator[](size_t index) const { 43 return at(index); 44 } 45 strStringCordView46 std::string str() const { 47 std::stringstream ss; 48 for (auto s : pieces_) { 49 ss << std::string(s); 50 } 51 return ss.str(); 52 } 53 54 bool operator==(const std::string& rhs) const; 55 56 bool operator==(const StringCordView& rhs) const; 57 pieceStringCordView58 c10::string_view piece(size_t index) const { 59 return pieces_[index]; 60 } 61 62 struct Iterator { IteratorStringCordView::Iterator63 Iterator( 64 const StringCordView* str, 65 size_t start_line, 66 size_t start_pos, 67 size_t size) 68 : line_(start_line), pos_(start_pos), str_(str), size_(size) {} IteratorStringCordView::Iterator69 explicit Iterator(const StringCordView* str) 70 : Iterator(str, 0, 0, str->size()) {} 71 IteratorStringCordView::Iterator72 Iterator() : Iterator(nullptr, 0, 0, 0) {} 73 74 Iterator(const Iterator&) = default; 75 Iterator(Iterator&&) = default; 76 Iterator& operator=(const Iterator&) = default; 77 Iterator& operator=(Iterator&&) = default; 78 79 Iterator operator++() { 80 if (size_ == 0) { 81 return *this; 82 } 83 if ((pos_ + 1) < str_->pieces_[line_].size()) { 84 pos_++; 85 } else { 86 line_++; 87 pos_ = 0; 88 } 89 return *this; 90 } 91 92 Iterator operator++(int) { 93 Iterator prev(*this); 94 ++(*this); 95 return prev; 96 } 97 next_iterStringCordView::Iterator98 Iterator next_iter() const { 99 Iterator next(*this); 100 ++next; 101 return next; 102 } 103 104 Iterator& operator+=(size_t num) { 105 if (!has_next()) { 106 return *this; 107 } 108 size_t target_pos = pos_ + num; 109 if (target_pos >= str_->accumulated_sizes_[line_] && 110 (line_ + 1) < str_->accumulated_sizes_.size() && 111 target_pos < str_->accumulated_sizes_[line_ + 1]) { 112 pos_ = target_pos; 113 return *this; 114 } 115 116 size_t target_abs_pos = pos() + num; 117 *this = str_->iter_for_pos(target_abs_pos); 118 return *this; 119 } 120 121 bool operator==(const Iterator& rhs) const { 122 if (!has_next() && !rhs.has_next()) { 123 return true; 124 } 125 return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_); 126 } 127 bool operator!=(const Iterator& rhs) { 128 return !((*this) == rhs); 129 } has_nextStringCordView::Iterator130 bool has_next() const { 131 return size_ > 0 && (line_ < str_->pieces_.size()); 132 } 133 134 char operator*() const { 135 TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size()); 136 TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size()); 137 return str_->pieces_[line_].at(pos_); 138 } 139 140 // returns rest of the line of the current iterator rest_lineStringCordView::Iterator141 c10::string_view rest_line() const { 142 if (line_ >= str_->pieces_.size()) { 143 return ""; 144 } 145 146 c10::string_view cur_line = str_->pieces_[line_]; 147 return cur_line.substr(pos_, std::string::npos); 148 } 149 posStringCordView::Iterator150 size_t pos() const { 151 if (size_ == 0) { 152 return 0; 153 } 154 return str_->accumulated_sizes_[line_] + pos_; 155 } 156 157 private: 158 size_t line_; 159 size_t pos_; 160 const StringCordView* str_; 161 size_t size_; 162 friend struct StringCordView; 163 }; 164 beginStringCordView165 Iterator begin() const { 166 return Iterator(this, 0, 0, size()); 167 } endStringCordView168 Iterator end() const { 169 return Iterator(this, pieces_.size(), 0, 0); 170 } 171 Iterator iter_for_pos(size_t pos) const; 172 173 private: 174 std::vector<c10::string_view> pieces_; 175 std::vector<size_t> accumulated_sizes_; 176 std::vector<std::shared_ptr<std::string>> owned_strings_; 177 }; 178 179 // Source represents a code segment. It keeps track of: 180 // - text_view : the view into text of the code segment 181 // - filename (optional) : if present, represents the name of the file from 182 // which the code segment originated. 183 // - starting_line_no : represents the line in the original file where the 184 // code segment started. 185 struct TORCH_API Source { 186 // Whether or not Source should copy the string passed in the constructor. 187 enum CopiesString { COPIES_STRING, DONT_COPY }; 188 189 explicit Source( 190 c10::string_view text_view, 191 std::optional<std::string> filename = std::nullopt, 192 size_t starting_line_no = 0, 193 std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr, 194 CopiesString copies_str = COPIES_STRING) filename_Source195 : filename_(std::move(filename)), 196 starting_line_no_(starting_line_no), 197 gen_ranges_(std::move(gen_ranges)) { 198 if (copies_str == COPIES_STRING) { 199 std::shared_ptr<std::string> allocated_str = 200 std::make_shared<std::string>(text_view.data(), text_view.size()); 201 text_view_ = StringCordView({*allocated_str}, {allocated_str}); 202 } else { 203 text_view_ = StringCordView({text_view}, {}); 204 } 205 206 calc_line_start_offsets(); 207 } 208 209 explicit Source( 210 StringCordView str, 211 std::optional<std::string> filename = std::nullopt, 212 size_t starting_line_no = 0, 213 std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr) text_view_Source214 : text_view_(std::move(str)), 215 filename_(std::move(filename)), 216 starting_line_no_(starting_line_no), 217 gen_ranges_(std::move(gen_ranges)) { 218 calc_line_start_offsets(); 219 } 220 // Given a line number (within source_), return the byte offset of the 221 // beginning of that line. offset_for_lineSource222 size_t offset_for_line(size_t line) const { 223 return line_starting_offsets_.at(line); 224 } 225 226 // Returns number of lines present. num_linesSource227 size_t num_lines() const { 228 return line_starting_offsets_.size(); 229 } 230 231 // Calculate the line (within the code segment) on which `offset` resides. lineno_for_offsetSource232 size_t lineno_for_offset(size_t offset) const { 233 auto iter = std::upper_bound( 234 line_starting_offsets_.begin(), line_starting_offsets_.end(), offset); 235 return iter - line_starting_offsets_.begin() - 1; 236 } 237 238 // Calculate the line (within the original source file, if present) on which 239 // `lineno` resides. lineno_to_source_linenoSource240 size_t lineno_to_source_lineno(size_t lineno) const { 241 if (filename_) { 242 return lineno + starting_line_no_; 243 } else { 244 return lineno; 245 } 246 } 247 get_lineSource248 StringCordView get_line(size_t lineno) const { 249 auto start = offset_for_line(lineno); 250 auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start 251 : text_view_.size() - start; 252 return text_view_.substr(start, size); 253 } 254 text_strSource255 const StringCordView& text_str() const { 256 return text_view_; 257 } 258 char_atSource259 char char_at(size_t index) const { 260 return text_view_.at(index); 261 } 262 sizeSource263 size_t size() const { 264 return text_view_.size(); 265 } 266 filenameSource267 std::optional<std::string>& filename() { 268 return filename_; 269 } 270 starting_line_noSource271 size_t starting_line_no() const { 272 return starting_line_no_; 273 } 274 275 std::optional<SourceRange> findSourceRangeThatGenerated( 276 const SourceRange& range); 277 278 ~Source() = default; 279 280 private: calc_line_start_offsetsSource281 void calc_line_start_offsets() { 282 line_starting_offsets_.clear(); 283 line_starting_offsets_.push_back(0); 284 size_t pos = 0; 285 while ((pos = text_view_.find("\n", pos)) != std::string::npos) { 286 line_starting_offsets_.push_back(++pos); 287 } 288 } 289 290 StringCordView text_view_; 291 292 std::optional<std::string> filename_; 293 // If filename_ is not present, starting_line_no_ is don't care 294 size_t starting_line_no_; 295 // Starting offsets for lines into the source. e.g. line 0 starts at 296 // line_starting_offsets_[0], etc. 297 std::vector<size_t> line_starting_offsets_; 298 299 std::shared_ptr<SourceRangeUnpickler> gen_ranges_; 300 }; 301 302 // A SourceRange is a reference to subset of a Source, specified by `start` and 303 // `end` byte offsets into the source text. 304 struct TORCH_API SourceRange { SourceRangeSourceRange305 SourceRange(std::shared_ptr<Source> source_view, size_t start_, size_t end_) 306 : source_view_(std::move(source_view)), start_(start_), end_(end_) { 307 if (source_view_) { 308 start_iter_ = source_view_->text_str().iter_for_pos(start_); 309 } 310 } 311 SourceRangeSourceRange312 SourceRange() : source_view_(nullptr), start_(0), end_(0) {} 313 SourceRangeSourceRange314 SourceRange( 315 std::shared_ptr<Source> source_view_, 316 StringCordView::Iterator start_iter, 317 size_t end_) 318 : source_view_(std::move(source_view_)), 319 start_(start_iter.pos()), 320 end_(end_), 321 start_iter_(start_iter) {} 322 token_textSourceRange323 const c10::string_view token_text() const { 324 size_t size = end() - start(); 325 return start_iter_.rest_line().substr(0, size); 326 } 327 textSourceRange328 const StringCordView text() const { 329 return source_view_->text_str().substr(start(), end() - start()); 330 } sizeSourceRange331 size_t size() const { 332 return end() - start(); 333 } 334 static const size_t CONTEXT = 3; 335 void highlight(std::ostream& out) const; 336 337 // Customizable version of 'highlight' method. 338 void print_with_context( 339 std::ostream& out, 340 size_t context, 341 bool highlight, 342 const std::string& funcname) const; 343 sourceSourceRange344 const std::shared_ptr<Source>& source() const { 345 return source_view_; 346 } startSourceRange347 size_t start() const { 348 return start_; 349 } endSourceRange350 size_t end() const { 351 return end_; 352 } strSourceRange353 std::string str() const { 354 std::stringstream ss; 355 highlight(ss); 356 return ss.str(); 357 } 358 file_line_colSourceRange359 std::optional<std::tuple<std::string, size_t, size_t>> file_line_col() const { 360 if (!source_view_ || !source()->filename()) { 361 return std::nullopt; 362 } 363 364 auto lineno = source_view_->lineno_for_offset(start_); 365 auto col_offset = (int)start_ - (int)source_view_->offset_for_line(lineno); 366 // TODO: std::optional<>::value returns an rvalue ref so can't use it here?? 367 return std::make_tuple<std::string, size_t, size_t>( 368 source_view_->filename().value_or(""), 369 source_view_->lineno_to_source_lineno(lineno), 370 (size_t)col_offset); 371 } 372 373 bool operator==(const SourceRange& rhs) const { 374 return start() == rhs.start() && end() == rhs.end() && 375 source() == rhs.source(); 376 } 377 378 bool operator!=(const SourceRange& rhs) const { 379 return !(*this == rhs); 380 } 381 findSourceRangeThatGeneratedSourceRange382 std::optional<SourceRange> findSourceRangeThatGenerated() const { 383 if (!source_view_) { 384 return std::nullopt; 385 } 386 return source_view_->findSourceRangeThatGenerated(*this); 387 } 388 389 protected: 390 std::shared_ptr<Source> source_view_; 391 392 private: 393 size_t start_; 394 size_t end_; 395 StringCordView::Iterator start_iter_; 396 }; 397 398 // OwnedSourceRange is just like a SourceRange except that it owns a `Source` 399 // instead of `Source`. Thus OwnedSourceRange owns a copy of source text. 400 struct OwnedSourceRange : public SourceRange { OwnedSourceRangeOwnedSourceRange401 explicit OwnedSourceRange(const SourceRange& source_range) 402 : SourceRange(source_range) { 403 const auto& source = source_range.source(); 404 if (source) { 405 source_view_ = std::make_shared<Source>( 406 source->text_str().str(), 407 source->filename(), 408 source->starting_line_no()); 409 } 410 } 411 }; 412 413 struct TORCH_API SourceRangeHasher { 414 public: 415 size_t operator()(const torch::jit::SourceRange& key) const; 416 }; 417 418 struct StackEntry { 419 std::string filename; 420 SourceRange range; 421 }; 422 423 TORCH_API void format_stack_trace( 424 std::ostream& out, 425 const std::vector<StackEntry>& entries); 426 427 inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) { 428 range.highlight(out); 429 return out; 430 } 431 432 // A pair of (byte offset, SourceRange) describing a specific segment 433 // of the output stream 434 struct TaggedRange { TaggedRangeTaggedRange435 TaggedRange(size_t bytes, SourceRange range) 436 : bytes(bytes), range(std::move(range)) {} 437 size_t bytes; 438 SourceRange range; 439 }; 440 using SourceRangeRecords = std::vector<TaggedRange>; 441 using SourceRangeTagMap = 442 std::unordered_map<SourceRange, int64_t, SourceRangeHasher>; 443 444 } // namespace torch::jit 445 446 namespace std { 447 template <> 448 struct iterator_traits<torch::jit::StringCordView::Iterator> { 449 using value_type = char; 450 using difference_type = ptrdiff_t; 451 using pointer = char*; 452 using reference = char&; 453 using iterator_category = std::forward_iterator_tag; 454 }; 455 } // namespace std 456