1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <schema_generated.h>

#include <doctest/doctest.h>

#include <array>
#include <cstdint>
#include <type_traits>
#include <vector>
20
21 namespace
22 {
23
CreateStridedSliceTfLiteModel(tflite::TensorType tensorType,const std::vector<int32_t> & inputTensorShape,const std::vector<int32_t> & beginTensorData,const std::vector<int32_t> & endTensorData,const std::vector<int32_t> & strideTensorData,const std::vector<int32_t> & beginTensorShape,const std::vector<int32_t> & endTensorShape,const std::vector<int32_t> & strideTensorShape,const std::vector<int32_t> & outputTensorShape,const int32_t beginMask,const int32_t endMask,const int32_t ellipsisMask,const int32_t newAxisMask,const int32_t ShrinkAxisMask,const armnn::DataLayout & dataLayout)24 std::vector<char> CreateStridedSliceTfLiteModel(tflite::TensorType tensorType,
25 const std::vector<int32_t>& inputTensorShape,
26 const std::vector<int32_t>& beginTensorData,
27 const std::vector<int32_t>& endTensorData,
28 const std::vector<int32_t>& strideTensorData,
29 const std::vector<int32_t>& beginTensorShape,
30 const std::vector<int32_t>& endTensorShape,
31 const std::vector<int32_t>& strideTensorShape,
32 const std::vector<int32_t>& outputTensorShape,
33 const int32_t beginMask,
34 const int32_t endMask,
35 const int32_t ellipsisMask,
36 const int32_t newAxisMask,
37 const int32_t ShrinkAxisMask,
38 const armnn::DataLayout& dataLayout)
39 {
40 using namespace tflite;
41 flatbuffers::FlatBufferBuilder flatBufferBuilder;
42
43 flatbuffers::Offset<tflite::Buffer> buffers[6] = {
44 CreateBuffer(flatBufferBuilder),
45 CreateBuffer(flatBufferBuilder),
46 CreateBuffer(flatBufferBuilder,
47 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(beginTensorData.data()),
48 sizeof(int32_t) * beginTensorData.size())),
49 CreateBuffer(flatBufferBuilder,
50 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(endTensorData.data()),
51 sizeof(int32_t) * endTensorData.size())),
52 CreateBuffer(flatBufferBuilder,
53 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(strideTensorData.data()),
54 sizeof(int32_t) * strideTensorData.size())),
55 CreateBuffer(flatBufferBuilder)
56 };
57
58 std::array<flatbuffers::Offset<Tensor>, 5> tensors;
59 tensors[0] = CreateTensor(flatBufferBuilder,
60 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
61 inputTensorShape.size()),
62 tensorType,
63 1,
64 flatBufferBuilder.CreateString("input"));
65 tensors[1] = CreateTensor(flatBufferBuilder,
66 flatBufferBuilder.CreateVector<int32_t>(beginTensorShape.data(),
67 beginTensorShape.size()),
68 ::tflite::TensorType_INT32,
69 2,
70 flatBufferBuilder.CreateString("begin_tensor"));
71 tensors[2] = CreateTensor(flatBufferBuilder,
72 flatBufferBuilder.CreateVector<int32_t>(endTensorShape.data(),
73 endTensorShape.size()),
74 ::tflite::TensorType_INT32,
75 3,
76 flatBufferBuilder.CreateString("end_tensor"));
77 tensors[3] = CreateTensor(flatBufferBuilder,
78 flatBufferBuilder.CreateVector<int32_t>(strideTensorShape.data(),
79 strideTensorShape.size()),
80 ::tflite::TensorType_INT32,
81 4,
82 flatBufferBuilder.CreateString("stride_tensor"));
83 tensors[4] = CreateTensor(flatBufferBuilder,
84 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
85 outputTensorShape.size()),
86 tensorType,
87 5,
88 flatBufferBuilder.CreateString("output"));
89
90
91 // create operator
92 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_StridedSliceOptions;
93 flatbuffers::Offset<void> operatorBuiltinOptions = CreateStridedSliceOptions(flatBufferBuilder,
94 beginMask,
95 endMask,
96 ellipsisMask,
97 newAxisMask,
98 ShrinkAxisMask).Union();
99
100 const std::vector<int> operatorInputs{ 0, 1, 2, 3 };
101 const std::vector<int> operatorOutputs{ 4 };
102 flatbuffers::Offset <Operator> sliceOperator =
103 CreateOperator(flatBufferBuilder,
104 0,
105 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
106 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
107 operatorBuiltinOptionsType,
108 operatorBuiltinOptions);
109
110 const std::vector<int> subgraphInputs{ 0, 1, 2, 3 };
111 const std::vector<int> subgraphOutputs{ 4 };
112 flatbuffers::Offset <SubGraph> subgraph =
113 CreateSubGraph(flatBufferBuilder,
114 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
115 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
116 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
117 flatBufferBuilder.CreateVector(&sliceOperator, 1));
118
119 flatbuffers::Offset <flatbuffers::String> modelDescription =
120 flatBufferBuilder.CreateString("ArmnnDelegate: StridedSlice Operator Model");
121 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
122 BuiltinOperator_STRIDED_SLICE);
123
124 flatbuffers::Offset <Model> flatbufferModel =
125 CreateModel(flatBufferBuilder,
126 TFLITE_SCHEMA_VERSION,
127 flatBufferBuilder.CreateVector(&operatorCode, 1),
128 flatBufferBuilder.CreateVector(&subgraph, 1),
129 modelDescription,
130 flatBufferBuilder.CreateVector(buffers, 6));
131
132 flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);
133
134 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
135 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
136 }
137
138 template <typename T>
StridedSliceTestImpl(std::vector<armnn::BackendId> & backends,std::vector<T> & inputValues,std::vector<T> & expectedOutputValues,std::vector<int32_t> & beginTensorData,std::vector<int32_t> & endTensorData,std::vector<int32_t> & strideTensorData,std::vector<int32_t> & inputTensorShape,std::vector<int32_t> & beginTensorShape,std::vector<int32_t> & endTensorShape,std::vector<int32_t> & strideTensorShape,std::vector<int32_t> & outputTensorShape,const int32_t beginMask=0,const int32_t endMask=0,const int32_t ellipsisMask=0,const int32_t newAxisMask=0,const int32_t ShrinkAxisMask=0,const armnn::DataLayout & dataLayout=armnn::DataLayout::NHWC)139 void StridedSliceTestImpl(std::vector<armnn::BackendId>& backends,
140 std::vector<T>& inputValues,
141 std::vector<T>& expectedOutputValues,
142 std::vector<int32_t>& beginTensorData,
143 std::vector<int32_t>& endTensorData,
144 std::vector<int32_t>& strideTensorData,
145 std::vector<int32_t>& inputTensorShape,
146 std::vector<int32_t>& beginTensorShape,
147 std::vector<int32_t>& endTensorShape,
148 std::vector<int32_t>& strideTensorShape,
149 std::vector<int32_t>& outputTensorShape,
150 const int32_t beginMask = 0,
151 const int32_t endMask = 0,
152 const int32_t ellipsisMask = 0,
153 const int32_t newAxisMask = 0,
154 const int32_t ShrinkAxisMask = 0,
155 const armnn::DataLayout& dataLayout = armnn::DataLayout::NHWC)
156 {
157 using namespace delegateTestInterpreter;
158 std::vector<char> modelBuffer = CreateStridedSliceTfLiteModel(
159 ::tflite::TensorType_FLOAT32,
160 inputTensorShape,
161 beginTensorData,
162 endTensorData,
163 strideTensorData,
164 beginTensorShape,
165 endTensorShape,
166 strideTensorShape,
167 outputTensorShape,
168 beginMask,
169 endMask,
170 ellipsisMask,
171 newAxisMask,
172 ShrinkAxisMask,
173 dataLayout);
174
175 // Setup interpreter with just TFLite Runtime.
176 auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
177 CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
178 CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
179 CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);
180 std::vector<T> tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(0);
181 std::vector<int32_t> tfLiteOutputShape = tfLiteInterpreter.GetOutputShape(0);
182
183 // Setup interpreter with Arm NN Delegate applied.
184 auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
185 CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
186 CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
187 CHECK(armnnInterpreter.Invoke() == kTfLiteOk);
188 std::vector<T> armnnOutputValues = armnnInterpreter.GetOutputResult<T>(0);
189 std::vector<int32_t> armnnOutputShape = armnnInterpreter.GetOutputShape(0);
190
191 armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues);
192 armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShape);
193
194 tfLiteInterpreter.Cleanup();
195 armnnInterpreter.Cleanup();
196 } // End of StridedSlice Test
197
198 } // anonymous namespace