//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

#include <algorithm>
#include <numeric>
#include <vector>

// Fallback page size for pageAlignedBlockPtr() when the platform headers do
// not define PAGE_SIZE. The mask arithmetic there requires a power of two.
#ifndef PAGE_SIZE
#define PAGE_SIZE 4096
#endif

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Returns the MPS data type of the serialized value at index `id` of the
// flatbuffer graph's mps_values() table.
MPSDataType
MPSGraphBuilder::getMPSDataType(int32_t id) {
  return getMPSDataType(_flatBufferGraph->mps_values()->Get(id)->datatype());
}

// Maps a serialized DataType onto an MPSDataType.
// - float64 is narrowed to Float32.
// - int4 uses MPSDataTypeInt4 when the OS exposes it; otherwise it builds the
//   raw value `MPSDataTypeSignedBit | 4` by hand (presumably the same encoding).
// - Any other value aborts through ET_CHECK_MSG.
MPSDataType
MPSGraphBuilder::getMPSDataType(DataType serializedDataType) {
  switch (serializedDataType) {
    case DataType::mps_data_type_float16:
      return MPSDataTypeFloat16;
    case DataType::mps_data_type_float32:
    case DataType::mps_data_type_float64:
      return MPSDataTypeFloat32;
    case DataType::mps_data_type_int8:
      return MPSDataTypeInt8;
    case DataType::mps_data_type_int4: {
      if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, *)) {
        return MPSDataTypeInt4;
      } else {
        return ((MPSDataType)(MPSDataTypeSignedBit | 4));
      }
    }
    case DataType::mps_data_type_int16:
      return MPSDataTypeInt16;
    case DataType::mps_data_type_int32:
      return MPSDataTypeInt32;
    case DataType::mps_data_type_int64:
      return MPSDataTypeInt64;
    case DataType::mps_data_type_bool:
      return MPSDataTypeBool;
    default:
      ET_CHECK_MSG(false, "[ERROR] Invalid MPS data type: %d!", (int32_t)serializedDataType);
      return MPSDataTypeInvalid;
  }
}

// Builds an MPSShape from the serialized dims of value `id`.
// A rank-0 tensor (num_dims() <= 0) is represented as the shape [1].
MPSShape*
MPSGraphBuilder::getMPSShape(int32_t id) {
  TensorPtr mpsTensor = _flatBufferGraph->mps_values()->Get(id);
  auto sizes = mpsTensor->dims();
  const int sz = mpsTensor->num_dims();
  const int sz_ = (sz > 0) ? sz : 1;

  std::vector<NSNumber*> numbers(sz_);

  for (int i = 0; i < sz_; i++) {
    NSInteger sz_i = (i < sz) ? sizes->Get(i) : 1;
    numbers[i] = @(sz_i);
  }
  return [NSArray arrayWithObjects:numbers.data() count:numbers.size()];
}

// Builds an MPSShape from a serialized shape vector.
// An empty vector is represented as the shape [1].
MPSShape*
MPSGraphBuilder::getMPSShape(const flatbuffers::Vector<int32_t>* shape) {
  const int sz = shape->size();
  const int sz_ = (sz > 0) ? sz : 1;

  std::vector<NSNumber*> numbers(sz_);

  for (int i = 0; i < sz_; i++) {
    NSInteger sz_i = (i < sz) ? shape->Get(i) : 1;
    numbers[i] = @(sz_i);
  }
  return [NSArray arrayWithObjects:numbers.data() count:numbers.size()];
}

// Returns the product of all dims in `shape`. An empty shape yields 1.
int64_t
MPSGraphBuilder::numel(const flatbuffers::Vector<int32_t>* shape) {
  int64_t numel = 1;
  for (auto dim : *shape) {
    numel = numel * dim;
  }
  return numel;
}

// Wraps the serialized constant bytes of value `id` in an NSData without
// copying them. The bytes start at `_constant_data_ptr + segment_offset`.
//
// The NSData must NOT free that memory. The pointer is an offset into a shared
// constant segment rather than an allocation handed over to the NSData, so it
// is created with freeWhenDone:NO. The returned NSData is therefore only valid
// while the constant segment stays alive.
// NOTE(review): confirm the owner of _constant_data_ptr outlives every user of
// the returned data.
NSData*
MPSGraphBuilder::getConstantData(int32_t id) {
  TensorPtr mpsTensor = _flatBufferGraph->mps_values()->Get(id);
  uint64_t constantBufferSize = mpsTensor->constant_buffer_size();
  uint64_t segmentOffset = mpsTensor->segment_offset();
  // Validate the base pointer. After an offset is added, the sum can no longer
  // be meaningfully compared against nullptr.
  ET_CHECK_MSG(constantBufferSize > 0 && _constant_data_ptr != nullptr, "[ERROR] Invalid constant buffer");
  const unsigned char* constantBuffer = _constant_data_ptr + segmentOffset;
  return [[NSData alloc] initWithBytesNoCopy:(void*)constantBuffer
                                      length:constantBufferSize
                                freeWhenDone:NO];
}

