//
// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Exceptions.hpp>

#include <tensorflow/lite/core/c/c_api.h>
#include <tensorflow/lite/kernels/custom_ops_register.h>
#include <tensorflow/lite/kernels/register.h>

#include <string>
#include <type_traits>
#include <vector>

namespace delegateTestInterpreter
{

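// Returns the interpreter's input tensor at the given index, throwing if no tensor exists there.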
inline TfLiteTensor* GetInputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
{
    TfLiteTensor* inputTensor = TfLiteInterpreterGetInputTensor(interpreter, index);
    if (inputTensor == nullptr)
    {
        throw armnn::Exception("Input tensor was not found at the given index: " + std::to_string(index));
    }
    return inputTensor;
}

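// Returns the interpreter's output tensor at the given index, throwing if no tensor exists there.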
inline const TfLiteTensor* GetOutputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
{
    const TfLiteTensor* outputTensor = TfLiteInterpreterGetOutputTensor(interpreter, index);
    if (outputTensor == nullptr)
    {
        throw armnn::Exception("Output tensor was not found at the given index: " + std::to_string(index));
    }
    return outputTensor;
}

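// Creates a TfLiteModel from a flatbuffer held in memory, throwing if the buffer cannot be parsed.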
inline TfLiteModel* CreateTfLiteModel(std::vector<char>& data)
{
    TfLiteModel* tfLiteModel = TfLiteModelCreate(data.data(), data.size());
    if (tfLiteModel == nullptr)
    {
        throw armnn::Exception("An error has occurred when creating the TfLiteModel.");
    }
    return tfLiteModel;
}

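// Creates a default set of interpreter options, throwing if creation fails.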
inline TfLiteInterpreterOptions* CreateTfLiteInterpreterOptions()
{
    TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
    if (options == nullptr)
    {
        throw armnn::Exception("An error has occurred when creating the TfLiteInterpreterOptions.");
    }
    return options;
}

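// Builds an op resolver that registers the named custom operator on top of the builtin ops.
// Only the MaxPool3D and AveragePool3D custom ops are supported.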
inline tflite::ops::builtin::BuiltinOpResolver GenerateCustomOpResolver(const std::string& opName)
{
    tflite::ops::builtin::BuiltinOpResolver opResolver;
    if (opName == "MaxPool3D")
    {
        opResolver.AddCustom("MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
    }
    else if (opName == "AveragePool3D")
    {
        opResolver.AddCustom("AveragePool3D", tflite::ops::custom::Register_AVG_POOL_3D());
    }
    else
    {
        throw armnn::Exception("The custom op isn't supported by the DelegateTestInterpreter.");
    }
    return opResolver;
}

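// Copies the given values into the tensor's data buffer. uint8_t and int16_t tensors are copied
// element by element; all other types go through TfLiteTensorCopyFromBuffer. Throws if the
// tensor's allocation is too small to hold the values.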
template<typename T>
inline TfLiteStatus CopyFromBufferToTensor(TfLiteTensor* tensor, std::vector<T>& values)
{
    // Make sure the tensor has enough bytes allocated to hold all of the values.
    if (tensor->bytes < values.size() * sizeof(T))
    {
        throw armnn::Exception("Tensor has not been allocated with enough bytes for the number of values.");
    }

    // uint8_t and int16_t tensors need an element-wise copy: TFLite allocates more bytes than
    // values.size() * sizeof(T) when creating tensors of these types, so the exact-size check
    // inside TfLiteTensorCopyFromBuffer would fail. All other types use the generic
    // TfLiteTensorCopyFromBuffer function.
    TfLiteStatus status = kTfLiteOk;
    if constexpr (std::is_same<T, uint8_t>::value)
    {
        for (size_t i = 0; i < values.size(); ++i)
        {
            tensor->data.uint8[i] = values[i];
        }
    }
    else if constexpr (std::is_same<T, int16_t>::value)
    {
        for (size_t i = 0; i < values.size(); ++i)
        {
            tensor->data.i16[i] = values[i];
        }
    }
    else
    {
        status = TfLiteTensorCopyFromBuffer(tensor, values.data(), values.size() * sizeof(T));
    }
    return status;
}
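
// A minimal end-to-end sketch of how these helpers compose (illustrative only, not used by the
// tests; 'modelBuffer' is a placeholder for a buffer assumed to hold a valid .tflite flatbuffer,
// and the model is assumed to have a single float32 input and a single float32 output of four
// elements):
//
//   std::vector<char> modelBuffer = /* read from a .tflite file */;
//   TfLiteModel* model = CreateTfLiteModel(modelBuffer);
//   TfLiteInterpreterOptions* options = CreateTfLiteInterpreterOptions();
//   TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
//   TfLiteInterpreterAllocateTensors(interpreter);
//
//   std::vector<float> inputValues { 1.0f, 2.0f, 3.0f, 4.0f };
//   CopyFromBufferToTensor(GetInputTensorFromInterpreter(interpreter, 0), inputValues);
//   TfLiteInterpreterInvoke(interpreter);
//
//   const TfLiteTensor* outputTensor = GetOutputTensorFromInterpreter(interpreter, 0);
//   std::vector<float> outputValues(4);
//   TfLiteTensorCopyToBuffer(outputTensor, outputValues.data(), outputValues.size() * sizeof(float));
//
//   TfLiteInterpreterDelete(interpreter);
//   TfLiteInterpreterOptionsDelete(options);
//   TfLiteModelDelete(model);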

} // namespace delegateTestInterpreter