#pragma once

#include <ATen/record_function.h>
#include <torch/csrc/Export.h>

#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace profiler {
namespace impl {

// ----------------------------------------------------------------------------
// -- Profiler Config ---------------------------------------------------------
// ----------------------------------------------------------------------------
enum class C10_API_ENUM ActivityType {
  CPU = 0,
  XPU, // XPU kernels, runtime
  CUDA, // CUDA kernels, runtime
  MTIA, // MTIA kernels, runtime
  PrivateUse1, // PrivateUse1 kernels, runtime
  NUM_KINETO_ACTIVITIES, // must be the last one
};

inline std::string actToString(ActivityType t) {
  const std::string ActivityTypeNames[] = {
      "CPU", "XPU", "CUDA", "MTIA", "PrivateUse1"};
  return ActivityTypeNames[static_cast<int>(t)];
}
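
// Usage sketch (illustrative comment, not part of the original header):
// actToString maps an activity to its printable name, e.g.
// actToString(ActivityType::CUDA) returns "CUDA". The name table is indexed
// by the enum value, so NUM_KINETO_ACTIVITIES must not be passed here; it
// would read past the end of the array.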

enum class C10_API_ENUM ProfilerState {
  Disabled = 0,
  CPU, // CPU-only profiling
  CUDA, // CPU + CUDA events
  NVTX, // only emit NVTX markers
  ITT, // only emit ITT markers
  PRIVATEUSE1, // only emit PRIVATEUSE1 markers
  KINETO, // use libkineto
  KINETO_GPU_FALLBACK, // use CUDA events when CUPTI is not available
  KINETO_PRIVATEUSE1_FALLBACK, // use PrivateUse1 events
  KINETO_ONDEMAND, // run the profiler in on-demand mode
  NUM_PROFILER_STATES, // must be the last one
};

enum class C10_API_ENUM ActiveProfilerType {
  NONE = 0,
  LEGACY,
  KINETO,
  NVTX,
  ITT,
  PRIVATEUSE1
};

struct TORCH_API ExperimentalConfig {
  ExperimentalConfig(
      std::vector<std::string> profiler_metrics = {},
      bool profiler_measure_per_kernel = false,
      bool verbose = false,
      std::vector<std::string> performance_events = {},
      bool enable_cuda_sync_events = false,
      bool adjust_timestamps = false);
  explicit operator bool() const;

  std::vector<std::string> profiler_metrics;
  bool profiler_measure_per_kernel;
  bool verbose;
  /*
   * List of performance events to be profiled.
   * An empty list will disable performance event based profiling altogether.
   */
  std::vector<std::string> performance_events;
  /*
   * For CUDA profiling mode, enable adding CUDA synchronization events
   * that expose CUDA device, stream and event synchronization activities.
   * This feature is new and currently disabled by default.
   */
  bool enable_cuda_sync_events;
  /*
   * Controls whether or not timestamp adjustment occurs after profiling.
   * The purpose of this is to adjust Vulkan event timelines to align with those
   * of their parent CPU events.
   * This sometimes requires increasing CPU event durations (to fully contain
   * their child events) and delaying CPU event start times (to
   * prevent overlaps), so this should not be used unless Vulkan events are
   * being profiled and it is ok to use this modified timestamp/duration
   * information instead of the original information.
   */
  bool adjust_timestamps;
};
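
// Usage sketch (hypothetical, illustrative only): enabling per-kernel
// measurement and CUDA sync events while keeping every other option at its
// declared default:
//   ExperimentalConfig exp(
//       /*profiler_metrics=*/{},
//       /*profiler_measure_per_kernel=*/true,
//       /*verbose=*/false,
//       /*performance_events=*/{},
//       /*enable_cuda_sync_events=*/true);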

struct TORCH_API ProfilerConfig {
  ProfilerConfig(
      ProfilerState state,
      bool report_input_shapes = false,
      bool profile_memory = false,
      bool with_stack = false,
      bool with_flops = false,
      bool with_modules = false,
      ExperimentalConfig experimental_config = ExperimentalConfig());

  bool disabled() const;
  bool global() const;

  ProfilerState state;
  ExperimentalConfig experimental_config;
  bool report_input_shapes;
  bool profile_memory;
  bool with_stack;
  bool with_flops;
  bool with_modules;

  // For serialization
  at::IValue toIValue() const;
  static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue);
};
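
// Usage sketch (hypothetical, illustrative only): building a Kineto config
// with memory profiling and round-tripping it through its IValue form via the
// serialization helpers declared above:
//   ProfilerConfig cfg(
//       ProfilerState::KINETO,
//       /*report_input_shapes=*/true,
//       /*profile_memory=*/true);
//   at::IValue packed = cfg.toIValue();
//   ProfilerConfig restored = ProfilerConfig::fromIValue(packed);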

// ----------------------------------------------------------------------------
// -- Profiler base class -----------------------------------------------------
// ----------------------------------------------------------------------------
struct TORCH_API ProfilerStateBase : public c10::MemoryReportingInfoBase {
  explicit ProfilerStateBase(ProfilerConfig config);
  ~ProfilerStateBase() override;

  static ProfilerStateBase* get(bool global);
  static ProfilerStateBase* get() {
    auto* out = get(/*global=*/true);
    return out ? out : get(/*global=*/false);
  }

  static void push(std::shared_ptr<ProfilerStateBase>&& state);

  static std::shared_ptr<ProfilerStateBase> pop(bool global);
  static std::shared_ptr<ProfilerStateBase> pop() {
    auto out = pop(/*global=*/true);
    return out ? std::move(out) : pop(/*global=*/false);
  }

  const ProfilerConfig& config() const {
    return config_;
  }

  void setCallbackHandle(at::CallbackHandle handle);
  void removeCallback();

  bool memoryProfilingEnabled() const override {
    return config_.profile_memory;
  }

  virtual ActiveProfilerType profilerType() = 0;

 protected:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex state_mutex_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled);
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  at::CallbackHandle handle_ = 0;
};
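
// Usage sketch (hypothetical; `MyProfilerState` is an assumed concrete
// subclass of ProfilerStateBase, not declared in this header): registering a
// thread-local state, querying the active one (the global state takes
// precedence in the argument-less get()/pop()), and tearing it down:
//   ProfilerStateBase::push(std::make_shared<MyProfilerState>(cfg));
//   if (auto* s = ProfilerStateBase::get()) {
//     bool mem = s->memoryProfilingEnabled();
//   }
//   auto popped = ProfilerStateBase::pop(/*global=*/false);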

// Note: The following are only for the active *thread local* profiler.
TORCH_API bool profilerEnabled();
TORCH_API ActiveProfilerType profilerType();
TORCH_API ProfilerConfig getProfilerConfig();
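
// Usage sketch (illustrative): guarding profiler-specific work on the current
// thread using the thread-local queries above:
//   if (torch::profiler::impl::profilerEnabled() &&
//       torch::profiler::impl::getProfilerConfig().profile_memory) {
//     /* record allocation metadata ... */
//   }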

} // namespace impl
} // namespace profiler
} // namespace torch