xref: /aosp_15_r20/external/armnn/delegate/common/src/test/DelegateTestInterpreterUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Exceptions.hpp>
9 
10 #include <tensorflow/lite/core/c/c_api.h>
11 #include <tensorflow/lite/kernels/custom_ops_register.h>
12 #include <tensorflow/lite/kernels/register.h>
13 
14 #include <type_traits>
15 
16 namespace delegateTestInterpreter
17 {
18 
GetInputTensorFromInterpreter(TfLiteInterpreter * interpreter,int index)19 inline TfLiteTensor* GetInputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
20 {
21     TfLiteTensor* inputTensor = TfLiteInterpreterGetInputTensor(interpreter, index);
22     if(inputTensor == nullptr)
23     {
24         throw armnn::Exception("Input tensor was not found at the given index: " + std::to_string(index));
25     }
26     return inputTensor;
27 }
28 
GetOutputTensorFromInterpreter(TfLiteInterpreter * interpreter,int index)29 inline const TfLiteTensor* GetOutputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
30 {
31     const TfLiteTensor* outputTensor = TfLiteInterpreterGetOutputTensor(interpreter, index);
32     if(outputTensor == nullptr)
33     {
34         throw armnn::Exception("Output tensor was not found at the given index: " + std::to_string(index));
35     }
36     return outputTensor;
37 }
38 
CreateTfLiteModel(std::vector<char> & data)39 inline TfLiteModel* CreateTfLiteModel(std::vector<char>& data)
40 {
41     TfLiteModel* tfLiteModel = TfLiteModelCreate(data.data(), data.size());
42     if(tfLiteModel == nullptr)
43     {
44         throw armnn::Exception("An error has occurred when creating the TfLiteModel.");
45     }
46     return tfLiteModel;
47 }
48 
CreateTfLiteInterpreterOptions()49 inline TfLiteInterpreterOptions* CreateTfLiteInterpreterOptions()
50 {
51     TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
52     if(options == nullptr)
53     {
54         throw armnn::Exception("An error has occurred when creating the TfLiteInterpreterOptions.");
55     }
56     return options;
57 }
58 
GenerateCustomOpResolver(const std::string & opName)59 inline tflite::ops::builtin::BuiltinOpResolver GenerateCustomOpResolver(const std::string& opName)
60 {
61     tflite::ops::builtin::BuiltinOpResolver opResolver;
62     if (opName == "MaxPool3D")
63     {
64         opResolver.AddCustom("MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
65     }
66     else if (opName == "AveragePool3D")
67     {
68         opResolver.AddCustom("AveragePool3D", tflite::ops::custom::Register_AVG_POOL_3D());
69     }
70     else
71     {
72         throw armnn::Exception("The custom op isn't supported by the DelegateTestInterpreter.");
73     }
74     return opResolver;
75 }
76 
77 template<typename T>
CopyFromBufferToTensor(TfLiteTensor * tensor,std::vector<T> & values)78 inline TfLiteStatus CopyFromBufferToTensor(TfLiteTensor* tensor, std::vector<T>& values)
79 {
80     // Make sure there is enough bytes allocated to copy into for uint8_t and int16_t case.
81     if(tensor->bytes < values.size() * sizeof(T))
82     {
83         throw armnn::Exception("Tensor has not been allocated to match number of values.");
84     }
85 
86     // Requires uint8_t and int16_t specific case as the number of bytes is larger than values passed when creating
87     // TFLite tensors of these types. Otherwise, use generic TfLiteTensorCopyFromBuffer function.
88     TfLiteStatus status = kTfLiteOk;
89     if (std::is_same<T, uint8_t>::value)
90     {
91         for (unsigned int i = 0; i < values.size(); ++i)
92         {
93             tensor->data.uint8[i] = values[i];
94         }
95     }
96     else if (std::is_same<T, int16_t>::value)
97     {
98         for (unsigned int i = 0; i < values.size(); ++i)
99         {
100             tensor->data.i16[i] = values[i];
101         }
102     }
103     else
104     {
105         status = TfLiteTensorCopyFromBuffer(tensor, values.data(), values.size() * sizeof(T));
106     }
107     return status;
108 }
109 
110 } // namespace delegateTestInterpreter