// Copyright © 2022 Apple Inc.

#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/mps/MPSStream.h>

@interface MPSGraphExecutionDescriptor ()
@property(readwrite, atomic) BOOL enableCommitAndContinue;
@end

namespace at::mps {

//-----------------------------------------------------------------
//  MPSStream
//-----------------------------------------------------------------

MPSStream::MPSStream(Stream stream) : _stream(stream) {
  _commandQueue = [MPSDevice::getInstance()->device() newCommandQueue];
  TORCH_CHECK(_stream.device_type() == DeviceType::MPS);
  _serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
  _executionDescriptor = [MPSGraphExecutionDescriptor new];
  _compilationDescriptor = [MPSGraphCompilationDescriptor new];

  // disable commitAndContinue if Signpost tracing or capture is enabled
  if (getMPSProfiler().isSignpostTracingEnabled() || getMPSProfiler().isCaptureEnabled()) {
    _enableCommitAndContinue = false;
  }
  _executionDescriptor.enableCommitAndContinue = _enableCommitAndContinue;

  // choose the level which optimizes for the GPU
  _compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
  _executionDescriptor.compilationDescriptor = _compilationDescriptor;
}

MPSStream::~MPSStream() {
  [_commandQueue release];
  _commandQueue = nil;
  [_executionDescriptor release];
  [_compilationDescriptor release];
  _executionDescriptor = nil;
  _compilationDescriptor = nil;

  assert(_commandBuffer == nil);
}

MPSCommandBuffer* MPSStream::commandBuffer() {
  if (!_commandBuffer) {
    _commandBuffer = [MPSCommandBuffer commandBufferFromCommandQueue:_commandQueue].retain;
  }

  return _commandBuffer;
}

id<MTLComputeCommandEncoder> MPSStream::commandEncoder() {
  if (!_commandEncoder) {
    _commandEncoder = [commandBuffer() computeCommandEncoder].retain;
  }

  return _commandEncoder;
}

void MPSStream::synchronize(SyncType syncType) {
  endKernelCoalescing();
  switch (syncType) {
    case SyncType::NONE:
      // typically in GPU-to-GPU copies we won't commit explicitly
      break;
    case SyncType::COMMIT:
      commit();
      break;
    case SyncType::COMMIT_ADAPTIVE:
      // the adaptive commit only commits if we hit the low watermark memory threshold
      if (getIMPSAllocator()->getLowWatermarkValue() <= 1) {
        commit();
      }
      break;
    case SyncType::COMMIT_AND_WAIT:
      commitAndWait();
      break;
    case SyncType::COMMIT_AND_CONTINUE:
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_enableCommitAndContinue,
                                       "CommitAndContinue is called but it is disabled globally!");
      commitAndContinue();
      break;
  }
}

void MPSStream::commit() {
  if (_enableCommitAndContinue) {
    [commandBuffer() commitAndContinue];
  } else {
    flush();
  }
}

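// commitAndWait() blocks the caller until all submitted work has completed on
// the GPU: it first drains any previously committed buffer kept alive by
// flush() (which happens when commitAndContinue is disabled), then commits the
// current buffer and waits on its completion.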
void MPSStream::commitAndWait() {
  if (_prevCommandBuffer) {
    // the previous command buffer (if it exists) has already been committed,
    // so we just wait until it's completed and then dispose of it.
    [_prevCommandBuffer waitUntilCompleted];
    [_prevCommandBuffer release];
    _prevCommandBuffer = nil;
  }

  if (_commandBuffer) {
    [_commandBuffer commit];
    [_commandBuffer waitUntilCompleted];
    [_commandBuffer release];
    _commandBuffer = nil;
  }
}

void MPSStream::commitAndContinue() {
  assert(_commandBuffer);
  [_commandBuffer commitAndContinue];
}

void MPSStream::endKernelCoalescing() {
  if (_commandEncoder) {
    [_commandEncoder endEncoding];
    [_commandEncoder release];
    _commandEncoder = nil;
  }
}

void MPSStream::flush() {
  if (_commandBuffer) {
    [_commandBuffer commit];
    // if commitAndContinue is disabled (e.g., for the profiler), we keep the command
    // buffer so we can wait on it later, if required.
    if (!_enableCommitAndContinue) {
      _prevCommandBuffer = _commandBuffer;
    } else {
      [_commandBuffer release];
    }
    _commandBuffer = nil;
  }
}

void MPSStream::addCompletedHandler(MTLCommandBufferHandler block) {
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
      [commandBuffer() addCompletedHandler:block];
    }
  });
}

void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) {
  if (length == 0) {
    return;
  }
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
      endKernelCoalescing();
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];

      [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value];
      [blitEncoder endEncoding];
      synchronize(syncType);
    }
  });
}

void MPSStream::copy(id<MTLBuffer> srcBuffer,
                     id<MTLBuffer> dstBuffer,
                     size_t length,
                     size_t srcOffset,
                     size_t dstOffset,
                     uint64_t profileId,
                     SyncType syncType) {
  dispatch_sync(_serialQueue, ^() {
    @autoreleasepool {
      endKernelCoalescing();
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];

      // For some reason copyFromBuffer for 4GB buffers fails without returning an error.
      // See https://github.com/pytorch/pytorch/issues/124335
      // Workaround: batch the copy commands into 2GB chunks
      constexpr size_t max_copy_size = 0x80000000; // 2GB
      size_t bytes_copied = 0;
      size_t bytes_remaining = length;
      while (bytes_remaining > 0) {
        NSUInteger bytes_to_copy = std::min(max_copy_size, bytes_remaining);
        [blitEncoder copyFromBuffer:srcBuffer
                       sourceOffset:(NSUInteger)srcOffset + bytes_copied
                           toBuffer:dstBuffer
                  destinationOffset:(NSUInteger)dstOffset + bytes_copied
                               size:bytes_to_copy];
        bytes_copied += bytes_to_copy;
        bytes_remaining -= bytes_to_copy;
      }
      [blitEncoder endEncoding];

      // profileId has a value only if copy profiling is enabled
      if (profileId) {
        getMPSProfiler().endProfileCopy(profileId, syncType);
      } else {
        synchronize(syncType);
      }
    }
  });
}

void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer,
                              id<MTLBuffer> dstBuffer,
                              size_t length,
                              size_t srcOffset,
                              size_t dstOffset,
                              bool non_blocking,
                              uint64_t profileId) {
  copy(srcBuffer,
       dstBuffer,
       length,
       srcOffset,
       dstOffset,
       profileId,
       non_blocking ? SyncType::COMMIT : SyncType::COMMIT_AND_WAIT);
}

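// Encodes an MPSGraph execution onto this stream's command buffer under the
// serial queue. `feeds` and `results` map MPSGraphTensor* keys to
// MPSGraphTensorData* values; the commit-and-continue behavior is controlled
// by the execution descriptor configured in the constructor.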
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) {
  auto& profiler = getMPSProfiler();
  const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();

  dispatch_sync(_serialQueue, ^() {
    endKernelCoalescing();
    if (isGraphProfilingEnabled) {
      // this function call is only relevant for interval-based Signposts
      // which exclude schedule time (only includes GPU run time)
      profiler.beginProfileGPUInterval(mpsGraph);
    }
    // note: the CommitAndContinue feature is enabled/disabled via "_executionDescriptor"
    [mpsGraph encodeToCommandBuffer:commandBuffer()
                              feeds:feeds
                   targetOperations:nil
                  resultsDictionary:results
                executionDescriptor:_executionDescriptor];

    SyncType _syncType = syncType;
    // if commitAndContinue is disabled, we need to always commit manually after encoding
    if (!_enableCommitAndContinue && syncType != SyncType::COMMIT_AND_WAIT) {
      _syncType = SyncType::COMMIT;
    }

    // check if graph execution profiling is enabled
    if (isGraphProfilingEnabled) {
      // with the profiler enabled, we commit after adding the completedHandler in MPSProfiler
      profiler.endProfileKernel(mpsGraph, _syncType);
    } else {
      synchronize(_syncType);
    }
  });
}

//-----------------------------------------------------------------
//  MPSStreamImpl
//-----------------------------------------------------------------

MPSStream* MPSStreamImpl::_stream = nullptr;

MPSStream* MPSStreamImpl::getInstance() {
  if (_stream == nullptr) {
    _stream = new MPSStream(Stream(Stream::UNSAFE, c10::Device(DeviceType::MPS, 0), 0));
  }
  return _stream;
}

MPSStreamImpl::MPSStreamImpl() {}

MPSStream* getCurrentMPSStream() {
  return getDefaultMPSStream();
}

MPSStream* getDefaultMPSStream() {
  return MPSStreamImpl::getInstance();
}

} // namespace at::mps
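
// Illustrative usage from an op implementation (a sketch only; the real call
// sites live in the MPS backend ops, not in this file):
//
//   @autoreleasepool {
//     MPSStream* stream = at::mps::getCurrentMPSStream();
//     // zero-fill a Metal buffer and block until the GPU finishes
//     stream->fill(buffer, 0, /*length=*/size, /*offset=*/0, SyncType::COMMIT_AND_WAIT);
//   }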