xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/OperationUtils.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8#include <numeric>
9
10#ifndef PAGE_SIZE
11#define PAGE_SIZE 4096
12#endif
13
14namespace executorch {
15namespace backends {
16namespace mps {
17namespace delegate {
18
// Returns the MPS data type of the serialized tensor stored at index `id`
// in the flatbuffer graph's `mps_values` table.
MPSDataType
MPSGraphBuilder::getMPSDataType(int32_t id) {
  TensorPtr serializedTensor = _flatBufferGraph->mps_values()->Get(id);
  return getMPSDataType(serializedTensor->datatype());
}
23
// Translates a serialized `DataType` into the matching `MPSDataType`.
//
// Any value without a case below fails the ET_CHECK_MSG check; the trailing
// `MPSDataTypeInvalid` return follows that check.
MPSDataType
MPSGraphBuilder::getMPSDataType(DataType serializedDataType) {
  switch (serializedDataType) {
    // Floating point types. float64 shares the Float32 mapping.
    case DataType::mps_data_type_float16:
      return MPSDataTypeFloat16;
    case DataType::mps_data_type_float32:
    case DataType::mps_data_type_float64:
      return MPSDataTypeFloat32;

    // Integer types.
    case DataType::mps_data_type_int4: {
      // Use the named constant where the OS provides it; otherwise build the
      // value from the signed bit and a width of 4.
      if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, *)) {
        return MPSDataTypeInt4;
      }
      return ((MPSDataType)(MPSDataTypeSignedBit | 4));
    }
    case DataType::mps_data_type_int8:
      return MPSDataTypeInt8;
    case DataType::mps_data_type_int16:
      return MPSDataTypeInt16;
    case DataType::mps_data_type_int32:
      return MPSDataTypeInt32;
    case DataType::mps_data_type_int64:
      return MPSDataTypeInt64;

    // Boolean type.
    case DataType::mps_data_type_bool:
      return MPSDataTypeBool;

    default:
      ET_CHECK_MSG(false, "[ERROR] Invalid MPS data type: %d!", (int32_t)serializedDataType);
      return MPSDataTypeInvalid;
  }
}
54
// Builds an MPSShape from the dims of the serialized tensor at index `id`.
//
// The first `num_dims()` entries of `dims()` are copied in order. When
// `num_dims()` is not positive, the result is the single-element shape [1].
MPSShape*
MPSGraphBuilder::getMPSShape(int32_t id) {
  TensorPtr serializedTensor = _flatBufferGraph->mps_values()->Get(id);
  auto dims = serializedTensor->dims();
  const int rank = serializedTensor->num_dims();

  std::vector<NSNumber*> extents;
  if (rank <= 0) {
    const NSInteger one = 1;
    extents.push_back(@(one));
  } else {
    extents.reserve(rank);
    for (int axis = 0; axis < rank; ++axis) {
      const NSInteger extent = dims->Get(axis);
      extents.push_back(@(extent));
    }
  }

  return [NSArray arrayWithObjects:extents.data() count:extents.size()];
}
71
// Builds an MPSShape from a serialized vector of dimension sizes.
//
// Entries are copied in order. An empty vector yields the single-element
// shape [1].
MPSShape*
MPSGraphBuilder::getMPSShape(const flatbuffers::Vector<int32_t>* shape) {
  const int rank = shape->size();

  std::vector<NSNumber*> extents;
  if (rank <= 0) {
    const NSInteger one = 1;
    extents.push_back(@(one));
  } else {
    extents.reserve(rank);
    for (int axis = 0; axis < rank; ++axis) {
      const NSInteger extent = shape->Get(axis);
      extents.push_back(@(extent));
    }
  }

  return [NSArray arrayWithObjects:extents.data() count:extents.size()];
}
86
// Returns the product of all entries in `shape`, accumulated in 64 bits.
// An empty shape yields 1.
int64_t
MPSGraphBuilder::numel(const flatbuffers::Vector<int32_t>* shape) {
  return std::accumulate(
      shape->begin(),
      shape->end(),
      int64_t{1},
      [](int64_t product, int32_t extent) { return product * extent; });
}
95
// Wraps the constant payload of the serialized tensor at index `id` in an
// NSData object without copying it.
//
// The payload lives at `_constant_data_ptr + segment_offset()` and is
// `constant_buffer_size()` bytes long. The check fails if the size is zero or
// the base pointer is null.
//
// The returned NSData does NOT own the bytes: they are a sub-range of a buffer
// this function did not allocate, so NSData must not free() them when it is
// deallocated. The caller must keep the underlying constant data alive for as
// long as the NSData is in use.
NSData*
MPSGraphBuilder::getConstantData(int32_t id) {
  TensorPtr mpsTensor = _flatBufferGraph->mps_values()->Get(id);
  const uint64_t constantBufferSize = mpsTensor->constant_buffer_size();
  const uint64_t segmentOffset = mpsTensor->segment_offset();

  // Validate the base pointer before offsetting it: a null base plus a
  // non-zero offset would otherwise look like a valid address.
  ET_CHECK_MSG(constantBufferSize > 0 && _constant_data_ptr != nullptr, "[ERROR] Invalid constant buffer");

  const unsigned char* constantBuffer = _constant_data_ptr + segmentOffset;
  return [[NSData alloc] initWithBytesNoCopy:(void*)constantBuffer
                                      length:constantBufferSize
                                freeWhenDone:NO];
}
106
// Returns the {min, max} pair attached to `nodePtr`.
//
// When the node carries no `min_max()` entry, the pair is {-INF, INF}.
std::pair<float, float>
MPSGraphBuilder::getMinMaxValues(NodePtr nodePtr) {
  auto bounds = nodePtr->min_max();
  const bool hasBounds = (bounds != nullptr);

  const float lower = hasBounds ? bounds->min_value() : -INF;
  const float upper = hasBounds ? bounds->max_value() : INF;
  return {lower, upper};
}
119
// Expands to one switch case: a node whose union type is MPS<kind> is handed
// to the member function mps<kind>Op(nodePtr), whose result is returned.
#define _DEFINE_MPS_NODE(kind)                 \
  case mpsgraph::MPSNodeUnion::MPS##kind:      \
    return mps##kind##Op(nodePtr);

// Dispatches a serialized node to the builder method that handles its type.
//
// Returns whatever the selected mps<Name>Op method returns. A node of type
// NONE, or any type not listed below, is reported through
// ET_CHECK_OR_RETURN_ERROR with the NotImplemented error.
Error
MPSGraphBuilder::addNodeToMPSGraph(NodePtr nodePtr) {
  const auto nodeType = nodePtr->mpsnode_union_type();

  switch (nodeType) {
    // Activation ops
    _DEFINE_MPS_NODE(HardTanh);
    _DEFINE_MPS_NODE(ReLU);
    _DEFINE_MPS_NODE(GELU);
    _DEFINE_MPS_NODE(LeakyReLU);
    _DEFINE_MPS_NODE(Softmax);
    _DEFINE_MPS_NODE(LogSoftmax);

    // Binary ops
    _DEFINE_MPS_NODE(Add);
    _DEFINE_MPS_NODE(Sub);
    _DEFINE_MPS_NODE(Mul);
    _DEFINE_MPS_NODE(Div);
    _DEFINE_MPS_NODE(Pow);
    _DEFINE_MPS_NODE(Fmod);
    _DEFINE_MPS_NODE(Remainder);
    _DEFINE_MPS_NODE(BitwiseAnd);
    _DEFINE_MPS_NODE(BitwiseOr);
    _DEFINE_MPS_NODE(BitwiseXor);
    _DEFINE_MPS_NODE(Minimum);

    // Unary ops
    _DEFINE_MPS_NODE(Exp);
    _DEFINE_MPS_NODE(Exp2);
    _DEFINE_MPS_NODE(Reciprocal);
    _DEFINE_MPS_NODE(Sqrt);
    _DEFINE_MPS_NODE(Neg);
    _DEFINE_MPS_NODE(Log);
    _DEFINE_MPS_NODE(Log10);
    _DEFINE_MPS_NODE(Log2);
    _DEFINE_MPS_NODE(Erf);
    _DEFINE_MPS_NODE(Floor);
    _DEFINE_MPS_NODE(Ceil);
    _DEFINE_MPS_NODE(Rsqrt);
    _DEFINE_MPS_NODE(Sigmoid);
    _DEFINE_MPS_NODE(Sin);
    _DEFINE_MPS_NODE(Sign);
    _DEFINE_MPS_NODE(Cos);
    _DEFINE_MPS_NODE(Tan);
    _DEFINE_MPS_NODE(Abs);
    _DEFINE_MPS_NODE(Asin);
    _DEFINE_MPS_NODE(Acos);
    _DEFINE_MPS_NODE(Atan);
    _DEFINE_MPS_NODE(Sinh);
    _DEFINE_MPS_NODE(Cosh);
    _DEFINE_MPS_NODE(Tanh);
    _DEFINE_MPS_NODE(Asinh);
    _DEFINE_MPS_NODE(Acosh);
    _DEFINE_MPS_NODE(Atanh);
    _DEFINE_MPS_NODE(BitwiseNot);
    _DEFINE_MPS_NODE(Isnan);
    _DEFINE_MPS_NODE(Isinf);
    _DEFINE_MPS_NODE(Round);
    _DEFINE_MPS_NODE(LogicalNot);

    // Clamp ops
    _DEFINE_MPS_NODE(Clamp);
    _DEFINE_MPS_NODE(Where);

    // Linear algebra ops
    _DEFINE_MPS_NODE(MatMul);
    _DEFINE_MPS_NODE(Addmm);

    // Constant ops
    _DEFINE_MPS_NODE(Full);
    _DEFINE_MPS_NODE(FullLike);

    // Indexing ops
    _DEFINE_MPS_NODE(IndexSelect);
    _DEFINE_MPS_NODE(Embedding);
    _DEFINE_MPS_NODE(IndexTensor);
    _DEFINE_MPS_NODE(IndexPut);
    _DEFINE_MPS_NODE(Scatter);

    // Reduce ops
    _DEFINE_MPS_NODE(Mean);

    // Shape ops
    _DEFINE_MPS_NODE(Permute);
    _DEFINE_MPS_NODE(View);
    _DEFINE_MPS_NODE(Expand);
    _DEFINE_MPS_NODE(Cat);
    _DEFINE_MPS_NODE(Squeeze);
    _DEFINE_MPS_NODE(Unsqueeze);
    _DEFINE_MPS_NODE(Select);
    _DEFINE_MPS_NODE(Slice);
    _DEFINE_MPS_NODE(PixelShuffle);
    _DEFINE_MPS_NODE(SplitWithSizes);
    _DEFINE_MPS_NODE(Cast);

    // Convolution ops
    _DEFINE_MPS_NODE(Conv2D);
    _DEFINE_MPS_NODE(DepthwiseConv2D);

    // Comparison ops
    _DEFINE_MPS_NODE(Eq);
    _DEFINE_MPS_NODE(Ne);
    _DEFINE_MPS_NODE(Ge);
    _DEFINE_MPS_NODE(Gt);
    _DEFINE_MPS_NODE(Le);
    _DEFINE_MPS_NODE(Lt);

    // Normalization ops
    _DEFINE_MPS_NODE(BatchNorm);
    _DEFINE_MPS_NODE(LayerNorm);

    // Pooling ops
    _DEFINE_MPS_NODE(MaxPool2DWithIndices);
    _DEFINE_MPS_NODE(AvgPool2D);

    // Pad ops
    _DEFINE_MPS_NODE(ConstantPadND);

    // Range ops
    _DEFINE_MPS_NODE(Arange);

    // Quant-Dequant ops
    _DEFINE_MPS_NODE(DequantizePerChannelGroup);

    // Everything else, including the explicit NONE tag, is unsupported.
    case mpsgraph::MPSNodeUnion::NONE:
    default:
      ET_CHECK_OR_RETURN_ERROR(
        false,
        NotImplemented,
        "[ERROR] Unhandled node type: %s!",
        mpsgraph::EnumNameMPSNodeUnion(nodeType));
  }
}
240
// Forwards `nodePtr` to addNodeToMPSGraph() and returns its result unchanged.
Error
MPSGraphBuilder::compileMetalKernel(NodePtr nodePtr) {
  const Error status = addNodeToMPSGraph(nodePtr);
  return status;
}
245
246#undef _DEFINE_MPS_NODE
247
// Looks up the MPSGraphTensor recorded for serialized value index `id` in
// `_idToMPSGraphTensor`. No range or presence check is performed here.
MPSGraphTensor*
MPSGraphBuilder::getMPSGraphTensor(int32_t id) {
  MPSGraphTensor* graphTensor = _idToMPSGraphTensor[id];
  return graphTensor;
}
252
// Maps an ExecuTorch scalar type to the MPS data type used to represent it.
//
// Float and Double both map to MPSDataTypeFloat32; Half maps to
// MPSDataTypeFloat16. Any other scalar type fails the ET_CHECK_MSG check, and
// the trailing `MPSDataTypeInvalid` return mirrors
// MPSGraphBuilder::getMPSDataType so every path yields a value.
MPSDataType getMPSScalarType(executorch::aten::ScalarType scalar_type) {
  switch (scalar_type) {
    // This is an intentional fallthrough supporting Double for Scalar
    // types as they are casted to Float32 currently.
    case executorch::aten::ScalarType::Double:
    case executorch::aten::ScalarType::Float:
      return MPSDataTypeFloat32;
    case executorch::aten::ScalarType::Half:
      return MPSDataTypeFloat16;
    default:
      ET_CHECK_MSG(false, "Unhandled ExecuTorch scalar type!");
      return MPSDataTypeInvalid;
  }
}
265
// Maps an MPS data type back to the corresponding ExecuTorch scalar type.
//
// Only the types listed below are handled; any other MPSDataType fails the
// ET_CHECK_MSG check.
executorch::aten::ScalarType getScalarType(MPSDataType mpsDataType) {
  using executorch::aten::ScalarType;

  switch (mpsDataType) {
    // Floating point types
    case MPSDataTypeFloat16:
      return ScalarType::Half;
    case MPSDataTypeFloat32:
      return ScalarType::Float;

    // Signed integer types
    case MPSDataTypeInt8:
      return ScalarType::Char;
    case MPSDataTypeInt16:
      return ScalarType::Short;
    case MPSDataTypeInt32:
      return ScalarType::Int;
    case MPSDataTypeInt64:
      return ScalarType::Long;

    // Boolean type
    case MPSDataTypeBool:
      return ScalarType::Bool;

    default:
      ET_CHECK_MSG(false, "Unhandled MPS data type!");
  }
}
286
// Casts `tensor` to the MPS data type that getMPSScalarType() yields for the
// ExecuTorch scalar type `toType`, by delegating to the MPSDataType overload.
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, executorch::aten::ScalarType toType) {
  const MPSDataType targetType = getMPSScalarType(toType);
  return castMPSTensor(mpsGraph, tensor, targetType);
}
290
// Adds a cast of `tensor` to `toType` to `mpsGraph` and returns the resulting
// tensor. The graph operation is named "castTensor".
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
  MPSGraphTensor* castResult = [mpsGraph castTensor:tensor
                                             toType:toType
                                               name:@"castTensor"];
  return castResult;
}
294
// Converts an MPSShape into a vector of 64-bit dimension sizes, preserving
// order. A nil or empty shape yields an empty vector.
//
// Each entry is read with `longLongValue` so that the full 64-bit range of
// the destination element type is preserved; reading through `intValue`
// would truncate any dimension that does not fit in 32 bits.
std::vector<int64_t> getMPSShapeVec(const MPSShape* shape) {
  const NSUInteger rank = [shape count];

  std::vector<int64_t> shapeVec;
  shapeVec.reserve(rank);
  for (NSUInteger axis = 0; axis < rank; ++axis) {
    NSNumber* extent = [shape objectAtIndex:axis];
    shapeVec.push_back([extent longLongValue]);
  }
  return shapeVec;
}
303
// Creates an MTLBuffer over the tensor's existing storage without copying.
//
// The buffer spans `tensor.nbytes()` bytes starting at the tensor's data
// pointer, is created with options 0 and has no deallocator, so the tensor's
// memory is not released by the buffer.
//
// NOTE(review): Metal's documentation for newBufferWithBytesNoCopy states
// alignment requirements on the pointer and length; nothing here enforces
// them - confirm that callers only pass suitably aligned storage.
id<MTLBuffer> getMTLBufferStorage(const executorch::aten::Tensor &tensor) {
  auto device = MPSDevice::getInstance()->device();
  uint8_t *storage = tensor.mutable_data_ptr<uint8_t>();
  const auto byteCount = tensor.nbytes();

  return [device newBufferWithBytesNoCopy:storage
                                   length:byteCount
                                  options:0
                              deallocator:nil];
}
311
// Computes the smallest PAGE_SIZE-aligned block that fully contains the byte
// range [ptr, ptr + size).
//
// Returns the start of that block (ptr rounded down to a page boundary) and
// writes the block length (a multiple of PAGE_SIZE) to `*alignedBlockSize`.
void* pageAlignedBlockPtr(const void* ptr, NSUInteger size, NSUInteger* alignedBlockSize) {
  const uintptr_t rangeStart = (uintptr_t)ptr;
  const uintptr_t rangeEnd = rangeStart + size;

  // Round the start down and the end up to the nearest page boundary.
  const uintptr_t blockStart = rangeStart & ~(PAGE_SIZE - 1);
  const uintptr_t blockEnd = (rangeEnd + PAGE_SIZE - 1) & ~(PAGE_SIZE - 1);
  const uint64_t blockLength = blockEnd - blockStart;

  // The aligned block must enclose the requested range.
  assert(rangeStart >= blockStart);
  assert(rangeStart + size <= blockStart + blockLength);

  *alignedBlockSize = blockLength;
  return (void*)blockStart;
}
324
325
// Permutes the dimensions of `inputTensor` according to `permuteOrder` by
// adding a sequence of pairwise transposes to `graph`.
//
// `permuteOrder[i]` names the source dimension that ends up at position `i`.
// Returns the permuted tensor, or nil when `permuteOrder` is not a valid
// permutation of the input's dimensions: wrong length, an axis that is
// negative or out of range, or a repeated axis. Validation happens before any
// operation is added, so a rejected request leaves `graph` untouched.
MPSGraphTensor* permuteTensor(MPSGraph* graph, MPSGraphTensor* inputTensor, NSArray* permuteOrder) {
  const NSUInteger srcRank = [[inputTensor shape] count];
  if (srcRank != [permuteOrder count]) {
    return nil;
  }

  // dimensionOrder[p] tracks which source dimension currently sits at
  // position p as the swaps are applied.
  std::vector<NSUInteger> dimensionOrder(srcRank);
  std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);

  // First pass: work out the swaps and validate the permutation.
  std::vector<std::pair<NSUInteger, NSUInteger>> swaps;
  swaps.reserve(srcRank);
  for (NSUInteger i = 0; i < srcRank; i++) {
    const NSUInteger axis = [permuteOrder[i] integerValue];
    // Positions before `i` are already final, so a valid axis can only be
    // found at or after `i`. A miss means the axis is out of range (negative
    // values wrap to huge unsigned ones) or was already used.
    auto axisIter = std::find(dimensionOrder.begin() + i, dimensionOrder.end(), axis);
    if (axisIter == dimensionOrder.end()) {
      return nil;
    }
    swaps.emplace_back(i, static_cast<NSUInteger>(axisIter - dimensionOrder.begin()));
    std::iter_swap(dimensionOrder.begin() + i, axisIter);
  }

  // Second pass: emit the transposes in the same order they were computed.
  MPSGraphTensor* outputTensor = inputTensor;
  for (const auto& axes : swaps) {
    outputTensor = [graph transposeTensor:outputTensor
                                dimension:axes.first
                            withDimension:axes.second
                                     name:nil];
  }

  return outputTensor;
}
348
349
350} // namespace delegate
351} // namespace mps
352} // namespace backends
353} // namespace executorch
354