1 #include <c10/util/Exception.h>
2 #include <c10/util/ThreadLocal.h>
3 #include <c10/util/ThreadLocalDebugInfo.h>
4
5 #include <utility>
6
7 namespace c10 {
8
// Per-thread head of the debug-info chain. Each ThreadLocalDebugInfo node
// points to the node that was current before it through parent_info_, so this
// slot holds the innermost entry of the chain (or is empty when nothing has
// been installed on this thread). Declared via C10_DEFINE_TLS_static from
// c10/util/ThreadLocal.h.
C10_DEFINE_TLS_static(std::shared_ptr<ThreadLocalDebugInfo>, tls_debug_info);
// Shorthand for this thread's slot. The functions in this file both read and
// assign through it, so it must yield a reference to the stored shared_ptr.
#define debug_info (tls_debug_info.get())
11
12 /* static */
get(DebugInfoKind kind)13 DebugInfoBase* ThreadLocalDebugInfo::get(DebugInfoKind kind) {
14 ThreadLocalDebugInfo* cur = debug_info.get();
15 while (cur) {
16 if (cur->kind_ == kind) {
17 return cur->info_.get();
18 }
19 cur = cur->parent_info_.get();
20 }
21 return nullptr;
22 }
23
24 /* static */
current()25 std::shared_ptr<ThreadLocalDebugInfo> ThreadLocalDebugInfo::current() {
26 return debug_info;
27 }
28
29 /* static */
_forceCurrentDebugInfo(std::shared_ptr<ThreadLocalDebugInfo> info)30 void ThreadLocalDebugInfo::_forceCurrentDebugInfo(
31 std::shared_ptr<ThreadLocalDebugInfo> info) {
32 debug_info = std::move(info);
33 }
34
35 /* static */
_push(DebugInfoKind kind,std::shared_ptr<DebugInfoBase> info)36 void ThreadLocalDebugInfo::_push(
37 DebugInfoKind kind,
38 std::shared_ptr<DebugInfoBase> info) {
39 auto prev_info = debug_info;
40 debug_info = std::make_shared<ThreadLocalDebugInfo>();
41 debug_info->parent_info_ = prev_info;
42 debug_info->kind_ = kind;
43 debug_info->info_ = std::move(info);
44 }
45
46 /* static */
_pop(DebugInfoKind kind)47 std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_pop(DebugInfoKind kind) {
48 TORCH_CHECK(
49 debug_info && debug_info->kind_ == kind,
50 "Expected debug info of type ",
51 (size_t)kind);
52 auto res = debug_info;
53 debug_info = debug_info->parent_info_;
54 return res->info_;
55 }
56
57 /* static */
_peek(DebugInfoKind kind)58 std::shared_ptr<DebugInfoBase> ThreadLocalDebugInfo::_peek(DebugInfoKind kind) {
59 TORCH_CHECK(
60 debug_info && debug_info->kind_ == kind,
61 "Expected debug info of type ",
62 (size_t)kind);
63 return debug_info->info_;
64 }
65
DebugInfoGuard(DebugInfoKind kind,std::shared_ptr<DebugInfoBase> info)66 DebugInfoGuard::DebugInfoGuard(
67 DebugInfoKind kind,
68 std::shared_ptr<DebugInfoBase> info) {
69 if (!info) {
70 return;
71 }
72 prev_info_ = debug_info;
73 ThreadLocalDebugInfo::_push(kind, std::move(info));
74 active_ = true;
75 }
76
~DebugInfoGuard()77 DebugInfoGuard::~DebugInfoGuard() {
78 if (active_) {
79 debug_info = prev_info_;
80 }
81 }
82
83 // Used only for setting a debug info after crossing the thread boundary;
84 // in this case we assume that thread pool's thread does not have an
85 // active debug info
DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info)86 DebugInfoGuard::DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info) {
87 if (!info) {
88 return;
89 }
90 prev_info_ = std::move(debug_info);
91 debug_info = std::move(info);
92 active_ = true;
93 }
94
95 } // namespace c10
96