#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;

// Encodes a single dispatch of the "hardshrink" compute kernel onto
// `commandBuffer`. The kernel reads texture index 0 (X) and writes texture
// index 1 (Y).
//
// The pipeline is specialized with function constants in this exact order:
//   (numberOfImages, featureChannels, height, width, lambda)
// and both the specialization and the launch grid are derived from X's
// dimensions. Both callers in this file allocate Y from input.sizes(), so Y is
// expected to match X's shape.
//
// Keeping this sequence in one place ensures the in-place and out-of-place
// variants cannot drift apart, for example in constant ordering.
static void encodeHardshrink(
    MetalCommandBuffer* commandBuffer,
    MPSImage* X,
    MPSImage* Y,
    float lambdaValue) {
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"hardshrink"
                                                    Constants:@[
                                                      @(X.numberOfImages),
                                                      @(X.featureChannels),
                                                      @(X.height),
                                                      @(X.width),
                                                      @(lambdaValue)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}

// In-place variant. It computes into a new temporary image on input's command
// buffer, then repoints input's Metal storage texture at that image. Returns
// `input`.
//
// NB: this is currently unused (it is not registered below), but I've left it
// because in principle it's useful.
static Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda = 0.5) {
  float l = lambda.toFloat();
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  IntArrayRef outputSize = input.sizes();
  std::vector<int64_t> imageSize = computeImageSize(outputSize);
  MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);

  encodeHardshrink(commandBuffer, X, Y, l);

  // Swap the freshly written image into the existing tensor's storage
  // rather than allocating a new tensor.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// Out-of-place variant, registered below as the Metal kernel for
// aten::hardshrink.
//
// It allocates temporary output storage with the same sizes as `input` on
// input's command buffer and encodes the kernel onto that buffer. The result
// is wrapped in a new tensor with `input`'s options.
//
// NOTE(review): presumably follows torch.nn.functional.hardshrink semantics;
// the element-wise math lives in the Metal shader, which is not visible here.
static Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda = 0.5) {
  float l = lambda.toFloat();
  MPSImage* X = imageFromTensor(input);
  IntArrayRef outputSize = input.sizes();
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();

  encodeHardshrink(commandBuffer, X, Y, l);

  return makeTensor(std::move(mt), input.options());
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink));
}

} // namespace at::native::metal