#pragma once
#include <ATen/core/ivalue.h>

#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/scope.h>

#include <atomic>

namespace torch {
namespace jit {

/*
 *  BackendDebugHandleManager is responsible for issuing debug handles to
 *  backends. Debug handles are associated with the nodes of a graph.
 *  BackendDebugHandleManager also maintains a map
 *  [debug-handle, DebugInfoTuple = {source range, inlined callstack ptr}]
 *  that helps generate a callstack for exceptions raised using debug
 *  handles. Effectively, debug handles are given to the backend, and when
 *  an exception later occurs in the backend, the backend can report, via a
 *  debug handle, where the exception occurred. The runtime can then
 *  generate the callstack corresponding to that exception.
 *  There are two parts to BackendDebugHandleManager:
 *  1. static std::atomic debug_handle
 *  2. Map of [debug-handle, DebugInfoTuple]
 *
 *  About 1:
 *  Why do debug handles have to be unique? By ensuring the uniqueness of
 *  debug handles, we remove the burden of another layer of mapping that
 *  would record which set of debug handles was generated for which lowered
 *  module or bytecode function. This simplifies the serialization API,
 *  since a debug handle uniquely identifies a DebugInfoTuple, and it
 *  simplifies the runtime API for throwing exceptions: exception throwing
 *  only needs to know the debug_handle, not which module or method threw
 *  it. There are two issues to keep in mind, though, for a static
 *  std::atomic debug_handle:
 *  A. The performance implications of using an atomic variable. However,
 *  it is only used during compilation, so we assume we can absorb some of
 *  that penalty; and with no contention there is even less to worry about.
 *  B. If repeated compilation is part of a long-running process, we may
 *  overflow int64_t. We could detect this and fail, but for now this is
 *  not done. A sketch of the counter's use follows.
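 *
 *  An illustrative sketch of 1 (a hypothetical standalone counter, not the
 *  actual member; the real one is unique_debug_handle_ in
 *  BackendDebugInfoRecorder below):
 *
 *    static std::atomic<int64_t> next_handle{0};
 *    int64_t debug_handle = next_handle.fetch_add(1); // unique process-wide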
 *
 *  Now about 2:
 *  There are two use cases for [debug-handle, DebugInfoTuple]:
 *  A. During bytecode generation, the DebugInfoTuples corresponding to the
 *  nodes of the inlined graph being serialized are stored in this object,
 *  and a unique debug handle is returned for each. This unique debug
 *  handle is stored in the mobile debug info for PyTorch Lite models. It
 *  is used for raising exceptions as well as for profiling.
 *  B. During backend lowering, each backend's preprocess/compile method
 *  can compile a method's graph and serialize those methods. Once the
 *  method is lowered to the backend, the graph is essentially lost.
 *  Without access to the graph, it is hard to generate model-level debug
 *  info. Debug handles thus provide a way to map nodes of the graph to
 *  model-level debug info.
 *
 *  During bytecode model serialization, [debug-handle, DebugInfoTuple] is
 *  serialized. Now we know (a) the debug handles and (b) how to map debug
 *  handles to model source code. Thus we can either do eager symbolication,
 *  converting debug handles to the corresponding source code at runtime,
 *  or do lazy symbolication offline.
 *
 *  Note that it is not necessary to serialize the [debug-handle,
 *  DebugInfoTuple] map corresponding to a lowered backend if the lowering
 *  process (that is, preprocess/compile) and execution happen in the same
 *  session; in that case eager symbolication can be employed.
 *
 *  Now how does BackendDebugHandleManager capture all of the above?
 *  By providing two APIs:
 *  1. getNextDebugHandle, which, given a Node*, returns a unique debug
 *     handle that uniquely identifies a DebugInfoTuple, and
 *  2. getCallStackPtrMap, which returns the map
 *     [debug-handle, DebugInfoTuple].
 *
 *  1 provides debug handles to backends and 2 provides the runtime a way
 *  to map debug handles to source-level debug info, as the sketch below
 *  illustrates.
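 *
 *  A sketch of how a lowering flow might use these two APIs (hypothetical
 *  driver code; `graph` stands for the method's inlined graph, and in the
 *  concrete recorder class below the map is returned by stopRecording):
 *
 *    BackendDebugInfoRecorder recorder;
 *    for (Node* n : graph->nodes()) {
 *      int64_t handle = recorder.getNextDebugHandle(n); // API 1
 *      // hand `handle` to the backend alongside n's compiled artifact
 *    }
 *    BackendDebugInfoMapType map = recorder.stopRecording(); // API 2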
 *
 *  So why does a debug handle map to DebugInfoTuple = {source range and
 *  inlined cs}, i.e. {debug_handle, source_range_tag, serialized_callstack}?
 *  Take this example:
 *  class L(nn.Module):
 *    def __init__(self) -> None:
 *      ...
 *    def forward(self, x):
 *      return x * 5
 *  class M(nn.Module):
 *    def __init__(self) -> None:
 *      self.l = L()
 *    def forward(self, x):
 *      return self.l(x) - 2
 *  class N(nn.Module):
 *    def __init__(self) -> None:
 *      self.m = M()
 *    def forward(self, x):
 *      return self.m(x) + 3
 *  m = torch.jit.script(N())
 *  Once you inline m's forward method, m.forward.graph will look something
 *  like this:
 *  graph(%self...):
 *   %x = aten::mul(..)
 *   %x = aten::sub(x, ..)
 *   %y = aten::add(x, ..)
 *   ..
 *  The inlined callstack ptrs for these nodes will look like:
 *  aten::mul's inlined CS (callstack):
 *    [N.forward, source range] -> [M.forward, source range]
 *  aten::sub's inlined CS (callstack): [N.forward, source range]
 *  aten::add's inlined CS: null
 *  The mul node's inlined CS contains only information about the
 *  callsites' source ranges. The information about the mul node's own
 *  source range ('return x * 5') is not available in its inlined CS; it is
 *  part of the node's source range instead. Thus, to get the full stack
 *  [N.forward, source range] -> [M.forward, source range] ->
 *  [aten::mul's source range], we need to track both mul's source range
 *  and its inlined CS, as the symbolication sketch below shows.
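 *
 *  A sketch of what symbolication might do with one entry of the map
 *  (assuming the DebugInfoTuple layout from torch/csrc/jit/ir/scope.h:
 *  {source range, node name, inlined callstack ptr}):
 *
 *    const DebugInfoTuple& t = map.at(handle);
 *    const SourceRange& node_range = std::get<0>(t); // mul's own range
 *    InlinedCallStackPtr cs = std::get<2>(t);        // callsite chain
 *    // Walk the callsites outermost-to-innermost, then append node_range
 *    // to obtain the full stack described above.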
 */

using BackendDebugInfoMapType =
    std::unordered_map<torch::jit::DebugHandleType, DebugInfoTuple>;

/*
 * This class is used to generate the debug info map.
 * A backend's preprocess will call generate_debug_handles (see
 * backend_detail.cpp), which uses the debug handle manager to generate
 * debug handles. When the lowering process finishes, calling stopRecording
 * will return the debug info map from the debug handle manager; a usage
 * sketch follows.
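 *
 * A minimal usage sketch (hypothetical call site; the real one is the
 * lowering flow in backend_detail.cpp, and `graph` stands for the method
 * graph being lowered):
 *
 *   BackendDebugInfoRecorder recorder;
 *   NodeToDebugHandle handles = recorder.generate_debug_handles(graph);
 *   // ... the backend's preprocess/compile step consumes `handles` ...
 *   BackendDebugInfoMapType debug_map = recorder.stopRecording();
 *   // serialize debug_map so the handles can be symbolicated later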
 */
class TORCH_API BackendDebugInfoRecorder {
 public:
  BackendDebugInfoRecorder() = default;
  int64_t getNextDebugHandle(const Node* node);
  // The reason this is not done via RAII is that the work done in
  // stopRecording can throw, and throwing from a dtor calls std::terminate,
  // which voids any exception catching at a higher level.
  BackendDebugInfoMapType stopRecording();
  NodeToDebugHandle generate_debug_handles(const std::shared_ptr<Graph>& graph);

 private:
  static std::atomic<DebugHandleType> unique_debug_handle_;
  BackendDebugInfoMapType handles_to_inlined_callstack_ptrs_;
};

} // namespace jit
} // namespace torch