1 #import <ATen/Tensor.h>
2 #import <ATen/native/metal/MetalCommandBuffer.h>
3 #import <ATen/native/metal/MetalTensorImpl.h>
4 #import <ATen/native/metal/MetalTensorUtils.h>
5
6 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
7
8 namespace at {
9 namespace native {
10 namespace metal {
11
// createStaticImage overloads (declarations only; the definitions are not
// part of this header). Each overload returns an MPSImage*.
//  - (sizes): takes only the tensor sizes.
//    NOTE(review): presumably allocates an uninitialized image — confirm.
//  - (src, sizes): additionally takes a float buffer.
//    NOTE(review): presumably the image is filled from `src` — confirm.
//  - (image, buffer, waitUntilCompleted): takes an MPSTemporaryImage and the
//    MetalCommandBuffer it is associated with.
//    NOTE(review): `waitUntilCompleted` presumably blocks until the command
//    buffer has finished executing — confirm against the implementation.
12 MPSImage* createStaticImage(IntArrayRef sizes);
13 MPSImage* createStaticImage(const float* src, const IntArrayRef sizes);
14 MPSImage* createStaticImage(
15 MPSTemporaryImage* image,
16 MetalCommandBuffer* buffer,
17 bool waitUntilCompleted);
18
// createTemporaryImage overloads (declarations only; the definitions are not
// part of this header). Each overload takes a MetalCommandBuffer and returns
// an MPSTemporaryImage*.
//  - (buffer, sizes): takes only the tensor sizes.
//  - (buffer, sizes, src): additionally takes a float buffer.
//    NOTE(review): presumably the image is filled from `src` — confirm.
//  - (buffer, image): takes an existing MPSImage.
//    NOTE(review): presumably copies `image` into a temporary image — confirm.
19 MPSTemporaryImage* createTemporaryImage(
20 MetalCommandBuffer* buffer,
21 const IntArrayRef sizes);
22 MPSTemporaryImage* createTemporaryImage(
23 MetalCommandBuffer* buffer,
24 const IntArrayRef sizes,
25 const float* src);
26 MPSTemporaryImage* createTemporaryImage(
27 MetalCommandBuffer* buffer,
28 MPSImage* image);
29
// Declaration only; the definition is not part of this header.
// Takes a destination float pointer and a source MPSImage.
// NOTE(review): the required capacity of `dst` is not stated here — confirm
// against the implementation before sizing the buffer.
30 void copyImageToFloatBuffer(float* dst, MPSImage* image);
31
// Declaration only; the definition is not part of this header.
// Takes a MetalCommandBuffer, a destination MTLBuffer and a source MPSImage.
// NOTE(review): presumably the copy is encoded on `buffer` rather than done
// synchronously — confirm against the implementation.
32 void copyImageToMetalBuffer(
33 MetalCommandBuffer* buffer,
34 id<MTLBuffer> dst,
35 MPSImage* image);
36
imageFromTensor(const Tensor & tensor)37 static inline MPSImage* imageFromTensor(const Tensor& tensor) {
38 TORCH_CHECK(tensor.is_metal());
39 using MetalTensorImplStorage = at::native::metal::MetalTensorImplStorage;
40 using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
41 MetalTensorImpl* impl = (MetalTensorImpl*)tensor.unsafeGetTensorImpl();
42 MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
43 return implStorage.texture()->image();
44 }
45
/*
 MPSImage carries an IntList shape which is identical to the shape of the CPU
 tensor it's converted from.
 1) 1D tensors (W,) are always stored as MPSImage(N=1, C=1, H=1, W=W).
 2) 2D tensors (H, W) are always stored as MPSImage(N=1, C=1, H=H, W=W).
 3) 3D tensors (C, H, W) are always stored as MPSImage(N=1, C=C, H=H, W=W).
 4) 4D tensors (N, C, H, W) are always stored as MPSImage(N=N, C=C, H=H, W=W).
 5) 5D tensors (T, N, C, H, W) are always stored as MPSImage(N=T*N, C=C, H=H,
    W=W).
 6) Higher-rank tensors follow the same rule: all leading dimensions are
    multiplied into N.
 */
computeImageSize(IntArrayRef sizes)56 static inline std::vector<int64_t> computeImageSize(IntArrayRef sizes) {
57 std::vector<int64_t> imageSize(4, 1);
58 int64_t index = 3;
59 int64_t batch = 1;
60 for (int64_t i = sizes.size() - 1; i >= 0; i--) {
61 if (index != 0) {
62 imageSize[index] = sizes[i];
63 index--;
64 continue;
65 }
66 // For higher dimensional tensors,
67 // multiply rest of dims into imageSize[0]
68 batch *= sizes[i];
69 }
70 imageSize[0] = batch;
71 return imageSize;
72 }
73
74 } // namespace metal
75 } // namespace native
76 } // namespace at
77