xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/kineto_shim.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstdint>
#include <memory>
#include <set>
#include <string>
5 
// Skip the Kineto dependency on mobile builds unless it is explicitly
// requested.
// When is it explicitly requested?
//   KinetoEdgeCPUProfiler uses KinetoProfiler for CPU event profiling, which
//   depends on the CPU-only libkineto. Defining EDGE_PROFILER_USE_KINETO keeps
//   USE_KINETO defined on C10_MOBILE builds; otherwise it is undefined here
//   and the rest of this header takes the non-Kineto branches.
#if defined(USE_KINETO) && defined(C10_MOBILE) && \
    !defined(EDGE_PROFILER_USE_KINETO)
#undef USE_KINETO
#endif
14 
15 #include <ActivityType.h>
16 
17 #include <torch/csrc/Export.h>
18 #include <torch/csrc/profiler/api.h>
19 
#ifdef USE_KINETO
// Forward-declare the libkineto types this header names, so that
// `libkineto.h` itself does not have to be pulled into a header.
namespace libkineto {
struct CpuTraceBuffer;
class ActivityTraceInterface;
class GenericTraceActivity;
} // namespace libkineto
#endif
28 
29 namespace torch {
30 namespace profiler {
31 
// Compile-time flag: true exactly when USE_KINETO is still defined at this
// point of the header.
#ifndef USE_KINETO
constexpr bool kKinetoAvailable = false;
#else
constexpr bool kKinetoAvailable = true;
#endif
37 
38 namespace impl::kineto {
39 
40 // ----------------------------------------------------------------------------
41 // -- Interface (Does not require Kineto) -------------------------------------
42 // ----------------------------------------------------------------------------
// Plain aggregate pairing a device id with a resource id.
struct DeviceAndResource {
  int32_t device;
  int32_t resource;
};

// Returns a DeviceAndResource by value. Declaration only; the definition lives
// outside this header (presumably the ids of the calling context -- confirm
// against the implementation).
const DeviceAndResource kineto_ids();
48 
// Type aliases used by the wrappers below. With Kineto they name the real
// libkineto types; without it they name empty placeholder types so that the
// rest of this header still compiles.
#ifdef USE_KINETO
using trace_t = libkineto::CpuTraceBuffer;
using interface_trace_t = libkineto::ActivityTraceInterface;
using activity_t = libkineto::GenericTraceActivity;
#else
struct DummyTraceBuffer {};
struct DummyTraceInterface {};

using trace_t = DummyTraceBuffer;
// Use the dedicated placeholder so that `trace_t` and `interface_trace_t`
// stay distinct types, as they are in the Kineto build.
using interface_trace_t = DummyTraceInterface;
struct activity_t;
#endif // USE_KINETO
61 
62 void addMetadata(
63     activity_t* activity,
64     const std::string& key,
65     const std::string& value);
66 
67 // Wraps: libkineto::CpuTraceBuffer
68 struct TraceWrapper {
69   TraceWrapper(const int64_t start_time, const std::string& name);
70   TraceWrapper(TraceWrapper&&) = default;
71   TraceWrapper(const TraceWrapper&) = delete;
72   ~TraceWrapper();
73 
74   // The caller is expected to hold a mutex when calling `addCPUActivity`.
75   activity_t* addCPUActivity(
76       const std::string& name,
77       const libkineto::ActivityType type,
78       const DeviceAndResource device_and_resource,
79       const uint64_t correlation_id,
80       const int64_t start_time,
81       const int64_t end_time);
82 
83   void transferCpuTrace(int64_t end_time);
84 
85   explicit operator bool() const;
86 
getTraceWrapper87   std::unique_ptr<trace_t>& get() {
88     return cpu_trace_;
89   }
90 
91  private:
92   std::unique_ptr<trace_t> cpu_trace_;
93 };
94 
95 // Wraps libkineto::ActivityTraceInterface
96 struct ActivityTraceWrapper {
97   explicit ActivityTraceWrapper(std::unique_ptr<interface_trace_t>&& trace);
98   ActivityTraceWrapper() = default;
99   ActivityTraceWrapper(ActivityTraceWrapper&&) = default;
100   ActivityTraceWrapper(const ActivityTraceWrapper&) = delete;
101   explicit operator bool() const;
102   void save(const std::string& path);
103 
getActivityTraceWrapper104   const std::unique_ptr<interface_trace_t>& get() {
105     return trace_;
106   }
107 
108  private:
109   std::unique_ptr<interface_trace_t> trace_;
110 #ifdef USE_KINETO
111   bool saved_ = false; // Kineto's save is destructive
112 #endif
113 };
114 
115 using ActivitySet = std::set<torch::autograd::profiler::ActivityType>;
116 void prepareTrace(
117     const bool cpuOnly,
118     const ActivitySet& activities,
119     const torch::profiler::impl::ExperimentalConfig& config);
120 
121 void toggleCollectionDynamic(const bool enable);
122 void startTrace();
123 ActivityTraceWrapper stopTrace();
124 void pushCorrelationId(uint64_t correlation_id);
125 void pushUserCorrelationId(uint64_t correlation_id);
126 void popCorrelationId();
127 void popUserCorrelationId();
128 void recordThreadInfo();
129 bool collectivesProfilerExists();
130 
// Declaration only. Takes the violated assertion, an error description, and
// the profile / group-profile identifiers, all as strings; going by the name
// these are reported to a log (confirm against the definition).
void logInvariantViolation(
    const std::string& assertion,
    const std::string& error,
    const std::string& profile_id,
    const std::string& group_profile_id);
136 
137 } // namespace impl::kineto
138 
139 } // namespace profiler
140 
141 namespace autograd::profiler {
142 c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type);
143 
144 TORCH_API void addMetadataJson(
145     const std::string& key,
146     const std::string& value);
147 
148 TORCH_API void profilerStep();
149 
150 } // namespace autograd::profiler
151 
152 } // namespace torch
153