xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSHooks.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1//  Copyright © 2022 Apple Inc.
2
3#include <ATen/mps/MPSAllocatorInterface.h>
4#include <ATen/mps/MPSDevice.h>
5#include <ATen/mps/MPSGeneratorImpl.h>
6#include <ATen/mps/MPSHooks.h>
7#include <ATen/mps/MPSProfiler.h>
8#include <ATen/mps/MPSStream.h>
9#include <c10/util/Logging.h>
10
11namespace at::mps {
12
13void MPSHooks::initMPS() const {
14  C10_LOG_API_USAGE_ONCE("aten.init.mps");
15  // TODO: initialize MPS devices and streams here
16}
17
18bool MPSHooks::hasMPS() const {
19  return at::mps::is_available();
20}
21
22bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
23  switch (major) {
24    case 15:
25      if (minor > 0)
26        TORCH_WARN("Can't check whether running on 15.", minor, "+ returning one for 15.0+");
27      return is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
28    case 14:
29      switch (minor) {
30        case 0:
31          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
32        case 4:
33          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
34        default:
35          TORCH_WARN("Can't check whether running on 14.", minor, "+ returning one for 14.4+");
36          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
37      }
38    case 13:
39      switch (minor) {
40        case 0:
41          return true;
42        case 1:
43          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
44        case 2:
45          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
46        case 3:
47          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
48        default:
49          TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+");
50          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
51      }
52    default:
53      TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false");
54      return false;
55  }
56}
57
58Allocator* MPSHooks::getMPSDeviceAllocator() const {
59  return at::mps::GetMPSAllocator();
60}
61
62const Generator& MPSHooks::getDefaultMPSGenerator() const {
63  return at::mps::detail::getDefaultMPSGenerator();
64}
65
66void MPSHooks::deviceSynchronize() const {
67  at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
68}
69
70void MPSHooks::commitStream() const {
71  at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT);
72}
73
74void* MPSHooks::getCommandBuffer() const {
75  return at::mps::getDefaultMPSStream()->commandBuffer();
76}
77
78void* MPSHooks::getDispatchQueue() const {
79  return at::mps::getDefaultMPSStream()->queue();
80}
81
82void MPSHooks::emptyCache() const {
83  at::mps::getIMPSAllocator()->emptyCache();
84}
85
86size_t MPSHooks::getCurrentAllocatedMemory() const {
87  return at::mps::getIMPSAllocator()->getCurrentAllocatedMemory();
88}
89
90size_t MPSHooks::getDriverAllocatedMemory() const {
91  return at::mps::getIMPSAllocator()->getDriverAllocatedMemory();
92}
93
94size_t MPSHooks::getRecommendedMaxMemory() const {
95  return at::mps::getIMPSAllocator()->getRecommendedMaxMemory();
96}
97
98void MPSHooks::setMemoryFraction(double ratio) const {
99  at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio);
100}
101
102void MPSHooks::profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
103  at::mps::getMPSProfiler().StartTrace(mode, waitUntilCompleted);
104}
105
106void MPSHooks::profilerStopTrace() const {
107  at::mps::getMPSProfiler().StopTrace();
108}
109
110uint32_t MPSHooks::acquireEvent(bool enable_timing) const {
111  return at::mps::getMPSEventPool()->acquireEvent(enable_timing);
112}
113
114void MPSHooks::releaseEvent(uint32_t event_id) const {
115  at::mps::getMPSEventPool()->releaseEvent(event_id);
116}
117
118void MPSHooks::recordEvent(uint32_t event_id) const {
119  at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true);
120}
121
122void MPSHooks::waitForEvent(uint32_t event_id) const {
123  at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true);
124}
125
126void MPSHooks::synchronizeEvent(uint32_t event_id) const {
127  at::mps::getMPSEventPool()->synchronizeEvent(event_id);
128}
129
130bool MPSHooks::queryEvent(uint32_t event_id) const {
131  return at::mps::getMPSEventPool()->queryEvent(event_id);
132}
133
134double MPSHooks::elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
135  return at::mps::getMPSEventPool()->elapsedTime(start_event_id, end_event_id);
136}
137
138bool MPSHooks::isPinnedPtr(const void* data) const {
139  return at::mps::isMPSPinnedPtr(data);
140}
141
142Allocator* MPSHooks::getPinnedMemoryAllocator() const {
143  return at::mps::getIMPSAllocator(true);
144}
145
146using at::MPSHooksRegistry;
147using at::RegistererMPSHooksRegistry;
148
149REGISTER_MPS_HOOKS(MPSHooks);
150
151} // namespace at::mps
152