xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/OperationUtils.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 //  Copyright (c) 2023 Apple Inc. All rights reserved.
3 //  Provided subject to the LICENSE file in the top level directory.
4 //
5 
6 #pragma once
7 
8 #import <Foundation/Foundation.h>
9 #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
10 #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
11 #include <executorch/backends/apple/mps/runtime/MPSDevice.h>
12 #include <executorch/backends/apple/mps/schema_generated.h>
13 
14 namespace executorch {
15 namespace backends {
16 namespace mps {
17 namespace delegate {
18 
// NOTE(review): `INF` is an unprefixed object-like macro in a shared header —
// it can silently collide with identifiers in any translation unit that
// includes this file. A prefixed constexpr (e.g. kMPSInf) would be safer;
// left unchanged here to avoid breaking existing users of the macro.
19 #define INF std::numeric_limits<float>::infinity()
20 
/// Maps an ExecuTorch ScalarType to the corresponding MPSDataType.
21 MPSDataType getMPSScalarType(executorch::aten::ScalarType scalar_type);
/// Inverse of getMPSScalarType: maps an MPSDataType back to a ScalarType.
22 executorch::aten::ScalarType getScalarType(MPSDataType mpsDataType);
/// Casts `tensor` to `toType` within `mpsGraph` (ScalarType overload).
/// NOTE(review): implementation not visible here — presumably emits an
/// MPSGraph cast node; confirm in OperationUtils.mm.
23 MPSGraphTensor *castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor *tensor, executorch::aten::ScalarType toType);
/// Casts `tensor` to `toType` within `mpsGraph` (MPSDataType overload).
24 MPSGraphTensor *castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor *tensor, MPSDataType toType);
/// Converts an MPSShape (NSArray<NSNumber *> of extents) into a
/// std::vector<int64_t> of dimension sizes.
25 std::vector<int64_t> getMPSShapeVec(const MPSShape *shape);
26 
/// Copies a serialized flatbuffer dimension vector into a std::vector<T>.
///
/// @tparam T Destination element type (defaults to size_t); each int32_t
///           dimension is converted with static_cast, so negative dims will
///           wrap for unsigned T — callers are expected to pass valid sizes.
/// @param dims The dims vector read from the serialized schema; flatbuffer
///             vector fields are optional and may be null.
/// @return One element per dimension, or an empty vector when `dims` is null.
template <typename T = size_t> std::vector<T> flatbufferDimsToVector(const flatbuffers::Vector<int32_t> *dims) {
  std::vector<T> dimsData;
  // Defensive: an absent (null) dims field must not be dereferenced.
  if (dims == nullptr) {
    return dimsData;
  }
  dimsData.reserve(dims->size());
  for (auto dim : *dims) {
    dimsData.push_back(static_cast<T>(dim));
  }
  return dimsData;
}
35 
/// Returns the MTLBuffer backing `tensor`'s storage.
/// NOTE(review): assumes the tensor's data was allocated through a
/// Metal-backed allocator — confirm against the delegate's allocation path.
36 id<MTLBuffer> getMTLBufferStorage(const executorch::aten::Tensor &tensor);
/// Computes a page-aligned block covering [ptr, ptr + size).
/// Presumably returns `ptr` rounded down to a page boundary and writes the
/// padded block size to `alignedBlockSize` — verify in the implementation.
37 void *pageAlignedBlockPtr(const void *ptr, NSUInteger size, NSUInteger *alignedBlockSize);
38 
/// Adds a permutation (transpose) of `inputTensor` to `graph`.
/// `permuteOrder` is an NSArray of axis indices — presumably NSNumber
/// elements, one per dimension; verify at call sites.
39 MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSArray *permuteOrder);
40 
41 } // namespace delegate
42 } // namespace mps
43 } // namespace backends
44 } // namespace executorch
45