xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSImageUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #import <ATen/Tensor.h>
2 #import <ATen/native/metal/MetalCommandBuffer.h>
3 #import <ATen/native/metal/MetalTensorImpl.h>
4 #import <ATen/native/metal/MetalTensorUtils.h>
5 
6 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
7 
8 namespace at {
9 namespace native {
10 namespace metal {
11 
12 MPSImage* createStaticImage(IntArrayRef sizes);
13 MPSImage* createStaticImage(const float* src, const IntArrayRef sizes);
14 MPSImage* createStaticImage(
15     MPSTemporaryImage* image,
16     MetalCommandBuffer* buffer,
17     bool waitUntilCompleted);
18 
19 MPSTemporaryImage* createTemporaryImage(
20     MetalCommandBuffer* buffer,
21     const IntArrayRef sizes);
22 MPSTemporaryImage* createTemporaryImage(
23     MetalCommandBuffer* buffer,
24     const IntArrayRef sizes,
25     const float* src);
26 MPSTemporaryImage* createTemporaryImage(
27     MetalCommandBuffer* buffer,
28     MPSImage* image);
29 
30 void copyImageToFloatBuffer(float* dst, MPSImage* image);
31 
32 void copyImageToMetalBuffer(
33     MetalCommandBuffer* buffer,
34     id<MTLBuffer> dst,
35     MPSImage* image);
36 
// Returns the MPSImage held by a Metal tensor's backing storage.
//
// Throws (via TORCH_CHECK) if `tensor` is not a Metal tensor. The returned
// image is obtained as impl->unsafe_opaque_handle().texture()->image().
static inline MPSImage* imageFromTensor(const Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_metal(),
      "imageFromTensor: expected a Metal tensor, but got a tensor on device ",
      tensor.device());
  using MetalTensorImplStorage = at::native::metal::MetalTensorImplStorage;
  using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
  // NOTE(review): assumes every Metal tensor's impl is a
  // MetalTensorImpl<MetalTensorImplStorage>; the unchecked downcast below
  // relies on that (same assumption as the previous C-style cast).
  MetalTensorImpl* impl =
      static_cast<MetalTensorImpl*>(tensor.unsafeGetTensorImpl());
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  return implStorage.texture()->image();
}
45 
/*
MPSImage carries an IntList shape that is identical to the shape of the CPU
tensor it was converted from. A tensor is mapped onto a 4D MPSImage
(N, C, H, W) as follows:
1) 1D tensors (W,) are always stored as MPSImage(N=1, C=1, H=1, W=W).
2) 2D tensors (H, W) are always stored as MPSImage(N=1, C=1, H=H, W=W).
3) 3D tensors (C, H, W) are always stored as MPSImage(N=1, C=C, H=H, W=W).
4) 4D tensors (N, C, H, W) are always stored as MPSImage(N=N, C=C, H=H, W=W).
5) 5D tensors (T, N, C, H, W) are always stored as
   MPSImage(N=T*N, C=C, H=H, W=W).
6) Higher-rank tensors follow the same rule: every dimension before the last
   three is multiplied into N.
 */
computeImageSize(IntArrayRef sizes)56 static inline std::vector<int64_t> computeImageSize(IntArrayRef sizes) {
57   std::vector<int64_t> imageSize(4, 1);
58   int64_t index = 3;
59   int64_t batch = 1;
60   for (int64_t i = sizes.size() - 1; i >= 0; i--) {
61     if (index != 0) {
62       imageSize[index] = sizes[i];
63       index--;
64       continue;
65     }
66     // For higher dimensional tensors,
67     // multiply rest of dims into imageSize[0]
68     batch *= sizes[i];
69   }
70   imageSize[0] = batch;
71   return imageSize;
72 }
73 
74 } // namespace metal
75 } // namespace native
76 } // namespace at
77