xref: /aosp_15_r20/external/pytorch/c10/util/ThreadLocalDebugInfo.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Export.h>
4 
5 #include <cstdint>
6 #include <memory>
7 
8 namespace c10 {
9 
10 enum class C10_API_ENUM DebugInfoKind : uint8_t {
11   PRODUCER_INFO = 0,
12   MOBILE_RUNTIME_INFO,
13   PROFILER_STATE,
14   INFERENCE_CONTEXT, // for inference usage
15   PARAM_COMMS_INFO,
16 
17   TEST_INFO, // used only in tests
18   TEST_INFO_2, // used only in tests
19 };
20 
21 class C10_API DebugInfoBase {
22  public:
23   DebugInfoBase() = default;
24   virtual ~DebugInfoBase() = default;
25 };
26 
27 // Thread local debug information is propagated across the forward
28 // (including async fork tasks) and backward passes and is supposed
29 // to be utilized by the user's code to pass extra information from
30 // the higher layers (e.g. model id) down to the lower levels
31 // (e.g. to the operator observers used for debugging, logging,
32 // profiling, etc)
33 class C10_API ThreadLocalDebugInfo {
34  public:
35   static DebugInfoBase* get(DebugInfoKind kind);
36 
37   // Get current ThreadLocalDebugInfo
38   static std::shared_ptr<ThreadLocalDebugInfo> current();
39 
40   // Internal, use DebugInfoGuard/ThreadLocalStateGuard
41   static void _forceCurrentDebugInfo(
42       std::shared_ptr<ThreadLocalDebugInfo> info);
43 
44   // Push debug info struct of a given kind
45   static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
46   // Pop debug info, throws in case the last pushed
47   // debug info is not of a given kind
48   static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
49   // Peek debug info, throws in case the last pushed debug info is not of the
50   // given kind
51   static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
52 
53  private:
54   std::shared_ptr<DebugInfoBase> info_;
55   DebugInfoKind kind_;
56   std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
57 
58   friend class DebugInfoGuard;
59 };
60 
61 // DebugInfoGuard is used to set debug information,
62 // ThreadLocalDebugInfo is semantically immutable, the values are set
63 // through the scope-based guard object.
64 // Nested DebugInfoGuard adds/overrides existing values in the scope,
65 // restoring the original values after exiting the scope.
66 // Users can access the values through the ThreadLocalDebugInfo::get() call;
67 class C10_API DebugInfoGuard {
68  public:
69   DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
70 
71   explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
72 
73   ~DebugInfoGuard();
74 
75   DebugInfoGuard(const DebugInfoGuard&) = delete;
76   DebugInfoGuard(DebugInfoGuard&&) = delete;
77 
78  private:
79   bool active_ = false;
80   std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
81 };
82 
83 } // namespace c10
84