//  Copyright © 2022 Apple Inc.

#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/mps/MPSStream.h>

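// Class extension that re-declares the private enableCommitAndContinue
// property of MPSGraphExecutionDescriptor as readwrite, so the stream can
// toggle commit-and-continue based on profiler state (see the constructor).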
@interface MPSGraphExecutionDescriptor ()
@property(readwrite, atomic) BOOL enableCommitAndContinue;
@end

namespace at::mps {

//-----------------------------------------------------------------
//  MPSStream
//-----------------------------------------------------------------

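// Each stream owns a dedicated MTLCommandQueue plus a serial dispatch queue
// ("metal gpu stream") that serializes all encoding work submitted to it.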
MPSStream::MPSStream(Stream stream) : _stream(stream) {
  _commandQueue = [MPSDevice::getInstance()->device() newCommandQueue];
  TORCH_CHECK(_stream.device_type() == DeviceType::MPS);
  _serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
  _executionDescriptor = [MPSGraphExecutionDescriptor new];
  _compilationDescriptor = [MPSGraphCompilationDescriptor new];

  // disable commitAndContinue if Signpost tracing or GPU capture is enabled
  if (getMPSProfiler().isSignpostTracingEnabled() || getMPSProfiler().isCaptureEnabled()) {
    _enableCommitAndContinue = false;
  }
  _executionDescriptor.enableCommitAndContinue = _enableCommitAndContinue;

  // choose the optimization level that favors GPU execution
  _compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
  _executionDescriptor.compilationDescriptor = _compilationDescriptor;
}

MPSStream::~MPSStream() {
  [_commandQueue release];
  _commandQueue = nil;
  [_executionDescriptor release];
  [_compilationDescriptor release];
  _executionDescriptor = nil;
  _compilationDescriptor = nil;

  assert(_commandBuffer == nil);
}

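// Lazily creates (and retains) the MPSCommandBuffer that subsequent kernels
// encode into; it is released again by commitAndWait()/flush().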
MPSCommandBuffer* MPSStream::commandBuffer() {
  if (!_commandBuffer) {
    _commandBuffer = [MPSCommandBuffer commandBufferFromCommandQueue:_commandQueue].retain;
  }

  return _commandBuffer;
}

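// Lazily creates (and retains) a compute encoder on the current command
// buffer; it stays open across kernels until endKernelCoalescing() closes it.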
id<MTLComputeCommandEncoder> MPSStream::commandEncoder() {
  if (!_commandEncoder) {
    _commandEncoder = [commandBuffer() computeCommandEncoder].retain;
  }

  return _commandEncoder;
}

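// Ends kernel coalescing, then commits (and possibly waits on) the current
// command buffer according to the requested SyncType.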
void MPSStream::synchronize(SyncType syncType) {
  endKernelCoalescing();
  switch (syncType) {
    case SyncType::NONE:
      // typically, in GPU-to-GPU copies we won't commit explicitly
      break;
    case SyncType::COMMIT:
      commit();
      break;
    case SyncType::COMMIT_ADAPTIVE:
      // the adaptive commit only commits if we hit the allocator's low-watermark memory threshold
      if (getIMPSAllocator()->getLowWatermarkValue() <= 1) {
        commit();
      }
      break;
    case SyncType::COMMIT_AND_WAIT:
      commitAndWait();
      break;
    case SyncType::COMMIT_AND_CONTINUE:
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_enableCommitAndContinue,
                                       "CommitAndContinue is called but it is disabled globally!");
      commitAndContinue();
      break;
  }
}

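// With commit-and-continue enabled, MPSCommandBuffer's commitAndContinue
// commits the underlying command buffer while letting encoding continue on
// the same MPSCommandBuffer, avoiding a CPU-side sync; otherwise we fall
// back to a full flush().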
void MPSStream::commit() {
  if (_enableCommitAndContinue) {
    [commandBuffer() commitAndContinue];
  } else {
    flush();
  }
}

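// Blocks the CPU until all work committed to this stream (including a
// previously flushed command buffer) has completed on the GPU.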
void MPSStream::commitAndWait() {
  if (_prevCommandBuffer) {
    // the previous command buffer (if it exists) has already been committed,
    // so we just wait until it's completed and then dispose of it.
    [_prevCommandBuffer waitUntilCompleted];
    [_prevCommandBuffer release];
    _prevCommandBuffer = nil;
  }

  if (_commandBuffer) {
    [_commandBuffer commit];
    [_commandBuffer waitUntilCompleted];
    [_commandBuffer release];
    _commandBuffer = nil;
  }
}

void MPSStream::commitAndContinue() {
  assert(_commandBuffer);
  [_commandBuffer commitAndContinue];
}

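// Consecutive compute kernels are "coalesced" onto a single open encoder;
// this closes it, which is required before encoding blit commands or an
// MPSGraph onto the same command buffer.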
void MPSStream::endKernelCoalescing() {
  if (_commandEncoder) {
    [_commandEncoder endEncoding];
    [_commandEncoder release];
    _commandEncoder = nil;
  }
}

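// Commits the current command buffer without waiting for completion.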
void MPSStream::flush() {
  if (_commandBuffer) {
    [_commandBuffer commit];
    // if commitAndContinue is disabled (e.g., for the profiler), we keep the
    // command buffer around so we can wait on it later, if required.
    if (!_enableCommitAndContinue) {
      _prevCommandBuffer = _commandBuffer;
    } else {
      [_commandBuffer release];
    }
    _commandBuffer = nil;
  }
}

void MPSStream::addCompletedHandler(MTLCommandBufferHandler block) {
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
      [commandBuffer() addCompletedHandler:block];
    }
  });
}

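// Fills `length` bytes of `buffer` starting at `offset` with `value`,
// using a blit encoder on this stream.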
void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) {
  if (length == 0) {
    return;
  }
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
      endKernelCoalescing();
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];

      [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value];
      [blitEncoder endEncoding];
      synchronize(syncType);
    }
  });
}

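// Copies `length` bytes from srcBuffer to dstBuffer via a blit encoder,
// splitting the copy into 2GB chunks (see the workaround note below).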
void MPSStream::copy(id<MTLBuffer> srcBuffer,
                     id<MTLBuffer> dstBuffer,
                     size_t length,
                     size_t srcOffset,
                     size_t dstOffset,
                     uint64_t profileId,
                     SyncType syncType) {
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
      endKernelCoalescing();
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];

      // For some reason, copyFromBuffer fails for 4GB copies without returning an error.
      // See https://github.com/pytorch/pytorch/issues/124335
      // Work around it by batching the copy commands into 2GB chunks.
      constexpr size_t max_copy_size = 0x80000000; // 2GB
      size_t bytes_copied = 0;
      size_t bytes_remaining = length;
      while (bytes_remaining > 0) {
        NSUInteger bytes_to_copy = std::min(max_copy_size, bytes_remaining);
        [blitEncoder copyFromBuffer:srcBuffer
                       sourceOffset:(NSUInteger)srcOffset + bytes_copied
                           toBuffer:dstBuffer
                  destinationOffset:(NSUInteger)dstOffset + bytes_copied
                               size:bytes_to_copy];
        bytes_copied += bytes_to_copy;
        bytes_remaining -= bytes_to_copy;
      }
      [blitEncoder endEncoding];

      // profileId has a value only if copy profiling is enabled
      if (profileId) {
        getMPSProfiler().endProfileCopy(profileId, syncType);
      } else {
        synchronize(syncType);
      }
    }
  });
}

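// Convenience wrapper: blocking copies commit and wait for completion,
// non-blocking copies just commit.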
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer,
                              id<MTLBuffer> dstBuffer,
                              size_t length,
                              size_t srcOffset,
                              size_t dstOffset,
                              bool non_blocking,
                              uint64_t profileId) {
  copy(srcBuffer,
       dstBuffer,
       length,
       srcOffset,
       dstOffset,
       profileId,
       !non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
}

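// Encodes an MPSGraph onto this stream's command buffer and synchronizes
// according to syncType (adjusted below when commit-and-continue is off).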
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) {
  auto& profiler = getMPSProfiler();
  const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();

  dispatch_sync(_serialQueue, ^() {
    endKernelCoalescing();
    if (isGraphProfilingEnabled) {
      // this function call is only relevant for interval-based Signposts,
      // which exclude schedule time (i.e., they only include GPU run time)
      profiler.beginProfileGPUInterval(mpsGraph);
    }
    // note: the CommitAndContinue feature is enabled/disabled via "_executionDescriptor"
    [mpsGraph encodeToCommandBuffer:commandBuffer()
                              feeds:feeds
                   targetOperations:nil
                  resultsDictionary:results
                executionDescriptor:_executionDescriptor];

    SyncType _syncType = syncType;
    // if commitAndContinue is disabled, we always need to commit manually after encoding
    if (!_enableCommitAndContinue && syncType != SyncType::COMMIT_AND_WAIT) {
      _syncType = SyncType::COMMIT;
    }

    // check if graph execution profiling is enabled
    if (isGraphProfilingEnabled) {
      // with the profiler enabled, we commit after MPSProfiler adds its completedHandler
      profiler.endProfileKernel(mpsGraph, _syncType);
    } else {
      synchronize(_syncType);
    }
  });
}
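// Illustrative call pattern from an MPS operation (a sketch, not code from
// this file; `graph`, `feeds`, and `results` are hypothetical placeholders):
//
//   MPSStream* stream = getCurrentMPSStream();
//   stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT_ADAPTIVE);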

//-----------------------------------------------------------------
//  MPSStreamImpl
//-----------------------------------------------------------------

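// Holds the single, lazily created default stream; the MPS backend currently
// exposes only one stream, on device index 0.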
MPSStream* MPSStreamImpl::_stream = nullptr;

MPSStream* MPSStreamImpl::getInstance() {
  if (_stream == nullptr) {
    _stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS, 0), 0));
  }
  return _stream;
}

MPSStreamImpl::MPSStreamImpl() {}

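// The "current" stream is simply an alias for the default stream, since
// there is only one MPS stream per process.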
MPSStream* getCurrentMPSStream() {
  return getDefaultMPSStream();
}

MPSStream* getDefaultMPSStream() {
  return MPSStreamImpl::getInstance();
}

} // namespace at::mps