xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/source_range_serialization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/serialization/source_range_serialization.h>
2 #include <torch/csrc/jit/serialization/source_range_serialization_impl.h>
3 
4 #include <c10/util/Exception.h>
5 #include <c10/util/Flags.h>
6 #include <torch/csrc/jit/mobile/type_parser.h>
7 #include <torch/csrc/jit/serialization/pickle.h>
8 #include <algorithm>
9 #include <memory>
10 
11 namespace torch::jit {
12 
// Selects the layout of debug_pkl when a model is saved to a .pt file.
// When true, the compact layout backed by a string table is emitted; the
// compact file is smaller but cannot be loaded by old torch binaries.
// TODO(qihan) remove when all binaries are using string table.
thread_local bool should_use_format_with_string_table_{true};
17 
18 class SourceRangeSerializer {
19  public:
20   // Serialize SourceRange as Tuple[SourceType, int, int]
21   // where SourceType = Tuple[int, int, int, List[int]],
22   // The first 2 ints are positions into the vector returned by textSaved
23   // after all the Ranges are processed. textSaved() returns a vector of str
24   // the serialized form of Source
25   c10::IValue serialize(const SourceRange& sr);
26 
texts_saved()27   const std::vector<c10::IValue>& texts_saved() {
28     return texts_;
29   }
30 
SourceRangeSerializer()31   SourceRangeSerializer() {
32     texts_.emplace_back("");
33     text_to_idx_[texts_.back().toStringRef()] = 0;
34   }
35 
36  private:
37   // Serialize Source as Tuple[str, Optional[str], int, List[int]]
38   // This caches serialized sources, since many SourceRanges can
39   // refer to the same one.
40   c10::IValue serialize_source(const std::shared_ptr<Source>& s);
41   std::unordered_map<std::shared_ptr<Source>, c10::IValue> serialized_sources;
42 
43   int64_t store_text_and_get_index(const std::string& text_view);
44 
45   std::vector<c10::IValue> texts_;
46   std::unordered_map<c10::string_view, int64_t> text_to_idx_;
47 };
48 
deserialize(const c10::IValue & iv)49 SourceRange SourceRangeDeserializer::deserialize(const c10::IValue& iv) {
50   const auto& tup_elems = iv.toTupleRef().elements();
51   TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
52   std::shared_ptr<Source> source_ = deserialize_source(tup_elems[0]);
53   int64_t start_ = tup_elems[1].toInt();
54   int64_t end_ = tup_elems[2].toInt();
55   return SourceRange(source_, start_, end_);
56 }
57 
deserialize_source(const c10::IValue & iv)58 std::shared_ptr<Source> SourceRangeDeserializer::deserialize_source(
59     const c10::IValue& iv) {
60   auto tup = iv.toTuple();
61   auto it = cached_sources.find(tup);
62   if (it != cached_sources.end()) {
63     return it->second;
64   }
65   std::shared_ptr<Source> source;
66   const auto& tup_elems = tup->elements();
67   TORCH_INTERNAL_ASSERT(tup_elems.size() == 3);
68   if (!text_table_.empty()) {
69     const auto& textIndex = tup_elems[0].toIntList();
70     int64_t fnameIndex = tup_elems[1].toInt();
71     int64_t starting_line_no_ = tup_elems[2].toInt();
72     std::optional<std::string> filename = std::nullopt;
73 
74     TORCH_CHECK(
75         (uint64_t)fnameIndex < text_table_.size(),
76         "Text table index is out of range")
77     filename = *text_table_[fnameIndex];
78 
79     std::vector<c10::string_view> pieces;
80     std::vector<std::shared_ptr<std::string>> strs;
81 
82     for (int64_t i : textIndex) {
83       pieces.emplace_back(*text_table_[i]);
84       strs.emplace_back(text_table_[i]);
85     }
86 
87     StringCordView str_cord(std::move(pieces), std::move(strs));
88 
89     source = std::make_shared<Source>(str_cord, filename, starting_line_no_);
90   } else {
91     std::string text_ = tup_elems[0].toStringRef();
92     std::optional<std::string> filename_ =
93         tup_elems[1].toOptional<std::string>();
94     int64_t starting_line_no_ = tup_elems[2].toInt();
95     source = std::make_shared<Source>(
96         std::move(text_), std::move(filename_), starting_line_no_);
97   }
98   cached_sources[tup] = source;
99   return source;
100 }
101 
serialize(const SourceRange & sr)102 c10::IValue SourceRangeSerializer::serialize(const SourceRange& sr) {
103   return c10::ivalue::Tuple::create(
104       serialize_source(sr.source()), (int64_t)sr.start(), (int64_t)sr.end());
105 }
106 
store_text_and_get_index(const std::string & text_view)107 int64_t SourceRangeSerializer::store_text_and_get_index(
108     const std::string& text_view) {
109   auto text_iter = text_to_idx_.find(text_view);
110   if (text_iter == text_to_idx_.end()) {
111     int64_t text_pos = static_cast<int64_t>(texts_.size());
112     texts_.emplace_back(text_view);
113     text_to_idx_[texts_.back().toStringView()] = text_pos;
114     return text_pos;
115   } else {
116     return text_iter->second;
117   }
118 }
119 
serialize_source(const std::shared_ptr<Source> & s)120 c10::IValue SourceRangeSerializer::serialize_source(
121     const std::shared_ptr<Source>& s) {
122   if (serialized_sources.count(s)) {
123     return serialized_sources.at(s);
124   }
125   c10::intrusive_ptr<c10::ivalue::Tuple> serialized;
126   c10::List<int64_t> lines;
127   if (should_use_format_with_string_table_) {
128     if (s == nullptr) {
129       serialized = c10::ivalue::Tuple::create({lines, 0, 0});
130     } else {
131       for (size_t lineno = 0; lineno < s->num_lines(); lineno++) {
132         std::string line_content = s->get_line(lineno).str();
133         int64_t text_pos = store_text_and_get_index(line_content);
134         lines.push_back(text_pos);
135       }
136 
137       int64_t fname_pos = 0;
138       if (s->filename().has_value()) {
139         fname_pos = store_text_and_get_index(*s->filename());
140       }
141       serialized = c10::ivalue::Tuple::create(
142           {lines, fname_pos, (int64_t)s->starting_line_no()});
143     }
144   } else {
145     if (s == nullptr) {
146       serialized = c10::ivalue::Tuple::create({"", "", 0});
147     } else {
148       serialized = c10::ivalue::Tuple::create(
149           {s->text_str().str(), s->filename(), (int64_t)s->starting_line_no()});
150     }
151   }
152   serialized_sources[s] = serialized;
153   return serialized;
154 }
155 
SourceRangePickler()156 SourceRangePickler::SourceRangePickler() : srs(new SourceRangeSerializer()) {}
157 
pickle(const SourceRangeRecords & ranges,const SourceRangeTagMap & source_range_tags)158 std::vector<char> SourceRangePickler::pickle(
159     const SourceRangeRecords& ranges,
160     const SourceRangeTagMap& source_range_tags) {
161   std::vector<c10::IValue> ivalues;
162   for (const auto& range : ranges) {
163     int64_t source_range_tag{-1};
164     const auto& it = source_range_tags.find(range.range);
165     if (it != source_range_tags.end()) {
166       source_range_tag = it->second;
167     }
168 
169     ivalues.emplace_back(c10::ivalue::Tuple::create(
170         {(int64_t)range.bytes,
171          srs->serialize(range.range),
172          static_cast<int64_t>(source_range_tag)}));
173   }
174 
175   std::vector<at::Tensor> table;
176   auto textTable = c10::ivalue::Tuple::create(srs->texts_saved());
177   auto ivalue = c10::ivalue::Tuple::create(std::move(ivalues));
178   std::vector<char> result;
179   if (should_use_format_with_string_table_) {
180     result = jit::pickle(
181         c10::ivalue::Tuple::create({kFormatWithStringTable, textTable, ivalue}),
182         &table);
183   } else {
184     result = jit::pickle(ivalue, &table);
185   }
186   TORCH_CHECK(table.empty(), "Expected 0 tensors to be written");
187   return result;
188 }
189 
ConcreteSourceRangeUnpickler(at::DataPtr && data,size_t size)190 ConcreteSourceRangeUnpickler::ConcreteSourceRangeUnpickler(
191     at::DataPtr&& data,
192     size_t size)
193     : data(std::move(data)),
194       size(size),
195       deserializer(nullptr),
196       unpickled_records(nullptr) {}
197 
unpickle()198 void ConcreteSourceRangeUnpickler::unpickle() {
199   std::lock_guard<std::mutex> guard(mutex);
200   if (unpickled_records) {
201     return;
202   }
203 
204   auto ivaluesTuple = jit::unpickle(
205                           reinterpret_cast<const char*>(data.get()),
206                           size,
207                           nullptr,
208                           {},
209                           c10::parseType)
210                           .toTuple();
211 
212   const auto& ivalues = ivaluesTuple->elements();
213   TORCH_CHECK(
214       !ivalues.empty(), "Invalid unpickle operation: empty ivalues tuple");
215   unpickled_records = std::make_shared<SourceRangeRecords>();
216   IValue lines;
217   if (ivalues[0].isString() &&
218       kFormatWithStringTable == ivalues[0].toStringRef()) {
219     deserializer = std::make_shared<SourceRangeDeserializer>(ivalues[1]);
220     lines = ivalues[2];
221   } else {
222     deserializer = std::make_shared<SourceRangeDeserializer>();
223     lines = ivaluesTuple;
224   }
225   for (auto& val : lines.toTuple()->elements()) {
226     const auto& tup_elems = val.toTupleRef().elements();
227     int64_t offset = tup_elems[kByteOffsetIndex].toInt();
228     auto source_range = deserializer->deserialize(tup_elems[kSourceRangeIndex]);
229     unpickled_records->emplace_back(offset, std::move(source_range));
230   }
231 }
232 
233 std::optional<SourceRange> ConcreteSourceRangeUnpickler::
findSourceRangeThatGenerated(const SourceRange & range)234     findSourceRangeThatGenerated(const SourceRange& range) {
235   unpickle();
236 
237   auto query = TaggedRange(range.start(), SourceRange{});
238   auto entry = std::upper_bound(
239       unpickled_records->begin(),
240       unpickled_records->end(),
241       query,
242       [](const TaggedRange& a, const TaggedRange& b) -> bool {
243         return a.bytes < b.bytes;
244       });
245 
246   // NB: must decrement iterator since upper_bound finds the element
247   // *greater than* the query.
248   if (entry != unpickled_records->begin()) {
249     return (entry - 1)->range;
250   }
251 
252   return std::nullopt;
253 }
254 
setShouldUseFormatWithStringTable(bool should_use_format_with_string_table)255 TORCH_API void setShouldUseFormatWithStringTable(
256     bool should_use_format_with_string_table) {
257   should_use_format_with_string_table_ = should_use_format_with_string_table;
258 }
259 
260 } // namespace torch::jit
261