#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;

// Records one dispatch of the specialized "leaky_relu" compute kernel on
// `commandBuffer`. X's texture is bound at index 0 (input) and Y's texture at
// index 1 (output). The pipeline is specialized on X's numberOfImages,
// featureChannels, height and width plus `negative_slope`; the launch
// configuration comes from spatialPointwiseKernelLaunchParams(state, X).
// NOTE(review): the kernel source is not visible here; presumably it computes
// y = x > 0 ? x : negative_slope * x elementwise — confirm against the shader.
static void encodeLeakyReLU(
    MetalCommandBuffer* commandBuffer,
    MPSImage* X,
    MPSImage* Y,
    float negative_slope) {
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"leaky_relu"
                                                    Constants:@[
                                                      @(X.numberOfImages),
                                                      @(X.featureChannels),
                                                      @(X.height),
                                                      @(X.width),
                                                      @(negative_slope)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}

// In-place leaky ReLU for Metal tensors (registered as aten::leaky_relu_).
//
// @param input              Metal tensor; its backing image is replaced by
//                           the result.
// @param negative_slope_val Slope applied by the kernel; converted with
//                           Scalar::toFloat() before any GPU work is set up.
// @return `input`.
//
// Zero-element tensors are returned unchanged and no GPU work is encoded.
static Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) {
  const float negative_slope = negative_slope_val.toFloat();
  // Nothing to compute. This also avoids requesting an image with a
  // zero-sized dimension and dispatching an empty grid.
  if (input.numel() == 0) {
    return input;
  }
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  std::vector<int64_t> imageSize = computeImageSize(input.sizes());
  MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);
  encodeLeakyReLU(commandBuffer, X, Y, negative_slope);

  // The result is written to a separate image Y, which is then installed as
  // the tensor's image. Presumably this avoids reading and writing the same
  // texture in one dispatch — confirm.
  MetalTensorImpl* impl =
      static_cast<MetalTensorImpl*>(input.unsafeGetTensorImpl());
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// Out-of-place leaky ReLU for Metal tensors (registered as aten::leaky_relu).
//
// @param input              Metal tensor to read; it is not modified.
// @param negative_slope_val Slope applied by the kernel; converted with
//                           Scalar::toFloat() before any GPU work is set up.
// @return A new Metal tensor with input's sizes and options. Its storage is
//         allocated via allocateTemporaryStorage on input's command buffer.
//
// For a zero-element input, the output has the same sizes and no texture
// storage is allocated.
static Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) {
  const float negative_slope = negative_slope_val.toFloat();
  IntArrayRef outputSize = input.sizes();
  MetalTensorImplStorage mt{outputSize.vec()};
  // Nothing to compute; skip allocating a zero-sized image.
  if (input.numel() == 0) {
    return makeTensor(std::move(mt), input.options());
  }
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  encodeLeakyReLU(commandBuffer, X, Y, negative_slope);
  return makeTensor(std::move(mt), input.options());
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), TORCH_FN(leaky_relu_));
  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), TORCH_FN(leaky_relu));
}

} // namespace at::native::metal