#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>

#include <ATen/InferSize.h>
#include <ATen/TensorUtils.h>
#include <torch/library.h>

namespace at::native::metal {

API_AVAILABLE(ios(11.0), macos(10.13))
static Tensor view(const Tensor& input, c10::SymIntArrayRef sym_size) {
  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  TORCH_CHECK(input.is_metal());
  auto inferred_size = at::infer_size(size, input.numel());
  auto stride =
      at::detail::computeStride(input.sizes(), input.strides(), inferred_size);
  TORCH_CHECK(
      stride.has_value(),
      "view size is "
      "not compatible with input tensor's size and stride (at least one dimension"
      " spans across two contiguous subspaces). Use .reshape(...) instead.");
  auto stride_value = *stride;
  // Empty tensors need no GPU work; just return a tensor with the new metadata.
  if (input.numel() == 0) {
    return makeTensor({inferred_size, stride_value}, input.options());
  }
  // Metal tensors are backed by MPSImage textures, so a view cannot simply
  // reinterpret strides; copy the data into a new texture with the target
  // shape using the "reshape" compute kernel.
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  MetalTensorImplStorage mt{inferred_size, stride_value};
  mt.texture()->allocateTemporaryStorage(inferred_size, commandBuffer);
  MPSImage* Y = mt.texture()->image();
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"reshape"
                                                     Constants:@[
                                                       @(Y.height),
                                                       @(Y.width),
                                                       @(Y.featureChannels),
                                                       @(Y.numberOfImages),
                                                       @(X.height),
                                                       @(X.width),
                                                       @(X.featureChannels),
                                                       @(X.numberOfImages),
                                                     ]];
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y texture] atIndex:1];
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  auto output = makeTensor(std::move(mt), input.options());
  return output;
}

static Tensor reshape(const Tensor& input, IntArrayRef shape) {
  TORCH_CHECK(input.is_metal());
  return view(input, c10::fromIntArrayRefSlow(shape));
}

static Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  TORCH_CHECK(input.is_metal());
  start_dim = maybe_wrap_dim(start_dim, input.dim());
  end_dim = maybe_wrap_dim(end_dim, input.dim());
  TORCH_CHECK(
      start_dim <= end_dim,
      "flatten() has invalid args: start_dim cannot come after end_dim");
  std::vector<int64_t> shape;
  if (input.dim() == 0) {
    return input.reshape({1});
  }
  if (start_dim == end_dim) {
    return input;
  }
  // Collapse dims [start_dim, end_dim] into a single dimension and reshape.
  const auto slice_numel = c10::multiply_integers(
      input.sizes().slice(start_dim, end_dim - start_dim + 1));
  shape.reserve(input.dim() - end_dim + start_dim);
  for (int64_t i = 0; i < start_dim; i++) {
    shape.push_back(input.size(i));
  }
  shape.push_back(slice_numel);
  for (int64_t i = end_dim + 1; i < input.dim(); i++) {
    shape.push_back(input.size(i));
  }
  return input.reshape(shape);
}

static Tensor detach(const Tensor& input) {
  // detach is a no-op for Metal tensors.
  TORCH_CHECK(input.is_metal());
  return input;
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::detach"), TORCH_FN(detach));
  m.impl(TORCH_SELECTIVE_NAME("aten::view"), TORCH_FN(view));
  m.impl(TORCH_SELECTIVE_NAME("aten::reshape"), TORCH_FN(reshape));
  m.impl(TORCH_SELECTIVE_NAME("aten::flatten.using_ints"), TORCH_FN(flatten_using_ints));
}

} // namespace at::native::metal