// Returns the node's {min, max} pair from its optional min_max field.
// If the node has no min_max, the range defaults to {-INF, INF}.
std::pair<float, float>
MPSGraphBuilder::getMinMaxValues(NodePtr nodePtr) {
  float minValue = -INF;
  float maxValue = INF;
  auto minMaxValues = nodePtr->min_max();
  if (minMaxValues != nullptr) {
    minValue = minMaxValues->min_value();
    maxValue = minMaxValues->max_value();
  }

  return {minValue, maxValue};
}

// Expands to a switch case that dispatches MPSNodeUnion::MPS<node> to the
// member function mps<node>Op(nodePtr).
#define _DEFINE_MPS_NODE(node)                \
  case mpsgraph::MPSNodeUnion::MPS##node:     \
    return mps##node##Op(nodePtr);

// Lowers a single serialized node into the MPSGraph by dispatching on its
// union type. Node types not listed here (including NONE) return
// Error::NotImplemented.
Error
MPSGraphBuilder::addNodeToMPSGraph(NodePtr nodePtr) {
  switch (nodePtr->mpsnode_union_type()) {
    // Activation ops
    _DEFINE_MPS_NODE(HardTanh);
    _DEFINE_MPS_NODE(ReLU);
    _DEFINE_MPS_NODE(GELU);
    _DEFINE_MPS_NODE(LeakyReLU);
    _DEFINE_MPS_NODE(Softmax);
    _DEFINE_MPS_NODE(LogSoftmax);
    // Binary ops
    _DEFINE_MPS_NODE(Add);
    _DEFINE_MPS_NODE(Sub);
    _DEFINE_MPS_NODE(Mul);
    _DEFINE_MPS_NODE(Div);
    _DEFINE_MPS_NODE(Pow);
    _DEFINE_MPS_NODE(Fmod);
    _DEFINE_MPS_NODE(Remainder);
    _DEFINE_MPS_NODE(BitwiseAnd);
    _DEFINE_MPS_NODE(BitwiseOr);
    _DEFINE_MPS_NODE(BitwiseXor);
    _DEFINE_MPS_NODE(Minimum);
    // Unary ops
    _DEFINE_MPS_NODE(Exp);
    _DEFINE_MPS_NODE(Exp2);
    _DEFINE_MPS_NODE(Reciprocal);
    _DEFINE_MPS_NODE(Sqrt);
    _DEFINE_MPS_NODE(Neg);
    _DEFINE_MPS_NODE(Log);
    _DEFINE_MPS_NODE(Log10);
    _DEFINE_MPS_NODE(Log2);
    _DEFINE_MPS_NODE(Erf);
    _DEFINE_MPS_NODE(Floor);
    _DEFINE_MPS_NODE(Ceil);
    _DEFINE_MPS_NODE(Rsqrt);
    _DEFINE_MPS_NODE(Sigmoid);
    _DEFINE_MPS_NODE(Sin);
    _DEFINE_MPS_NODE(Sign);
    _DEFINE_MPS_NODE(Cos);
    _DEFINE_MPS_NODE(Tan);
    _DEFINE_MPS_NODE(Abs);
    _DEFINE_MPS_NODE(Asin);
    _DEFINE_MPS_NODE(Acos);
    _DEFINE_MPS_NODE(Atan);
    _DEFINE_MPS_NODE(Sinh);
    _DEFINE_MPS_NODE(Cosh);
    _DEFINE_MPS_NODE(Tanh);
    _DEFINE_MPS_NODE(Asinh);
    _DEFINE_MPS_NODE(Acosh);
    _DEFINE_MPS_NODE(Atanh);
    _DEFINE_MPS_NODE(BitwiseNot);
    _DEFINE_MPS_NODE(Isnan);
    _DEFINE_MPS_NODE(Isinf);
    _DEFINE_MPS_NODE(Round);
    _DEFINE_MPS_NODE(LogicalNot);
    // Clamp ops
    _DEFINE_MPS_NODE(Clamp);
    _DEFINE_MPS_NODE(Where);
    // Linear algebra ops
    _DEFINE_MPS_NODE(MatMul);
    _DEFINE_MPS_NODE(Addmm);
    // Constant ops
    _DEFINE_MPS_NODE(Full);
    _DEFINE_MPS_NODE(FullLike);
    // Indexing ops
    _DEFINE_MPS_NODE(IndexSelect);
    _DEFINE_MPS_NODE(Embedding);
    _DEFINE_MPS_NODE(IndexTensor);
    _DEFINE_MPS_NODE(IndexPut);
    _DEFINE_MPS_NODE(Scatter);
    // Reduce ops
    _DEFINE_MPS_NODE(Mean);
    // Shape ops
    _DEFINE_MPS_NODE(Permute);
    _DEFINE_MPS_NODE(View);
    _DEFINE_MPS_NODE(Expand);
    _DEFINE_MPS_NODE(Cat);
    _DEFINE_MPS_NODE(Squeeze);
    _DEFINE_MPS_NODE(Unsqueeze);
    _DEFINE_MPS_NODE(Select);
    _DEFINE_MPS_NODE(Slice);
    _DEFINE_MPS_NODE(PixelShuffle);
    _DEFINE_MPS_NODE(SplitWithSizes);
    _DEFINE_MPS_NODE(Cast);
    // Convolution ops
    _DEFINE_MPS_NODE(Conv2D);
    _DEFINE_MPS_NODE(DepthwiseConv2D);
    // Comparison ops
    _DEFINE_MPS_NODE(Eq);
    _DEFINE_MPS_NODE(Ne);
    _DEFINE_MPS_NODE(Ge);
    _DEFINE_MPS_NODE(Gt);
    _DEFINE_MPS_NODE(Le);
    _DEFINE_MPS_NODE(Lt);
    // Normalization ops
    _DEFINE_MPS_NODE(BatchNorm);
    _DEFINE_MPS_NODE(LayerNorm);
    // Pooling ops
    _DEFINE_MPS_NODE(MaxPool2DWithIndices);
    _DEFINE_MPS_NODE(AvgPool2D);
    // Pad ops
    _DEFINE_MPS_NODE(ConstantPadND);
    // Range ops
    _DEFINE_MPS_NODE(Arange);
    // Quant-Dequant ops
    _DEFINE_MPS_NODE(DequantizePerChannelGroup);

    case mpsgraph::MPSNodeUnion::NONE:
    default:
      ET_CHECK_OR_RETURN_ERROR(
        false,
        NotImplemented,
        "[ERROR] Unhandled node type: %s!",
        mpsgraph::EnumNameMPSNodeUnion(nodePtr->mpsnode_union_type()));
  }
}

