xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/debug_info.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/source_range.h>
2 #include <torch/csrc/jit/mobile/debug_info.h>
3 #include <torch/csrc/jit/mobile/type_parser.h>
4 #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
5 #include <torch/csrc/jit/serialization/source_range_serialization.h>
6 
7 #include <ATen/core/ivalue.h>
8 #include <torch/csrc/jit/serialization/pickle.h>
9 
10 #include <c10/util/string_view.h>
11 
12 namespace torch::jit {
13 
14 namespace {
15 
debugHandlesNotFoundMessage(const std::string & debug_handles_string)16 C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
17     const std::string& debug_handles_string) {
18   return "Debug info for handle(s): " + debug_handles_string +
19       ", was not found.";
20 }
21 
getStackTraceWithModuleHierarchy(const DebugInfoTuple & source_callstack,const std::string & caller_name)22 std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
23     const DebugInfoTuple& source_callstack,
24     const std::string& caller_name) {
25   std::vector<StackEntry> entries;
26 
27   const SourceRange& range =
28       std::get<kDebugInfoTupleSourceRangeIndex>(source_callstack);
29   InlinedCallStackPtr callstack_ptr =
30       std::get<kDebugInfoTupleInlinedCSIndex>(source_callstack);
31   std::string prev_function_name = caller_name;
32   std::string module_info;
33   if (!callstack_ptr) {
34     // If not cs then top level node
35     entries.emplace_back(StackEntry{prev_function_name, range});
36     return {std::move(entries), std::move(module_info)};
37   } else {
38     while (callstack_ptr) {
39       const auto& opt_module_instance_info = callstack_ptr->module_instance();
40       if (opt_module_instance_info.has_value()) {
41         const auto& module_instance_info = opt_module_instance_info.value();
42         // Sometimes (e.g., in lowered backends) we augment instance name with
43         // type name instead of losing type name. In those cases instance_name
44         // includes both instance name and type name. See
45         // callstack_debug_info_serialization.cpp
46         if (module_instance_info.class_type()) {
47           module_info.append(".").append(
48               utils::get_module_info(module_instance_info));
49         } else {
50           module_info.append(".").append(module_instance_info.instance_name());
51         }
52       } else {
53         module_info.append(".UNKNOWN_INSTANCE(UNKNOWN_TYPE)");
54       }
55       // Now add source range info to stack
56       entries.emplace_back(
57           StackEntry{prev_function_name, callstack_ptr->source_range()});
58       prev_function_name = callstack_ptr->function_name();
59       // Function name appended here
60       // It is renamed to prev_function_name because for StackEntry
61       // it will be appended in the next iteration. This is the format
62       // in which format_stack_trace expects function names.
63       module_info.append("::").append(prev_function_name);
64 
65       if (callstack_ptr->callee()) {
66         callstack_ptr = callstack_ptr->callee().value();
67       } else {
68         callstack_ptr = c10::intrusive_ptr<InlinedCallStack>();
69       }
70     }
71     entries.emplace_back(StackEntry{prev_function_name, range});
72     return {std::move(entries), std::move(module_info)};
73   }
74 }
75 
76 // This function construct stacktrace with module hierarchy
77 // Module hierarchy will contain information about where in the
78 // module hierarchy this source is. For example if conv2d op
79 // exist in hierarcy A->B->C->Conv2d with type annotations of
80 // A -> TopM, B->MyModule, C->SomeModule, then module hierarchy
81 // will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv)
82 // Source level stack information will be from model source code.
getStackTraceWithModuleHierarchy(const std::vector<DebugInfoTuple> & source_callstacks,const std::string & root_scope_string,const std::string & top_module_type_name)83 std::pair<std::string, std::string> getStackTraceWithModuleHierarchy(
84     const std::vector<DebugInfoTuple>& source_callstacks,
85     const std::string& root_scope_string,
86     const std::string& top_module_type_name) {
87   std::vector<StackEntry> stack_entries;
88   std::string module_info =
89       root_scope_string + "(" + top_module_type_name + ")";
90   std::string caller_fn_name = "<unknown>";
91   module_info.append("::").append(caller_fn_name);
92   for (const auto& debug_info : source_callstacks) {
93     auto debug_info_pair =
94         getStackTraceWithModuleHierarchy(debug_info, caller_fn_name);
95     auto entries = std::move(debug_info_pair.first);
96     stack_entries.insert(stack_entries.end(), entries.begin(), entries.end());
97     module_info.append(debug_info_pair.second);
98   }
99   // Only last entry in the callstack will have a node name of interest.
100   // Rest are likely CallMethod/CallFunction nodes
101   auto last_entry = source_callstacks.back();
102   const std::string& node_name =
103       std::get<kDebugInfoTupleNodeNameIndex>(last_entry);
104   module_info.append(".").append(node_name);
105   std::ostringstream ss;
106   ss << "Module hierarchy:" << module_info << "\n";
107   format_stack_trace(ss, stack_entries);
108   return {ss.str(), std::move(module_info)};
109 }
110 
111 } // namespace
112 
MobileDebugTable(std::unique_ptr<caffe2::serialize::PyTorchStreamReader> & reader,const std::shared_ptr<CompilationUnit> & cu)113 MobileDebugTable::MobileDebugTable(
114     std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader,
115     const std::shared_ptr<CompilationUnit>& cu) {
116   ska::flat_hash_map<int64_t, SourceRange> source_range_map;
117   const std::vector<std::string>& record_names = reader->getAllRecords();
118   const c10::string_view suffix(".debug_pkl");
119   for (const auto& record_name : record_names) {
120     if (c10::string_view(record_name).ends_with(suffix)) {
121       auto [debug_data, debug_size] = reader->getRecord(record_name);
122       auto ivalueTuple = jit::unpickle(
123           reinterpret_cast<const char*>(debug_data.get()),
124           debug_size,
125           nullptr,
126           {},
127           c10::parseType);
128       const auto& ivalues = ivalueTuple.toTuple()->elements();
129       IValue lines;
130       std::unique_ptr<SourceRangeDeserializer> deserializer;
131       if (ivalues.size() == 3 && ivalues[0].isString() &&
132           kFormatWithStringTable == ivalues[0].toStringRef()) {
133         // new format
134         deserializer = std::make_unique<SourceRangeDeserializer>(ivalues[1]);
135         lines = ivalues[2];
136       } else {
137         deserializer = std::make_unique<SourceRangeDeserializer>();
138         lines = ivalueTuple;
139       }
140 
141       for (auto& val : lines.toTuple()->elements()) {
142         auto tup_elems = std::move(*val.toTuple()).elements();
143         // For BC we decode only tuples with 3 elements
144         // assuming it contains
145         // byte_offset, debug_handle (=source range tag), source range
146         if (tup_elems.size() == 3) {
147           int64_t debug_handle = tup_elems[kSourceRangeTagIndex].toInt();
148           auto source_range =
149               deserializer->deserialize(tup_elems[kSourceRangeIndex]);
150           source_range_map.emplace(debug_handle, std::move(source_range));
151         }
152       }
153     }
154   }
155   const std::string callstack_debug_file("callstack_debug_map.pkl");
156   if (reader->hasRecord("callstack_debug_map.pkl")) {
157     auto [callstack_data, callstack_data_size] =
158         reader->getRecord(callstack_debug_file);
159     CallStackDebugInfoUnpickler unpickler;
160     callstack_ptr_map_ = unpickler.unpickle(
161         callstack_data, callstack_data_size, source_range_map, cu);
162   }
163 }
164 
getModuleHierarchyInfo(const int64_t debug_handle,const std::string & top_module_type_name) const165 std::string MobileDebugTable::getModuleHierarchyInfo(
166     const int64_t debug_handle,
167     const std::string& top_module_type_name) const {
168   const auto it = callstack_ptr_map_.find(debug_handle);
169   if (it == callstack_ptr_map_.end()) {
170     return debugHandlesNotFoundMessage(std::to_string(debug_handle));
171   }
172   return (getStackTraceWithModuleHierarchy(
173               {it->second}, "top", top_module_type_name))
174       .second;
175 }
176 
getModuleHierarchyInfo(const std::vector<int64_t> & debug_handles,const std::string & top_module_type_name) const177 std::string MobileDebugTable::getModuleHierarchyInfo(
178     const std::vector<int64_t>& debug_handles,
179     const std::string& top_module_type_name) const {
180   return getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name)
181       .second;
182 }
183 
getSourceDebugString(const int64_t debug_handle,const std::string & top_module_type_name) const184 std::string MobileDebugTable::getSourceDebugString(
185     const int64_t debug_handle,
186     const std::string& top_module_type_name) const {
187   const auto it = callstack_ptr_map_.find(debug_handle);
188   if (it == callstack_ptr_map_.end()) {
189     return debugHandlesNotFoundMessage(std::to_string(debug_handle));
190   }
191   return (getStackTraceWithModuleHierarchy(
192               {it->second}, "top", top_module_type_name))
193       .first;
194 }
195 
getSourceDebugString(const std::vector<int64_t> & debug_handles,const std::string & top_module_type_name) const196 std::string MobileDebugTable::getSourceDebugString(
197     const std::vector<int64_t>& debug_handles,
198     const std::string& top_module_type_name) const {
199   return getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name)
200       .first;
201 }
202 
203 std::pair<std::string, std::string> MobileDebugTable::
getSourceDebugModuleHierarchyInfo(const std::vector<int64_t> & debug_handles,const std::string & top_module_type_name) const204     getSourceDebugModuleHierarchyInfo(
205         const std::vector<int64_t>& debug_handles,
206         const std::string& top_module_type_name) const {
207   std::vector<DebugInfoTuple> debug_infos;
208   bool debug_handle_not_found{false};
209   for (auto it = debug_handles.rbegin(); it != debug_handles.rend(); ++it) {
210     auto debug_handle = *it;
211     const auto cs_it = callstack_ptr_map_.find(debug_handle);
212     if (cs_it == callstack_ptr_map_.end()) {
213       debug_handle_not_found = true;
214       break;
215     }
216     debug_infos.emplace_back(cs_it->second);
217   }
218   if (debug_handle_not_found) {
219     std::string debug_handles_string = "debug_handles:{";
220     for (const auto debug_handle : debug_handles) {
221       debug_handles_string += std::to_string(debug_handle);
222     }
223     debug_handles_string += "}";
224     debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
225     return {debug_handles_string, debug_handles_string};
226   }
227   return (getStackTraceWithModuleHierarchy(
228       debug_infos, "top", top_module_type_name));
229 }
230 
231 } // namespace torch::jit
232