xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/ops/MetalReshape.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalCommandBuffer.h>
2#import <ATen/native/metal/MetalTensorImpl.h>
3#import <ATen/native/metal/MetalTensorImplStorage.h>
4#import <ATen/native/metal/MetalTensorUtils.h>
5#import <ATen/native/metal/MetalContext.h>
6#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
7#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
8#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
9
10#include <ATen/InferSize.h>
11#include <ATen/TensorUtils.h>
12#include <torch/library.h>
13
14namespace at::native::metal {
15
API_AVAILABLE(ios(11.0), macos(10.13))
// Metal kernel for aten::view.
//
// The result does NOT alias `input`. Unless the tensor is empty, a new texture
// is allocated and the "reshape" compute kernel copies the input image into it
// using the new shape. The stride-compatibility check is still performed, so
// shapes that a strided view could not express are rejected here.
//
// @param input    A tensor on the Metal device.
// @param sym_size Requested shape; may contain a single -1 to be inferred.
// @return A new Metal tensor with the inferred shape.
static Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) {
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  TORCH_CHECK(input.is_metal());

  // Resolve any -1 entry, then make sure the new shape is view-compatible with
  // the input's current sizes/strides.
  auto inferred_size = at::infer_size(size, input.numel());
  auto maybe_strides =
      at::detail::computeStride(input.sizes(), input.strides(), inferred_size);
  TORCH_CHECK(
      maybe_strides.has_value(),
      "view size is "
      "not compatible with input tensor's size and stride (at least one dimension"
      " spans across two contiguous subspaces). Use .reshape(...) instead.");
  auto strides = *maybe_strides;

  // Empty tensors have no texture to copy, so only the metadata is created.
  if (input.numel() == 0) {
    MetalTensorImplStorage emptyStorage{inferred_size, strides};
    return makeTensor(std::move(emptyStorage), input.options());
  }

  MPSImage* inputImage = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  MetalTensorImplStorage outputStorage{inferred_size, strides};
  outputStorage.texture()->allocateTemporaryStorage(inferred_size, commandBuffer);
  MPSImage* outputImage = outputStorage.texture()->image();

  // Output geometry first, then input geometry. This order must match the
  // function constants declared by the "reshape" Metal kernel (declared
  // elsewhere; NOTE(review): keep in sync).
  NSArray<NSNumber*>* reshapeConstants = @[
    @(outputImage.height),
    @(outputImage.width),
    @(outputImage.featureChannels),
    @(outputImage.numberOfImages),
    @(inputImage.height),
    @(inputImage.width),
    @(inputImage.featureChannels),
    @(inputImage.numberOfImages),
  ];
  id<MTLComputePipelineState> pipeline =
      [[MetalContext sharedInstance] specializedPipelineState:"reshape"
                                                    Constants:reshapeConstants];

  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  [encoder setComputePipelineState:pipeline];
  // Texture slot 0 is the source and slot 1 is the destination.
  [encoder setTexture:inputImage.texture atIndex:0];
  [encoder setTexture:outputImage.texture atIndex:1];
  const auto launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(pipeline, outputImage);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];

  return makeTensor(std::move(outputStorage), input.options());
}
62
// Metal kernel for aten::reshape.
//
// Forwards directly to view(). view() already writes its result into a newly
// allocated texture, so no separate copy path is needed. Unlike the generic
// aten::reshape, however, there is no fallback: if view()'s stride
// compatibility check fails, that error propagates to the caller.
//
// Carries the same availability annotation as view(), the function it calls.
// Without it, clang's unguarded-availability diagnostics fire when the
// deployment target is below these versions.
//
// @param input A tensor on the Metal device.
// @param shape Requested shape; may contain a single -1 to be inferred.
// @return A new Metal tensor with the requested shape.
API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor reshape(const Tensor& input, IntArrayRef shape) {
  TORCH_CHECK(input.is_metal());
  return view(input, c10::fromIntArrayRefSlow(shape));
}
67
// Metal kernel for aten::flatten.using_ints.
//
// Collapses dimensions [start_dim, end_dim] (inclusive, negative indices
// allowed) into one dimension whose size is the product of the collapsed
// sizes. The result is produced via input.reshape().
//
// @param input     A tensor on the Metal device.
// @param start_dim First dimension to collapse.
// @param end_dim   Last dimension to collapse; must not precede start_dim.
// @return The flattened tensor. For a 0-dim input this is a 1-element tensor;
//         when start_dim == end_dim it is `input` itself.
static Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  TORCH_CHECK(input.is_metal());
  const int64_t ndim = input.dim();
  start_dim = maybe_wrap_dim(start_dim, ndim);
  end_dim = maybe_wrap_dim(end_dim, ndim);
  TORCH_CHECK(
      start_dim <= end_dim,
      "flatten() has invalid args: start_dim cannot come after end_dim");

  // Scalars flatten to a one-element 1-D tensor.
  if (ndim == 0) {
    return input.reshape({1});
  }
  // Collapsing a single dimension is a no-op.
  if (start_dim == end_dim) {
    return input;
  }

  const IntArrayRef sizes = input.sizes();
  std::vector<int64_t> flattened_shape;
  // Kept: [0, start_dim) and (end_dim, ndim); plus one merged dimension.
  flattened_shape.reserve(ndim - end_dim + start_dim);
  flattened_shape.insert(
      flattened_shape.end(), sizes.begin(), sizes.begin() + start_dim);
  flattened_shape.push_back(c10::multiply_integers(
      sizes.slice(start_dim, end_dim - start_dim + 1)));
  flattened_shape.insert(
      flattened_shape.end(), sizes.begin() + end_dim + 1, sizes.end());
  return input.reshape(flattened_shape);
}
97
// Metal kernel for aten::detach.
//
// Verifies that the tensor lives on the Metal device and returns that same
// tensor object. No new tensor or storage is created.
static Tensor detach(const Tensor& input) {
  TORCH_CHECK(input.is_metal());
  return input;
}
102
// Register the shape-manipulation kernels above for the Metal dispatch key.
// Each registration targets a distinct operator, so their order is irrelevant.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::detach"), TORCH_FN(detach));
  m.impl(
      TORCH_SELECTIVE_NAME("aten::flatten.using_ints"),
      TORCH_FN(flatten_using_ints));
  m.impl(TORCH_SELECTIVE_NAME("aten::reshape"), TORCH_FN(reshape));
  m.impl(TORCH_SELECTIVE_NAME("aten::view"), TORCH_FN(view));
}
109
110} // namespace at::native::metal
111