1// Copyright © 2022 Apple Inc. 2 3#include <ATen/mps/MPSAllocatorInterface.h> 4#include <ATen/mps/MPSDevice.h> 5#include <ATen/mps/MPSGeneratorImpl.h> 6#include <ATen/mps/MPSHooks.h> 7#include <ATen/mps/MPSProfiler.h> 8#include <ATen/mps/MPSStream.h> 9#include <c10/util/Logging.h> 10 11namespace at::mps { 12 13void MPSHooks::initMPS() const { 14 C10_LOG_API_USAGE_ONCE("aten.init.mps"); 15 // TODO: initialize MPS devices and streams here 16} 17 18bool MPSHooks::hasMPS() const { 19 return at::mps::is_available(); 20} 21 22bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const { 23 switch (major) { 24 case 15: 25 if (minor > 0) 26 TORCH_WARN("Can't check whether running on 15.", minor, "+ returning one for 15.0+"); 27 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); 28 case 14: 29 switch (minor) { 30 case 0: 31 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS); 32 case 4: 33 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); 34 default: 35 TORCH_WARN("Can't check whether running on 14.", minor, "+ returning one for 14.4+"); 36 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); 37 } 38 case 13: 39 switch (minor) { 40 case 0: 41 return true; 42 case 1: 43 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS); 44 case 2: 45 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); 46 case 3: 47 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); 48 default: 49 TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+"); 50 return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); 51 } 52 default: 53 TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false"); 54 return false; 55 } 56} 57 58Allocator* MPSHooks::getMPSDeviceAllocator() const { 59 return at::mps::GetMPSAllocator(); 60} 61 62const Generator& MPSHooks::getDefaultMPSGenerator() const { 63 return at::mps::detail::getDefaultMPSGenerator(); 64} 65 66void MPSHooks::deviceSynchronize() const { 67 at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT); 68} 69 70void MPSHooks::commitStream() const { 71 at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT); 72} 73 74void* MPSHooks::getCommandBuffer() const { 75 return at::mps::getDefaultMPSStream()->commandBuffer(); 76} 77 78void* MPSHooks::getDispatchQueue() const { 79 return at::mps::getDefaultMPSStream()->queue(); 80} 81 82void MPSHooks::emptyCache() const { 83 at::mps::getIMPSAllocator()->emptyCache(); 84} 85 86size_t MPSHooks::getCurrentAllocatedMemory() const { 87 return at::mps::getIMPSAllocator()->getCurrentAllocatedMemory(); 88} 89 90size_t MPSHooks::getDriverAllocatedMemory() const { 91 return at::mps::getIMPSAllocator()->getDriverAllocatedMemory(); 92} 93 94size_t MPSHooks::getRecommendedMaxMemory() const { 95 return at::mps::getIMPSAllocator()->getRecommendedMaxMemory(); 96} 97 98void MPSHooks::setMemoryFraction(double ratio) const { 99 at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio); 100} 101 102void MPSHooks::profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const { 103 at::mps::getMPSProfiler().StartTrace(mode, waitUntilCompleted); 104} 105 106void MPSHooks::profilerStopTrace() const { 107 at::mps::getMPSProfiler().StopTrace(); 108} 109 110uint32_t MPSHooks::acquireEvent(bool enable_timing) const { 111 return at::mps::getMPSEventPool()->acquireEvent(enable_timing); 112} 113 114void MPSHooks::releaseEvent(uint32_t event_id) const { 115 at::mps::getMPSEventPool()->releaseEvent(event_id); 116} 117 118void MPSHooks::recordEvent(uint32_t event_id) const { 119 at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true); 120} 121 122void MPSHooks::waitForEvent(uint32_t event_id) const { 123 at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true); 124} 125 126void MPSHooks::synchronizeEvent(uint32_t event_id) const { 127 at::mps::getMPSEventPool()->synchronizeEvent(event_id); 128} 129 130bool MPSHooks::queryEvent(uint32_t event_id) const { 131 return at::mps::getMPSEventPool()->queryEvent(event_id); 132} 133 134double MPSHooks::elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const { 135 return at::mps::getMPSEventPool()->elapsedTime(start_event_id, end_event_id); 136} 137 138bool MPSHooks::isPinnedPtr(const void* data) const { 139 return at::mps::isMPSPinnedPtr(data); 140} 141 142Allocator* MPSHooks::getPinnedMemoryAllocator() const { 143 return at::mps::getIMPSAllocator(true); 144} 145 146using at::MPSHooksRegistry; 147using at::RegistererMPSHooksRegistry; 148 149REGISTER_MPS_HOOKS(MPSHooks); 150 151} // namespace at::mps 152