xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/cuda/cuda.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cuda/CUDAEvent.h>
2 #include <c10/core/Device.h>
3 #include <c10/cuda/CUDAStream.h>
4 #include <torch/custom_class.h>
5 
6 namespace torch::jit {
7 
8 class CUDAEvent;
9 // This class is a wrapper around c10::cuda::CUDAStream.
10 // It is needed because TorchBind does not support all of the argument types
11 // for c10::cuda::CUDAStream. For more details, please refer to
12 // c10/cuda/CUDAStream.h.
13 class CUDAStream final : public CustomClassHolder {
14  public:
15   CUDAStream(
16       std::optional<c10::Device> device = std::nullopt,
17       int64_t priority = 0) {
18     c10::DeviceIndex device_index =
19         device.has_value() ? device->index() : c10::cuda::current_device();
20     stream_ = std::make_unique<c10::cuda::CUDAStream>(
21         c10::cuda::getStreamFromPool(static_cast<int>(priority), device_index));
22   }
23 
CUDAStream(c10::cuda::CUDAStream s)24   CUDAStream(c10::cuda::CUDAStream s) {
25     stream_ = std::make_unique<c10::cuda::CUDAStream>(s);
26   }
27 
query()28   bool query() {
29     return stream_->query();
30   }
31 
32   c10::intrusive_ptr<CUDAEvent> recordEvent(
33       c10::intrusive_ptr<CUDAEvent> event);
34 
synchronize()35   void synchronize() {
36     stream_->synchronize();
37   }
38 
39   void waitEvent(const c10::intrusive_ptr<CUDAEvent>& event);
40 
41   void waitStream(const c10::intrusive_ptr<CUDAStream>& stream);
42 
43   /// Get the CUDA device index that this stream is associated with.
device_index()44   int64_t device_index() const {
45     return stream_->device_index();
46   }
47 
48   /// Get the full Device that this stream is associated with.  The Device
49   /// is guaranteed to be a CUDA device.
device()50   c10::Device device() const {
51     return stream_->device();
52   }
53 
54   /// Return the stream ID corresponding to this particular stream.
id()55   int64_t id() const {
56     return stream_->id();
57   }
58 
59  private:
60   std::unique_ptr<c10::cuda::CUDAStream> stream_;
61   friend class CUDAEvent;
62 };
63 
// This class is a wrapper around at::cuda::CUDAEvent.
// It is needed because TorchBind does not support all of the argument types
// for at::cuda::CUDAEvent. For more details, please refer to
// aten/src/ATen/cuda/CUDAEvent.h.
68 class CUDAEvent final : public CustomClassHolder {
69  public:
70   CUDAEvent(
71       bool enable_timing = false,
72       bool blocking = false,
73       bool interprocess = false) {
74     int flags = cudaEventDisableTiming;
75     if (enable_timing) {
76       flags = cudaEventDefault;
77     }
78     if (blocking) {
79       flags |= cudaEventBlockingSync;
80     }
81     if (interprocess) {
82       TORCH_CHECK(!enable_timing);
83       flags |= cudaEventInterprocess;
84     }
85 
86     event_ = std::make_unique<at::cuda::CUDAEvent>(flags);
87   }
88 
elapsedTime(const c10::intrusive_ptr<CUDAEvent> & end)89   double elapsedTime(const c10::intrusive_ptr<CUDAEvent>& end) {
90     return event_->elapsed_time(*end->event_);
91   }
92 
ipcHandle()93   std::string ipcHandle() {
94     cudaIpcEventHandle_t handle{};
95     event_->ipc_handle(&handle);
96     std::string str_handle((const char*)&handle, sizeof(handle));
97     return str_handle;
98   }
99 
query()100   bool query() {
101     return event_->query();
102   }
103 
104   void record(const c10::intrusive_ptr<CUDAStream>& stream);
105 
synchronize()106   void synchronize() {
107     event_->synchronize();
108   }
109   void wait(const c10::intrusive_ptr<CUDAStream>& stream);
110 
111  private:
112   void recordInternal(CUDAStream* stream);
113   std::unique_ptr<at::cuda::CUDAEvent> event_;
114 
115   friend class CUDAStream;
116 };
117 
recordEvent(c10::intrusive_ptr<CUDAEvent> event)118 inline c10::intrusive_ptr<CUDAEvent> CUDAStream::recordEvent(
119     c10::intrusive_ptr<CUDAEvent> event) {
120   if (!event) {
121     event = c10::make_intrusive<CUDAEvent>();
122   }
123 
124   event->recordInternal(this);
125   return event;
126 }
127 
waitEvent(const c10::intrusive_ptr<CUDAEvent> & event)128 inline void CUDAStream::waitEvent(const c10::intrusive_ptr<CUDAEvent>& event) {
129   event->event_->block(*stream_);
130 }
131 
waitStream(const c10::intrusive_ptr<CUDAStream> & stream)132 inline void CUDAStream::waitStream(
133     const c10::intrusive_ptr<CUDAStream>& stream) {
134   auto ev = c10::make_intrusive<CUDAEvent>();
135   stream->recordEvent(ev);
136   waitEvent(ev);
137 }
138 
record(const c10::intrusive_ptr<CUDAStream> & stream)139 inline void CUDAEvent::record(const c10::intrusive_ptr<CUDAStream>& stream) {
140   event_->record(*stream->stream_);
141 }
142 
recordInternal(CUDAStream * stream)143 inline void CUDAEvent::recordInternal(CUDAStream* stream) {
144   event_->record(*stream->stream_);
145 }
146 
wait(const c10::intrusive_ptr<CUDAStream> & stream)147 inline void CUDAEvent::wait(const c10::intrusive_ptr<CUDAStream>& stream) {
148   event_->block(*stream->stream_);
149 }
150 
// Registers the two wrappers with TorchBind under the `cuda` library
// namespace as the custom classes "Stream" and "Event".
TORCH_LIBRARY(cuda, m) {
  // Both classes are created (with their constructors) before any method is
  // bound: the Stream methods mention CUDAEvent in their signatures and the
  // Event methods mention CUDAStream.
  // NOTE(review): presumably both types must already be registered when those
  // signatures are processed -- keep this ordering unless confirmed otherwise.
  auto stream_class = m.class_<torch::jit::CUDAStream>("Stream").def(
      torch::init<std::optional<c10::Device>, int64_t>(),
      "",
      {torch::arg("device") = std::nullopt, torch::arg("priority") = 0});
  auto event_class = m.class_<torch::jit::CUDAEvent>("Event").def(
      torch::init<bool, bool, bool>(),
      "",
      {torch::arg("enable_timing") = false,
       torch::arg("blocking") = false,
       torch::arg("interprocess") = false});

  // `device` is bound via def_property with a getter only.
  stream_class.def("query", &CUDAStream::query)
      .def("record_event", &CUDAStream::recordEvent)
      .def("synchronize", &CUDAStream::synchronize)
      .def("wait_event", &CUDAStream::waitEvent)
      .def("wait_stream", &CUDAStream::waitStream)
      .def("device_index", &CUDAStream::device_index)
      .def_property("device", &CUDAStream::device)
      .def("id", &CUDAStream::id);

  // CUDAEvent::ipcHandle is not bound here.
  event_class.def("elapsed_time", &CUDAEvent::elapsedTime)
      .def("query", &CUDAEvent::query)
      .def("record", &CUDAEvent::record)
      .def("synchronize", &CUDAEvent::synchronize)
      .def("wait", &CUDAEvent::wait);
}
178 
179 } // namespace torch::jit
180