xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalHardswish.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalContext.h>
4#import <ATen/native/metal/MetalTensorImpl.h>
5#import <ATen/native/metal/MetalTensorImplStorage.h>
6#import <ATen/native/metal/MetalTensorUtils.h>
7#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
8#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
9#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
10#include <torch/library.h>
11
12namespace at::native::metal {
13
14using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
15
// Encodes a single dispatch of the compute kernel named "hardswish" onto
// `commandBuffer`.
//
// The pipeline state is specialized with four constants taken from `X`
// (number of images, feature channels, height, width). The texture of `X` is
// bound at index 0 and the texture of `Y` at index 1, and the threadgroup
// geometry is computed from `X` by spatialPointwiseKernelLaunchParams.
//
// Shared by the in-place and out-of-place entry points below so the encoding
// sequence lives in exactly one place.
static void encodeHardswish(
    MetalCommandBuffer* commandBuffer,
    MPSImage* X,
    MPSImage* Y) {
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"hardswish"
                                                    Constants:@[
                                                      @(X.numberOfImages),
                                                      @(X.featureChannels),
                                                      @(X.height),
                                                      @(X.width)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
}

// In-place variant registered for aten::hardswish_.
//
// Runs the "hardswish" kernel from the tensor's current image into a newly
// created image sized from `input.sizes()`, then makes that new image the
// tensor's backing image. Returns `input` itself.
//
// NOTE(review): zero-element inputs are not special-cased here; confirm
// whether callers can reach this op with an empty tensor.
static Tensor& hardswish_(Tensor& input) {
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  std::vector<int64_t> imageSize = computeImageSize(input.sizes());
  MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);

  encodeHardswish(commandBuffer, X, Y);

  // Swap the tensor's backing image for the one the kernel wrote into, so the
  // same Tensor object now refers to the result.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(Y);
  return input;
}

// Out-of-place variant registered for aten::hardswish.
//
// Builds a new storage with the same sizes as `input`, allocates its image on
// the input's command buffer, runs the "hardswish" kernel from the input image
// into it, and wraps the storage in a new tensor that carries the input's
// options. `input` is not modified by this function.
//
// NOTE(review): zero-element inputs are not special-cased here; confirm
// whether callers can reach this op with an empty tensor.
static Tensor hardswish(const at::Tensor& input) {
  MPSImage* X = imageFromTensor(input);
  IntArrayRef outputSize = input.sizes();
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  MPSImage* Y = mt.texture()->image();

  encodeHardswish(commandBuffer, X, Y);

  return makeTensor(std::move(mt), input.options());
}
79
// Registers the Metal implementations of the two hardswish operators with the
// dispatcher. TORCH_SELECTIVE_NAME wraps each operator name in the
// registration call.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  // Out-of-place: returns a new tensor.
  m.impl(TORCH_SELECTIVE_NAME("aten::hardswish"), TORCH_FN(hardswish));
  // In-place: updates and returns the argument tensor.
  m.impl(TORCH_SELECTIVE_NAME("aten::hardswish_"), TORCH_FN(hardswish_));
}
84
85} // namespace at::native::metal
86