xref: /aosp_15_r20/external/armnn/delegate/common/src/test/DelegateTestInterpreterUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
#include <armnn/Exceptions.hpp>

#include <tensorflow/lite/core/c/c_api.h>
#include <tensorflow/lite/kernels/custom_ops_register.h>
#include <tensorflow/lite/kernels/register.h>

#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker namespace delegateTestInterpreter
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker 
GetInputTensorFromInterpreter(TfLiteInterpreter * interpreter,int index)19*89c4ff92SAndroid Build Coastguard Worker inline TfLiteTensor* GetInputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker     TfLiteTensor* inputTensor = TfLiteInterpreterGetInputTensor(interpreter, index);
22*89c4ff92SAndroid Build Coastguard Worker     if(inputTensor == nullptr)
23*89c4ff92SAndroid Build Coastguard Worker     {
24*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Input tensor was not found at the given index: " + std::to_string(index));
25*89c4ff92SAndroid Build Coastguard Worker     }
26*89c4ff92SAndroid Build Coastguard Worker     return inputTensor;
27*89c4ff92SAndroid Build Coastguard Worker }
28*89c4ff92SAndroid Build Coastguard Worker 
GetOutputTensorFromInterpreter(TfLiteInterpreter * interpreter,int index)29*89c4ff92SAndroid Build Coastguard Worker inline const TfLiteTensor* GetOutputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* outputTensor = TfLiteInterpreterGetOutputTensor(interpreter, index);
32*89c4ff92SAndroid Build Coastguard Worker     if(outputTensor == nullptr)
33*89c4ff92SAndroid Build Coastguard Worker     {
34*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Output tensor was not found at the given index: " + std::to_string(index));
35*89c4ff92SAndroid Build Coastguard Worker     }
36*89c4ff92SAndroid Build Coastguard Worker     return outputTensor;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker 
CreateTfLiteModel(std::vector<char> & data)39*89c4ff92SAndroid Build Coastguard Worker inline TfLiteModel* CreateTfLiteModel(std::vector<char>& data)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker     TfLiteModel* tfLiteModel = TfLiteModelCreate(data.data(), data.size());
42*89c4ff92SAndroid Build Coastguard Worker     if(tfLiteModel == nullptr)
43*89c4ff92SAndroid Build Coastguard Worker     {
44*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("An error has occurred when creating the TfLiteModel.");
45*89c4ff92SAndroid Build Coastguard Worker     }
46*89c4ff92SAndroid Build Coastguard Worker     return tfLiteModel;
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker 
CreateTfLiteInterpreterOptions()49*89c4ff92SAndroid Build Coastguard Worker inline TfLiteInterpreterOptions* CreateTfLiteInterpreterOptions()
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker     TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
52*89c4ff92SAndroid Build Coastguard Worker     if(options == nullptr)
53*89c4ff92SAndroid Build Coastguard Worker     {
54*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("An error has occurred when creating the TfLiteInterpreterOptions.");
55*89c4ff92SAndroid Build Coastguard Worker     }
56*89c4ff92SAndroid Build Coastguard Worker     return options;
57*89c4ff92SAndroid Build Coastguard Worker }
58*89c4ff92SAndroid Build Coastguard Worker 
GenerateCustomOpResolver(const std::string & opName)59*89c4ff92SAndroid Build Coastguard Worker inline tflite::ops::builtin::BuiltinOpResolver GenerateCustomOpResolver(const std::string& opName)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker     tflite::ops::builtin::BuiltinOpResolver opResolver;
62*89c4ff92SAndroid Build Coastguard Worker     if (opName == "MaxPool3D")
63*89c4ff92SAndroid Build Coastguard Worker     {
64*89c4ff92SAndroid Build Coastguard Worker         opResolver.AddCustom("MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
65*89c4ff92SAndroid Build Coastguard Worker     }
66*89c4ff92SAndroid Build Coastguard Worker     else if (opName == "AveragePool3D")
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         opResolver.AddCustom("AveragePool3D", tflite::ops::custom::Register_AVG_POOL_3D());
69*89c4ff92SAndroid Build Coastguard Worker     }
70*89c4ff92SAndroid Build Coastguard Worker     else
71*89c4ff92SAndroid Build Coastguard Worker     {
72*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("The custom op isn't supported by the DelegateTestInterpreter.");
73*89c4ff92SAndroid Build Coastguard Worker     }
74*89c4ff92SAndroid Build Coastguard Worker     return opResolver;
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker template<typename T>
CopyFromBufferToTensor(TfLiteTensor * tensor,std::vector<T> & values)78*89c4ff92SAndroid Build Coastguard Worker inline TfLiteStatus CopyFromBufferToTensor(TfLiteTensor* tensor, std::vector<T>& values)
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker     // Make sure there is enough bytes allocated to copy into for uint8_t and int16_t case.
81*89c4ff92SAndroid Build Coastguard Worker     if(tensor->bytes < values.size() * sizeof(T))
82*89c4ff92SAndroid Build Coastguard Worker     {
83*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("Tensor has not been allocated to match number of values.");
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     // Requires uint8_t and int16_t specific case as the number of bytes is larger than values passed when creating
87*89c4ff92SAndroid Build Coastguard Worker     // TFLite tensors of these types. Otherwise, use generic TfLiteTensorCopyFromBuffer function.
88*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus status = kTfLiteOk;
89*89c4ff92SAndroid Build Coastguard Worker     if (std::is_same<T, uint8_t>::value)
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < values.size(); ++i)
92*89c4ff92SAndroid Build Coastguard Worker         {
93*89c4ff92SAndroid Build Coastguard Worker             tensor->data.uint8[i] = values[i];
94*89c4ff92SAndroid Build Coastguard Worker         }
95*89c4ff92SAndroid Build Coastguard Worker     }
96*89c4ff92SAndroid Build Coastguard Worker     else if (std::is_same<T, int16_t>::value)
97*89c4ff92SAndroid Build Coastguard Worker     {
98*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < values.size(); ++i)
99*89c4ff92SAndroid Build Coastguard Worker         {
100*89c4ff92SAndroid Build Coastguard Worker             tensor->data.i16[i] = values[i];
101*89c4ff92SAndroid Build Coastguard Worker         }
102*89c4ff92SAndroid Build Coastguard Worker     }
103*89c4ff92SAndroid Build Coastguard Worker     else
104*89c4ff92SAndroid Build Coastguard Worker     {
105*89c4ff92SAndroid Build Coastguard Worker         status = TfLiteTensorCopyFromBuffer(tensor, values.data(), values.size() * sizeof(T));
106*89c4ff92SAndroid Build Coastguard Worker     }
107*89c4ff92SAndroid Build Coastguard Worker     return status;
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace