xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalContext.h>
4#import <ATen/native/metal/MetalTensorImpl.h>
5#import <ATen/native/metal/MetalTensorImplStorage.h>
6#import <ATen/native/metal/MetalTensorUtils.h>
7#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
8#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
9#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
10#include <torch/library.h>
11
namespace at::native::metal {

using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;

// Encodes a single dispatch of the "leaky_relu" compute kernel that reads
// image X (texture index 0) and writes image Y (texture index 1) on
// `commandBuffer`. Shared by leaky_relu and leaky_relu_ so the pipeline
// specialization and launch geometry are defined in exactly one place.
//
// The pipeline is specialized on X's numberOfImages, featureChannels, height,
// width and on `negativeSlope`. Launch parameters are derived from X only, so
// callers must pass a Y whose image shape matches X.
static void encodeLeakyReLU(
    MetalCommandBuffer* commandBuffer,
    MPSImage* X,
    MPSImage* Y,
    float negativeSlope) {
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  // `negativeSlope` stays a `float` so it is boxed exactly as before
  // (NSNumber with float objCType); presumably the shader reads it as a float
  // function constant — confirm against the "leaky_relu" kernel source.
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"leaky_relu"
                                                    Constants:@[
                                                      @(X.numberOfImages),
                                                      @(X.featureChannels),
                                                      @(X.height),
                                                      @(X.width),
                                                      @(negativeSlope)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}

// In-place leaky ReLU for Metal tensors (registered as aten::leaky_relu_).
//
// Allocates a temporary image on `input`'s command buffer, encodes the kernel
// from `input`'s current image into it, and then rebinds `input`'s texture
// storage to the new image. The result is produced when the command buffer
// runs. Returns `input`.
static Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) {
  // Convert first: Scalar::toFloat is a checked conversion that can throw
  // (e.g. on float overflow), and failing here allocates no GPU resources.
  const float negative_slope = negative_slope_val.toFloat();

  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  std::vector<int64_t> imageSize = computeImageSize(input.sizes());
  MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);

  encodeLeakyReLU(commandBuffer, X, Y, negative_slope);

  // Point the input tensor's storage at the output image.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// Out-of-place leaky ReLU for Metal tensors (registered as aten::leaky_relu).
//
// Creates fresh Metal storage with `input`'s sizes, backs it with temporary
// storage on `input`'s command buffer, encodes the kernel into it, and wraps
// it in a new Tensor carrying `input.options()`. `input` is not modified.
static Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) {
  // Convert first so a failed conversion happens before any allocation.
  const float negative_slope = negative_slope_val.toFloat();

  MPSImage* X = imageFromTensor(input);
  IntArrayRef outputSize = input.sizes();
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();

  encodeLeakyReLU(commandBuffer, X, Y, negative_slope);

  return makeTensor(std::move(mt), input.options());
}

// Register the Metal-backend kernels for the functional and in-place ops.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), TORCH_FN(leaky_relu_));
  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), TORCH_FN(leaky_relu));
}

} // namespace at::native::metal
90