//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
#include <iostream>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

/// Lowers an MPSMatMul node: output = input1 @ input2.
/// The resulting graph tensor is recorded in _idToMPSGraphTensor under the
/// node's output id.
Error
MPSGraphBuilder::mpsMatMulOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSMatMul();
  ET_LOG(
    Debug, "%s: (%d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph matrixMultiplicationWithPrimaryTensor:getMPSGraphTensor(graphNode->input1_id())
                                     secondaryTensor:getMPSGraphTensor(graphNode->input2_id())
                                                name:nil];

  return Error::Ok;
}

/// Lowers an MPSAddmm node: output = beta * input1 + alpha * (input2 @ input3).
///
/// input1 is the bias, input2 the primary (left) matmul operand and input3 the
/// secondary (right) matmul operand. Multiplications by alpha / beta are only
/// emitted when the factor differs from 1.
///
/// When beta is exactly 0 the bias term is dropped entirely instead of being
/// multiplied by a zero constant: 0 * NaN and 0 * Inf are NaN, so scaling would
/// leak non-finite bias values into the output, whereas torch.addmm specifies
/// that the bias is ignored (and NaN/Inf in it are not propagated) when beta is 0.
Error
MPSGraphBuilder::mpsAddmmOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSAddmm();
  ET_LOG(
    Debug, "%s: (%d, %d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->input3_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input2_id());
  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input3_id());
  float beta = graphNode->beta();
  float alpha = graphNode->alpha();

  MPSGraphTensor* multiplyTensor = [_mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor
                                                                     secondaryTensor:weightTensor
                                                                                name:@"addmm/matmul"];

  // alpha * (input2 @ input3); the scaling node is skipped for the common alpha == 1.
  MPSGraphTensor* alphaTimesMultiply = multiplyTensor;
  if (alpha != 1.0f) {
    MPSGraphTensor* alphaTensor = [_mpsGraph constantWithScalar:alpha
                                                       dataType:inputTensor.dataType];

    alphaTimesMultiply = [_mpsGraph multiplicationWithPrimaryTensor:multiplyTensor
                                                    secondaryTensor:alphaTensor
                                                               name:@"addmm/alpha*matmul"];
  }

  // beta == 0: the bias does not contribute (torch.addmm semantics), so the
  // scaled matmul already is the result. This also keeps NaN/Inf in the bias
  // from propagating through a 0 * bias product.
  if (beta == 0.0f) {
    _idToMPSGraphTensor[graphNode->output_id()] = alphaTimesMultiply;
    return Error::Ok;
  }

  // beta * bias; the scaling node is skipped for the common beta == 1.
  MPSGraphTensor* betaBiasTensor = biasTensor;
  if (beta != 1.0f) {
    MPSGraphTensor* betaTensor = [_mpsGraph constantWithScalar:beta
                                                      dataType:inputTensor.dataType];

    betaBiasTensor = [_mpsGraph multiplicationWithPrimaryTensor:biasTensor
                                                secondaryTensor:betaTensor
                                                           name:@"addmm/beta*bias"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph additionWithPrimaryTensor:alphaTimesMultiply
                                                                    secondaryTensor:betaBiasTensor
                                                                               name:@"addmm/beta*bias*alpha*matmul"];

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch