xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalHardshrink.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalContext.h>
4#import <ATen/native/metal/MetalTensorImpl.h>
5#import <ATen/native/metal/MetalTensorImplStorage.h>
6#import <ATen/native/metal/MetalTensorUtils.h>
7#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
8#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
9#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
10#include <torch/library.h>
11
12namespace at::native::metal {
13
14using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
15
// Encodes a single dispatch of the "hardshrink" compute kernel onto
// `commandBuffer`.
//
//   commandBuffer  command buffer whose underlying `buffer` supplies the
//                  compute encoder.
//   X              source image, bound as texture 0. Its numberOfImages,
//                  featureChannels, height and width (in that order) are
//                  passed as pipeline specialization constants and it also
//                  drives the launch-parameter computation.
//   Y              destination image, bound as texture 1.
//   lambdaValue    passed as the fifth specialization constant.
//
// The kernel source lives outside this file; presumably it applies the
// aten::hardshrink element-wise rule -- confirm against the shader.
// Shared by the in-place and out-of-place entry points below so the two
// cannot drift apart.
static void encodeHardshrink(
    MetalCommandBuffer* commandBuffer,
    MPSImage* X,
    MPSImage* Y,
    float lambdaValue) {
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"hardshrink"
                                                    Constants:@[
                                                      @(X.numberOfImages),
                                                      @(X.featureChannels),
                                                      @(X.height),
                                                      @(X.width),
                                                      @(lambdaValue)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}

// NB: this is currently unused, but I've left it because in principle
// it's useful
//
// In-place variant: runs the "hardshrink" kernel from `input`'s image into a
// newly created temporary image, then installs that image as `input`'s
// texture image and returns `input` itself. `lambda` is converted to float
// before being handed to the kernel (default 0.5).
// Marked [[maybe_unused]] because nothing in this file references it.
[[maybe_unused]] static Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) {
  float l = lambda.toFloat();
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  IntArrayRef outputSize = input.sizes();
  std::vector<int64_t> imageSize = computeImageSize(outputSize);
  MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);
  encodeHardshrink(commandBuffer, X, Y, l);
  // Swap the tensor's backing image for the one the kernel just wrote to;
  // this is what makes the operation "in place" from the caller's view.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// Out-of-place variant registered for aten::hardshrink on the Metal backend.
// Builds a new storage with `input`'s sizes, allocates its temporary storage
// on `input`'s command buffer, runs the "hardshrink" kernel from `input`'s
// image into it, and returns a new tensor wrapping that storage with
// `input`'s options. `lambda` is converted to float before being handed to
// the kernel (default 0.5).
static Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) {
  float l = lambda.toFloat();
  MPSImage* X = imageFromTensor(input);
  IntArrayRef outputSize = input.sizes();
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  encodeHardshrink(commandBuffer, X, Y, l);

  auto output = makeTensor(std::move(mt), input.options());
  return output;
}
85
// Registers the out-of-place `hardshrink` function from this file as the
// Metal dispatch-key implementation of aten::hardshrink. The in-place
// `hardshrink_` is intentionally not registered here.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("aten::hardshrink"),
      TORCH_FN(hardshrink));
}
89
90} // namespace at::native::metal
91