xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSDevice.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 //  Copyright © 2022 Apple Inc.
2 
3 #pragma once
4 #include <c10/core/Allocator.h>
5 #include <c10/macros/Macros.h>
6 #include <c10/util/Exception.h>
7 
8 
#ifdef __OBJC__
// Objective-C(++) translation units: pull in the Apple frameworks and alias
// the Metal protocol-qualified `id` types used by the declarations below.
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
#else
// Non-Objective-C translation units: the same names become opaque `void*`
// aliases so this header can still be included.
typedef void* MTLDevice;
typedef void* MTLDevice_t;
typedef void* MTLLibrary_t;
typedef void* MTLComputePipelineState_t;
#endif
24 
25 namespace at::mps {
26 
// Version thresholds used to check whether an MPSGraph op is supported in a
// given macOS version. Enumerators are numbered consecutively from 0 in the
// order listed.
enum class MacOSVersion : uint32_t {
  MACOS_VER_13_1_PLUS = 0,
  MACOS_VER_13_2_PLUS = 1,
  MACOS_VER_13_3_PLUS = 2,
  MACOS_VER_14_0_PLUS = 3,
  MACOS_VER_14_4_PLUS = 4,
  MACOS_VER_15_0_PLUS = 5,
};
36 
37 //-----------------------------------------------------------------
38 //  MPSDevice
39 //
40 // MPSDevice is a singleton class that returns the default device
41 //-----------------------------------------------------------------
42 
43 class TORCH_API MPSDevice {
44  public:
45   /**
46    * MPSDevice should not be cloneable.
47    */
48   MPSDevice(MPSDevice& other) = delete;
49   /**
50    * MPSDevice should not be assignable.
51    */
52   void operator=(const MPSDevice&) = delete;
53   /**
54    * Gets single instance of the Device.
55    */
56   static MPSDevice* getInstance();
57   /**
58    * Returns the single device.
59    */
device()60   MTLDevice_t device() {
61     return _mtl_device;
62   }
63   /**
64    * Returns whether running on Ventura or newer
65    */
66   bool isMacOS13Plus(MacOSVersion version) const;
67 
68   MTLComputePipelineState_t metalIndexingPSO(const std::string &kernel);
69   MTLLibrary_t getMetalIndexingLibrary();
70 
71   ~MPSDevice();
72 
73  private:
74   static MPSDevice* _device;
75   MTLDevice_t _mtl_device;
76   MTLLibrary_t _mtl_indexing_library;
77   MPSDevice();
78 };
79 
80 TORCH_API bool is_available();
81 TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
82 TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
83 
84 } // namespace at::mps
85