xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/LinearAlgebra.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8#include <iostream>
9
10namespace executorch {
11namespace backends {
12namespace mps {
13namespace delegate {
14
// Lowers an MPSMatMul node into the MPSGraph.
//
// Emits a single matrix multiplication with input1 as the primary operand and
// input2 as the secondary operand (output = input1 @ input2), then registers
// the resulting tensor under the node's output_id. Always returns Error::Ok.
Error
MPSGraphBuilder::mpsMatMulOp(NodePtr nodePtr) {
  auto matMulNode = nodePtr->mpsnode_union_as_MPSMatMul();
  const auto lhsId = matMulNode->input1_id();
  const auto rhsId = matMulNode->input2_id();
  const auto outId = matMulNode->output_id();

  ET_LOG(Debug, "%s: (%d, %d) -> %d", __FUNCTION__, lhsId, rhsId, outId);

  MPSGraphTensor* lhsTensor = getMPSGraphTensor(lhsId);
  MPSGraphTensor* rhsTensor = getMPSGraphTensor(rhsId);
  MPSGraphTensor* productTensor =
      [_mpsGraph matrixMultiplicationWithPrimaryTensor:lhsTensor
                                       secondaryTensor:rhsTensor
                                                  name:nil];

  _idToMPSGraphTensor[outId] = productTensor;
  return Error::Ok;
}
33
// Lowers an MPSAddmm node into the MPSGraph:
//
//   output = beta * input1 + alpha * (input2 @ input3)
//
// input1 is the additive term (called "bias" below); input2 and input3 are the
// primary / secondary matrix-multiplication operands. alpha and beta are read
// from the node. Each scaling multiply is only emitted when its scalar is not
// exactly 1.0, and the scalar constants are created with input2's data type.
//
// When beta is exactly 0.0 the additive term is skipped entirely rather than
// multiplied by zero, so NaN/Inf values in it cannot reach the result
// (0 * NaN == NaN). This follows the documented torch.addmm contract ("if beta
// is 0, input is ignored"). NOTE(review): assumes this node is lowered from
// torch.addmm semantics — confirm against the partitioner/serializer.
//
// The result tensor is registered under the node's output_id. Always returns
// Error::Ok.
Error
MPSGraphBuilder::mpsAddmmOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSAddmm();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->input3_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input2_id());
  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input3_id());
  const float beta = graphNode->beta();
  const float alpha = graphNode->alpha();

  // input2 @ input3
  MPSGraphTensor* multiplyTensor = [_mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor
                                                                    secondaryTensor:weightTensor
                                                                               name:@"addmm/matmul"];

  // alpha * (input2 @ input3); the multiply is omitted when alpha == 1.
  MPSGraphTensor* alphaTimesMultiply = multiplyTensor;
  if (alpha != 1.0f) {
    MPSGraphTensor* alphaTensor = [_mpsGraph constantWithScalar:alpha
                                                       dataType:inputTensor.dataType];

    alphaTimesMultiply = [_mpsGraph multiplicationWithPrimaryTensor:multiplyTensor
                                                    secondaryTensor:alphaTensor
                                                               name:@"addmm/alpha*matmul"];
  }

  // beta == 0: ignore the additive term instead of computing 0 * bias, which
  // would turn NaN/Inf entries of the bias into NaN in the output.
  if (beta == 0.0f) {
    _idToMPSGraphTensor[graphNode->output_id()] = alphaTimesMultiply;
    return Error::Ok;
  }

  // beta * input1; the multiply is omitted when beta == 1.
  MPSGraphTensor* betaBiasTensor = biasTensor;
  if (beta != 1.0f) {
    MPSGraphTensor* betaTensor = [_mpsGraph constantWithScalar:beta
                                                      dataType:inputTensor.dataType];

    betaBiasTensor = [_mpsGraph multiplicationWithPrimaryTensor:biasTensor
                                                secondaryTensor:betaTensor
                                                           name:@"addmm/beta*bias"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph additionWithPrimaryTensor:alphaTimesMultiply
                                                                    secondaryTensor:betaBiasTensor
                                                                               name:@"addmm/beta*bias*alpha*matmul"];

  return Error::Ok;
}
82
83} // namespace delegate
84} // namespace mps
85} // namespace backends
86} // namespace executorch
87