#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;

// Encodes one dispatch of the "hardswish" compute pipeline onto
// `commandBuffer`.
//
// The kernel reads from `X` (texture index 0) and writes to `Y` (texture
// index 1). The pipeline is specialized on X's number of images, feature
// channels, height and width. The threadgroup grid comes from
// mpscnn::spatialPointwiseKernelLaunchParams(state, X).
//
// This is shared by the in-place and out-of-place entry points below, so the
// encoder setup exists in exactly one place.
static void encodeHardswishKernel(
    MetalCommandBuffer* commandBuffer,
    MPSImage* X,
    MPSImage* Y) {
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"hardswish"
                                                    Constants:@[
                                                      @(X.numberOfImages),
                                                      @(X.featureChannels),
                                                      @(X.height),
                                                      @(X.width)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

  const auto launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}

// In-place hardswish for Metal tensors (registered as aten::hardswish_).
//
// The kernel writes into a separate image obtained from createTemporaryImage
// rather than into X. The tensor's storage texture is then repointed at that
// image, and `input` itself is returned.
//
// A tensor with no elements is returned untouched, without allocating an
// image or encoding any GPU work.
static Tensor& hardswish_(Tensor& input) {
  if (input.numel() == 0) {
    return input;
  }
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  std::vector<int64_t> imageSize = computeImageSize(input.sizes());
  MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);
  encodeHardswishKernel(commandBuffer, X, Y);

  // Swap the freshly written image into the tensor's own storage.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// Out-of-place hardswish for Metal tensors (registered as aten::hardswish).
//
// This allocates a new MetalTensorImplStorage with the same sizes as
// `input`. It then encodes the kernel from input's image into the storage's
// image on input's command buffer, and wraps the storage with
// input.options().
//
// A tensor with no elements yields a same-sized storage with no texture
// storage allocated and no GPU work encoded.
static Tensor hardswish(const at::Tensor& input) {
  IntArrayRef outputSize = input.sizes();
  if (input.numel() == 0) {
    return makeTensor(
        MetalTensorImplStorage{outputSize.vec()}, input.options());
  }
  MPSImage* X = imageFromTensor(input);
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  encodeHardswishKernel(commandBuffer, X, mt.texture()->image());
  return makeTensor(std::move(mt), input.options());
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), TORCH_FN(hardswish_));
  m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), TORCH_FN(hardswish));
}

} // namespace at::native::metal