#include <torch/extension.h>
#include <ATen/native/mps/OperationUtils.h>

// this sample custom kernel is taken from:
// https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu
static const char* CUSTOM_KERNEL = R"MPS_ADD_ARRAYS(
#include <metal_stdlib>
using namespace metal;
kernel void add_arrays(device const float* inA,
                       device const float* inB,
                       device float* result,
                       uint index [[thread_position_in_grid]])
{
    result[index] = inA[index] + inB[index];
}
)MPS_ADD_ARRAYS";

at::Tensor get_cpu_add_output(at::Tensor& cpu_input1, at::Tensor& cpu_input2) {
  return cpu_input1 + cpu_input2;
}

at::Tensor get_mps_add_output(at::Tensor& mps_input1, at::Tensor& mps_input2) {
  // Smoke tests: both inputs must live on the MPS device and have matching
  // shapes. The kernel also assumes contiguous float32 storage, since it
  // indexes the raw Metal buffers directly.
  TORCH_CHECK(mps_input1.is_mps());
  TORCH_CHECK(mps_input2.is_mps());
  TORCH_CHECK(mps_input1.sizes() == mps_input2.sizes());

  using namespace at::native::mps;
  at::Tensor mps_output = at::empty_like(mps_input1);

  @autoreleasepool {
    id<MTLDevice> device = MPSDevice::getInstance()->device();
    NSError* error = nil;
    size_t numThreads = mps_output.numel();

    // Compile the Metal source at runtime and look up the kernel function.
    id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource: [NSString stringWithUTF8String:CUSTOM_KERNEL]
                                                              options: nil
                                                                error: &error];
    TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);

    id<MTLFunction> customFunction = [customKernelLibrary newFunctionWithName: @"add_arrays"];
    TORCH_CHECK(customFunction, "Failed to create function state object for the kernel");

    id<MTLComputePipelineState> kernelPSO = [device newComputePipelineStateWithFunction: customFunction error: &error];
    TORCH_CHECK(kernelPSO, error.localizedDescription.UTF8String);

    MPSStream* mpsStream = getCurrentMPSStream();

    dispatch_sync(mpsStream->queue(), ^() {
      // Start a compute pass.
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

      // Encode the pipeline state object and its parameters.
      [computeEncoder setComputePipelineState: kernelPSO];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_input1) offset:0 atIndex:0];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_input2) offset:0 atIndex:1];
      [computeEncoder setBuffer: getMTLBufferStorage(mps_output) offset:0 atIndex:2];
      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);

      // Calculate a thread group size: one thread per element, capped at the
      // pipeline's maximum threadgroup width.
      NSUInteger threadsPerGroupSize = std::min(kernelPSO.maxTotalThreadsPerThreadgroup, numThreads);
      MTLSize threadGroupSize = MTLSizeMake(threadsPerGroupSize, 1, 1);

      // Encode the compute command.
      [computeEncoder dispatchThreads: gridSize threadsPerThreadgroup: threadGroupSize];

      // Commit the command buffer and wait for completion, so the result is
      // ready when the caller reads mps_output.
      mpsStream->synchronize(SyncType::COMMIT_AND_WAIT);
    });
  }
  return mps_output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("get_cpu_add_output", &get_cpu_add_output);
  m.def("get_mps_add_output", &get_mps_add_output);
}
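/*
A minimal usage sketch, kept as a comment so this file still compiles. It
assumes the source is saved as "mps_extension.mm", that the module name
"mps_extension" is acceptable, and that torch.utils.cpp_extension.load can
JIT-compile Objective-C++ (.mm) sources on macOS. It builds the extension,
runs the custom Metal kernel on MPS tensors, and compares against the CPU
path:

    import torch
    from torch.utils.cpp_extension import load

    # Compile this Objective-C++ source into a loadable Python module.
    ext = load(name="mps_extension", sources=["mps_extension.mm"])

    a = torch.randn(4096, device="mps", dtype=torch.float32)
    b = torch.randn(4096, device="mps", dtype=torch.float32)

    mps_out = ext.get_mps_add_output(a, b)
    cpu_out = ext.get_cpu_add_output(a.cpu(), b.cpu())
    torch.testing.assert_close(mps_out.cpu(), cpu_out)
*/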