xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/profiler_legacy.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/Export.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
14 
15 namespace torch::autograd::profiler {
16 
// Kind of a profiler event recorded by the legacy autograd profiler.
enum class C10_API_ENUM EventKind : uint16_t {
  Mark, // instantaneous marker event
  PushRange, // start of a nested profiled range
  PopRange, // end of a nested profiled range
  MemoryAlloc, // memory allocation / deallocation record
};
23 
24 // To be deprecated, once we switch to Kineto profiling
25 struct TORCH_API LegacyEvent {
26   LegacyEvent(
27       EventKind kind,
28       at::StringView name,
29       uint16_t thread_id,
30       bool record_cuda,
31       at::RecordFunctionHandle handle = 0,
32       std::vector<std::vector<int64_t>>&& shapes = {},
33       int64_t node_id = -1,
34       bool is_async = false)
name_LegacyEvent35       : name_(std::move(name)),
36         kind_(kind),
37         thread_id_(thread_id),
38         handle_(handle),
39         shapes_(std::move(shapes)),
40         node_id_(node_id),
41         is_async_(is_async) {
42     record(record_cuda);
43   }
44 
45   // Constructor to be used in conjunction with LegacyEvent::fromIValue.
46   LegacyEvent(
47       EventKind kind,
48       at::StringView name,
49       uint16_t thread_id,
50       at::RecordFunctionHandle handle,
51       std::vector<std::vector<int64_t>>&& shapes,
52       int64_t node_id,
53       bool is_remote,
54       int64_t cpu_memory_usage,
55       int64_t cpu_ns,
56       bool cuda_recorded,
57       int64_t cuda_memory_usage = 0,
58       c10::DeviceIndex device = -1,
59       double cuda_us = -1)
cpu_ns_LegacyEvent60       : cpu_ns_(cpu_ns),
61         name_(std::move(name)),
62         kind_(kind),
63         thread_id_(thread_id),
64         handle_(handle),
65         shapes_(std::move(shapes)),
66         cpu_memory_usage_(cpu_memory_usage),
67         cuda_memory_usage_(cuda_memory_usage),
68         device_(device),
69         node_id_(node_id),
70         is_remote_(is_remote),
71         cuda_us_(static_cast<int64_t>(cuda_us)) {
72     // Sanity check values that were deserialized
73     TORCH_INTERNAL_ASSERT(cpu_ns_ > 0);
74     if (cuda_recorded) {
75       TORCH_INTERNAL_ASSERT(device_ >= 0);
76       TORCH_INTERNAL_ASSERT(cuda_us_ >= 0);
77     }
78   }
79 
80   // Returns IValues corresponding to event structure, to be used for
81   // serialization.
82   at::IValue toIValue() const;
83 
84   // Reconstructs an event from IValues given by toIValue.
85   static LegacyEvent fromIValue(const at::IValue& eventIValue);
86 
87   void record(bool record_cuda);
88 
kindStrLegacyEvent89   std::string kindStr() const {
90     switch (kind_) {
91       case EventKind::Mark:
92         return "mark";
93       case EventKind::PushRange:
94         return "push";
95       case EventKind::PopRange:
96         return "pop";
97       case EventKind::MemoryAlloc:
98         return "memory_alloc";
99     }
100     throw std::runtime_error("unknown event kind");
101   }
102 
kindLegacyEvent103   EventKind kind() const {
104     return kind_;
105   }
106 
nameLegacyEvent107   const char* name() const {
108     return name_.str();
109   }
110 
threadIdLegacyEvent111   uint64_t threadId() const {
112     return thread_id_;
113   }
114 
shapesLegacyEvent115   std::vector<std::vector<int64_t>> shapes() const {
116     return shapes_;
117   }
118 
cpuElapsedUsLegacyEvent119   double cpuElapsedUs(const LegacyEvent& e) const {
120     return static_cast<double>(e.cpu_ns_ - cpu_ns_) / (1000.0);
121   }
122 
setCpuUsLegacyEvent123   void setCpuUs(int64_t cpu_us) {
124     cpu_ns_ = cpu_us * 1000;
125   }
126 
cpuUsLegacyEvent127   double cpuUs() const {
128     return static_cast<double>(cpu_ns_) / (1000.0);
129   }
130 
131   double cudaElapsedUs(const LegacyEvent& e) const;
132 
hasCudaLegacyEvent133   bool hasCuda() const {
134     return cuda_event != nullptr || (isRemote() && device_ != -1);
135   }
136 
deviceLegacyEvent137   c10::DeviceIndex device() const {
138     return device_;
139   }
140 
updateMemoryStatsLegacyEvent141   void updateMemoryStats(int64_t alloc_size, c10::Device device) {
142     if (device.is_cuda() || device.type() == c10::DeviceType::HIP) {
143       cuda_memory_usage_ = alloc_size;
144     } else if (
145         device.is_cpu() || device.type() == c10::DeviceType::MKLDNN ||
146         device.type() == c10::DeviceType::IDEEP) {
147       cpu_memory_usage_ = alloc_size;
148     } else {
149       LOG(WARNING) << "Unsupported memory profiling device: " << device;
150     }
151   }
152 
cpuMemoryUsageLegacyEvent153   int64_t cpuMemoryUsage() const {
154     return cpu_memory_usage_;
155   }
156 
cudaMemoryUsageLegacyEvent157   int64_t cudaMemoryUsage() const {
158     return cuda_memory_usage_;
159   }
160 
handleLegacyEvent161   at::RecordFunctionHandle handle() const {
162     return handle_;
163   }
164 
165   // Node ID corresponding to this event.
nodeIdLegacyEvent166   int64_t nodeId() const {
167     return node_id_;
168   }
169 
170   // Set Node ID on this event.
setNodeIdLegacyEvent171   void setNodeId(int64_t node_id) {
172     node_id_ = node_id;
173   }
174 
setNameLegacyEvent175   void setName(at::StringView newName_) {
176     name_ = std::move(newName_);
177   }
178 
isRemoteLegacyEvent179   bool isRemote() const {
180     return is_remote_;
181   }
182 
setCudaUsLegacyEvent183   void setCudaUs(int64_t cuda_us) {
184     cuda_us_ = cuda_us;
185   }
186 
setSequenceNrLegacyEvent187   void setSequenceNr(int64_t sequence_nr) {
188     sequence_nr_ = sequence_nr;
189   }
190 
sequenceNrLegacyEvent191   int64_t sequenceNr() const {
192     return sequence_nr_;
193   }
194 
setCorrelationIdLegacyEvent195   void setCorrelationId(uint64_t correlation_id) {
196     correlation_id_ = correlation_id;
197   }
198 
correlationIdLegacyEvent199   uint64_t correlationId() const {
200     return correlation_id_;
201   }
202 
stackLegacyEvent203   const std::vector<std::string>& stack() const {
204     return stack_;
205   }
206 
setStackLegacyEvent207   void setStack(const std::vector<std::string>& stack) {
208     stack_ = stack;
209   }
210 
fwdThreadIdLegacyEvent211   uint64_t fwdThreadId() const {
212     return fwd_thread_id_;
213   }
214 
setFwdThreadIdLegacyEvent215   void setFwdThreadId(uint64_t fwd_thread_id) {
216     fwd_thread_id_ = fwd_thread_id;
217   }
218 
scopeLegacyEvent219   uint8_t scope() const {
220     return scope_;
221   }
222 
setScopeLegacyEvent223   void setScope(uint8_t scope) {
224     scope_ = scope;
225   }
226 
extraArgsLegacyEvent227   const std::unordered_map<std::string, c10::IValue>& extraArgs() const {
228     return extra_args_;
229   }
230 
setExtraArgsLegacyEvent231   void setExtraArgs(std::unordered_map<std::string, c10::IValue>&& save_args) {
232     extra_args_ = std::move(save_args);
233   }
234 
flopsLegacyEvent235   uint64_t flops() {
236     return flops_;
237   }
238 
isAsyncLegacyEvent239   bool isAsync() {
240     return is_async_;
241   }
242 
setFlopsLegacyEvent243   void setFlops(uint64_t flops) {
244     flops_ = flops;
245   }
246 
247  private:
248   // signed to allow for negative intervals, initialized for safety.
249   int64_t cpu_ns_ = 0;
250   at::StringView name_;
251   EventKind kind_;
252   uint64_t thread_id_;
253   uint64_t fwd_thread_id_{0};
254   at::RecordFunctionHandle handle_{0};
255   std::vector<std::vector<int64_t>> shapes_;
256   int64_t cpu_memory_usage_ = 0;
257   int64_t cuda_memory_usage_ = 0;
258   c10::DeviceIndex device_ = -1;
259   torch::profiler::impl::ProfilerVoidEventStub cuda_event = nullptr;
260   int64_t node_id_ = 0;
261   bool is_remote_ = false;
262   int64_t cuda_us_ = -1;
263   int64_t sequence_nr_ = -1;
264   bool is_async_ = false;
265 
266   std::vector<std::string> stack_;
267   uint8_t scope_{0};
268   uint64_t correlation_id_{0};
269   // Extra arguments for computing op flops
270   std::unordered_map<std::string, c10::IValue> extra_args_;
271   uint64_t flops_ = 0;
272 };
273 
// A linked list of fixed-size vectors, to avoid a std::vector resize from
// taking a large amount of time inside a profiling event.
277 struct RangeEventList {
RangeEventListRangeEventList278   RangeEventList() {
279     events_.reserve(kReservedCapacity);
280   }
281 
282   template <typename... Args>
recordRangeEventList283   void record(Args&&... args) {
284     std::lock_guard<std::mutex> guard(mutex_);
285     events_.emplace_back(std::forward<Args>(args)...);
286   }
287 
consolidateRangeEventList288   std::vector<LegacyEvent> consolidate() {
289     std::lock_guard<std::mutex> lock(mutex_);
290     std::vector<LegacyEvent> result;
291     result.insert(
292         result.begin(),
293         std::make_move_iterator(events_.begin()),
294         std::make_move_iterator(events_.end()));
295     events_.erase(events_.begin(), events_.end());
296     return result;
297   }
298 
sizeRangeEventList299   size_t size() {
300     std::lock_guard<std::mutex> lock(mutex_);
301     return events_.size();
302   }
303 
304  private:
305   // This mutex is used to serialize access when different threads are writing
306   // to the same instance of RangeEventList.
307   std::mutex mutex_;
308   std::vector<LegacyEvent> events_;
309 
310   static const size_t kReservedCapacity = 1024;
311 };
312 
313 // A struct to control settings of disableProfiler options.
314 struct TORCH_API ProfilerDisableOptions {
315   ProfilerDisableOptions() = default;
ProfilerDisableOptionsProfilerDisableOptions316   ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
317       : cleanupTLSState(shouldCleanupTLSState),
318         consolidate(shouldConsolidate) {}
319   // Whether we should clean up profiler states that are thread local, such as
320   // ThreadLocalDebugInfo and thread local RecordFunction callbacks.
321   bool cleanupTLSState = true;
322   // Whether we should consolidate all currently recorded profiled events. If
323   // false, will not consolidate and other threads can continue to write to the
324   // event lists.
325   bool consolidate = true;
326 };
327 
// NOTE: profiler mode is thread local, with automatic propagation
// across thread boundary (e.g. at::launch tasks)
TORCH_API void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig&);
// Per-thread lists of the events collected while the profiler was enabled.
using thread_event_lists = std::vector<std::vector<LegacyEvent>>;
// Stops the legacy profiler and returns the recorded events, honoring the
// cleanup/consolidate flags in profilerDisableOptions.
TORCH_API thread_event_lists disableProfilerLegacy(
    std::optional<ProfilerDisableOptions> profilerDisableOptions =
        std::nullopt);

