1 #pragma once 2 3 #include <c10/macros/Export.h> 4 5 #include <cstdint> 6 #include <memory> 7 8 namespace c10 { 9 10 enum class C10_API_ENUM DebugInfoKind : uint8_t { 11 PRODUCER_INFO = 0, 12 MOBILE_RUNTIME_INFO, 13 PROFILER_STATE, 14 INFERENCE_CONTEXT, // for inference usage 15 PARAM_COMMS_INFO, 16 17 TEST_INFO, // used only in tests 18 TEST_INFO_2, // used only in tests 19 }; 20 21 class C10_API DebugInfoBase { 22 public: 23 DebugInfoBase() = default; 24 virtual ~DebugInfoBase() = default; 25 }; 26 27 // Thread local debug information is propagated across the forward 28 // (including async fork tasks) and backward passes and is supposed 29 // to be utilized by the user's code to pass extra information from 30 // the higher layers (e.g. model id) down to the lower levels 31 // (e.g. to the operator observers used for debugging, logging, 32 // profiling, etc) 33 class C10_API ThreadLocalDebugInfo { 34 public: 35 static DebugInfoBase* get(DebugInfoKind kind); 36 37 // Get current ThreadLocalDebugInfo 38 static std::shared_ptr<ThreadLocalDebugInfo> current(); 39 40 // Internal, use DebugInfoGuard/ThreadLocalStateGuard 41 static void _forceCurrentDebugInfo( 42 std::shared_ptr<ThreadLocalDebugInfo> info); 43 44 // Push debug info struct of a given kind 45 static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info); 46 // Pop debug info, throws in case the last pushed 47 // debug info is not of a given kind 48 static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind); 49 // Peek debug info, throws in case the last pushed debug info is not of the 50 // given kind 51 static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind); 52 53 private: 54 std::shared_ptr<DebugInfoBase> info_; 55 DebugInfoKind kind_; 56 std::shared_ptr<ThreadLocalDebugInfo> parent_info_; 57 58 friend class DebugInfoGuard; 59 }; 60 61 // DebugInfoGuard is used to set debug information, 62 // ThreadLocalDebugInfo is semantically immutable, the values are set 63 // through the scope-based guard object. 64 // Nested DebugInfoGuard adds/overrides existing values in the scope, 65 // restoring the original values after exiting the scope. 66 // Users can access the values through the ThreadLocalDebugInfo::get() call; 67 class C10_API DebugInfoGuard { 68 public: 69 DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info); 70 71 explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info); 72 73 ~DebugInfoGuard(); 74 75 DebugInfoGuard(const DebugInfoGuard&) = delete; 76 DebugInfoGuard(DebugInfoGuard&&) = delete; 77 78 private: 79 bool active_ = false; 80 std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr; 81 }; 82 83 } // namespace c10 84