xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalClamp.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalTensorImpl.h>
4#import <ATen/native/metal/MetalTensorImplStorage.h>
5#import <ATen/native/metal/MetalTensorUtils.h>
6#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9#include <torch/library.h>
10
11namespace at::native::metal {
12
// In-place hardtanh for Metal tensors: clamps the elements of `input` to
// [min_val, max_val] and returns `input` itself.
//
// The clamp is encoded on the tensor's command buffer, reading from the
// tensor's current image and writing into a freshly created temporary image.
// The tensor's texture is then pointed at that temporary image, so the tensor
// keeps its identity while its backing image is swapped.
//
// input   - Metal tensor to clamp; checked with TORCH_CHECK.
// min_val - lower bound, converted to float.
// max_val - upper bound, converted to float.
static Tensor& hardtanh_(Tensor& input, const Scalar& min_val, const Scalar& max_val) {
  TORCH_CHECK(input.is_metal());
  // An empty tensor has nothing to clamp. Mirror the out-of-place hardtanh and
  // return early instead of creating an image and encoding a kernel for it.
  if (input.numel() == 0) {
    return input;
  }
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  MPSImage* Y = createTemporaryImage(commandBuffer, input.sizes().vec());
  float min = min_val.toFloat();
  float max = max_val.toFloat();
  MPSCNNClampOp* clampOp = [MPSCNNClampOp newWithTextures:@[ X, Y ]
                                                     Args:@[ @(min), @(max) ]];
  [clampOp encode:commandBuffer.buffer];
  // Swap the tensor's backing image for the clamped result.
  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}
29
// Out-of-place hardtanh for Metal tensors: returns a new tensor holding the
// elements of `input` clamped to [min_val, max_val].
//
// An empty input yields a tensor of the same sizes without encoding any work.
// Otherwise storage for the result is allocated as temporary storage on the
// input's command buffer and an MPSCNNClampOp is encoded from the input image
// into it.
//
// input   - Metal tensor to clamp; checked with TORCH_CHECK.
// min_val - lower bound, converted to float.
// max_val - upper bound, converted to float.
static Tensor hardtanh(
    const Tensor& input,
    const Scalar& min_val,
    const Scalar& max_val) {
  TORCH_CHECK(input.is_metal());
  IntArrayRef sizes = input.sizes();
  if (input.numel() == 0) {
    return makeTensor({sizes.vec()}, input.options());
  }

  // Destination: temporary storage tied to the input's command buffer.
  MetalTensorImplStorage resultStorage{sizes.vec()};
  MetalCommandBuffer* cmdBuffer = getCommandBuffer(input);
  resultStorage.texture()->allocateTemporaryStorage(sizes, cmdBuffer);
  MPSImage* dstImage = resultStorage.texture()->image();

  float lowerBound = min_val.toFloat();
  float upperBound = max_val.toFloat();

  // Source image first, destination second; bounds as [min, max].
  MPSImage* srcImage = imageFromTensor(input);
  NSArray* textures = @[ srcImage, dstImage ];
  NSArray* bounds = @[ @(lowerBound), @(upperBound) ];
  MPSCNNClampOp* op = [MPSCNNClampOp newWithTextures:textures Args:bounds];
  [op encode:cmdBuffer.buffer];

  return makeTensor(std::move(resultStorage), input.options());
}
52
// aten::clamp for Metal tensors. Both bounds must be supplied (enforced with
// TORCH_CHECK); the call is then forwarded to hardtanh with those bounds.
static at::Tensor clamp(
    const at::Tensor& input,
    const std::optional<at::Scalar>& min,
    const std::optional<at::Scalar>& max) {
  TORCH_CHECK(min.has_value() && max.has_value());
  const at::Scalar& lowerBound = *min;
  const at::Scalar& upperBound = *max;
  return hardtanh(input, lowerBound, upperBound);
}
60
// Registers this file's clamp-style kernels as the Metal implementations of
// the corresponding aten operators.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  // In-place hardtanh.
  m.impl(
      TORCH_SELECTIVE_NAME("aten::hardtanh_"),
      TORCH_FN(hardtanh_));
  // Out-of-place hardtanh.
  m.impl(
      TORCH_SELECTIVE_NAME("aten::hardtanh"),
      TORCH_FN(hardtanh));
  // clamp, which requires both bounds and forwards to hardtanh.
  m.impl(
      TORCH_SELECTIVE_NAME("aten::clamp"),
      TORCH_FN(clamp));
}
66
67} // namespace at::native::metal
68