// Adds profiledEvents to the current thread local recorded events.
// NOTE(review): the original comment referenced a `fromNodeId` argument, but
// the signature takes none — node IDs presumably come from the events
// themselves; confirm against the definition in profiler_legacy.cpp.
TORCH_API void addEventList(std::vector<LegacyEvent>&& profiledEvents);
// Writes profiled events to a stream.
TORCH_API void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events);
344 
// Usage:
//   {
//     RecordProfile guard("filename.trace");
//     // code you want to profile
//   }
// Then open filename.trace in chrome://tracing
//
// RAII guard for the legacy profiler; on destruction the collected events
// are written to the given stream or file (definitions live in the .cpp).
struct TORCH_API RecordProfile {
  // Writes the trace to a caller-owned stream.
  RecordProfile(std::ostream& out);
  // Opens `filename` and writes the trace there.
  RecordProfile(const std::string& filename);

  ~RecordProfile();

 private:
  void init();
  // Owns the output file when constructed from a filename; otherwise null
  // and out_ references the caller's stream.
  std::unique_ptr<std::ofstream> file_;
  std::ostream& out_;
  void processEvents(const std::vector<LegacyEvent*>& events);
};
363 
// A guard that enables the legacy profiler, taking in an optional callback to
// process the results. Usage:
366 // {
367 //   TLSLegacyProfilerGuard g([](thread_event_lists profilerResults) {
368 //     // process profilerResults
369 //   });
370 //   Code to profile
371 // }
372 struct TORCH_API TLSLegacyProfilerGuard {
373   explicit TLSLegacyProfilerGuard(
374       const torch::profiler::impl::ProfilerConfig& cfg,
375       std::optional<std::function<void(const thread_event_lists&)>>
376           resultCallback = std::nullopt,
377       std::optional<ProfilerDisableOptions> profilerDisableOptions =
378           std::nullopt)
cb_TLSLegacyProfilerGuard379       : cb_(std::move(resultCallback)),
380         profilerDisableOptions_(profilerDisableOptions) {
381     enableProfilerLegacy(cfg);
382   }
~TLSLegacyProfilerGuardTLSLegacyProfilerGuard383   ~TLSLegacyProfilerGuard() {
384     thread_event_lists event_lists =
385         disableProfilerLegacy(profilerDisableOptions_);
386     if (cb_) {
387       try {
388         (*cb_)(event_lists);
389       } catch (const std::exception& e) {
390         LOG(ERROR) << "Got error processing profiler events: " << e.what();
391       }
392     }
393   }
394 
395  private:
396   std::optional<std::function<void(const thread_event_lists&)>> cb_;
397   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
398   const std::optional<ProfilerDisableOptions> profilerDisableOptions_;
399 };
400 
401 } // namespace torch::autograd::profiler
402