// Compiles a node into a Metal kernel. This currently just forwards to
// addNodeToMPSGraph().
Error
MPSGraphBuilder::compileMetalKernel(NodePtr nodePtr) {
  return addNodeToMPSGraph(nodePtr);
}

#undef _DEFINE_MPS_NODE

// Returns the graph tensor registered for value `id` in _idToMPSGraphTensor.
// No presence or bounds check is performed here.
// NOTE(review): callers are expected to query only ids that were already
// registered — confirm.
MPSGraphTensor*
MPSGraphBuilder::getMPSGraphTensor(int32_t id) {
  return _idToMPSGraphTensor[id];
}

// Maps an ExecuTorch floating-point scalar type onto an MPSDataType.
// Only Double, Float and Half are supported; anything else aborts.
MPSDataType getMPSScalarType(executorch::aten::ScalarType scalar_type) {
  switch (scalar_type) {
    // This is an intentional fallthrough supporting Double for Scalar
    // types as they are casted to Float32 currently.
    case executorch::aten::ScalarType::Double:
    case executorch::aten::ScalarType::Float:
      return MPSDataTypeFloat32;
    case executorch::aten::ScalarType::Half:
      return MPSDataTypeFloat16;
    default:
      ET_CHECK_MSG(false, "Unhandled ExecuTorch scalar type!");
      return MPSDataTypeInvalid;
  }
}

// Maps an MPSDataType back onto the ExecuTorch ScalarType.
// Unsupported types abort through ET_CHECK_MSG.
executorch::aten::ScalarType getScalarType(MPSDataType mpsDataType) {
  switch (mpsDataType) {
    case MPSDataTypeFloat16:
      return executorch::aten::ScalarType::Half;
    case MPSDataTypeFloat32:
      return executorch::aten::ScalarType::Float;
    case MPSDataTypeInt8:
      return executorch::aten::ScalarType::Char;
    case MPSDataTypeInt16:
      return executorch::aten::ScalarType::Short;
    case MPSDataTypeInt32:
      return executorch::aten::ScalarType::Int;
    case MPSDataTypeInt64:
      return executorch::aten::ScalarType::Long;
    case MPSDataTypeBool:
      return executorch::aten::ScalarType::Bool;
    default:
      ET_CHECK_MSG(false, "Unhandled MPS data type!");
  }
}

