1 // Copyright © 2022 Apple Inc. 2 3 #pragma once 4 5 #include <ATen/detail/MPSHooksInterface.h> 6 #include <ATen/Generator.h> 7 #include <ATen/mps/MPSEvent.h> 8 #include <optional> 9 10 namespace at::mps { 11 12 // The real implementation of MPSHooksInterface 13 struct MPSHooks : public at::MPSHooksInterface { MPSHooksMPSHooks14 MPSHooks(at::MPSHooksArgs) {} 15 void initMPS() const override; 16 17 // MPSDevice interface 18 bool hasMPS() const override; 19 bool isOnMacOSorNewer(unsigned major, unsigned minor) const override; 20 21 // MPSGeneratorImpl interface 22 const Generator& getDefaultMPSGenerator() const override; 23 24 // MPSStream interface 25 void deviceSynchronize() const override; 26 void commitStream() const override; 27 void* getCommandBuffer() const override; 28 void* getDispatchQueue() const override; 29 30 // MPSAllocator interface 31 Allocator* getMPSDeviceAllocator() const override; 32 void emptyCache() const override; 33 size_t getCurrentAllocatedMemory() const override; 34 size_t getDriverAllocatedMemory() const override; 35 size_t getRecommendedMaxMemory() const override; 36 void setMemoryFraction(double ratio) const override; 37 bool isPinnedPtr(const void* data) const override; 38 Allocator* getPinnedMemoryAllocator() const override; 39 40 // MPSProfiler interface 41 void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override; 42 void profilerStopTrace() const override; 43 44 // MPSEvent interface 45 uint32_t acquireEvent(bool enable_timing) const override; 46 void releaseEvent(uint32_t event_id) const override; 47 void recordEvent(uint32_t event_id) const override; 48 void waitForEvent(uint32_t event_id) const override; 49 void synchronizeEvent(uint32_t event_id) const override; 50 bool queryEvent(uint32_t event_id) const override; 51 double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override; 52 53 // Compatibility with Accelerator API hasPrimaryContextMPSHooks54 bool hasPrimaryContext(DeviceIndex device_index) const override { 55 // When MPS is available, it is always in use for the one device. 56 return true; 57 } 58 }; 59 60 } // namespace at::mps 61