//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#pragma once

#import <Foundation/Foundation.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
#include <executorch/backends/apple/mps/schema_generated.h>

// Standard-library headers used directly by this header (std::numeric_limits
// in INF, std::vector, int64_t). Included explicitly so the header does not
// rely on them being pulled in transitively by the framework headers above.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Positive float infinity. Kept as a macro (not a constexpr) so existing
// users outside this namespace keep working; note that macros ignore the
// enclosing namespace.
#define INF std::numeric_limits<float>::infinity()

/// Returns the MPSDataType corresponding to an ExecuTorch scalar type.
/// NOTE(review): handling of unsupported scalar types is defined in the
/// implementation file, not visible here.
MPSDataType getMPSScalarType(executorch::aten::ScalarType scalar_type);

/// Returns the ExecuTorch scalar type corresponding to an MPSDataType
/// (inverse direction of getMPSScalarType).
executorch::aten::ScalarType getScalarType(MPSDataType mpsDataType);

/// Adds a cast of `tensor` to `toType` in `mpsGraph` and returns the
/// resulting graph tensor. The overloads differ only in how the target
/// type is expressed.
MPSGraphTensor *castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor *tensor, executorch::aten::ScalarType toType);
MPSGraphTensor *castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor *tensor, MPSDataType toType);

/// Converts an MPSShape into a vector of int64_t dimensions.
std::vector<int64_t> getMPSShapeVec(const MPSShape *shape);

/// Copies a FlatBuffers int32 dimension vector into a std::vector<T>.
///
/// @param dims FlatBuffers vector of dimensions. May be null; FlatBuffers
///             accessors return nullptr for absent vector fields.
/// @return One element per entry of `dims`, each converted with
///         static_cast<T>. Empty when `dims` is null or empty.
///
/// Note: with the default unsigned T (size_t), a negative input dimension
/// wraps around under static_cast rather than being rejected.
template <typename T = size_t> std::vector<T> flatbufferDimsToVector(const flatbuffers::Vector<int32_t> *dims) {
  std::vector<T> dimsData;
  // Guard against an absent FlatBuffers field instead of dereferencing null.
  if (dims == nullptr) {
    return dimsData;
  }
  dimsData.reserve(dims->size());
  for (auto dim : *dims) {
    dimsData.push_back(static_cast<T>(dim));
  }
  return dimsData;
}

/// Returns the MTLBuffer backing `tensor`'s storage.
/// NOTE(review): ownership/lifetime of the returned buffer is defined by the
/// implementation, not visible here.
id<MTLBuffer> getMTLBufferStorage(const executorch::aten::Tensor &tensor);

/// Returns a pointer derived from `ptr` and `size`, and writes a block size
/// through the `alignedBlockSize` out-parameter.
/// NOTE(review): presumably the page-aligned start of the block containing
/// [ptr, ptr + size) and that block's page-rounded size; confirm against
/// the implementation.
void *pageAlignedBlockPtr(const void *ptr, NSUInteger size, NSUInteger *alignedBlockSize);

/// Adds a permutation of `inputTensor`'s dimensions to `graph` according to
/// `permuteOrder` and returns the resulting graph tensor.
/// NOTE(review): element type of `permuteOrder` is presumably NSNumber; verify.
MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSArray *permuteOrder);

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch