xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalNeurons.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalTensorImpl.h>
4#import <ATen/native/metal/MetalTensorImplStorage.h>
5#import <ATen/native/metal/MetalTensorUtils.h>
6#import <ATen/native/metal/MetalContext.h>
7#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
8#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
9#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
10#include <torch/library.h>
11
12namespace at::native::metal {
13
// Shorthand for the Metal TensorImpl specialization whose opaque handle is a
// MetalTensorImplStorage; used below to reach a tensor's backing texture.
typedef at::MetalTensorImpl<MetalTensorImplStorage> MetalTensorImpl;
15
// Out-of-place helper shared by the unary neuron ops in this file.
//
// Reads the MPSImage backing `input`, allocates temporary storage of the same
// sizes on the command buffer returned by getCommandBuffer(input), encodes
// `neuron` from the input image into that storage, and wraps the storage in a
// new tensor (via makeTensor) carrying `input`'s options.
//
// When `input` has no elements, no GPU work is encoded; a tensor with the same
// sizes and options is built directly from the sizes instead.
static Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
  MPSImage* sourceImage = imageFromTensor(input);
  const IntArrayRef sizes = input.sizes();
  if (input.numel() == 0) {
    return makeTensor({sizes.vec()}, input.options());
  }

  MetalTensorImplStorage resultStorage{sizes.vec()};
  MetalCommandBuffer* cmdBuffer = getCommandBuffer(input);
  resultStorage.texture()->allocateTemporaryStorage(sizes, cmdBuffer);
  MPSImage* resultImage = resultStorage.texture()->image();

  [neuron encodeToCommandBuffer:cmdBuffer.buffer
                    sourceImage:sourceImage
               destinationImage:resultImage];

  return makeTensor(std::move(resultStorage), input.options());
}
33
// In-place counterpart of neuronKernel.
//
// Encodes `neuron` from `input`'s current image into a temporary image
// obtained from createTemporaryImage on getCommandBuffer(input), then points
// `input`'s texture at that new image via setImage. Empty tensors are returned
// unchanged without encoding any work. Always returns `input` itself.
static Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) {
  MPSImage* sourceImage = imageFromTensor(input);
  const IntArrayRef sizes = input.sizes();
  if (input.numel() != 0) {
    MetalCommandBuffer* cmdBuffer = getCommandBuffer(input);
    MPSImage* resultImage = createTemporaryImage(cmdBuffer, sizes);
    [neuron encodeToCommandBuffer:cmdBuffer.buffer
                      sourceImage:sourceImage
                 destinationImage:resultImage];
    // Swap the freshly written image into the tensor's own storage so the
    // caller's tensor observes the result.
    auto* metalImpl = static_cast<MetalTensorImpl*>(input.unsafeGetTensorImpl());
    metalImpl->unsafe_opaque_handle().texture()->setImage(resultImage);
  }
  return input;
}
51
// Out-of-place ReLU (aten::relu) for Metal tensors; rejects non-Metal input,
// then runs the `[MPSCNNNeuronOp relu]` neuron through neuronKernel.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor relu(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel(input, reluNeuron);
}
57
// In-place ReLU (aten::relu_) for Metal tensors; rejects non-Metal input,
// then applies the `[MPSCNNNeuronOp relu]` neuron via neuronKernel_ and
// returns `input`.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor& relu_(Tensor& input) {
  TORCH_CHECK(input.is_metal());
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel_(input, reluNeuron);
}
63
// Out-of-place sigmoid (aten::sigmoid) for Metal tensors.
//
// Rejects non-Metal input up front, matching the other op entry points in
// this file, then runs the `[MPSCNNNeuronOp sigmoid]` neuron through
// neuronKernel, which returns a new tensor.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor sigmoid(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
}
68
// In-place hard sigmoid (aten::hardsigmoid_) for Metal tensors; only the
// in-place variant is provided in this file. Rejects non-Metal input, then
// applies the `[MPSCNNNeuronOp hardSigmoid]` neuron via neuronKernel_.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor& hardsigmoid_(Tensor& input) {
  TORCH_CHECK(input.is_metal());
  MPSCNNNeuron* hardSigmoidNeuron = [MPSCNNNeuronOp hardSigmoid];
  return neuronKernel_(input, hardSigmoidNeuron);
}
74
// Out-of-place hyperbolic tangent (aten::tanh) for Metal tensors; rejects
// non-Metal input, then runs the `[MPSCNNNeuronOp tanh]` neuron through
// neuronKernel.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor tanh(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  MPSCNNNeuron* tanhNeuron = [MPSCNNNeuronOp tanh];
  return neuronKernel(input, tanhNeuron);
}
80
// Registers this file's neuron kernels as the Metal-backend implementations
// of the corresponding aten ops. Every kernel is passed through TORCH_FN so
// all registrations here bind their function the same way (aten::tanh was
// the only one handed over as a plain function pointer).
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::tanh"), TORCH_FN(tanh));
  m.impl(TORCH_SELECTIVE_NAME("aten::relu"), TORCH_FN(relu));
  m.impl(TORCH_SELECTIVE_NAME("aten::relu_"), TORCH_FN(relu_));
  m.impl(TORCH_SELECTIVE_NAME("aten::sigmoid"), TORCH_FN(sigmoid));
  m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), TORCH_FN(hardsigmoid_));
}
88
89} // namepsace at::native::metal
90