xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalChunk.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#include <ATen/Tensor.h>
2#import <ATen/native/metal/MetalCommandBuffer.h>
3#import <ATen/native/metal/MetalTensorImpl.h>
4#import <ATen/native/metal/MetalTensorImplStorage.h>
5#import <ATen/native/metal/MetalTensorUtils.h>
6#import <ATen/native/metal/MetalContext.h>
7#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
8#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
9#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
10#include <torch/library.h>
11
namespace at::native::metal {

// Splits a Metal tensor into two tensors along the channel dimension by
// encoding the "split_channels" compute kernel on the command buffer
// associated with `input`.
//
// Supported configuration (anything else raises via TORCH_CHECK):
//   - chunks == 2 and dim == 1
//   - input is 4-D with input.size(0) == 1
//   - input has at least `chunks` channels, so that both outputs are non-empty
//
// The first output receives ceil(C / 2) channels and the second receives the
// remaining floor(C / 2) channels, where C == input.size(1). Sizes along
// dimensions 0, 2 and 3 are copied from the input.
//
// Returns a vector holding exactly two tensors, created with input.options().
//
// TODO: [T87567124] Fully implement chunk in Metal shader
static std::vector<Tensor> chunk(const Tensor& input, int64_t chunks, int64_t dim) {
  TORCH_CHECK(
      chunks == 2 && dim == 1,
      "Metal chunk only supports chunks == 2 and dim == 1, but got chunks = ",
      chunks,
      " and dim = ",
      dim);
  TORCH_CHECK(
      input.dim() == 4,
      "Metal chunk expects a 4-D input, but got a ",
      input.dim(),
      "-D input");
  TORCH_CHECK(
      input.size(0) == 1,
      "Metal chunk expects input.size(0) == 1, but got ",
      input.size(0));

  const int64_t dim_size = input.size(dim);
  // This implementation always produces two outputs. With fewer than two
  // channels there is nothing to place in the second output, so reject the
  // input instead of handing the kernel an output that has no source data.
  TORCH_CHECK(
      dim_size >= chunks,
      "Metal chunk expects at least ",
      chunks,
      " channels to split, but got ",
      dim_size);

  // Round up so the first chunk takes the extra channel when dim_size is odd;
  // the second chunk takes whatever is left (>= 1 because dim_size >= 2).
  const int64_t split_size = (dim_size + chunks - 1) / chunks;
  const int64_t last_split_size = dim_size - split_size;

  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);

  auto outputSize1 = {input.size(0), split_size, input.size(2), input.size(3)};
  auto outputSize2 = {input.size(0), last_split_size, input.size(2), input.size(3)};
  MetalTensorImplStorage mt1(outputSize1);
  MetalTensorImplStorage mt2(outputSize2);
  // Back both outputs with temporary storage tied to the input's command buffer.
  mt1.texture()->allocateTemporaryStorage(outputSize1, commandBuffer);
  mt2.texture()->allocateTemporaryStorage(outputSize2, commandBuffer);
  MPSImage* Y1 = mt1.texture()->image();
  MPSImage* Y2 = mt2.texture()->image();

  // The pipeline is specialized on the channel counts of the input and of
  // both outputs, in that order.
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:"split_channels"
                     Constants:@[
                       @(X.featureChannels),
                       @(Y1.featureChannels),
                       @(Y2.featureChannels)
                     ]];
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  // Texture slots: 0 = input, 1 = first output, 2 = second output.
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y1 texture] atIndex:1];
  [encoder setTexture:[Y2 texture] atIndex:2];
  // Launch parameters are derived from the input image.
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];

  auto output1 = makeTensor(std::move(mt1), input.options());
  auto output2 = makeTensor(std::move(mt2), input.options());
  return {output1, output2};
}

// Registers the function above as the Metal-backend implementation of
// aten::chunk.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::chunk"), TORCH_FN(chunk));
}

} // namespace at::native::metal
65