xref: /aosp_15_r20/external/pytorch/c10/util/ThreadLocalDebugInfo.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Exception.h>
2 #include <c10/util/ThreadLocal.h>
3 #include <c10/util/ThreadLocalDebugInfo.h>
4 
5 #include <utility>
6 
7 namespace c10 {
8 
// File-local storage for the current debug-info entry, declared through the
// C10_DEFINE_TLS_static helper macro. It holds a shared_ptr to the entry that
// the functions in this file treat as the head of the parent_info_ chain.
C10_DEFINE_TLS_static(std::shared_ptr<ThreadLocalDebugInfo>, tls_debug_info);
// Shorthand for the stored shared_ptr: every use of `debug_info` below expands
// to `tls_debug_info.get()`, which this file both reads from and assigns to.
#define debug_info (tls_debug_info.get())
11 
12 /* static */
get(DebugInfoKind kind)13 DebugInfoBase* ThreadLocalDebugInfo::get(DebugInfoKind kind) {
14   ThreadLocalDebugInfo* cur = debug_info.get();
15   while (cur) {
16     if (cur->kind_ == kind) {
17       return cur->info_.get();
18     }
19     cur = cur->parent_info_.get();
20   }
21   return nullptr;
22 }
23 
24 /* static */
current()25 std::shared_ptr<ThreadLocalDebugInfo> ThreadLocalDebugInfo::current() {
26   return debug_info;
27 }
28 
29 /* static */
_forceCurrentDebugInfo(std::shared_ptr<ThreadLocalDebugInfo> info)30 void ThreadLocalDebugInfo::_forceCurrentDebugInfo(
31     std::shared_ptr<ThreadLocalDebugInfo> info) {
32   debug_info = std::move(info);
33 }
34 
35 /* static */
_push(DebugInfoKind kind,std::shared_ptr<DebugInfoBase> info)36 void ThreadLocalDebugInfo::_push(
37     DebugInfoKind kind,
38     std::shared_ptr<DebugInfoBase> info) {
39   auto prev_info = debug_info;
40   debug_info = std::make_shared<ThreadLocalDebugInfo>();
41   debug_info->parent_info_ = prev_info;
42   debug_info->kind_ = kind;
43   debug_info->info_ = std::move(info);
44 }
45 
46 /* static */
_pop(DebugInfoKind kind)47 std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_pop(DebugInfoKind kind) {
48   TORCH_CHECK(
49       debug_info && debug_info->kind_ == kind,
50       "Expected debug info of type ",
51       (size_t)kind);
52   auto res = debug_info;
53   debug_info = debug_info->parent_info_;
54   return res->info_;
55 }
56 
57 /* static */
_peek(DebugInfoKind kind)58 std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_peek(DebugInfoKind kind) {
59   TORCH_CHECK(
60       debug_info && debug_info->kind_ == kind,
61       "Expected debug info of type ",
62       (size_t)kind);
63   return debug_info->info_;
64 }
65 
DebugInfoGuard(DebugInfoKind kind,std::shared_ptr<DebugInfoBase> info)66 DebugInfoGuard::DebugInfoGuard(
67     DebugInfoKind kind,
68     std::shared_ptr<DebugInfoBase> info) {
69   if (!info) {
70     return;
71   }
72   prev_info_ = debug_info;
73   ThreadLocalDebugInfo::_push(kind, std::move(info));
74   active_ = true;
75 }
76 
~DebugInfoGuard()77 DebugInfoGuard::~DebugInfoGuard() {
78   if (active_) {
79     debug_info = prev_info_;
80   }
81 }
82 
83 // Used only for setting a debug info after crossing the thread boundary;
84 // in this case we assume that thread pool's thread does not have an
85 // active debug info
DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info)86 DebugInfoGuard::DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info) {
87   if (!info) {
88     return;
89   }
90   prev_info_ = std::move(debug_info);
91   debug_info = std::move(info);
92   active_ = true;
93 }
94 
95 } // namespace c10
96