1 // Copyright © 2022 Apple Inc. 2 3 #pragma once 4 #include <c10/core/Allocator.h> 5 #include <c10/macros/Macros.h> 6 #include <c10/util/Exception.h> 7 8 9 #ifdef __OBJC__ 10 #include <Foundation/Foundation.h> 11 #include <Metal/Metal.h> 12 #include <MetalPerformanceShaders/MetalPerformanceShaders.h> 13 typedef id<MTLDevice> MTLDevice_t; 14 typedef id<MTLLibrary> MTLLibrary_t; 15 typedef id<MTLComputePipelineState> MTLComputePipelineState_t; 16 typedef id<MTLLibrary> MTLLibrary_t; 17 #else 18 typedef void* MTLDevice; 19 typedef void* MTLDevice_t; 20 typedef void* MTLLibrary_t; 21 typedef void* MTLComputePipelineState_t; 22 typedef void* MTLLibrary_t; 23 #endif 24 25 namespace at::mps { 26 27 // Helper enum to check if a MPSGraph op is supported in a given macOS version 28 enum class MacOSVersion : uint32_t { 29 MACOS_VER_13_1_PLUS = 0, 30 MACOS_VER_13_2_PLUS, 31 MACOS_VER_13_3_PLUS, 32 MACOS_VER_14_0_PLUS, 33 MACOS_VER_14_4_PLUS, 34 MACOS_VER_15_0_PLUS, 35 }; 36 37 //----------------------------------------------------------------- 38 // MPSDevice 39 // 40 // MPSDevice is a singleton class that returns the default device 41 //----------------------------------------------------------------- 42 43 class TORCH_API MPSDevice { 44 public: 45 /** 46 * MPSDevice should not be cloneable. 47 */ 48 MPSDevice(MPSDevice& other) = delete; 49 /** 50 * MPSDevice should not be assignable. 51 */ 52 void operator=(const MPSDevice&) = delete; 53 /** 54 * Gets single instance of the Device. 55 */ 56 static MPSDevice* getInstance(); 57 /** 58 * Returns the single device. 59 */ device()60 MTLDevice_t device() { 61 return _mtl_device; 62 } 63 /** 64 * Returns whether running on Ventura or newer 65 */ 66 bool isMacOS13Plus(MacOSVersion version) const; 67 68 MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel); 69 MTLLibrary_t getMetalIndexingLibrary(); 70 71 ~MPSDevice(); 72 73 private: 74 static MPSDevice* _device; 75 MTLDevice_t _mtl_device; 76 MTLLibrary_t _mtl_indexing_library; 77 MPSDevice(); 78 }; 79 80 TORCH_API bool is_available(); 81 TORCH_API bool is_macos_13_or_newer(MacOSVersion version); 82 TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); 83 84 } // namespace at::mps 85