#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;

// Out-of-place neuron application.
//
// Encodes `neuron` onto the command buffer associated with `input`, reading
// from the image backing `input` and writing into freshly allocated storage.
// Returns a new tensor with the same sizes and options as `input`; `input`
// itself is left untouched.
static Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
  IntArrayRef outputSize = input.sizes();
  // An empty tensor has nothing to encode. The guard runs before the source
  // image is looked up so that no texture work is done in this case.
  if (input.numel() == 0) {
    return makeTensor({outputSize.vec()}, input.options());
  }
  MPSImage* X = imageFromTensor(input);
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  // The destination is sized exactly like the input.
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  [neuron encodeToCommandBuffer:commandBuffer.buffer
                    sourceImage:X
               destinationImage:Y];
  return makeTensor(std::move(mt), input.options());
}

// In-place neuron application.
//
// Encodes `neuron` onto the command buffer associated with `input`, writing
// the result into a newly created temporary image, and then installs that
// image as the image of `input`'s texture. Returns `input`.
static Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) {
  // An empty tensor has nothing to encode; return it unchanged without
  // touching its texture.
  if (input.numel() == 0) {
    return input;
  }
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  MPSImage* Y = createTemporaryImage(commandBuffer, input.sizes());
  [neuron encodeToCommandBuffer:commandBuffer.buffer
                    sourceImage:X
               destinationImage:Y];
  // Swap the result image into the tensor's existing storage so that callers
  // holding `input` observe the activated values.
  MetalTensorImpl* impl =
      static_cast<MetalTensorImpl*>(input.unsafeGetTensorImpl());
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// aten::relu for Metal tensors (out-of-place).
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor relu(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  return neuronKernel(input, [MPSCNNNeuronOp relu]);
}

// aten::relu_ for Metal tensors (in-place).
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor& relu_(Tensor& input) {
  TORCH_CHECK(input.is_metal());
  return neuronKernel_(input, [MPSCNNNeuronOp relu]);
}

// aten::sigmoid for Metal tensors (out-of-place).
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor sigmoid(const Tensor& input) {
  // Same precondition as every other operator in this file.
  TORCH_CHECK(input.is_metal());
  return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
}

// aten::hardsigmoid_ for Metal tensors (in-place).
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor& hardsigmoid_(Tensor& input) {
  TORCH_CHECK(input.is_metal());
  return neuronKernel_(input, [MPSCNNNeuronOp hardSigmoid]);
}

// aten::tanh for Metal tensors (out-of-place).
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor tanh(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  return neuronKernel(input, [MPSCNNNeuronOp tanh]);
}

// Register the neuron operators above for the Metal dispatch key.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::tanh"), TORCH_FN(tanh));
  m.impl(TORCH_SELECTIVE_NAME("aten::relu"), TORCH_FN(relu));
  m.impl(TORCH_SELECTIVE_NAME("aten::relu_"), TORCH_FN(relu_));
  m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid"), TORCH_FN(sigmoid));
  m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), TORCH_FN(hardsigmoid_));
}

} // namespace at::native::metal