//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ClassicDelegateUtils.hpp>

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

namespace armnnDelegate
{

TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
                                      TfLiteContext* tfLiteContext,
                                      TfLiteNode* tfLiteNode,
                                      int nodeIndex,
                                      int32_t operatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& kTfLiteLHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
    const TfLiteTensor& kTfLiteRHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];

    if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }
    if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
    {
        return kTfLiteError;
    }

    if (IsDynamicTensor(kTfLiteLHSInputTensor) || IsDynamicTensor(kTfLiteRHSInputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
            operatorCode, nodeIndex);
        return kTfLiteError;
    }

    const TfLiteTensor& kTfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (IsDynamicTensor(kTfLiteOutputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
            operatorCode, nodeIndex);
        return kTfLiteError;
    }

    const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteLHSInputTensor);
    const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteRHSInputTensor);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteOutputTensor, true);

    armnn::BatchMatMulDescriptor descriptor;
    auto* params = reinterpret_cast<TfLiteBatchMatMulParams*>(tfLiteNode->builtin_data);

    // The TensorFlow parameters are called adjoint, but they only perform a transpose
    // behind the scenes; they do not perform a true adjoint (no complex conjugation).
    descriptor.m_TransposeX = params->adj_x;
    descriptor.m_TransposeY = params->adj_y;
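
    // Illustration only (hypothetical shapes, nothing here executes): with
    // adj_x == true and an LHS tensor of shape [2, 3, 4], the multiply is
    // performed as if the LHS had been transposed to [2, 4, 3]. A true adjoint
    // would additionally conjugate complex values, which this operator does not do.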

    // Check whether any of the configured backends supports the layer
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("BATCH_MATMUL",
                                   tfLiteContext,
                                   IsBatchMatMulSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   armnnLHSInputTensorInfo,
                                   armnnRHSInputTensorInfo,
                                   outputTensorInfo,
                                   descriptor);
    };

    // No network yet: this is the validation pass, so only report support
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor);
    ARMNN_ASSERT(layer != nullptr);
    layer->SetBackendId(setBackend);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the constant inputs if there are any
    if (ProcessInputs(layer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

    return Connect(layer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate