xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSEvent.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1//  Copyright © 2023 Apple Inc.
2
3#include <ATen/mps/MPSEvent.h>
4
5namespace at::mps {
6
// Builds an event with the given ID that is bound to `stream`.
// The backing MTLSharedEvent is created from the stream's device via
// -newSharedEvent; the destructor releases it.
// `stream` is dereferenced here, so it must not be null.
MPSEvent::MPSEvent(id_t ID, MPSStream* stream, bool enable_timing)
    : m_id(ID),
      m_enable_timing(enable_timing),
      m_stream(stream),
      m_event([stream->device() newSharedEvent]) {}
9
// Releases the Metal objects held by this event (the shared event and, if one
// was ever created, the notification listener) and clears the members.
MPSEvent::~MPSEvent() {
  // Sending -release to nil is a no-op in Objective-C, so no nil guards are
  // required around these messages.
  [m_event release];
  m_event = nil;

  [m_listener release];
  m_listener = nil;
}
20
// Encodes a "signal" of this event into the stream's current command buffer.
// The "Locked" suffix suggests the caller is expected to already be running
// on the stream's queue (see record()).
//
// syncEvent: when true, the stream is synchronized with SyncType::COMMIT
//            right after the signal has been encoded.
void MPSEvent::recordLocked(bool syncEvent) {
  // Any active encoder has to be ended before a signal can be encoded.
  m_stream->endKernelCoalescing();

  // Each record bumps the value the shared event will be signaled with.
  m_signalCounter += 1;

  if (m_enable_timing) {
    // When timing is enabled, capture the completion timestamp once the
    // shared event reaches the new counter value and wake any CPU waiter.
    MTLSharedEventNotificationBlock onSignaled = ^(id<MTLSharedEvent>, uint64_t) {
      m_completion_time = getTime();
      notifyCpuSync();
    };
    notifyLocked(onSignaled);
  }

  id<MTLCommandBuffer> cmdBuffer = m_stream->commandBuffer();
  [cmdBuffer encodeSignalEvent:m_event value:m_signalCounter];

  if (syncEvent) {
    m_stream->synchronize(SyncType::COMMIT);
  }
}
37
// Encodes a "wait for this event" into the stream's current command buffer.
//
// syncEvent: when true, the stream is synchronized with SyncType::COMMIT
//            after the wait has been encoded.
// Returns true if a wait was encoded, false if there was nothing to wait for.
bool MPSEvent::waitLocked(bool syncEvent) {
  // Nothing is pending when the shared event's signaled value has already
  // reached the most recently recorded counter value.
  const bool hasPendingSignal = m_event.signaledValue < m_signalCounter;
  if (!hasPendingSignal) {
    return false;
  }

  // Any active encoder has to be ended before a wait can be encoded.
  m_stream->endKernelCoalescing();

  id<MTLCommandBuffer> cmdBuffer = m_stream->commandBuffer();
  [cmdBuffer encodeWaitForEvent:m_event value:m_signalCounter];

  if (syncEvent) {
    m_stream->synchronize(SyncType::COMMIT);
  }
  return true;
}
52
// Registers `block` to be invoked when the shared event reaches the most
// recently recorded counter value.
//
// Returns true if the notification was scheduled, false if the shared event
// has already reached that value (in which case `block` is not registered).
bool MPSEvent::notifyLocked(MTLSharedEventNotificationBlock block) {
  const bool hasPendingSignal = m_event.signaledValue < m_signalCounter;
  if (!hasPendingSignal) {
    return false;
  }

  // The listener is created lazily on first use and reused afterwards.
  if (m_listener == nil) {
    m_listener = [[MTLSharedEventListener alloc] init];
  }

  [m_event notifyListener:m_listener atValue:m_signalCounter block:block];
  return true;
}
64
// Records this event on its stream.
//
// needsLock: when true, the work is executed synchronously on the queue
//            returned by m_stream->queue(); when false, it runs directly on
//            the calling thread.
// syncEvent: forwarded to recordLocked().
void MPSEvent::record(bool needsLock, bool syncEvent) {
  if (needsLock) {
    dispatch_sync(m_stream->queue(), ^() {
      @autoreleasepool {
        recordLocked(syncEvent);
      }
    });
  } else {
    recordLocked(syncEvent);
  }
}
76
// Makes the stream wait for this event.
//
// needsLock: when true, the work is executed synchronously on the queue
//            returned by m_stream->queue(); when false, it runs directly on
//            the calling thread.
// syncEvent: forwarded to waitLocked().
// Returns the result of waitLocked(): true if a wait was encoded.
bool MPSEvent::wait(bool needsLock, bool syncEvent) {
  if (!needsLock) {
    return waitLocked(syncEvent);
  }

  // __block so the block dispatched below can write the result back.
  __block bool didWait = false;
  dispatch_sync(m_stream->queue(), ^() {
    @autoreleasepool {
      didWait = waitLocked(syncEvent);
    }
  });
  return didWait;
}
89
// Schedules `block` as a notification for this event.
//
// needsLock: when true, the work is executed synchronously on the queue
//            returned by m_stream->queue(); when false, it runs directly on
//            the calling thread.
// Returns the result of notifyLocked(): true if the notification was
// scheduled.
bool MPSEvent::notify(bool needsLock, MTLSharedEventNotificationBlock block) {
  // __block so the dispatched block can write the result back.
  __block bool wasScheduled = false;

  if (needsLock) {
    dispatch_sync(m_stream->queue(), ^() {
      @autoreleasepool {
        wasScheduled = notifyLocked(block);
      }
    });
  } else {
    wasScheduled = notifyLocked(block);
  }
  return wasScheduled;
}
102
// Marks the CPU-sync flag as completed and wakes one thread blocked in
// waitForCpuSync(). The flag is written and the condition variable notified
// while m_cpu_sync_mutex is held.
void MPSEvent::notifyCpuSync() {
  const std::lock_guard<std::mutex> guard(m_cpu_sync_mutex);
  m_cpu_sync_completed = true;
  m_cpu_sync_cv.notify_one();
}
108
// Blocks the calling thread until notifyCpuSync() has set the CPU-sync flag,
// then clears the flag again so that a later wait blocks anew.
void MPSEvent::waitForCpuSync() {
  std::unique_lock<std::mutex> guard(m_cpu_sync_mutex);
  // Loop to tolerate spurious wake-ups of the condition variable.
  while (!m_cpu_sync_completed) {
    m_cpu_sync_cv.wait(guard);
  }
  m_cpu_sync_completed = false;
}
114
// Blocks the calling thread until the most recently recorded signal of this
// event has been reached.
//
// A notification is scheduled that stores the completion time and wakes the
// CPU waiter; notifyLocked() is invoked directly here, without dispatching to
// the stream's queue.
//
// Returns true if the call had to wait, false if no notification needed to be
// scheduled (the shared event had already reached the recorded value).
bool MPSEvent::synchronize() {
  const bool wasScheduled = notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
    m_completion_time = getTime();
    notifyCpuSync();
  });

  if (!wasScheduled) {
    return false;
  }
  waitForCpuSync();
  return true;
}
127
// Reports whether the event has been recorded and its latest recorded signal
// has already been reached. Does not block.
bool MPSEvent::query() const {
  // A zero counter means nothing has been recorded on this event.
  if (!m_signalCounter) {
    return false;
  }
  return m_event.signaledValue >= m_signalCounter;
}
132
// Re-initializes this event with new parameters so it can be reused.
//
// The signal counter and the shared event's signaled value are rewound to 0
// only when the event is moved to a different stream. The timing state
// (completion time, timing flag, CPU-sync flag) is always reset.
void MPSEvent::reset(MPSStream* stream, bool enable_timing) {
  const bool streamChanged = (stream != m_stream);
  if (streamChanged) {
    m_signalCounter = 0;
    m_event.signaledValue = 0;
    m_stream = stream;
  }

  // Forget the previously recorded completion time.
  m_completion_time = 0;
  m_enable_timing = enable_timing;
  m_cpu_sync_completed = false;
}
144
145//-----------------------------------------------------------------
146//  MPSEventPool
147//-----------------------------------------------------------------
148
// Creates a pool whose events default to `default_stream` when no stream is
// specified on acquisition.
MPSEventPool::MPSEventPool(MPSStream* default_stream) : m_default_stream(default_stream) {
  // Default deleter for handed-out events: instead of destroying the event,
  // it is pushed back onto the pool (under the pool mutex) for later reuse.
  // The deleter refers to this pool instance.
  m_default_deleter = [this](MPSEvent* released_event) {
    std::lock_guard<std::recursive_mutex> guard(m_mutex);
    m_pool.push(std::unique_ptr<MPSEvent>(released_event));
  };
}
156
// Destroys the pool, dropping every event currently cached in it.
MPSEventPool::~MPSEventPool() {
  emptyCache();
}
160
// Hands out an event bound to `stream`, or to the pool's default stream when
// `stream` is null.
//
// A cached event is reused (after reset()) when the pool is not empty;
// otherwise a new event with a fresh ID is created. The returned pointer
// carries the pool's default deleter, so releasing it gives the event back to
// the pool instead of destroying it.
//
// enable_timing: forwarded to the event (reset() or constructor).
MPSEventPtr MPSEventPool::acquireEvent(bool enable_timing, MPSStream* stream) {
  if (!stream) {
    stream = m_default_stream;
  }

  id_t new_event_id = 0;
  {
    std::lock_guard<std::recursive_mutex> lock(m_mutex);
    if (!m_pool.empty()) {
      // Keep the recycled event owned by a unique_ptr until it is handed to
      // the returned MPSEventPtr, so it is never held as a bare pointer while
      // reset() runs.
      std::unique_ptr<MPSEvent> recycled = std::move(m_pool.top());
      m_pool.pop();
      recycled->reset(stream, enable_timing);
      return MPSEventPtr(recycled.release(), m_default_deleter);
    }
    // Reserve the new ID while the pool mutex is still held, so concurrent
    // callers cannot race on the counter or observe duplicate IDs.
    // (Assumes m_event_counter is a plain integer; harmless if it is atomic.)
    new_event_id = ++m_event_counter;
  }

  // Creating the event does not touch pool state, so it happens outside the
  // lock.
  auto new_event = std::make_unique<MPSEvent>(new_event_id, stream, enable_timing);
  return MPSEventPtr(new_event.release(), m_default_deleter);
}
177
// Destroys every event currently cached in the pool. Events that are handed
// out (not in m_pool) are not affected.
void MPSEventPool::emptyCache() {
  std::lock_guard<std::recursive_mutex> guard(m_mutex);
  // Pop from the top until nothing is left; each pop destroys one event.
  for (; !m_pool.empty(); m_pool.pop()) {
  }
}
184
// Acquires an event on the default stream and keeps it alive inside the pool,
// keyed by its ID, until releaseEvent() is called with that ID.
//
// Returns the ID of the acquired event.
id_t MPSEventPool::acquireEvent(bool enable_timing) {
  // m_mutex is recursive, so the nested acquireEvent() call may lock it again.
  std::lock_guard<std::recursive_mutex> guard(m_mutex);

  MPSEventPtr event = acquireEvent(enable_timing, /*stream*/ nullptr);
  TORCH_INTERNAL_ASSERT(event);

  const id_t acquired_id = event->getID();
  // Ownership moves into the in-use table; only the ID is given to the caller.
  m_in_use_events.emplace(acquired_id, std::move(event));
  return acquired_id;
}
193
// Drops the in-use entry for `event_id`. Destroying the stored MPSEventPtr
// runs its deleter, which returns the event to the pool.
//
// Raises (via TORCH_CHECK) if `event_id` is not currently in use.
void MPSEventPool::releaseEvent(id_t event_id) {
  std::lock_guard<std::recursive_mutex> guard(m_mutex);
  const auto entry = m_in_use_events.find(event_id);
  TORCH_CHECK(entry != m_in_use_events.end(), "Invalid Event ID: ", event_id);
  m_in_use_events.erase(entry);
}
200
// Records the in-use event identified by `event_id`, executing on the event's
// stream queue (needsLock = true).
// Raises (via getInUseEvent) if the ID is not in use.
void MPSEventPool::recordEvent(id_t event_id, bool syncEvent) {
  getInUseEvent(event_id)->record(/*needsLock*/ true, syncEvent);
}
205
// Makes the event's stream wait for the in-use event identified by
// `event_id`, executing on the event's stream queue (needsLock = true).
// The boolean result of MPSEvent::wait() is intentionally not propagated.
// Raises (via getInUseEvent) if the ID is not in use.
void MPSEventPool::waitForEvent(id_t event_id, bool syncEvent) {
  getInUseEvent(event_id)->wait(/*needsLock*/ true, syncEvent);
}
210
// Blocks the calling thread on the in-use event identified by `event_id`
// (see MPSEvent::synchronize()). The boolean result is not propagated.
// Raises (via getInUseEvent) if the ID is not in use.
void MPSEventPool::synchronizeEvent(id_t event_id) {
  getInUseEvent(event_id)->synchronize();
}
215
// Returns MPSEvent::query() for the in-use event identified by `event_id`.
// Raises (via getInUseEvent) if the ID is not in use.
bool MPSEventPool::queryEvent(id_t event_id) {
  return getInUseEvent(event_id)->query();
}
220
// Computes the time elapsed between the completion of two in-use events.
//
// Returns the difference of the two completion timestamps scaled by 1e-6
// (presumably nanoseconds to milliseconds; confirm against getTime()).
// Raises (via TORCH_CHECK) if either ID is not in use, if either completion
// time is 0, or if the end time is not greater than the start time.
double MPSEventPool::elapsedTime(id_t start_event_id, id_t end_event_id) {
  // Commit and wait on the default stream first, so that the notification
  // callbacks capturing the events' completion times get triggered.
  dispatch_sync(m_default_stream->queue(), ^() {
    m_default_stream->synchronize(SyncType::COMMIT_AND_WAIT);
  });

  std::lock_guard<std::recursive_mutex> guard(m_mutex);
  // The pool mutex is already held, so the lookups must not lock again.
  MPSEvent* const first_event = getInUseEvent(start_event_id, /*locked*/ false);
  MPSEvent* const last_event = getInUseEvent(end_event_id, /*locked*/ false);

  // The notification runs on a separate thread; wait until it has fired.
  // NOTE(review): if the end event never schedules a notification (e.g. it
  // was recorded without timing enabled), this wait looks like it would not
  // return, and the check below would never be reached. Confirm.
  last_event->waitForCpuSync();

  const uint64_t start_time = first_event->getCompletionTime();
  const uint64_t end_time = last_event->getCompletionTime();

  TORCH_CHECK(start_time > 0 && end_time > 0, "Events were not created with argument 'enable_timing=True'");
  TORCH_CHECK(
      end_time > start_time, "End event ", end_event_id, " was not recorded after start event ", start_event_id);

  const uint64_t elapsed_ticks = end_time - start_time;
  return double(elapsed_ticks) * 1e-6;
}
239
// Looks up the in-use event registered under `event_id`.
//
// locked: when true, the pool mutex is acquired for the duration of the
//         lookup; pass false when the caller already holds it.
// Returns a non-owning pointer to the event; the pool keeps ownership.
// Raises (via TORCH_CHECK) if `event_id` is not currently in use.
MPSEvent* MPSEventPool::getInUseEvent(id_t event_id, bool locked) {
  // RAII guard instead of manual lock()/unlock(): TORCH_CHECK throws on an
  // invalid ID, and a manual unlock would be skipped in that case, leaving the
  // mutex held by this thread.
  std::unique_lock<std::recursive_mutex> guard(m_mutex, std::defer_lock);
  if (locked) {
    guard.lock();
  }

  // Single lookup; avoids operator[] which would insert on a missing key.
  const auto entry = m_in_use_events.find(event_id);
  TORCH_CHECK(entry != m_in_use_events.end(), "Invalid Event ID: ", event_id);
  return entry->second.get();
}
251
// Returns the process-wide event pool. It is created on first use, bound to
// the default MPS stream, and shared by all callers.
std::shared_ptr<MPSEventPool> getMPSEventPool() {
  static const std::shared_ptr<MPSEventPool> shared_pool =
      std::make_shared<MPSEventPool>(getDefaultMPSStream());
  return shared_pool;
}
256
257} // namespace at::mps
258