#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

namespace at::native::metal {

// Splits a 4-D, batch-size-1 Metal tensor into two tensors along the channel
// dimension. The first output gets ceil(C / 2) channels and the second gets the
// remaining C - ceil(C / 2) channels.
//
// Only chunks == 2 along the channel dimension (dim == 1, or equivalently -3)
// is supported. The input must have at least two channels, because the
// "split_channels" kernel always writes exactly two outputs.
//
// TODO: [T87567124] Fully implement chunk in Metal shader
static std::vector<Tensor> chunk(const Tensor& input, int64_t chunks, int64_t dim) {
  TORCH_CHECK(
      input.dim() == 4,
      "Metal chunk expects a 4-D input, but got a ",
      input.dim(),
      "-D tensor");
  // Accept the negative spelling of the channel dimension as well.
  const int64_t channel_dim = dim < 0 ? dim + input.dim() : dim;
  TORCH_CHECK(
      chunks == 2 && channel_dim == 1,
      "Metal chunk only supports chunks=2 along dim=1, but got chunks=",
      chunks,
      ", dim=",
      dim);
  TORCH_CHECK(
      input.size(0) == 1,
      "Metal chunk only supports batch size 1, but got ",
      input.size(0));
  const int64_t dim_size = input.size(1);
  // With fewer than two channels there is nothing to put in the second output.
  TORCH_CHECK(
      dim_size >= 2,
      "Metal chunk requires at least 2 channels to split into 2 chunks, but got ",
      dim_size);
  // First chunk is ceil(C / chunks) wide; with chunks == 2 and C >= 2 the
  // second chunk always exists and holds the remainder.
  const int64_t split_size = (dim_size + chunks - 1) / chunks;
  const int64_t last_split_size = dim_size - split_size;

  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  auto outputSize1 = {input.size(0), split_size, input.size(2), input.size(3)};
  auto outputSize2 = {input.size(0), last_split_size, input.size(2), input.size(3)};
  MetalTensorImplStorage mt1(outputSize1);
  MetalTensorImplStorage mt2(outputSize2);
  // Output textures are allocated as temporary storage tied to the input's
  // command buffer.
  mt1.texture()->allocateTemporaryStorage(outputSize1, commandBuffer);
  mt2.texture()->allocateTemporaryStorage(outputSize2, commandBuffer);
  MPSImage* Y1 = mt1.texture()->image();
  MPSImage* Y2 = mt2.texture()->image();

  // The kernel is specialized on the channel counts of the input and both
  // outputs. NOTE(review): the shader presumably copies the first
  // Y1.featureChannels channels into Y1 and the rest into Y2; its source is
  // not visible here, so confirm there.
  id<MTLComputePipelineState> state = [[MetalContext sharedInstance]
      specializedPipelineState:"split_channels"
                     Constants:@[
                       @(X.featureChannels),
                       @(Y1.featureChannels),
                       @(Y2.featureChannels)
                     ]];
  id<MTLComputeCommandEncoder> encoder = [commandBuffer.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y1 texture] atIndex:1];
  [encoder setTexture:[Y2 texture] atIndex:2];
  // The dispatch grid is derived from the input image's spatial extent.
  const auto& launchParams = mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];

  auto output1 = makeTensor(std::move(mt1), input.options());
  auto output2 = makeTensor(std::move(mt2), input.options());
  return {output1, output2};
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::chunk"), TORCH_FN(chunk));
}

} // namespace at::native::metal