1 #include <ATen/cuda/CUDAEvent.h>
2 #include <c10/core/Device.h>
3 #include <c10/cuda/CUDAStream.h>
4 #include <torch/custom_class.h>
5
6 namespace torch::jit {
7
8 class CUDAEvent;
9 // This class is a wrapper around c10::cuda::CUDAStream.
10 // It is needed because TorchBind does not support all of the argument types
11 // for c10::cuda::CUDAStream. For more details, please refer to
12 // c10/cuda/CUDAStream.h.
13 class CUDAStream final : public CustomClassHolder {
14 public:
15 CUDAStream(
16 std::optional<c10::Device> device = std::nullopt,
17 int64_t priority = 0) {
18 c10::DeviceIndex device_index =
19 device.has_value() ? device->index() : c10::cuda::current_device();
20 stream_ = std::make_unique<c10::cuda::CUDAStream>(
21 c10::cuda::getStreamFromPool(static_cast<int>(priority), device_index));
22 }
23
CUDAStream(c10::cuda::CUDAStream s)24 CUDAStream(c10::cuda::CUDAStream s) {
25 stream_ = std::make_unique<c10::cuda::CUDAStream>(s);
26 }
27
query()28 bool query() {
29 return stream_->query();
30 }
31
32 c10::intrusive_ptr<CUDAEvent> recordEvent(
33 c10::intrusive_ptr<CUDAEvent> event);
34
synchronize()35 void synchronize() {
36 stream_->synchronize();
37 }
38
39 void waitEvent(const c10::intrusive_ptr<CUDAEvent>& event);
40
41 void waitStream(const c10::intrusive_ptr<CUDAStream>& stream);
42
43 /// Get the CUDA device index that this stream is associated with.
device_index()44 int64_t device_index() const {
45 return stream_->device_index();
46 }
47
48 /// Get the full Device that this stream is associated with. The Device
49 /// is guaranteed to be a CUDA device.
device()50 c10::Device device() const {
51 return stream_->device();
52 }
53
54 /// Return the stream ID corresponding to this particular stream.
id()55 int64_t id() const {
56 return stream_->id();
57 }
58
59 private:
60 std::unique_ptr<c10::cuda::CUDAStream> stream_;
61 friend class CUDAEvent;
62 };
63
64 // This class is a wrapper around at::cuda::CUDAStream.
65 // It is needed because TorchBind does not support all of the argument types
66 // for at::cuda::CUDAEvent. For more details, please refer to
67 // aten/src/ATen/cuda/CUDAEvent.h.
68 class CUDAEvent final : public CustomClassHolder {
69 public:
70 CUDAEvent(
71 bool enable_timing = false,
72 bool blocking = false,
73 bool interprocess = false) {
74 int flags = cudaEventDisableTiming;
75 if (enable_timing) {
76 flags = cudaEventDefault;
77 }
78 if (blocking) {
79 flags |= cudaEventBlockingSync;
80 }
81 if (interprocess) {
82 TORCH_CHECK(!enable_timing);
83 flags |= cudaEventInterprocess;
84 }
85
86 event_ = std::make_unique<at::cuda::CUDAEvent>(flags);
87 }
88
elapsedTime(const c10::intrusive_ptr<CUDAEvent> & end)89 double elapsedTime(const c10::intrusive_ptr<CUDAEvent>& end) {
90 return event_->elapsed_time(*end->event_);
91 }
92
ipcHandle()93 std::string ipcHandle() {
94 cudaIpcEventHandle_t handle{};
95 event_->ipc_handle(&handle);
96 std::string str_handle((const char*)&handle, sizeof(handle));
97 return str_handle;
98 }
99
query()100 bool query() {
101 return event_->query();
102 }
103
104 void record(const c10::intrusive_ptr<CUDAStream>& stream);
105
synchronize()106 void synchronize() {
107 event_->synchronize();
108 }
109 void wait(const c10::intrusive_ptr<CUDAStream>& stream);
110
111 private:
112 void recordInternal(CUDAStream* stream);
113 std::unique_ptr<at::cuda::CUDAEvent> event_;
114
115 friend class CUDAStream;
116 };
117
recordEvent(c10::intrusive_ptr<CUDAEvent> event)118 inline c10::intrusive_ptr<CUDAEvent> CUDAStream::recordEvent(
119 c10::intrusive_ptr<CUDAEvent> event) {
120 if (!event) {
121 event = c10::make_intrusive<CUDAEvent>();
122 }
123
124 event->recordInternal(this);
125 return event;
126 }
127
waitEvent(const c10::intrusive_ptr<CUDAEvent> & event)128 inline void CUDAStream::waitEvent(const c10::intrusive_ptr<CUDAEvent>& event) {
129 event->event_->block(*stream_);
130 }
131
waitStream(const c10::intrusive_ptr<CUDAStream> & stream)132 inline void CUDAStream::waitStream(
133 const c10::intrusive_ptr<CUDAStream>& stream) {
134 auto ev = c10::make_intrusive<CUDAEvent>();
135 stream->recordEvent(ev);
136 waitEvent(ev);
137 }
138
record(const c10::intrusive_ptr<CUDAStream> & stream)139 inline void CUDAEvent::record(const c10::intrusive_ptr<CUDAStream>& stream) {
140 event_->record(*stream->stream_);
141 }
142
recordInternal(CUDAStream * stream)143 inline void CUDAEvent::recordInternal(CUDAStream* stream) {
144 event_->record(*stream->stream_);
145 }
146
wait(const c10::intrusive_ptr<CUDAStream> & stream)147 inline void CUDAEvent::wait(const c10::intrusive_ptr<CUDAStream>& stream) {
148 event_->block(*stream->stream_);
149 }
150
// Register the TorchBind classes "Stream" (CUDAStream) and "Event"
// (CUDAEvent) in the `cuda` library namespace, along with their
// constructors (including default argument values) and methods.
TORCH_LIBRARY(cuda, m) {
  auto stream_class = m.class_<torch::jit::CUDAStream>("Stream").def(
      torch::init<std::optional<c10::Device>, int64_t>(),
      "",
      {torch::arg("device") = std::nullopt, torch::arg("priority") = 0});
  auto event_class = m.class_<torch::jit::CUDAEvent>("Event").def(
      torch::init<bool, bool, bool>(),
      "",
      {torch::arg("enable_timing") = false,
       torch::arg("blocking") = false,
       torch::arg("interprocess") = false});

  stream_class.def("query", &CUDAStream::query)
      .def("record_event", &CUDAStream::recordEvent)
      .def("synchronize", &CUDAStream::synchronize)
      .def("wait_event", &CUDAStream::waitEvent)
      .def("wait_stream", &CUDAStream::waitStream)
      .def("device_index", &CUDAStream::device_index)
      .def_property("device", &CUDAStream::device)
      .def("id", &CUDAStream::id);

  event_class.def("elapsed_time", &CUDAEvent::elapsedTime)
      .def("query", &CUDAEvent::query)
      .def("record", &CUDAEvent::record)
      .def("synchronize", &CUDAEvent::synchronize)
      .def("wait", &CUDAEvent::wait);
  // No trailing ';' after this brace: TORCH_LIBRARY expands to a function
  // definition, so a semicolon here would be a stray empty declaration.
}
178
179 } // namespace torch::jit
180