#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

// Metal backend registrations for aten::hardtanh_, aten::hardtanh and
// aten::clamp. All three are implemented on top of MPSCNNClampOp.
namespace at::native::metal {

// Builds an MPSCNNClampOp over the texture pair {src, dst} with the bounds
// converted to float, and encodes it on the given command buffer.
//
// src           - image passed as the first texture of the op.
// dst           - image passed as the second texture of the op.
// min_val       - lower bound; converted with Scalar::toFloat().
// max_val       - upper bound; converted with Scalar::toFloat().
// commandBuffer - command buffer whose `buffer` the op is encoded on.
static void encodeClamp(
    MPSImage* src,
    MPSImage* dst,
    const Scalar& min_val,
    const Scalar& max_val,
    MetalCommandBuffer* commandBuffer) {
  const float lower = min_val.toFloat();
  const float upper = max_val.toFloat();
  MPSCNNClampOp* clampOp =
      [MPSCNNClampOp newWithTextures:@[ src, dst ]
                                Args:@[ @(lower), @(upper) ]];
  [clampOp encode:commandBuffer.buffer];
}

// In-place hardtanh for Metal tensors.
//
// Encodes a clamp of the tensor's current image into a freshly created
// temporary image, then points the tensor's storage texture at that new image.
//
// input   - tensor to update; must be a Metal tensor (TORCH_CHECK otherwise).
// min_val - lower bound.
// max_val - upper bound.
// Returns `input`.
static Tensor& hardtanh_(Tensor& input, const Scalar& min_val, const Scalar& max_val) {
  TORCH_CHECK(input.is_metal());
  // Nothing to clamp for an empty tensor. This mirrors the early return in
  // the out-of-place hardtanh below and avoids requesting a temporary image
  // for a zero-element shape.
  if (input.numel() == 0) {
    return input;
  }
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  MPSImage* Y = createTemporaryImage(commandBuffer, input.sizes().vec());
  encodeClamp(X, Y, min_val, max_val, commandBuffer);
  // Swap the image held by the tensor's storage so `input` now refers to the
  // clamped result.
  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// Out-of-place hardtanh for Metal tensors.
//
// input   - source tensor; must be a Metal tensor (TORCH_CHECK otherwise).
// min_val - lower bound.
// max_val - upper bound.
// Returns a new tensor with the same sizes and options as `input`. For an
// empty input, a tensor of the same sizes is returned without encoding any
// work.
static Tensor hardtanh(
    const Tensor& input,
    const Scalar& min_val,
    const Scalar& max_val) {
  TORCH_CHECK(input.is_metal());
  IntArrayRef outputSize = input.sizes();
  if (input.numel() == 0) {
    return makeTensor({outputSize.vec()}, input.options());
  }
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  // Back the output storage with temporary storage tied to the input's
  // command buffer before fetching its image.
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  MPSImage* X = imageFromTensor(input);
  encodeClamp(X, Y, min_val, max_val, commandBuffer);
  return makeTensor(std::move(mt), input.options());
}

// aten::clamp for Metal tensors, implemented by forwarding to hardtanh.
//
// Both bounds are required here; a call with only one of them fails the
// TORCH_CHECK.
static at::Tensor clamp(
    const at::Tensor& input,
    const std::optional<at::Scalar>& min,
    const std::optional<at::Scalar>& max) {
  TORCH_CHECK(
      min.has_value() && max.has_value(),
      "clamp: the Metal backend requires both min and max to be specified");
  return hardtanh(input, min.value(), max.value());
}

// Register the Metal implementations with the dispatcher.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh_"), TORCH_FN(hardtanh_));
  m.impl(TORCH_SELECTIVE_NAME("aten::hardtanh"), TORCH_FN(hardtanh));
  m.impl(TORCH_SELECTIVE_NAME("aten::clamp"), TORCH_FN(clamp));
}

} // namespace at::native::metal