#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#include <ATen/metal/Context.h>
#include <torch/script.h>

namespace at {
namespace native::metal {

// Copies the contents of a Metal tensor into a pre-allocated, contiguous CPU
// float tensor.
static Tensor& copy_from_metal_(Tensor& dst, const Tensor& src) {
  TORCH_INTERNAL_ASSERT(
      src.device().type() == DeviceType::Metal,
      "copy_from_metal input tensor's device is not metal");
  TORCH_INTERNAL_ASSERT(
      dst.device().is_cpu(),
      "copy_from_metal is implemented only for CPU device output");
  TORCH_INTERNAL_ASSERT(
      dst.layout() == Layout::Strided,
      "copy_from_metal is implemented only for Strided layout output");
  TORCH_INTERNAL_ASSERT(
      dst.scalar_type() == ScalarType::Float,
      "copy_from_metal is implemented only for float dtype output, got:",
      dst.scalar_type());
  TORCH_INTERNAL_ASSERT(
      dst.is_contiguous(),
      "copy_from_metal is implemented only for contiguous output tensor");
  if (dst.numel() == 0) {
    return dst;
  }
  MetalTensorImplStorage& tensorImplStorage = getTensorImplStorage(src);
  tensorImplStorage.copy_data_to_host(dst.data_ptr<float>());
  return dst;
}

// Uploads the contents of a CPU float tensor into a Metal tensor.
static Tensor& copy_to_metal_(Tensor& dst, const Tensor& src) {
  TORCH_INTERNAL_ASSERT(
      dst.device().type() == DeviceType::Metal,
      "copy_to_metal_ output tensor's device is not metal");
  TORCH_INTERNAL_ASSERT(
      src.device().is_cpu(),
      "copy_to_metal_ is implemented only for CPU device input");
  TORCH_INTERNAL_ASSERT(
      src.layout() == Layout::Strided,
      "copy_to_metal_ is implemented only for Strided layout input");
  TORCH_INTERNAL_ASSERT(
      src.scalar_type() == ScalarType::Float,
      "copy_to_metal_ is implemented only for float dtype");

  auto cpu_tensor_contiguous = src.contiguous();
  MetalTensorImplStorage& tensorImplStorage = getTensorImplStorage(dst);
  tensorImplStorage.set_data_from_host(cpu_tensor_contiguous.data_ptr<float>());
  return dst;
}

// Routes a cross-device copy; only CPU -> Metal and Metal -> CPU transfers are
// supported.
static Tensor& metal_copy_impl_(Tensor& dst, const Tensor& src) {
  if (src.device().type() == at::kMetal && dst.device().type() == at::kCPU) {
    return copy_from_metal_(dst, src);
  }
  if (src.device().type() == at::kCPU && dst.device().type() == at::kMetal) {
    return copy_to_metal_(dst, src);
  }
  TORCH_INTERNAL_ASSERT(
      src.device().type() == DeviceType::Metal,
      "metal_copy_ is implemented only for CPU,Strided,float->Metal; Metal->CPU,Strided,float");
  return dst;
}
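
// Illustrative only (not part of this file): a typical CPU <-> Metal round trip
// that ends up in metal_copy_impl_ above, assuming a Metal-enabled build:
//
//   at::Tensor cpu = at::rand({1, 3, 224, 224});  // float, contiguous, CPU
//   at::Tensor gpu = cpu.metal();                  // CPU -> Metal (copy_to_metal_)
//   at::Tensor out = gpu.cpu();                    // Metal -> CPU (copy_from_metal_)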

#pragma mark - ATen Ops

// Allocates an uninitialized Metal tensor of the given sizes.
static Tensor empty(
    c10::SymIntArrayRef sym_size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<MemoryFormat> memory_format) {
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  TORCH_CHECK(
      !pin_memory.has_value(),
      "'pin_memory' argument is incompatible with Metal tensor");
  TORCH_CHECK(
      !memory_format.has_value(),
      "'memory_format' argument is incompatible with Metal tensor");
  MetalTensorImplStorage mt{size.vec()};
  return makeTensor(
      std::move(mt), at::device(at::kMetal).dtype(dtype));
}

// Allocates an uninitialized Metal tensor with explicit sizes and strides.
static Tensor empty_strided(
    IntArrayRef size,
    IntArrayRef stride,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  TORCH_CHECK(
      !pin_memory.has_value() || !pin_memory.value(),
      "'pin_memory' argument is incompatible with Metal tensor");
  MetalTensorImplStorage mt{size.vec(), stride.vec()};
  return makeTensor(
      std::move(mt), at::device(at::kMetal).dtype(dtype));
}

// Registers the factory kernels above under the Metal dispatch key.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::empty.memory_format"), empty);
  m.impl(TORCH_SELECTIVE_NAME("aten::empty_strided"), TORCH_FN(empty_strided));
}

} // namespace native::metal

// Hooks the Metal backend into at::metal::Context so that generic ATen code
// (e.g. cross-device copies) can query availability and delegate to the
// kernels above.
struct MetalImpl : public at::metal::MetalInterface {
  bool is_metal_available() const override {
#if defined(USE_PYTORCH_METAL)
    return [[MetalContext sharedInstance] available];
#else
    return false;
#endif
  }
  at::Tensor& metal_copy_(at::Tensor& input, const at::Tensor& src)
      const override {
    TORCH_CHECK(
        is_metal_available(), "Metal is not available on the current device");
    return native::metal::metal_copy_impl_(input, src);
  }
};
#if defined(USE_PYTORCH_METAL)
static at::metal::MetalImplRegistrar g_metal_impl(new MetalImpl());
#endif

} // namespace at
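
// Illustrative only (not part of this file): with the kernels above registered
// under the Metal dispatch key, a factory call such as
//   at::empty({2, 3}, at::TensorOptions().device(at::kMetal).dtype(at::kFloat))
// dispatches to `empty` in this file, and at::empty_strided dispatches to
// `empty_strided`.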