xref: /aosp_15_r20/external/pytorch/aten/src/ATen/metal/Context.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifndef MetalContext_h
2 #define MetalContext_h
3 
4 #include <atomic>
5 
6 #include <ATen/Tensor.h>
7 
8 namespace at::metal {
9 
10 struct MetalInterface {
11   virtual ~MetalInterface() = default;
12   virtual bool is_metal_available() const = 0;
13   virtual at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src)
14       const = 0;
15 };
16 
17 extern std::atomic<const MetalInterface*> g_metal_impl_registry;
18 
19 class MetalImplRegistrar {
20  public:
21   explicit MetalImplRegistrar(MetalInterface*);
22 };
23 
24 at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src);
25 
26 } // namespace at::metal
27 
namespace at::native {

// Native-layer query for Metal availability; takes no arguments and returns
// a bool.
// NOTE(review): presumably forwards to the implementation registered in
// at::metal::g_metal_impl_registry — the definition is not visible here.
auto is_metal_available() -> bool;

} // namespace at::native
31 
32 #endif /* MetalContext_h */
33