1 #include <torch/csrc/jit/serialization/source_range_serialization.h>
2 #include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
3
4 #include <c10/util/Exception.h>
5 #include <c10/util/Flags.h>
6 #include <torch/csrc/jit/mobile/type_parser.h>
7 #include <torch/csrc/jit/serialization/pickle.h>
8 #include <algorithm>
9 #include <memory>
10
11 namespace torch::jit {
12
// Selects the layout used for debug_pkl when a model is saved to a .pt file.
// When true, the compact layout (shared string table plus index-based
// sources) is emitted. The compact file is smaller but cannot be loaded by
// old torch binaries.
// TODO(qihan) remove when all binaries are using string table.
thread_local bool should_use_format_with_string_table_{true};
17
18 class SourceRangeSerializer {
19 public:
20 // Serialize SourceRange as Tuple[SourceType, int, int]
21 // where SourceType = Tuple[int, int, int, List[int]],
22 // The first 2 ints are positions into the vector returned by textSaved
23 // after all the Ranges are processed. textSaved() returns a vector of str
24 // the serialized form of Source
25 c10::IValue serialize(const SourceRange& sr);
26
texts_saved()27 const std::vector<c10::IValue>& texts_saved() {
28 return texts_;
29 }
30
SourceRangeSerializer()31 SourceRangeSerializer() {
32 texts_.emplace_back("");
33 text_to_idx_[texts_.back().toStringRef()] = 0;
34 }
35
36 private:
37 // Serialize Source as Tuple[str, Optional[str], int, List[int]]
38 // This caches serialized sources, since many SourceRanges can
39 // refer to the same one.
40 c10::IValue serialize_source(const std::shared_ptr<Source>& s);
41 std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
42
43 int64_t store_text_and_get_index(const std::string& text_view);
44
45 std::vector<c10::IValue> texts_;
46 std::unordered_map<c10::string_view, int64_t> text_to_idx_;
47 };
48
deserialize(const c10::IValue & iv)49 SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
50 const auto& tup_elems = iv.toTupleRef().elements();
51 TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
52 std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
53 int64_t start_ = tup_elems[1].toInt();
54 int64_t end_ = tup_elems[2].toInt();
55 return SourceRange(source_, start_, end_);
56 }
57
deserialize_source(const c10::IValue & iv)58 std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
59 const c10::IValue& iv) {
60 auto tup = iv.toTuple();
61 auto it = cached_sources.find(tup);
62 if (it != cached_sources.end()) {
63 return it->second;
64 }
65 std::shared_ptr<Source> source;
66 const auto& tup_elems = tup->elements();
67 TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
68 if (!text_table_.empty()) {
69 const auto& textIndex = tup_elems[0].toIntList();
70 int64_t fnameIndex = tup_elems[1].toInt();
71 int64_t starting_line_no_ = tup_elems[2].toInt();
72 std::optional<std::string> filename = std::nullopt;
73
74 TORCH_CHECK(
75 (uint64_t)fnameIndex < text_table_.size(),
76 "Text table index is out of range")
77 filename = *text_table_[fnameIndex];
78
79 std::vector<c10::string_view> pieces;
80 std::vector<std::shared_ptr<std::string>> strs;
81
82 for (int64_t i : textIndex) {
83 pieces.emplace_back(*text_table_[i]);
84 strs.emplace_back(text_table_[i]);
85 }
86
87 StringCordView str_cord(std::move(pieces), std::move(strs));
88
89 source = std::make_shared<Source>(str_cord, filename, starting_line_no_);
90 } else {
91 std::string text_ = tup_elems[0].toStringRef();
92 std::optional<std::string> filename_ =
93 tup_elems[1].toOptional<std::string>();
94 int64_t starting_line_no_ = tup_elems[2].toInt();
95 source = std::make_shared<Source>(
96 std::move(text_), std::move(filename_), starting_line_no_);
97 }
98 cached_sources[tup] = source;
99 return source;
100 }
101
serialize(const SourceRange & sr)102 c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
103 return c10::ivalue::Tuple::create(
104 serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end());
105 }
106
store_text_and_get_index(const std::string & text_view)107 int64_t SourceRangeSerializer::store_text_and_get_index(
108 const std::string& text_view) {
109 auto text_iter = text_to_idx_.find(text_view);
110 if (text_iter == text_to_idx_.end()) {
111 int64_t text_pos = static_cast<int64_t>(texts_.size());
112 texts_.emplace_back(text_view);
113 text_to_idx_[texts_.back().toStringView()] = text_pos;
114 return text_pos;
115 } else {
116 return text_iter->second;
117 }
118 }
119
serialize_source(const std::shared_ptr<Source> & s)120 c10::IValue SourceRangeSerializer::serialize_source(
121 const std::shared_ptr<Source>& s) {
122 if (serialized_sources.count(s)) {
123 return serialized_sources.at(s);
124 }
125 c10::intrusive_ptr<c10::ivalue::Tuple> serialized;
126 c10::List<int64_t> lines;
127 if (should_use_format_with_string_table_) {
128 if (s == nullptr) {
129 serialized = c10::ivalue::Tuple::create({lines, 0, 0});
130 } else {
131 for (size_t lineno = 0; lineno < s->num_lines(); lineno++) {
132 std::string line_content = s->get_line(lineno).str();
133 int64_t text_pos = store_text_and_get_index(line_content);
134 lines.push_back(text_pos);
135 }
136
137 int64_t fname_pos = 0;
138 if (s->filename().has_value()) {
139 fname_pos = store_text_and_get_index(*s->filename());
140 }
141 serialized = c10::ivalue::Tuple::create(
142 {lines, fname_pos, (int64_t)s->starting_line_no()});
143 }
144 } else {
145 if (s == nullptr) {
146 serialized = c10::ivalue::Tuple::create({"", "", 0});
147 } else {
148 serialized = c10::ivalue::Tuple::create(
149 {s->text_str().str(), s->filename(), (int64_t)s->starting_line_no()});
150 }
151 }
152 serialized_sources[s] = serialized;
153 return serialized;
154 }
155
SourceRangePickler()156 SourceRangePickler::SourceRangePickler() : srs(new SourceRangeSerializer()) {}
157
pickle(const SourceRangeRecords & ranges,const SourceRangeTagMap & source_range_tags)158 std::vector<char> SourceRangePickler::pickle(
159 const SourceRangeRecords& ranges,
160 const SourceRangeTagMap& source_range_tags) {
161 std::vector<c10::IValue> ivalues;
162 for (const auto& range : ranges) {
163 int64_t source_range_tag{-1};
164 const auto& it = source_range_tags.find(range.range);
165 if (it != source_range_tags.end()) {
166 source_range_tag = it->second;
167 }
168
169 ivalues.emplace_back(c10::ivalue::Tuple::create(
170 {(int64_t)range.bytes,
171 srs->serialize(range.range),
172 static_cast<int64_t>(source_range_tag)}));
173 }
174
175 std::vector<at::Tensor> table;
176 auto textTable = c10::ivalue::Tuple::create(srs->texts_saved());
177 auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
178 std::vector<char> result;
179 if (should_use_format_with_string_table_) {
180 result = jit::pickle(
181 c10::ivalue::Tuple::create({kFormatWithStringTable, textTable, ivalue}),
182 &table);
183 } else {
184 result = jit::pickle(ivalue, &table);
185 }
186 TORCH_CHECK(table.empty(), "Expected 0 tensors to be written");
187 return result;
188 }
189
ConcreteSourceRangeUnpickler(at::DataPtr && data,size_t size)190 ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
191 at::DataPtr&& data,
192 size_t size)
193 : data(std::move(data)),
194 size(size),
195 deserializer(nullptr),
196 unpickled_records(nullptr) {}
197
unpickle()198 void ConcreteSourceRangeUnpickler::unpickle() {
199 std::lock_guard<std::mutex> guard(mutex);
200 if (unpickled_records) {
201 return;
202 }
203
204 auto ivaluesTuple = jit::unpickle(
205 reinterpret_cast<const char*>(data.get()),
206 size,
207 nullptr,
208 {},
209 c10::parseType)
210 .toTuple();
211
212 const auto& ivalues = ivaluesTuple->elements();
213 TORCH_CHECK(
214 !ivalues.empty(), "Invalid unpickle operation: empty ivalues tuple");
215 unpickled_records = std::make_shared<SourceRangeRecords>();
216 IValue lines;
217 if (ivalues[0].isString() &&
218 kFormatWithStringTable == ivalues[0].toStringRef()) {
219 deserializer = std::make_shared<SourceRangeDeserializer>(ivalues[1]);
220 lines = ivalues[2];
221 } else {
222 deserializer = std::make_shared<SourceRangeDeserializer>();
223 lines = ivaluesTuple;
224 }
225 for (auto& val : lines.toTuple()->elements()) {
226 const auto& tup_elems = val.toTupleRef().elements();
227 int64_t offset = tup_elems[kByteOffsetIndex].toInt();
228 auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
229 unpickled_records->emplace_back(offset, std::move(source_range));
230 }
231 }
232
233 std::optional<SourceRange> ConcreteSourceRangeUnpickler::
findSourceRangeThatGenerated(const SourceRange & range)234 findSourceRangeThatGenerated(const SourceRange& range) {
235 unpickle();
236
237 auto query = TaggedRange(range.start(), SourceRange{});
238 auto entry = std::upper_bound(
239 unpickled_records->begin(),
240 unpickled_records->end(),
241 query,
242 [](const TaggedRange& a, const TaggedRange& b) -> bool {
243 return a.bytes < b.bytes;
244 });
245
246 // NB: must decrement iterator since upper_bound finds the element
247 // *greater than* the query.
248 if (entry != unpickled_records->begin()) {
249 return (entry - 1)->range;
250 }
251
252 return std::nullopt;
253 }
254
setShouldUseFormatWithStringTable(bool should_use_format_with_string_table)255 TORCH_API void setShouldUseFormatWithStringTable(
256 bool should_use_format_with_string_table) {
257 should_use_format_with_string_table_ = should_use_format_with_string_table;
258 }
259
260 } // namespace torch::jit
261