xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSStream.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2//  Copyright (c) 2023 Apple Inc. All rights reserved.
3//  Provided subject to the LICENSE file in the top level directory.
4//
5
6#include <executorch/backends/apple/mps/runtime/MPSStream.h>
7#include <executorch/runtime/platform/assert.h>
8#include <vector>
9
// Class extension declaring `enableCommitAndContinue` as a read-write, atomic
// BOOL property on MPSGraphExecutionDescriptor for use by this backend.
// NOTE(review): nothing in this file's visible code reads or writes it.
@interface MPSGraphExecutionDescriptor ()
@property (atomic, readwrite) BOOL enableCommitAndContinue;
@end
13
14namespace executorch {
15namespace backends {
16namespace mps {
17namespace delegate {
18
19using executorch::runtime::Error;
20
21//-----------------------------------------------------------------
22//  MPSStream
23//-----------------------------------------------------------------
24
// Constructs the stream: asks the device returned by MPSDevice::getInstance()
// for a new command queue and creates the dispatch queue that the copy()
// overloads run their work on.
MPSStream::MPSStream() {
  auto metalDevice = MPSDevice::getInstance()->device();
  _commandQueue = [metalDevice newCommandQueue];
  // A null attribute asks dispatch_queue_create for a serial queue.
  _serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
}
29
// Tears the stream down and releases the references it still owns.
//
// The stream is expected to have been synchronized beforehand: a live
// `_commandBuffer` at this point trips the assertion below.
MPSStream::~MPSStream() {
  [_commandQueue release];
  _commandQueue = nil;

  // flush() parks a committed buffer in _prevCommandBuffer when
  // commit-and-continue is disabled; drop that reference if commitAndWait()
  // never collected it. Messaging nil is a no-op, so no guard is needed.
  [_prevCommandBuffer release];
  _prevCommandBuffer = nil;

  // This file manages object lifetimes manually (explicit retain/release), so
  // the queue obtained from dispatch_queue_create() in the constructor has to
  // be balanced here as well.
  if (_serialQueue) {
    dispatch_release(_serialQueue);
    _serialQueue = nullptr;
  }

  assert(_commandBuffer == nil);
}
36
// Returns true while the stream currently holds a command buffer
// (i.e. `_commandBuffer` is non-nil).
bool MPSStream::hasLiveCommandBuffer() {
  return _commandBuffer != nil;
}
40
// Returns the stream's current command buffer. If none is held, a new
// MPSCommandBuffer is created from `_commandQueue`, retained, and cached in
// `_commandBuffer` before being returned.
API_AVAILABLE(ios(13.0))
MPSCommandBuffer* MPSStream::commandBuffer() {
  if (_commandBuffer) {
    return _commandBuffer;
  }

  _commandBuffer = [[MPSCommandBuffer commandBufferFromCommandQueue:_commandQueue] retain];
  return _commandBuffer;
}
49
// Returns the cached compute command encoder, creating (and retaining) one on
// the current command buffer when none is cached. On systems older than
// iOS 13.0 no encoder is created and the cached value (nil) is returned.
id<MTLComputeCommandEncoder> MPSStream::commandEncoder() {
  if (_commandEncoder) {
    return _commandEncoder;
  }

  if (@available(iOS 13.0, *)) {
    _commandEncoder = [[commandBuffer() computeCommandEncoder] retain];
  }
  return _commandEncoder;
}
59
// Ends any open compute encoder and then applies the requested
// synchronization mode.
//
// Returns Error::Ok on success, or Error::Internal when COMMIT_AND_CONTINUE is
// requested while commit-and-continue is disabled, or when `syncType` is not
// one of the handled values.
ET_NODISCARD
Error MPSStream::synchronize(SyncType syncType) {
  // The encoder must be closed before the command buffer can be committed.
  endKernelCoalescing();

  switch (syncType) {
    case SyncType::COMMIT_ADAPTIVE:
      // No commit is issued for this mode here.
      return Error::Ok;

    case SyncType::COMMIT:
      commit();
      return Error::Ok;

    case SyncType::COMMIT_AND_WAIT:
      commitAndWait();
      return Error::Ok;

    case SyncType::COMMIT_AND_CONTINUE:
      ET_CHECK_OR_RETURN_ERROR(
        _enableCommitAndContinue == true,
        Internal,
        "CommitAndContinue is called but it is disabled globally!");
      commitAndContinue();
      return Error::Ok;

    default:
      break;
  }

  // Reached only for values not handled above; the check always fails.
  ET_CHECK_OR_RETURN_ERROR(
    false,
    Internal,
    "Unhandled syncType type");

  return Error::Ok;
}
88
// Returns the value of the stream's `_enableCommitAndContinue` flag.
bool MPSStream::commitAndContinueEnabled() {
  return this->_enableCommitAndContinue;
}
92
// Sends `commitAndContinue` to the current command buffer.
// Precondition (asserted): a command buffer is currently held.
void MPSStream::commitAndContinue() {
  assert(_commandBuffer);
  MPSCommandBuffer* activeBuffer = _commandBuffer;
  [activeBuffer commitAndContinue];
}
97
// Closes the cached compute command encoder, if any: ends its encoding,
// releases the reference taken in commandEncoder(), and clears the cache.
void MPSStream::endKernelCoalescing() {
  if (!_commandEncoder) {
    return;
  }

  [_commandEncoder endEncoding];
  [_commandEncoder release];
  _commandEncoder = nil;
}
105
// Blocks until all outstanding work tracked by the stream has completed.
//
// First drains the previously committed buffer (if flush() left one behind),
// then commits the current buffer, waits for it, and releases it.
void MPSStream::commitAndWait() {
  if (_prevCommandBuffer != nil) {
    // Already committed by flush(); only waiting and disposal remain.
    [_prevCommandBuffer waitUntilCompleted];
    [_prevCommandBuffer release];
    _prevCommandBuffer = nil;
  }

  if (_commandBuffer == nil) {
    return;
  }

  [_commandBuffer commit];
  [_commandBuffer waitUntilCompleted];
  [_commandBuffer release];
  _commandBuffer = nil;
  // The buffer is gone, so its accumulated resource size starts over.
  _commandBufferResourceSize = 0;
}
124
// Submits the pending work without waiting for it.
//
// With commit-and-continue disabled the current buffer is handed to flush();
// otherwise `commitAndContinue` is sent to the buffer returned by
// commandBuffer() (iOS 13.0+ only). In both cases the accumulated resource
// size counter is reset afterwards.
void MPSStream::commit() {
  if (!_enableCommitAndContinue) {
    flush();
  } else {
    if (@available(iOS 13.0, *)) {
      [commandBuffer() commitAndContinue];
    }
  }

  _commandBufferResourceSize = 0;
}
136
// Commits the current command buffer (if any) and detaches it from the stream.
//
// When commit-and-continue is disabled (e.g. for the profiler) the committed
// buffer is kept in `_prevCommandBuffer` so commitAndWait() can wait on it
// later. Otherwise the stream's reference is released immediately.
void MPSStream::flush() {
  if (!_commandBuffer) {
    return;
  }

  [_commandBuffer commit];

  if (!_enableCommitAndContinue) {
    // A buffer parked by an earlier flush() may still be here if nobody called
    // commitAndWait() in between. It was committed back then, so release our
    // reference (as the branch below does for freshly committed buffers)
    // instead of overwriting the pointer and leaking it. Messaging nil is a
    // no-op when nothing is parked.
    [_prevCommandBuffer release];
    // Ownership of the retain taken in commandBuffer() moves to
    // _prevCommandBuffer.
    _prevCommandBuffer = _commandBuffer;
  } else {
    [_commandBuffer release];
  }
  _commandBuffer = nil;
}
150
// Encodes a buffer-to-buffer blit of `length` bytes from `srcBuffer` (starting
// at `srcOffset`) into `dstBuffer` (starting at `dstOffset`) and then applies
// `syncType` via synchronize(). The work runs synchronously on `_serialQueue`;
// a failing synchronize() aborts through ET_CHECK.
void MPSStream::copy(id<MTLBuffer> srcBuffer,
                     id<MTLBuffer> dstBuffer,
                     size_t length,
                     size_t srcOffset,
                     size_t dstOffset,
                     SyncType syncType) {
  dispatch_block_t encodeAndSync = ^{
    @autoreleasepool {
      // A blit encoder cannot be opened while a compute encoder is active.
      endKernelCoalescing();
      if (@available(iOS 13.0, *)) {
        id<MTLBlitCommandEncoder> blit = [commandBuffer() blitCommandEncoder];
        [blit copyFromBuffer:srcBuffer
                sourceOffset:(NSUInteger)srcOffset
                    toBuffer:dstBuffer
           destinationOffset:(NSUInteger)dstOffset
                        size:(NSUInteger)length];
        [blit endEncoding];
      }
      ET_CHECK(synchronize(syncType) == Error::Ok);
    }
  };

  dispatch_sync(_serialQueue, encodeAndSync);
}
174
// Convenience wrapper around copy(): a blocking call (`non_blocking == false`)
// uses SyncType::COMMIT_AND_WAIT, a non-blocking call uses
// SyncType::COMMIT_ADAPTIVE.
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer,
                              id<MTLBuffer> dstBuffer,
                              size_t length,
                              size_t srcOffset,
                              size_t dstOffset,
                              bool non_blocking) {
  const SyncType syncMode =
      non_blocking ? SyncType::COMMIT_ADAPTIVE : SyncType::COMMIT_AND_WAIT;
  copy(srcBuffer, dstBuffer, length, srcOffset, dstOffset, syncMode);
}
188
// Copies every entry of `dataBuffers` from its source to its destination.
// The work runs synchronously on `_serialQueue`.
//
// Simulator builds: the copy is done on the CPU with memcpy through the
// MTLBuffer `contents` pointer. If the first entry targets a CPU buffer the
// stream is committed and waited on first so the GPU results are complete.
// `syncType` is not used on this path.
// NOTE(review): only the first entry decides whether to wait; this presumably
// assumes all entries in a batch share the same direction - confirm with
// callers.
//
// Device builds: every entry is encoded into a single blit encoder and
// `syncType` is applied afterwards through synchronize().
//
// A failing synchronize() aborts through ET_CHECK.
void MPSStream::copy(std::vector<CPUBufferWrapper>& dataBuffers,
                     SyncType syncType) {
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
#if TARGET_OS_SIMULATOR
      // Guard the front() access: indexing an empty vector is undefined
      // behavior. An empty batch simply falls through the loop below.
      if (!dataBuffers.empty() && dataBuffers[0].dstCpu) {
        // If the destination is a CPU buffer,
        // wait for the GPU to finish executing
        // before copying into the CPU buffers.
        ET_CHECK(synchronize(SyncType::COMMIT_AND_WAIT) == Error::Ok);
      }
      for (size_t i = 0; i < dataBuffers.size(); i++) {
        const auto& entry = dataBuffers[i];
        uint8_t* src = nullptr;
        uint8_t* dst = nullptr;
        if (entry.srcCpu) {
          // CPU -> MTLBuffer
          src = static_cast<uint8_t*>(entry.srcBuffer) + entry.srcOffset;
          dst = (uint8_t*)([(id<MTLBuffer>)entry.dstBuffer contents]) + entry.dstOffset;
        } else {
          // MTLBuffer -> CPU; any other combination is rejected.
          ET_CHECK(dataBuffers[i].dstCpu);
          src = (uint8_t*)([(id<MTLBuffer>)entry.srcBuffer contents]) + entry.srcOffset;
          dst = static_cast<uint8_t*>(entry.dstBuffer) + entry.dstOffset;
        }
        memcpy(dst, src, entry.length);
      }
#else
      // A blit encoder cannot be opened while a compute encoder is active.
      endKernelCoalescing();
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];

      for (size_t i = 0; i < dataBuffers.size(); i++) {
        const auto& entry = dataBuffers[i];
        [blitEncoder copyFromBuffer:(id<MTLBuffer>)entry.srcBuffer
                       sourceOffset:(NSUInteger)entry.srcOffset
                           toBuffer:(id<MTLBuffer>)entry.dstBuffer
                  destinationOffset:(NSUInteger)entry.dstOffset
                               size:(NSUInteger)entry.length];
      }
      [blitEncoder endEncoding];
      ET_CHECK(synchronize(syncType) == Error::Ok);
#endif
    }
  });
}
230
// Convenience wrapper around the batched copy(): a blocking call
// (`non_blocking == false`) uses SyncType::COMMIT_AND_WAIT, a non-blocking
// call uses SyncType::COMMIT_ADAPTIVE.
void MPSStream::copy_and_sync(std::vector<CPUBufferWrapper>& dataBuffers,
                              bool non_blocking) {
  const SyncType syncMode =
      non_blocking ? SyncType::COMMIT_ADAPTIVE : SyncType::COMMIT_AND_WAIT;
  copy(dataBuffers, syncMode);
}
236
237//-----------------------------------------------------------------
238//  MPSStreamImpl
239//-----------------------------------------------------------------
240
// Storage for the lazily created stream returned by getInstance().
MPSStream* MPSStreamImpl::_stream = nullptr;

// Returns the process-wide MPSStream, allocating it on the first call.
// The instance is never deleted in the code visible here.
MPSStream* MPSStreamImpl::getInstance() {
  if (_stream != nullptr) {
    return _stream;
  }

  _stream = new MPSStream();
  return _stream;
}
250
// Nothing to initialize; the stream itself is created in getInstance().
MPSStreamImpl::MPSStreamImpl() = default;
252
// Returns the current stream, which is always the default stream.
MPSStream* getCurrentMPSStream() {
  MPSStream* current = getDefaultMPSStream();
  return current;
}
256
// Returns the default stream, i.e. the instance owned by MPSStreamImpl.
MPSStream* getDefaultMPSStream() {
  MPSStream* defaultStream = MPSStreamImpl::getInstance();
  return defaultStream;
}
260
261} // namespace delegate
262} // namespace mps
263} // namespace backends
264} // namespace executorch
265