// Copyright © 2023 Apple Inc.

#include <ATen/mps/MPSEvent.h>

namespace at::mps {

MPSEvent::MPSEvent(id_t ID, MPSStream* stream, bool enable_timing)
    : m_id(ID), m_enable_timing(enable_timing), m_stream(stream), m_event([stream->device() newSharedEvent]) {}

MPSEvent::~MPSEvent() {
  if (m_event) {
    [m_event release];
    m_event = nil;
  }
  if (m_listener) {
    [m_listener release];
    m_listener = nil;
  }
}

void MPSEvent::recordLocked(bool syncEvent) {
  // active encoders must end before encoding or waiting
  m_stream->endKernelCoalescing();
  ++m_signalCounter;
  if (m_enable_timing) {
    notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
      m_completion_time = getTime();
      notifyCpuSync();
    });
  }
  id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
  [commandBuffer encodeSignalEvent:m_event value:m_signalCounter];
  if (syncEvent) {
    m_stream->synchronize(SyncType::COMMIT);
  }
}

bool MPSEvent::waitLocked(bool syncEvent) {
  // nothing to wait for if the event has not been recorded yet, or has already signaled
  if (m_event.signaledValue >= m_signalCounter) {
    return false;
  }
  // active encoders must end before encoding or waiting
  m_stream->endKernelCoalescing();
  id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
  [commandBuffer encodeWaitForEvent:m_event value:m_signalCounter];
  if (syncEvent) {
    m_stream->synchronize(SyncType::COMMIT);
  }
  return true;
}

bool MPSEvent::notifyLocked(MTLSharedEventNotificationBlock block) {
  // skip the notification if the event has not been recorded yet, or has already signaled
  if (m_event.signaledValue >= m_signalCounter) {
    return false;
  }
  if (!m_listener) {
    m_listener = [[MTLSharedEventListener alloc] init];
  }
  [m_event notifyListener:m_listener atValue:m_signalCounter block:block];
  return true;
}
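
// Note: the public record()/wait()/notify() wrappers below serialize work onto the
// stream's dispatch queue when `needsLock` is true; the *Locked() variants above
// assume the caller is already executing on that queue.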
void MPSEvent::record(bool needsLock, bool syncEvent) {
  if (!needsLock) {
    recordLocked(syncEvent);
    return;
  }
  dispatch_sync(m_stream->queue(), ^() {
    @autoreleasepool {
      recordLocked(syncEvent);
    }
  });
}

bool MPSEvent::wait(bool needsLock, bool syncEvent) {
  __block bool waited = false;
  if (!needsLock) {
    return waitLocked(syncEvent);
  }
  dispatch_sync(m_stream->queue(), ^() {
    @autoreleasepool {
      waited = waitLocked(syncEvent);
    }
  });
  return waited;
}

bool MPSEvent::notify(bool needsLock, MTLSharedEventNotificationBlock block) {
  if (!needsLock) {
    return notifyLocked(block);
  }
  __block bool scheduledNotify = false;
  dispatch_sync(m_stream->queue(), ^() {
    @autoreleasepool {
      scheduledNotify = notifyLocked(block);
    }
  });
  return scheduledNotify;
}
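
// notifyCpuSync()/waitForCpuSync() form a CPU-side handshake: the condition variable
// is signaled from the shared event's notification block (see recordLocked() and
// synchronize()), letting a host thread block until the GPU-side signal has fired.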
void MPSEvent::notifyCpuSync() {
  std::lock_guard<std::mutex> lock(m_cpu_sync_mutex);
  m_cpu_sync_completed = true;
  m_cpu_sync_cv.notify_one();
}

void MPSEvent::waitForCpuSync() {
  std::unique_lock<std::mutex> lock(m_cpu_sync_mutex);
  m_cpu_sync_cv.wait(lock, [&] { return m_cpu_sync_completed; });
  m_cpu_sync_completed = false;
}

bool MPSEvent::synchronize() {
  bool scheduledNotify = notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
    m_completion_time = getTime();
    notifyCpuSync();
  });

  if (scheduledNotify) {
    waitForCpuSync();
    return true;
  }
  return false;
}

bool MPSEvent::query() const {
  // return false if not recorded or signaled yet
  return m_signalCounter && (m_event.signaledValue >= m_signalCounter);
}

void MPSEvent::reset(MPSStream* stream, bool enable_timing) {
  if (stream != m_stream) {
    m_signalCounter = 0;
    m_event.signaledValue = 0;
    m_stream = stream;
  }
  // reset record time
  m_completion_time = 0;
  m_enable_timing = enable_timing;
  m_cpu_sync_completed = false;
}

//-----------------------------------------------------------------
// MPSEventPool
//-----------------------------------------------------------------

MPSEventPool::MPSEventPool(MPSStream* default_stream) : m_default_stream(default_stream) {
  // default deleter to return the event back to pool after it gets released
  m_default_deleter = [&](MPSEvent* event) {
    std::lock_guard<std::recursive_mutex> lock(m_mutex);
    m_pool.push(std::unique_ptr<MPSEvent>(event));
  };
}

MPSEventPool::~MPSEventPool() {
  emptyCache();
}

MPSEventPtr MPSEventPool::acquireEvent(bool enable_timing, MPSStream* stream) {
  if (!stream) {
    stream = m_default_stream;
  }
  {
    std::lock_guard<std::recursive_mutex> lock(m_mutex);
    if (!m_pool.empty()) {
      auto event = m_pool.top().release();
      m_pool.pop();
      event->reset(stream, enable_timing);
      return MPSEventPtr(event, m_default_deleter);
    }
  }
  auto new_event = std::make_unique<MPSEvent>(++m_event_counter, stream, enable_timing);
  return MPSEventPtr(new_event.release(), m_default_deleter);
}

void MPSEventPool::emptyCache() {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);
  while (!m_pool.empty()) {
    m_pool.pop();
  }
}

id_t MPSEventPool::acquireEvent(bool enable_timing) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);
  MPSEventPtr event = acquireEvent(enable_timing, nullptr);
  TORCH_INTERNAL_ASSERT(event);
  id_t event_id = event->getID();
  m_in_use_events.emplace(event_id, std::move(event));
  return event_id;
}

void MPSEventPool::releaseEvent(id_t event_id) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);
  TORCH_CHECK(m_in_use_events.count(event_id) > 0, "Invalid Event ID: ", event_id);
  // returns the event back to the MPSEventPool
  m_in_use_events.erase(event_id);
}

void MPSEventPool::recordEvent(id_t event_id, bool syncEvent) {
  MPSEvent* event = getInUseEvent(event_id);
  event->record(/*needsLock*/ true, syncEvent);
}

void MPSEventPool::waitForEvent(id_t event_id, bool syncEvent) {
  MPSEvent* event = getInUseEvent(event_id);
  event->wait(/*needsLock*/ true, syncEvent);
}

void MPSEventPool::synchronizeEvent(id_t event_id) {
  MPSEvent* event = getInUseEvent(event_id);
  event->synchronize();
}

bool MPSEventPool::queryEvent(id_t event_id) {
  MPSEvent* event = getInUseEvent(event_id);
  return event->query();
}

double MPSEventPool::elapsedTime(id_t start_event_id, id_t end_event_id) {
  // first make sure notifyListeners are called to capture events' completion times
  dispatch_sync(m_default_stream->queue(), ^() {
    m_default_stream->synchronize(SyncType::COMMIT_AND_WAIT);
  });
  std::lock_guard<std::recursive_mutex> lock(m_mutex);
  MPSEvent* start_event = getInUseEvent(start_event_id, false);
  MPSEvent* end_event = getInUseEvent(end_event_id, false);
  // the notify is called on a separate thread, so this waits for that
  end_event->waitForCpuSync();
  const uint64_t start_time = start_event->getCompletionTime();
  const uint64_t end_time = end_event->getCompletionTime();

  TORCH_CHECK(start_time > 0 && end_time > 0, "Events were not created with argument 'enable_timing=True'");
  TORCH_CHECK(
      end_time > start_time, "End event ", end_event_id, " was not recorded after start event ", start_event_id);
  return double(end_time - start_time) * 1e-6;
}

MPSEvent* MPSEventPool::getInUseEvent(id_t event_id, bool locked) {
  if (locked) {
    m_mutex.lock();
  }
  TORCH_CHECK(m_in_use_events.count(event_id) > 0, "Invalid Event ID: ", event_id);
  MPSEvent* event = m_in_use_events[event_id].get();
  if (locked) {
    m_mutex.unlock();
  }
  return event;
}

std::shared_ptr<MPSEventPool> getMPSEventPool() {
  static std::shared_ptr<MPSEventPool> event_pool = std::make_shared<MPSEventPool>(getDefaultMPSStream());
  return event_pool;
}

} // namespace at::mps
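
// Illustrative usage sketch (not part of this translation unit): acquire an event
// from the shared pool, record it on a stream, then block the host until it signals.
//   auto pool = at::mps::getMPSEventPool();
//   at::mps::MPSEventPtr event = pool->acquireEvent(/*enable_timing=*/true, /*stream=*/nullptr);
//   event->record(/*needsLock=*/true, /*syncEvent=*/false);
//   event->synchronize(); // blocks until the recorded signal completes (no-op if already signaled)
//   // dropping the MPSEventPtr returns the event to the pool via m_default_deleter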