xref: /aosp_15_r20/external/armnn/delegate/classic/src/BatchMatMul.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <algorithm>
11 #include <iterator>
12 #include <string>
13 #include <vector>
14 
15 namespace armnnDelegate
16 {
VisitBatchMatMulOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)17     TfLiteStatus VisitBatchMatMulOperator(DelegateData& delegateData,
18                                           TfLiteContext* tfLiteContext,
19                                           TfLiteNode* tfLiteNode,
20                                           int nodeIndex,
21                                           int32_t operatorCode)
22     {
23         TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24         TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 
26         const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27         const TfLiteTensor& kTfLiteLHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28         const TfLiteTensor& kTfLiteRHSInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
29 
30         if (!IsValid(tfLiteContext, kTfLiteLHSInputTensor, operatorCode, nodeIndex))
31         {
32             return kTfLiteError;
33         }
34         if (!IsValid(tfLiteContext, kTfLiteRHSInputTensor, operatorCode, nodeIndex))
35         {
36             return kTfLiteError;
37         }
38 
39         if (IsDynamicTensor(kTfLiteLHSInputTensor) || IsDynamicTensor(kTfLiteRHSInputTensor))
40         {
41             TF_LITE_MAYBE_KERNEL_LOG(
42                     tfLiteContext,
43                     "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
44                     operatorCode, nodeIndex);
45             return kTfLiteError;
46         }
47 
48         const TfLiteTensor& kTfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
49         if (IsDynamicTensor(kTfLiteOutputTensor))
50         {
51             TF_LITE_MAYBE_KERNEL_LOG(
52                     tfLiteContext,
53                     "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
54                     operatorCode, nodeIndex);
55             return kTfLiteError;
56         }
57 
58         const armnn::TensorInfo& armnnLHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteLHSInputTensor);
59         const armnn::TensorInfo& armnnRHSInputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteRHSInputTensor);
60         const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(kTfLiteOutputTensor, true);
61 
62         armnn::BatchMatMulDescriptor descriptor;
63         auto* params = reinterpret_cast<TfLiteBatchMatMulParams *>(tfLiteNode->builtin_data);
64 
65         // Tensorflow params are called adjoint, however they are actually just transposes behind the scene. They do
66         // not perform ajoint.
67         descriptor.m_TransposeX = params->adj_x;
68         descriptor.m_TransposeY = params->adj_y;
69 
70         // Check if supported
71         bool isSupported = false;
72         armnn::BackendId setBackend;
73         auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
74         {
75             FORWARD_LAYER_SUPPORT_FUNC("BATCH_MATMUL",
76                                        tfLiteContext,
77                                        IsBatchMatMulSupported,
78                                        delegateData.m_Backends,
79                                        isSupported,
80                                        setBackend,
81                                        armnnLHSInputTensorInfo,
82                                        armnnRHSInputTensorInfo,
83                                        outputTensorInfo,
84                                        descriptor);
85         };
86 
87         if (!delegateData.m_Network)
88         {
89             validateFunc(outputTensorInfo, isSupported);
90             return isSupported ? kTfLiteOk : kTfLiteError;
91         }
92 
93         armnn::IConnectableLayer* layer = delegateData.m_Network->AddBatchMatMulLayer(descriptor);
94         layer->SetBackendId(setBackend);
95         ARMNN_ASSERT(layer != nullptr);
96 
97         armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
98         outputSlot.SetTensorInfo(outputTensorInfo);
99 
100         // try to connect the Constant Inputs if there are any
101         if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
102         {
103             return kTfLiteError;
104         }
105 
106        return Connect(layer, tfLiteNode, delegateData);
107     }
108 } // namespace armnnDelegate
109