//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSStream.h>
#include <executorch/runtime/platform/assert.h>
#include <vector>

@interface MPSGraphExecutionDescriptor ()
@property (readwrite, atomic) BOOL enableCommitAndContinue;
@end

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

using executorch::runtime::Error;

//-----------------------------------------------------------------
// MPSStream
//-----------------------------------------------------------------

MPSStream::MPSStream() {
  _commandQueue = [MPSDevice::getInstance()->device() newCommandQueue];
  _serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
}

MPSStream::~MPSStream() {
  [_commandQueue release];
  _commandQueue = nil;

  assert(_commandBuffer == nil);
}

bool MPSStream::hasLiveCommandBuffer() {
  return _commandBuffer;
}

API_AVAILABLE(ios(13.0))
MPSCommandBuffer* MPSStream::commandBuffer() {
  if (!_commandBuffer) {
    _commandBuffer = [MPSCommandBuffer commandBufferFromCommandQueue:_commandQueue].retain;
  }

  return _commandBuffer;
}

id<MTLComputeCommandEncoder> MPSStream::commandEncoder() {
  if (!_commandEncoder) {
    if (@available(iOS 13.0, *)) {
      _commandEncoder = [commandBuffer() computeCommandEncoder].retain;
    }
  }

  return _commandEncoder;
}

ET_NODISCARD
Error MPSStream::synchronize(SyncType syncType) {
  endKernelCoalescing();
  switch (syncType) {
    case SyncType::COMMIT:
      commit();
      break;
    case SyncType::COMMIT_AND_WAIT:
      commitAndWait();
      break;
    case SyncType::COMMIT_ADAPTIVE:
      break;
    case SyncType::COMMIT_AND_CONTINUE:
      ET_CHECK_OR_RETURN_ERROR(
          _enableCommitAndContinue == true,
          Internal,
          "CommitAndContinue is called but it is disabled globally!");
      commitAndContinue();
      break;
    default:
      ET_CHECK_OR_RETURN_ERROR(
          false,
          Internal,
          "Unhandled SyncType");
  }

  return Error::Ok;
}

bool MPSStream::commitAndContinueEnabled() {
  return _enableCommitAndContinue;
}

void MPSStream::commitAndContinue() {
  assert(_commandBuffer);
  [_commandBuffer commitAndContinue];
}

void MPSStream::endKernelCoalescing() {
  if (_commandEncoder) {
    [_commandEncoder endEncoding];
    [_commandEncoder release];
    _commandEncoder = nil;
  }
}

void MPSStream::commitAndWait() {
  if (_prevCommandBuffer) {
    // The previous command buffer (if it exists) has already been committed,
    // so we just wait until it completes and then dispose of it.
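    // (It was committed by flush() when commit-and-continue is disabled,
    // e.g. while profiling; see flush() below.)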
    [_prevCommandBuffer waitUntilCompleted];
    [_prevCommandBuffer release];
    _prevCommandBuffer = nil;
  }

  if (_commandBuffer) {
    [_commandBuffer commit];
    [_commandBuffer waitUntilCompleted];
    [_commandBuffer release];
    _commandBuffer = nil;
    // Reset the accumulated resource size for the command buffer.
    _commandBufferResourceSize = 0;
  }
}

void MPSStream::commit() {
  if (_enableCommitAndContinue) {
    if (@available(iOS 13.0, *)) {
      [commandBuffer() commitAndContinue];
    }
  } else {
    flush();
  }
  // Reset the accumulated resource size for the command buffer.
  _commandBufferResourceSize = 0;
}

void MPSStream::flush() {
  if (_commandBuffer) {
    [_commandBuffer commit];
    // If commitAndContinue is disabled (e.g., for the profiler), keep the
    // command buffer around so we can wait on it later, if required.
    if (!_enableCommitAndContinue) {
      _prevCommandBuffer = _commandBuffer;
    } else {
      [_commandBuffer release];
    }
    _commandBuffer = nil;
  }
}

void MPSStream::copy(id<MTLBuffer> srcBuffer,
                     id<MTLBuffer> dstBuffer,
                     size_t length,
                     size_t srcOffset,
                     size_t dstOffset,
                     SyncType syncType) {
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
      endKernelCoalescing();
      if (@available(iOS 13.0, *)) {
        id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];

        [blitEncoder copyFromBuffer:srcBuffer
                       sourceOffset:(NSUInteger)srcOffset
                           toBuffer:dstBuffer
                  destinationOffset:(NSUInteger)dstOffset
                               size:(NSUInteger)length];
        [blitEncoder endEncoding];
      }
      ET_CHECK(synchronize(syncType) == Error::Ok);
    }
  });
}

void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer,
                              id<MTLBuffer> dstBuffer,
                              size_t length,
                              size_t srcOffset,
                              size_t dstOffset,
                              bool non_blocking) {
  copy(srcBuffer,
       dstBuffer,
       length,
       srcOffset,
       dstOffset,
       non_blocking ? SyncType::COMMIT_ADAPTIVE : SyncType::COMMIT_AND_WAIT);
}

void MPSStream::copy(std::vector<CPUBufferWrapper>& dataBuffers,
                     SyncType syncType) {
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
#if TARGET_OS_SIMULATOR
      if (dataBuffers[0].dstCpu) {
        // If the destination is a CPU buffer, wait for the GPU to finish
        // executing before copying into the CPU buffers.
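        // (On the simulator the buffer contents are host-visible, so once
        // the GPU work is done the plain memcpy below is sufficient.)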
        ET_CHECK(synchronize(SyncType::COMMIT_AND_WAIT) == Error::Ok);
      }
      for (size_t i = 0; i < dataBuffers.size(); i++) {
        uint8_t* src = nil;
        uint8_t* dst = nil;
        if (dataBuffers[i].srcCpu) {
          // CPU -> GPU: read from the raw host pointer and write into the
          // MTLBuffer's contents.
          src = static_cast<uint8_t*>(dataBuffers[i].srcBuffer) + dataBuffers[i].srcOffset;
          dst = (uint8_t*)([(id<MTLBuffer>)dataBuffers[i].dstBuffer contents]) + dataBuffers[i].dstOffset;
        } else {
          // GPU -> CPU: the destination must be a CPU buffer here.
          ET_CHECK(dataBuffers[i].dstCpu);
          src = (uint8_t*)([(id<MTLBuffer>)dataBuffers[i].srcBuffer contents]) + dataBuffers[i].srcOffset;
          dst = static_cast<uint8_t*>(dataBuffers[i].dstBuffer) + dataBuffers[i].dstOffset;
        }
        memcpy(dst, src, dataBuffers[i].length);
      }
#else
      endKernelCoalescing();
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];

      for (size_t i = 0; i < dataBuffers.size(); i++) {
        [blitEncoder copyFromBuffer:(id<MTLBuffer>)dataBuffers[i].srcBuffer
                       sourceOffset:(NSUInteger)dataBuffers[i].srcOffset
                           toBuffer:(id<MTLBuffer>)dataBuffers[i].dstBuffer
                  destinationOffset:(NSUInteger)dataBuffers[i].dstOffset
                               size:(NSUInteger)dataBuffers[i].length];
      }
      [blitEncoder endEncoding];
      ET_CHECK(synchronize(syncType) == Error::Ok);
#endif
    }
  });
}

void MPSStream::copy_and_sync(std::vector<CPUBufferWrapper>& dataBuffers,
                              bool non_blocking) {
  copy(dataBuffers,
       non_blocking ? SyncType::COMMIT_ADAPTIVE : SyncType::COMMIT_AND_WAIT);
}

//-----------------------------------------------------------------
// MPSStreamImpl
//-----------------------------------------------------------------

MPSStream* MPSStreamImpl::_stream = nullptr;

MPSStream* MPSStreamImpl::getInstance() {
  if (_stream == nullptr) {
    _stream = new MPSStream();
  }
  return _stream;
}

MPSStreamImpl::MPSStreamImpl() {}

MPSStream* getCurrentMPSStream() {
  return getDefaultMPSStream();
}

MPSStream* getDefaultMPSStream() {
  return MPSStreamImpl::getInstance();
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch
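
//-----------------------------------------------------------------
// Usage sketch (illustrative only, not part of the build): a minimal
// blocking copy between two Metal buffers through the default stream,
// using the id<MTLBuffer> overload of copy_and_sync() above. Buffer
// names and sizes are hypothetical; this file is compiled without ARC,
// hence the explicit releases.
//
//   using namespace executorch::backends::mps::delegate;
//
//   id<MTLDevice> device = MPSDevice::getInstance()->device();
//   id<MTLBuffer> src = [device newBufferWithLength:1024
//                                           options:MTLResourceStorageModeShared];
//   id<MTLBuffer> dst = [device newBufferWithLength:1024
//                                           options:MTLResourceStorageModeShared];
//   memset([src contents], 0xAB, 1024);
//
//   // non_blocking == false maps to COMMIT_AND_WAIT, so the copy is
//   // complete when the call returns.
//   getDefaultMPSStream()->copy_and_sync(src, dst, /*length=*/1024,
//                                        /*srcOffset=*/0, /*dstOffset=*/0,
//                                        /*non_blocking=*/false);
//
//   [src release];
//   [dst release];
//-----------------------------------------------------------------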