// Casts `tensor` to the MPS equivalent of the ExecuTorch scalar type `toType`.
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, executorch::aten::ScalarType toType) {
  return castMPSTensor(mpsGraph, tensor, getMPSScalarType(toType));
}

// Adds a cast node to `mpsGraph` that converts `tensor` to `toType`.
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
  return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];
}

// Converts an MPSShape into a std::vector<int64_t> of the same dims.
// Each dim is read as a 64-bit value so that large extents are not truncated.
std::vector<int64_t> getMPSShapeVec(const MPSShape* shape) {
  __block std::vector<int64_t> shapeVec;
  shapeVec.reserve([shape count]);
  [shape enumerateObjectsUsingBlock:^(NSNumber * _Nonnull obj, NSUInteger idx, BOOL * _Nonnull stop) {
    shapeVec.push_back(obj.longLongValue);
  }];
  return shapeVec;
}

// Wraps the tensor's storage in an MTLBuffer on the shared MPS device without
// copying it. The buffer has no deallocator, so the tensor's memory must
// outlive the returned buffer.
// NOTE(review): Metal's no-copy buffers expect page-aligned storage; callers
// presumably guarantee this (compare pageAlignedBlockPtr below) — confirm.
id<MTLBuffer> getMTLBufferStorage(const executorch::aten::Tensor &tensor) {
  uint8_t *data = tensor.mutable_data_ptr<uint8_t>();
  return [MPSDevice::getInstance()->device() newBufferWithBytesNoCopy:data
                                                               length:tensor.nbytes()
                                                              options:0
                                                          deallocator:nil];
}

// Expands [ptr, ptr + size) outward to PAGE_SIZE boundaries.
// Returns the rounded-down start address and writes the rounded-up length of
// the enclosing block to *alignedBlockSize.
void* pageAlignedBlockPtr(const void* ptr, NSUInteger size, NSUInteger* alignedBlockSize) {
  uintptr_t address = (uintptr_t)ptr;
  uintptr_t alignedAddress = address & ~(PAGE_SIZE - 1);
  uintptr_t alignedEnd = ((address + size) + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1);
  uint64_t alignedLength = alignedEnd - alignedAddress;

  assert(address >= alignedAddress);
  assert(address + size <= alignedAddress + alignedLength);

  *alignedBlockSize = alignedLength;
  return (void*)alignedAddress;
}

// Permutes the dims of `inputTensor` so that output dim i is input dim
// permuteOrder[i]. The permutation is built as a sequence of pairwise
// transposes.
// Returns nil when:
// - permuteOrder's length differs from the tensor rank, or
// - permuteOrder is not a permutation of [0, rank): an axis is out of range,
//   negative, or repeated.
MPSGraphTensor* permuteTensor(MPSGraph* graph, MPSGraphTensor* inputTensor, NSArray* permuteOrder) {
  NSUInteger srcRank = [[inputTensor shape] count];
  if (srcRank != [permuteOrder count]) {
    return nil;
  }

  MPSGraphTensor* outputTensor = inputTensor;
  // dimensionOrder[p] is the input axis currently at position p of
  // outputTensor. It starts as the identity.
  std::vector<NSUInteger> dimensionOrder(srcRank);
  std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);

  for (NSUInteger i = 0; i < srcRank; i++) {
    NSUInteger axis = [permuteOrder[i] integerValue];
    // Positions [0, i) are already final, so a valid permutation always has
    // `axis` somewhere in [i, srcRank). If it is not there, the axis is out of
    // range or was already used. Bail out instead of swapping through end().
    auto axisIter = std::find(dimensionOrder.begin() + i, dimensionOrder.end(), axis);
    if (axisIter == dimensionOrder.end()) {
      return nil;
    }
    NSUInteger axis1 = i;
    NSUInteger axis2 = axisIter - dimensionOrder.begin();
    std::iter_swap(dimensionOrder.begin() + i, axisIter);

    outputTensor = [graph transposeTensor:outputTensor dimension:axis1 withDimension:axis2 name:nil];
  }

  return outputTensor;
}


} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch